vectors – torch.Tensor, shape (…, n_nodes, 3)
cosines – torch.Tensor, shape (…, n_nodes * n_nodes)
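The shapes above suggest that, for each batch of `n_nodes` 3-D vectors, the result holds the cosine of the angle between every ordered pair of vectors, flattened into one trailing axis. The following is a minimal sketch under that assumption, not the library's implementation: the function name `pairwise_cosines`, the `eps` guard against zero-length vectors, and the row-major flattening order (pair `(i, j)` at index `i * n_nodes + j`) are all hypothetical.

```python
import torch


def pairwise_cosines(vectors: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Hypothetical sketch: cosines between all ordered pairs of vectors.

    vectors: shape (..., n_nodes, 3)
    returns: shape (..., n_nodes * n_nodes)
    """
    # Normalize each vector to unit length; clamp the norm (assumed eps guard)
    # so zero-length vectors do not cause division by zero.
    norms = vectors.norm(dim=-1, keepdim=True).clamp_min(eps)
    unit = vectors / norms
    # Batched Gram matrix of unit vectors: entry (i, j) is cos(angle(v_i, v_j)).
    cos = unit @ unit.transpose(-1, -2)  # shape (..., n_nodes, n_nodes)
    # Flatten the last two axes row-major (assumed ordering): index i * n_nodes + j.
    return cos.flatten(start_dim=-2)
```

Computing the Gram matrix of the normalized vectors handles every pair in a single batched matrix multiply, and the leading `…` dimensions broadcast through unchanged.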