What you're trying to achieve is input masking. Instead of running the LSTM separately for each output (like taking a separate photo for each person), you can run the LSTM once to capture all the information (one group photo) and attach a separate fully connected layer for each output (one crop per person).
If you want to make sure that particular inputs don't affect particular outputs, you can mask (zero out) portions of the LSTM's hidden state before sending it to the corresponding output layer. By running the LSTM just once, you accomplish the same goal much more efficiently.
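A minimal sketch of the core idea (the names and shapes here are just illustrative, not from your code):

import torch
import torch.nn as nn

hidden_size = 8
hidden = torch.randn(2, hidden_size)  # pretend this is the LSTM's final hidden state (batch, hidden)
mask = torch.tensor([1., 1., 0., 1., 1., 1., 1., 1.])  # zero out hidden dim 2
head = nn.Linear(hidden_size, 1)  # fully connected layer dedicated to one output
output = head(hidden * mask)  # this output never sees hidden dim 2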
I hope that helps.
Edit:
OK, here's the modification needed based on your code:
import torch
import torch.nn as nn

class LstmModel(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(LstmModel, self).__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
        # Fully connected layers for the different outputs
        self.fc1 = nn.Linear(hidden_size, output_size)
        self.fc2 = nn.Linear(hidden_size, output_size)
        self.fc3 = nn.Linear(hidden_size, output_size)
        # Fixed (non-trainable) masks that remove certain inputs' effects;
        # the trailing 0's pad dims 5-7 since hidden_size > input_size
        self.mask1 = nn.Parameter(torch.tensor([1, 1, 0, 1, 1, 0, 0, 0], dtype=torch.float32), requires_grad=False)  # ignore input3 for output1
        self.mask2 = nn.Parameter(torch.tensor([1, 0, 1, 1, 1, 0, 0, 0], dtype=torch.float32), requires_grad=False)  # ignore input2 for output2
        self.mask3 = nn.Parameter(torch.tensor([1, 1, 1, 0, 1, 0, 0, 0], dtype=torch.float32), requires_grad=False)  # ignore input4 for output3

    def forward(self, x):
        # LSTM forward pass
        _, (hn, _) = self.lstm(x)  # hn shape: (num_layers, batch, hidden_size)
        hidden = hn[-1]  # final hidden state: (batch, hidden_size)
        # Apply masks by element-wise multiplication with the hidden state
        hidden_masked1 = hidden * self.mask1  # mask for output1
        hidden_masked2 = hidden * self.mask2  # mask for output2
        hidden_masked3 = hidden * self.mask3  # mask for output3
        # Generate outputs
        output1 = self.fc1(hidden_masked1)
        output2 = self.fc2(hidden_masked2)
        output3 = self.fc3(hidden_masked3)
        return output1, output2, output3
# === TESTING THE MODEL ===
# Example input: (batch_size=2, seq_length=10, input_size=5)
batch_size = 2
seq_length = 10
input_size = 5
hidden_size = 8
output_size = 1 # Single value per output
# Create model
model = LstmModel(input_size, hidden_size, output_size)
# Generate some random input data
x = torch.randn(batch_size, seq_length, input_size)
# Forward pass
output1, output2, output3 = model(x)
print("Output 1:", output1)
print("Output 2:", output2)
print("Output 3:", output3)
The results without the masking applied:
Output 1: tensor([[-0.0487],
[-0.0439]], grad_fn=<AddmmBackward0>)
Output 2: tensor([[-0.2588],
[-0.2890]], grad_fn=<AddmmBackward0>)
Output 3: tensor([[0.1792],
[0.1249]], grad_fn=<AddmmBackward0>)
And with the masking applied:
Output 1: tensor([[0.3568],
[0.3477]], grad_fn=<AddmmBackward0>)
Output 2: tensor([[-0.3200],
[-0.3470]], grad_fn=<AddmmBackward0>)
Output 3: tensor([[0.4120],
[0.2970]], grad_fn=<AddmmBackward0>)
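If you want to double-check that the masking really takes effect, here's a quick sanity check (assuming the model and x defined above):

with torch.no_grad():
    _, (hn, _) = model.lstm(x)
    masked = hn[-1] * model.mask1
print(masked[:, [2, 5, 6, 7]])  # should be all zeros: the dims mask1 zeroes out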
I really hope the comments I added to the code clarify each line and its role.
You can quickly test the code in Google Colab.
Edit 2:
Since I used hard-coded values above, here's a more robust way:
import torch
import torch.nn as nn

class MaskedLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, exclusion_map):
        """
        Args:
            exclusion_map: Dictionary mapping output_idx to excluded_input_idx.
                Example: {1: 2, 2: 1, 3: 3}  # output1 excludes input3, output2 excludes input2, etc.
        """
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.fc_layers = nn.ModuleList([nn.Linear(hidden_size, output_size)
                                        for _ in range(len(exclusion_map))])
        # Create fixed (non-trainable) masks, one per output; keep the keys
        # sorted so each fc layer is paired with the right mask in forward()
        self.output_keys = sorted(exclusion_map.keys())
        self.masks = nn.ParameterDict()
        for output_idx, excluded_input in exclusion_map.items():
            mask = torch.ones(hidden_size)
            # Zero the block of hidden dimensions associated with the excluded
            # input (hidden_size // input_size dims per input, e.g. 2 of 10 here)
            input_span = hidden_size // input_size
            start = excluded_input * input_span
            end = (excluded_input + 1) * input_span
            mask[start:end] = 0
            self.masks[f"mask_{output_idx}"] = nn.Parameter(mask, requires_grad=False)

    def forward(self, x):
        _, (hn, _) = self.lstm(x)
        final_hidden = hn[-1]  # (batch_size, hidden_size)
        outputs = []
        for key, fc in zip(self.output_keys, self.fc_layers):
            masked_hidden = final_hidden * self.masks[f"mask_{key}"]
            outputs.append(fc(masked_hidden))
        return tuple(outputs)
# Configuration
exclusion_rules = {
1: 2, # Output1 excludes input3
2: 1, # Output2 excludes input2
3: 3 # Output3 excludes input4
}
model = MaskedLSTM(input_size=5, hidden_size=10,
output_size=1, exclusion_map=exclusion_rules)
# Test
x = torch.randn(3, 10, 5) # batch_size=3, seq_len=10
out1, out2, out3 = model(x)
print("Output 1:", out1)
print("Output 2:", out2)
print("Output 3:", out3)
One last note: the final hidden state at time t contains information from all timesteps before t, so masking hidden dimensions reduces, but does not strictly eliminate, an excluded input's influence. If you don't want that, you should use a different model architecture.
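For example, if you need a hard guarantee that a given input never influences a given output, one option (just a sketch of the idea, at the cost of running one LSTM per output) is to zero that input feature before it ever enters the LSTM:

import torch
import torch.nn as nn

class InputMaskedLSTM(nn.Module):
    """Hypothetical alternative: mask the excluded input *feature* itself,
    so it can never leak into this output's hidden state."""
    def __init__(self, input_size, hidden_size, output_size, excluded_input):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)
        mask = torch.ones(input_size)
        mask[excluded_input] = 0
        self.register_buffer("input_mask", mask)

    def forward(self, x):
        # x: (batch, seq_len, input_size); the mask broadcasts over batch and time
        _, (hn, _) = self.lstm(x * self.input_mask)
        return self.fc(hn[-1])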