
I'd like to do this calculation without any for loops (i.e. vectorized), but I can't find a good solution. Can anyone help?

    edge_in = torch.ones(len(edge_embeds), len(edge_embeds[0]), len(edge_embeds[0][0]) + 2*len(nodes_a_embeds[0]))

    for i in range(0, len(nodes_a_embeds)): # A
      for u in range(0, len(nodes_b_embeds)): # B
        edge_in[i][u] = torch.cat([nodes_a_embeds[i], nodes_b_embeds[u], edge_embeds[i][u]], dim=0)

    # OUT: edge_in: torch.Tensor with shape (|A|, |B|, 2*node_dim + 2*edge_dim)

    # IN: edge_embeds: torch.Tensor with shape (|A|, |B|, 2*edge_dim)
    # IN: nodes_a_embeds: torch.Tensor with shape (|A|, node_dim)
    # IN: nodes_b_embeds: torch.Tensor with shape (|B|, node_dim)

Answer


You can expand nodes_a_embeds and nodes_b_embeds to the same leading shape [n_A, n_B] as edge_embeds (where n_A = |A| and n_B = |B|) and then concatenate all three tensors along the last dimension:

  • nodes_a_embeds[:, None].expand(-1, n_B, -1): [n_A, node_dim] => [n_A, 1, node_dim] => [n_A, n_B, node_dim]
  • nodes_b_embeds[None].expand(n_A, -1, -1): [n_B, node_dim] => [1, n_B, node_dim] => [n_A, n_B, node_dim]
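To see what the two expand calls produce, here is a small sketch with made-up sizes (n_A=3, n_B=4, node_dim=5, edge_dim=2 are assumptions chosen only for illustration). expand returns a view whose broadcast dimension has stride 0, so no data is copied until torch.cat writes the output. torch.cat itself does not broadcast, which is why the explicit expansion is needed.

```python
import torch

# Small, arbitrary sizes chosen only for illustration
n_A, n_B, node_dim, edge_dim = 3, 4, 5, 2

nodes_a_embeds = torch.randn(n_A, node_dim)
nodes_b_embeds = torch.randn(n_B, node_dim)
edge_embeds = torch.randn(n_A, n_B, 2 * edge_dim)

# Add a size-1 axis, then broadcast it to the full size (a view, no copy)
a_exp = nodes_a_embeds[:, None].expand(-1, n_B, -1)  # [n_A, 1, d] -> [n_A, n_B, d]
b_exp = nodes_b_embeds[None].expand(n_A, -1, -1)     # [1, n_B, d] -> [n_A, n_B, d]

print(a_exp.shape, a_exp.stride())  # torch.Size([3, 4, 5]) (5, 0, 1)
print(b_exp.shape, b_exp.stride())  # torch.Size([3, 4, 5]) (0, 5, 1)

# The expanded tensor shares storage with the original
print(a_exp.data_ptr() == nodes_a_embeds.data_ptr())  # True

# torch.cat does not broadcast, so the unexpanded size-1 axes are rejected
try:
    torch.cat([nodes_a_embeds[:, None], nodes_b_embeds[None], edge_embeds], dim=-1)
except RuntimeError as err:
    print("cat without expand fails:", type(err).__name__)
```

Every position a_exp[i, u] holds nodes_a_embeds[i] and every b_exp[i, u] holds nodes_b_embeds[u], which is exactly what the inner loop in the question indexes.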

Verification:

    import torch

    n_A = 100
    n_B = 200
    node_dim = 32
    edge_dim = 32

    edge_embeds = torch.randn(n_A, n_B, 2*edge_dim)
    nodes_a_embeds = torch.randn(n_A, node_dim)
    nodes_b_embeds = torch.randn(n_B, node_dim)

    # loop version from the question
    edge_in = torch.empty(n_A, n_B, 2*node_dim + 2*edge_dim)

    for i in range(n_A):  # A
        for u in range(n_B):  # B
            edge_in[i, u] = torch.cat([nodes_a_embeds[i], nodes_b_embeds[u], edge_embeds[i, u]], dim=0)

    # vectorized version
    edge_in_vectorized = torch.cat([
        nodes_a_embeds[:, None].expand(-1, n_B, -1),  # [n_A, n_B, node_dim]
        nodes_b_embeds[None].expand(n_A, -1, -1),     # [n_A, n_B, node_dim]
        edge_embeds], dim=-1)                         # [n_A, n_B, 2*edge_dim]

    print(torch.equal(edge_in_vectorized, edge_in))  # True
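If this is needed in several places, the vectorized expression can be wrapped in a small helper. The sketch below is one way to do it; the name build_edge_in is hypothetical, and the size check is an addition that is not part of the answer above. Because torch.cat only requires the non-concatenated dimensions to agree, the same code also works when the A and B node embeddings have different widths.

```python
import torch


def build_edge_in(nodes_a_embeds: torch.Tensor,
                  nodes_b_embeds: torch.Tensor,
                  edge_embeds: torch.Tensor) -> torch.Tensor:
    """Concatenate [a_i, b_u, e_iu] for every pair (i, u).

    nodes_a_embeds: [n_A, d_a], nodes_b_embeds: [n_B, d_b], edge_embeds: [n_A, n_B, d_e]
    returns:        [n_A, n_B, d_a + d_b + d_e]
    """
    n_A, n_B = edge_embeds.shape[:2]
    # Added sanity check (hypothetical): node counts must match edge_embeds
    if nodes_a_embeds.shape[0] != n_A or nodes_b_embeds.shape[0] != n_B:
        raise ValueError("node counts do not match edge_embeds")
    return torch.cat([
        nodes_a_embeds[:, None].expand(-1, n_B, -1),  # broadcast A over the B axis
        nodes_b_embeds[None].expand(n_A, -1, -1),     # broadcast B over the A axis
        edge_embeds,
    ], dim=-1)


# Usage: compare against the double loop from the question on small random inputs
n_A, n_B, node_dim, edge_dim = 3, 4, 5, 2
nodes_a = torch.randn(n_A, node_dim)
nodes_b = torch.randn(n_B, node_dim)
edges = torch.randn(n_A, n_B, 2 * edge_dim)

expected = torch.empty(n_A, n_B, 2 * node_dim + 2 * edge_dim)
for i in range(n_A):
    for u in range(n_B):
        expected[i, u] = torch.cat([nodes_a[i], nodes_b[u], edges[i, u]], dim=0)

result = build_edge_in(nodes_a, nodes_b, edges)
print(result.shape)                   # torch.Size([3, 4, 14])
print(torch.equal(result, expected))  # True
```

Since expand and torch.cat are both differentiable, gradients still flow back to the node embeddings; the backward pass of expand sums the gradient over the broadcast axis.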