NUM_OF_IMUS = 13
NUM_OF_NOISE_PARAMS = 9
class Noise_regressormod(nn.Module):
def __init__(self, d_model, device):
super(Noise_Regressor, self).__init__()
self.norm1 = nn.LayerNorm(d_model)
self.hidden_state_to_noise_params = nn.Linear(d_model, NUM_OF_IMUS * NUM_OF_NOISE_PARAMS)
self.eps = 1e-5
self.device = device
self.t_step_init_mat = torch.triu(torch.arange(10000, device=self.device) - torch.arange(10000, device=self.device)[:, None])
self.MASK = torch.triu(torch.ones((10000, 10000), device=self.device), diagonal=0)
"""
hidden_states should be of dimension (Batch, Sequence Len, Dim)
B should always be 1
Sequence Length can be up to 10000
The dimension can be 512
"""
def forward(self, hidden_states, min_orig_accel_norm):
seq_len = hidden_states.shape[1]
t_step_init_mat = self.t_step_init_mat[:seq_len, :seq_len]
MASK = self.MASK[:seq_len, :seq_len]
hidden_normed = self.norm1(hidden_states)
noise_params = self.hidden_state_to_noise_params(hidden_normed).view(seq_len, NUM_OF_NOISE_PARAMS, NUM_OF_IMUS)
c = noise_params[:, 4, :].view(seq_len, 1, NUM_OF_IMUS)
c_theta = noise_params[:, 5, :].view(seq_len, 1, NUM_OF_IMUS)
phi = noise_params[:, 6, :].view(seq_len, 1, NUM_OF_IMUS)
phi_theta = noise_params[:, 7, :].view(seq_len, 1, NUM_OF_IMUS)
d = torch.sqrt((noise_params[:, 1, :] ** 2) + self.eps).view(seq_len, 1, NUM_OF_IMUS)
k = (d**2) / 4 + F.softplus(noise_params[:, 0, :]).view(seq_len, 1, NUM_OF_IMUS)
d_theta = torch.sqrt((noise_params[:, 3, :] ** 2) + self.eps).view(seq_len, 1, NUM_OF_IMUS)
k_theta = (d_theta**2) / 4 + F.softplus(noise_params[:, 2, :]).view(seq_len, 1, NUM_OF_IMUS)
noise_bias = noise_params[:, 8, :].T
dynamics_list = []
for imu_num in range(NUM_OF_IMUS):
omega1 = torch.sqrt(4 * k[:, :, imu_num] - (d[:, :, imu_num] ** 2)) / 2
linear_dynamics = c[:, :, imu_num] * torch.exp((-d[:, :, imu_num] / 2) * t_step_init_mat) * torch.sin(phi[:, :, imu_num] + t_step_init_mat * omega1)
omega1_theta = torch.sqrt(4 * k_theta[:, :, imu_num] - (d_theta[:, :, imu_num] ** 2)) / 2
angular_dynamics = c_theta[:, :, imu_num] * torch.exp((-d_theta[:, :, imu_num] / 2) * t_step_init_mat) * torch.sin(phi_theta[:, :, imu_num] + t_step_init_mat * omega1_theta)
spring_damper_dynamics_per_step = (linear_dynamics + angular_dynamics) * MASK
dynamics_list.append(torch.sum(spring_damper_dynamics_per_step, dim=0, keepdim=True))
return torch.cat(dynamics_list, dim=0) + min_orig_accel_norm + noise_bias