I wrote an implementation of a 1D convolutional LSTM (ConvLSTM), but I'm not sure whether it is correct or idiomatic. Would anyone be willing to take a look at my code?
import torch
import torch.nn as nn


class ConvLSTMCell1D(nn.Module):
    """A single 1D convolutional LSTM cell with peephole connections."""

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        # Separate convolutions for the input, the hidden state of the layer below
        # at the current time step, and this layer's hidden state from the previous
        # time step. Each produces the four gate pre-activations at once.
        self.conv_x = nn.Conv1d(input_dim, 4 * hidden_dim, kernel_size=3, stride=1, padding=1)
        self.conv_h_prev_layer = nn.Conv1d(hidden_dim, 4 * hidden_dim, kernel_size=3, stride=1, padding=1)
        self.conv_h = nn.Conv1d(hidden_dim, 4 * hidden_dim, kernel_size=3, stride=1, padding=1)
        # Peephole connections (element-wise weights on the cell state)
        self.W_ci = nn.Parameter(torch.zeros(1, hidden_dim, 1))
        self.W_cf = nn.Parameter(torch.zeros(1, hidden_dim, 1))
        self.W_co = nn.Parameter(torch.zeros(1, hidden_dim, 1))
        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()

    def forward(self, x, h_prev_layer, h_prev_time, c_prev):
        conv_x = self.conv_x(x)                             # W_x * X_t
        conv_h_prev = self.conv_h_prev_layer(h_prev_layer)  # W_{h^{n-1} h^n} * h_t^{n-1}
        conv_h_time = self.conv_h(h_prev_time)              # W_{h^n h^n} * h_{t-1}^{n}
        conv_output = conv_x + conv_h_prev + conv_h_time    # combined gate pre-activations
        cc_i, cc_f, cc_o, cc_g = torch.split(conv_output, self.hidden_dim, dim=1)
        # Gates with peephole connections on the cell state
        i = self.sigmoid(cc_i + self.W_ci * c_prev)   # input gate
        f = self.sigmoid(cc_f + self.W_cf * c_prev)   # forget gate
        g = self.tanh(cc_g)                           # candidate cell state
        c_next = f * c_prev + i * g
        o = self.sigmoid(cc_o + self.W_co * c_next)   # output gate (peeks at the new cell state)
        h_next = o * self.tanh(c_next)
        return h_next, c_next
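
To make the expected tensor shapes easier to check, this is the kind of quick smoke test I run on a single cell. The batch size, channel counts, and spatial length are just made-up placeholders, not my real data:

# Hypothetical dimensions for a shape sanity check of ConvLSTMCell1D
cell = ConvLSTMCell1D(input_dim=2, hidden_dim=16)
x = torch.randn(8, 2, 64)         # (batch, input_dim, spatial)
h_below = torch.zeros(8, 16, 64)  # hidden state from the layer below at time t
h_prev = torch.zeros(8, 16, 64)   # this layer's hidden state at t-1
c_prev = torch.zeros(8, 16, 64)   # this layer's cell state at t-1
h_next, c_next = cell(x, h_below, h_prev, c_prev)
print(h_next.shape, c_next.shape)  # both should be (8, 16, 64)
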
class Seq2SeqConvLSTM(nn.Module):
    """Encoder-decoder ConvLSTM: encodes the input sequence, then rolls the
    decoder forward autoregressively to forecast the same number of steps."""

    def __init__(self, input_dim, hidden_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        # Encoder & decoder stacks (one ConvLSTM cell per layer)
        self.encoder = nn.ModuleList([
            ConvLSTMCell1D(input_dim, hidden_dim) for _ in range(num_layers)])
        self.decoder = nn.ModuleList([
            ConvLSTMCell1D(input_dim, hidden_dim) for _ in range(num_layers)])
        # Final 1x1 convolution mapping the hidden state back to the input channels
        self.output_conv = nn.Conv1d(hidden_dim, input_dim, kernel_size=1)
        # self.output_conv = nn.Conv1d(hidden_dim, input_dim, kernel_size=3, padding=1)

    def forward(self, src):
        # src: (batch, seq_len, input_dim, spatial)
        batch_size, seq_len, _, spatial = src.shape
        # Zero-initialise hidden and cell states on the same device as the input
        h_enc = [torch.zeros(batch_size, self.hidden_dim, spatial, device=src.device)
                 for _ in range(self.num_layers)]
        c_enc = [torch.zeros(batch_size, self.hidden_dim, spatial, device=src.device)
                 for _ in range(self.num_layers)]

        # Encoder: run every input time step through the stacked cells
        for t in range(seq_len):
            for layer in range(self.num_layers):
                h_prev_layer = h_enc[layer - 1] if layer > 0 else torch.zeros_like(h_enc[layer])
                h_prev_time = h_enc[layer]
                h_enc[layer], c_enc[layer] = self.encoder[layer](
                    src[:, t, :, :], h_prev_layer, h_prev_time, c_enc[layer])

        # Seed the decoder with detached copies of the final encoder states;
        # note that detaching means the loss will not backpropagate into the encoder
        h_dec = [h.detach() for h in h_enc]
        c_dec = [c.detach() for c in c_enc]
        decoder_input = src[:, -1, :, :]
        outputs = []

        # Decoder / forecasting: feed each prediction back in as the next input
        for t in range(seq_len):
            for layer in range(self.num_layers):
                h_prev_layer = h_dec[layer - 1] if layer > 0 else torch.zeros_like(h_dec[layer])
                h_prev_time = h_dec[layer]
                h_dec[layer], c_dec[layer] = self.decoder[layer](
                    decoder_input, h_prev_layer, h_prev_time, c_dec[layer])
            pred = self.output_conv(h_dec[-1])
            outputs.append(pred)
            decoder_input = pred

        return torch.stack(outputs, dim=1)  # (batch, seq_len, input_dim, spatial)
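
And here is a minimal end-to-end forward pass I use to confirm the output shape matches the input; again, all the sizes below are placeholder values, not my actual dataset:

# Hypothetical smoke test for the full seq2seq model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Seq2SeqConvLSTM(input_dim=2, hidden_dim=16, num_layers=2).to(device)
src = torch.randn(8, 10, 2, 64, device=device)  # (batch, seq_len, input_dim, spatial)
out = model(src)
print(out.shape)  # expected: (8, 10, 2, 64) -- one predicted frame per input step
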