
A single instance of the following module uses almost 75% of my VRAM, so I was wondering how I could reduce that without slowing down the runtime too much. The code is below:

import torch
import torch.nn as nn
import torch.nn.functional as F

NUM_OF_IMUS = 13
NUM_OF_NOISE_PARAMS = 9

class mod(nn.Module):
    def __init__(self, d_model, device):
        super(Noise_Regressor, self).__init__()
        
        self.norm1 = nn.LayerNorm(d_model)

        self.hidden_state_to_noise_params = nn.Linear(d_model, NUM_OF_IMUS * NUM_OF_NOISE_PARAMS)
        self.eps = 1e-5
        self.device = device
        self.t_step_init_mat = torch.triu(torch.arange(10000, device=self.device) - torch.arange(10000, device=self.device)[:, None])
        self.MASK = torch.triu(torch.ones((10000, 10000), device=self.device), diagonal=0)
    
    """
        hidden_states should be of dimension (Batch, Sequence Len, Dim)
        B should always be 1
        Sequence Length can be up to 10000
        The dimension can be 512
    """

    def forward(self, hidden_states, min_orig_accel_norm):
        seq_len = hidden_states.shape[1]
        
        t_step_init_mat = self.t_step_init_mat[:seq_len, :seq_len]
        MASK = self.MASK[:seq_len, :seq_len]
        
        hidden_normed = self.norm1(hidden_states)
        noise_params = self.hidden_state_to_noise_params(hidden_normed).view(seq_len, NUM_OF_NOISE_PARAMS, NUM_OF_IMUS)
        
        c = noise_params[:, 4, :].view(seq_len, 1, NUM_OF_IMUS)
        c_theta = noise_params[:, 5, :].view(seq_len, 1, NUM_OF_IMUS)
        phi = noise_params[:, 6, :].view(seq_len, 1, NUM_OF_IMUS)
        phi_theta = noise_params[:, 7, :].view(seq_len, 1, NUM_OF_IMUS)

        d = torch.sqrt((noise_params[:, 1, :] ** 2) + self.eps).view(seq_len, 1, NUM_OF_IMUS)
        k = (d**2) / 4 + F.softplus(noise_params[:, 0, :]).view(seq_len, 1, NUM_OF_IMUS)

        d_theta = torch.sqrt((noise_params[:, 3, :] ** 2) + self.eps).view(seq_len, 1, NUM_OF_IMUS)
        k_theta = (d_theta**2) / 4 + F.softplus(noise_params[:, 2, :]).view(seq_len, 1, NUM_OF_IMUS)
        
        noise_bias = noise_params[:, 8, :].T
        

        dynamics_list = [] 
        for imu_num in range(NUM_OF_IMUS):
            omega1 = torch.sqrt(4 * k[:, :, imu_num] - (d[:, :, imu_num] ** 2)) / 2
            linear_dynamics = c[:, :, imu_num] * torch.exp((-d[:, :, imu_num] / 2)  * t_step_init_mat) *  torch.sin(phi[:, :, imu_num] + t_step_init_mat * omega1)

            omega1_theta = torch.sqrt(4 * k_theta[:, :, imu_num] - (d_theta[:, :, imu_num] ** 2)) / 2
            angular_dynamics = c_theta[:, :, imu_num] * torch.exp((-d_theta[:, :, imu_num] / 2)  * t_step_init_mat) *  torch.sin(phi_theta[:, :, imu_num] + t_step_init_mat * omega1_theta)

            spring_damper_dynamics_per_step = (linear_dynamics + angular_dynamics) * MASK
            dynamics_list.append(torch.sum(spring_damper_dynamics_per_step, dim=0, keepdim=True)) 
        
        return torch.cat(dynamics_list, dim=0) + min_orig_accel_norm + noise_bias
  • Please do not edit the question, especially the code, after an answer has been posted. Changing the question may cause answer invalidation. Everyone needs to be able to see what the reviewer was referring to. See "What to do after the question has been answered". (Commented Dec 8, 2024 at 13:20)

1 Answer


computing a discarded result

Please don't write code like this:

def greet(name):
    42
    name + " is cool."
    print(f"Hello {name}!")

Yes, you can compute a literal or an expression and then discard the result; the Python interpreter will let you do that. But it doesn't help the readability of your code.

Rather than

    """
        hidden_states should be of dimension (Batch, Sequence Len, Dim)
        B should always be 1
        Sequence Length can be up to 10000
        The dimension can be 512
    """

you meant to write

      # hidden_states should be of dimension (Batch, Sequence Len, Dim)
      # B should always be 1
      # Sequence Length can be up to 10000
      # The dimension can be 512

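Please note that the OP code does not contain any docstrings, despite the presence of a discarded triple-quoted string. A string literal only becomes a docstring when it is the first statement of a module, class, or function body, and this one sits between two methods.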
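To see the difference concretely, here is a minimal sketch. The class names AsWritten and AsDocstring are made up for illustration, and the method bodies are stubs. Only a string placed as the first statement of forward() ends up in forward.__doc__, where help() and IDEs can find it.

```python
class AsWritten:
    def __init__(self):
        pass

    # A bare string here is just an expression statement in the class body;
    # it is not attached to __init__, forward, or the class.
    """
        hidden_states should be of dimension (Batch, Sequence Len, Dim)
    """

    def forward(self, hidden_states):
        return hidden_states


class AsDocstring:
    def __init__(self):
        pass

    def forward(self, hidden_states):
        """Apply the module.

        hidden_states should be of dimension (Batch, Sequence Len, Dim).
        Batch should always be 1, Sequence Len can be up to 10000,
        and Dim can be 512.
        """
        return hidden_states


print(AsWritten.forward.__doc__)                    # None
print(AsDocstring.forward.__doc__.splitlines()[0])  # Apply the module.
```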

inheritance

class mod(nn.Module):
    ...
        super(Noise_Regressor, self).__init__()

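Maybe you'd prefer for the class MRO to inherit from both those classes? As written, super(Noise_Regressor, self) inside mod raises NameError if Noise_Regressor is undefined, or TypeError if it is defined but is not in mod's MRO. In Python 3 the zero-argument form, super().__init__(), picks up the enclosing class for you, so renaming the class can't leave a stale reference behind.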
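Here is a small stand-alone sketch of the failure mode. It uses plain classes instead of nn.Module so it runs without PyTorch. Base, Renamed, Stale and the constructs() helper are hypothetical names for illustration. Stale mimics the mod / Noise_Regressor mismatch: it names a class that is defined but is not in its own MRO, so construction raises TypeError.

```python
class Base:
    def __init__(self):
        self.initialised = True


class Renamed(Base):
    def __init__(self):
        # Zero-argument super() resolves the enclosing class automatically,
        # so it keeps working if this class is renamed later.
        super().__init__()


class Stale(Base):
    def __init__(self):
        # Names a class that is not in Stale's MRO (stand-in for the
        # super(Noise_Regressor, self) call inside class mod).
        super(Renamed, self).__init__()


def constructs(cls) -> bool:
    """Return True if cls() can be instantiated without raising TypeError."""
    try:
        cls()
    except TypeError:
        return False
    return True


print(constructs(Renamed))  # True
print(constructs(Stale))    # False
```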

helpful names

The various d, k, c, and phi local variables are admirably clear. Thank you for spelling out the meaning of what's at those indices.

Do try to cite your references. As written, it's unclear which Wikipedia page or other textbook resource mod might be trying to implement.
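For what it's worth, here is my reading of the code; the OP does not say this. The loop in forward() appears to sum the responses of underdamped, unit-mass spring-damper oscillators. One oscillator is launched at every source step j and evaluated at every step i ≥ j. Per IMU, with m standing for min_orig_accel_norm and b_i for noise_bias:

$$
y_i = \sum_{j=0}^{i} \Big[ c_j\, e^{-d_j (i-j)/2} \sin\!\big(\varphi_j + \omega_j (i-j)\big) + c^{\theta}_j\, e^{-d^{\theta}_j (i-j)/2} \sin\!\big(\varphi^{\theta}_j + \omega^{\theta}_j (i-j)\big) \Big] + m + b_i,
\qquad \omega_j = \tfrac{1}{2}\sqrt{4k_j - d_j^2}
$$

Here $\omega^{\theta}_j$ is defined the same way from $k^{\theta}_j$ and $d^{\theta}_j$. Each bracketed term is a free response of $\ddot{x} + d\,\dot{x} + k\,x = 0$. Such a response decays like $e^{-d t/2}$ and oscillates at angular frequency $\sqrt{k - d^2/4} = \tfrac{1}{2}\sqrt{4k - d^2}$. Because the code sets $k_j = d_j^2/4 + \operatorname{softplus}(\cdot)$, the radicand $4k_j - d_j^2 = 4\operatorname{softplus}(\cdot)$ is positive, so $\omega_j$ is always real. If that is the intended model, a one-line comment pointing at a damped harmonic oscillator reference would settle it.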

performance

75% of my vram

You didn't tell us the business problem you're trying to solve, the problem size, your VRAM size, or any elapsed timings. I'm willing to believe we're computing some figures which don't impinge directly on the business problem and could be discarded, but the OP doesn't help us understand which aspects of the computation matter most for the use case.
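To put rough numbers on the 75%, here is a back-of-the-envelope sketch. It assumes the default dtypes: torch.arange gives int64, while torch.ones and the float math in forward() give float32. It also assumes the 10000-step maximum mentioned in the discarded string. MAX_LEN and square_bytes are names made up for this estimate.

```python
# Rough memory estimate for the dense seq_len x seq_len tensors in mod.
# Assumption: default dtypes (int64 = 8 bytes, float32 = 4 bytes).
MAX_LEN = 10_000
NUM_OF_IMUS = 13


def square_bytes(n: int, itemsize: int) -> int:
    """Bytes needed for one dense n x n tensor with the given element size."""
    return n * n * itemsize


t_step_bytes = square_bytes(MAX_LEN, 8)    # int64 t_step_init_mat buffer
mask_bytes = square_bytes(MAX_LEN, 4)      # float32 MASK buffer
one_float_temp = square_bytes(MAX_LEN, 4)  # any single float32 L x L temporary

print(f"t_step_init_mat: {t_step_bytes / 1e9:.1f} GB")               # 0.8 GB
print(f"MASK: {mask_bytes / 1e9:.1f} GB")                            # 0.4 GB
print(f"one temp per IMU: {NUM_OF_IMUS * one_float_temp / 1e9:.1f} GB")  # 5.2 GB
```

So the two buffers created in __init__ cost about 1.2 GB before forward() ever runs. Every full-size float32 temporary in the IMU loop costs another 0.4 GB, and the loop creates several per IMU (the exp, the sin, their products, the masked product). If you are training through this module, autograd will also keep a number of them alive until backward().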

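One common way to cut peak memory is to never materialize the full seq_len × seq_len matrix. Instead, process the source steps j in chunks and accumulate into the (IMUs, seq_len) output. Below is a minimal sketch under a few assumptions:

- damped_sum_chunked and its chunk parameter are hypothetical names.
- c, d, phi and omega have already been shaped (seq_len, NUM_OF_IMUS), e.g. noise_params[:, 4, :] rather than the .view(seq_len, 1, NUM_OF_IMUS) form.
- The per-term formula is copied from the loop in forward().

It also handles all IMUs at once instead of looping over them, and it needs neither the t_step_init_mat nor the MASK buffer.

```python
import torch


def damped_sum_chunked(c, d, phi, omega, chunk=256):
    """Hypothetical chunked version of the per-IMU loop in mod.forward().

    c, d, phi, omega: (seq_len, num_imus) tensors, indexed by source step j.
    Returns a (num_imus, seq_len) tensor whose [n, i] entry is
        sum over j <= i of c[j, n] * exp(-d[j, n] / 2 * (i - j))
                          * sin(phi[j, n] + omega[j, n] * (i - j)).
    The largest temporary is (chunk, seq_len, num_imus) instead of
    (seq_len, seq_len) per IMU.
    """
    seq_len, num_imus = c.shape
    steps = torch.arange(seq_len, device=c.device, dtype=c.dtype)
    out = c.new_zeros(num_imus, seq_len)
    for start in range(0, seq_len, chunk):
        stop = min(start + chunk, seq_len)
        # t[j, i] = i - j for the source steps j in this chunk: (C, seq_len)
        t = steps[None, :] - steps[start:stop, None]
        keep = t >= 0  # replaces MASK: only j <= i contributes
        # Clamp negative offsets to 0 (as triu does in the original) so exp()
        # cannot overflow to inf where the term is masked out (inf * 0 = NaN).
        t = t.clamp(min=0)[..., None]  # (C, seq_len, 1)
        # This chunk's parameters with a broadcast axis: (C, 1, num_imus)
        cj, dj, pj, wj = (x[start:stop, None, :] for x in (c, d, phi, omega))
        terms = cj * torch.exp(-dj / 2 * t) * torch.sin(pj + wj * t)
        terms = terms * keep[..., None]  # (C, seq_len, num_imus)
        # Sum over this chunk's source steps and accumulate per IMU
        out += terms.sum(dim=0).T
    return out


# Small CPU usage example with random parameters (d > 0, as in the module)
torch.manual_seed(0)
L, N = 50, 13
c, phi, omega = (torch.randn(L, N) for _ in range(3))
d = torch.rand(L, N)
full = damped_sum_chunked(c, d, phi, omega, chunk=L)  # one chunk: dense case
small = damped_sum_chunked(c, d, phi, omega, chunk=8)
print(full.shape, torch.allclose(full, small, atol=1e-5))
```

In forward() you would call it once with (c, d, phi, omega1) and once with the θ parameters, add the two results, then add min_orig_accel_norm + noise_bias as before. With chunk=256 and seq_len=10000, each (256, 10000, 13) float32 temporary is about 133 MB. The chunk size is the speed/memory knob: larger chunks mean fewer loop iterations, smaller ones mean less VRAM.

One caveat: under autograd, every chunk's intermediates are still saved for backward(), so chunking on its own mostly helps inference. For training, wrapping the per-chunk computation in torch.utils.checkpoint.checkpoint trades that memory for recomputation during the backward pass.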
