One instance of the following module uses up to almost 75% of my vram. So, I was wondering how I could improve that without slowing down runtime too much. The code is below:
NUM_OF_IMUS = 13
NUM_OF_NOISE_PARAMS = 9
class mod(nn.Module):
    def __init__(self, d_model, device):
        super(Noise_Regressor, self).__init__()
        
        self.norm1 = nn.LayerNorm(d_model)
        self.hidden_state_to_noise_params = nn.Linear(d_model, NUM_OF_IMUS * NUM_OF_NOISE_PARAMS)
        self.eps = 1e-5
        self.device = device
        self.t_step_init_mat = torch.triu(torch.arange(10000, device=self.device) - torch.arange(10000, device=self.device)[:, None])
        self.MASK = torch.triu(torch.ones((10000, 10000), device=self.device), diagonal=0)
    
    """
        hidden_states should be of dimension (Batch, Sequence Len, Dim)
        B should always be 1
        Sequence Length can be up to 10000
        The dimension can be 512
    """
    def forward(self, hidden_states, min_orig_accel_norm):
        seq_len = hidden_states.shape[1]
        
        t_step_init_mat = self.t_step_init_mat[:seq_len, :seq_len]
        MASK = self.MASK[:seq_len, :seq_len]
        
        hidden_normed = self.norm1(hidden_states)
        noise_params = self.hidden_state_to_noise_params(hidden_normed).view(seq_len, NUM_OF_NOISE_PARAMS, NUM_OF_IMUS)
        
        c = noise_params[:, 4, :].view(seq_len, 1, NUM_OF_IMUS)
        c_theta = noise_params[:, 5, :].view(seq_len, 1, NUM_OF_IMUS)
        phi = noise_params[:, 6, :].view(seq_len, 1, NUM_OF_IMUS)
        phi_theta = noise_params[:, 7, :].view(seq_len, 1, NUM_OF_IMUS)
        d = torch.sqrt((noise_params[:, 1, :] ** 2) + self.eps).view(seq_len, 1, NUM_OF_IMUS)
        k = (d**2) / 4 + F.softplus(noise_params[:, 0, :]).view(seq_len, 1, NUM_OF_IMUS)
        d_theta = torch.sqrt((noise_params[:, 3, :] ** 2) + self.eps).view(seq_len, 1, NUM_OF_IMUS)
        k_theta = (d_theta**2) / 4 + F.softplus(noise_params[:, 2, :]).view(seq_len, 1, NUM_OF_IMUS)
        
        noise_bias = noise_params[:, 8, :].T
        
        dynamics_list = [] 
        for imu_num in range(NUM_OF_IMUS):
            omega1 = torch.sqrt(4 * k[:, :, imu_num] - (d[:, :, imu_num] ** 2)) / 2
            linear_dynamics = c[:, :, imu_num] * torch.exp((-d[:, :, imu_num] / 2)  * t_step_init_mat) *  torch.sin(phi[:, :, imu_num] + t_step_init_mat * omega1)
            omega1_theta = torch.sqrt(4 * k_theta[:, :, imu_num] - (d_theta[:, :, imu_num] ** 2)) / 2
            angular_dynamics = c_theta[:, :, imu_num] * torch.exp((-d_theta[:, :, imu_num] / 2)  * t_step_init_mat) *  torch.sin(phi_theta[:, :, imu_num] + t_step_init_mat * omega1_theta)
            spring_damper_dynamics_per_step = (linear_dynamics + angular_dynamics) * MASK
            dynamics_list.append(torch.sum(spring_damper_dynamics_per_step, dim=0, keepdim=True)) 
        
        return torch.cat(dynamics_list, dim=0) + min_orig_accel_norm + noise_bias