I am looking to make this calculation without using any for loops (vectorized) but cant really seem to find a good solution. Maybe someone can help?
edge_in = torch.ones(len(edge_embeds), len(edge_embeds[0]), len(edge_embeds[0][0]) + 2*len(nodes_a_embeds[0]))
for i in range(0, len(nodes_a_embeds)): # A
for u in range(0, len(nodes_b_embeds)): # B
edge_in[i][u] = torch.cat([nodes_a_embeds[i], nodes_b_embeds[u], edge_embeds[i][u]], dim=0)
# OUT: edge_in: torch.Tensor with shape (|A|, |B|, 2*node_dim + 2*edge_dim)
# IN: edge_embeds: torch.Tensor with shape (|A|, |B|, 2 x edge_dim)
# IN: nodes_a_embeds: torch.Tensor with shape (|A|, node_dim)
# IN: nodes_b_embeds: torch.Tensor with shape (|B|, node_dim)