
I wrote a convolutional LSTM (a sequence-to-sequence model built from 1-D ConvLSTM cells), but I am not sure whether my code is good enough. Would anyone be interested in reviewing it?

class ConvLSTMCell1D(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
   #    print(hidden_dim)
        # Separate convolutional layers for input and previous hidden states
        self.conv_x = nn.Conv1d(input_dim, 4 * hidden_dim, kernel_size=3, stride=1, padding=1)
        self.conv_h_prev_layer = nn.Conv1d(hidden_dim, 4 * hidden_dim, kernel_size=3, stride=1, padding=1)
        self.conv_h = nn.Conv1d(hidden_dim, 4 * hidden_dim, kernel_size=3, stride=1, padding=1)
        
        # Peephole connections
        self.W_ci = nn.Parameter(torch.zeros(1, hidden_dim, 1))
        self.W_cf = nn.Parameter(torch.zeros(1, hidden_dim, 1))
        self.W_co = nn.Parameter(torch.zeros(1, hidden_dim, 1))

        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()
    
    def forward(self, x, h_prev_layer, h_prev_time, c_prev):
     #  print(self.hidden_dim)
        conv_x = self.conv_x(x)  # W_x * X_t
        conv_h_prev = self.conv_h_prev_layer(h_prev_layer)  # W_{hn-1 hn} * h_t^{n-1}
        conv_h_time = self.conv_h(h_prev_time)  # W_{hn hn} * h_{t-1}^{n}
        
        conv_output = conv_x + conv_h_prev + conv_h_time  # Combined convolution output
        cc_i, cc_f, cc_o, cc_g = torch.split(conv_output, conv_output.shape[1]//4, dim=1)
        
        # Apply peephole connections
        i = self.sigmoid(cc_i + self.W_ci * c_prev)
        f = self.sigmoid(cc_f + self.W_cf * c_prev)
        g = self.tanh(cc_g)
        c_next = f * c_prev + i * g
                
        o = self.sigmoid(cc_o + self.W_co * c_next)
            
        h_next = o * self.tanh(c_next)
        
        return h_next, c_next

class Seq2SeqConvLSTM(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
   #    print(hidden_dim)
        # Encoder & Decoder layers
        self.encoder = nn.ModuleList([
            ConvLSTMCell1D(input_dim,hidden_dim) for _ in range(num_layers)])
        
        self.decoder = nn.ModuleList([
            ConvLSTMCell1D(input_dim,hidden_dim) for _ in range(num_layers)])
                    
        # Final output layer
        self.output_conv = nn.Conv1d(hidden_dim, input_dim, kernel_size=1)
      # self.output_conv = nn.Conv1d(hidden_dim, input_dim, kernel_size=3, padding=1)
    def forward(self, src):
        batch_size, seq_len, _, spatial = src.shape
        h_enc = [torch.zeros(batch_size, self.hidden_dim, spatial, requires_grad=True).to(device) for _ in range(self.num_layers)]
        c_enc = [torch.zeros(batch_size, self.hidden_dim, spatial, requires_grad=True).to(device) for _ in range(self.num_layers)]

        # Encoder
        for t in range(seq_len):
            for layer in range(self.num_layers):
                h_prev_layer = h_enc[layer - 1] if layer > 0 else torch.zeros_like(h_enc[layer])
                h_prev_time = h_enc[layer]
                h_enc[layer], c_enc[layer] = self.encoder[layer](src[:, t, :, :], h_prev_layer, h_prev_time, c_enc[layer])

        h_dec = [h.detach() for h in h_enc]
        c_dec = [c.detach() for c in c_enc]
        decoder_input = src[:, -1, :, :]      
        outputs = []

        # forecasting
        for t in range(seq_len):
     #  for t in range(4):
            for layer in range(self.num_layers):
                h_prev_layer = h_dec[layer - 1] if layer > 0 else torch.zeros_like(h_dec[layer])
                h_prev_time = h_dec[layer]
                h_dec[layer], c_dec[layer] = self.decoder[layer](decoder_input, h_prev_layer, h_prev_time, c_dec[layer])

      #     print(h_dec[0].shape)
            pred = self.output_conv(h_dec[-1])
            outputs.append(pred)
            decoder_input = pred

        return torch.stack(outputs, dim=1)
  • Wow, it would not have occurred to me that "random indent on comment" would still allow us to produce valid *.pyc bytecode. I have seen a great deal of Python code, and have never encountered random whitespace like that checked into a repo. Commented Mar 1 at 19:35

1 Answer


Here are some minor coding style suggestions.

Firstly, I assume these lines are at the top of your code:

import torch
from torch import nn
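The posted Seq2SeqConvLSTM.forward also refers to a global named device that the snippet never defines, so presumably something along these lines sits next to the imports as well. This is a sketch; the actual definition in the original code is unknown:

```python
import torch
from torch import nn

# `device` is used as a global in Seq2SeqConvLSTM.forward but is not defined
# in the posted snippet. This definition is an assumption: use the GPU when
# one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
```

Creating the initial states with torch.zeros(..., device=src.device) instead would remove the dependency on this global, so the states would simply follow whatever device the input is on.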

Documentation

PEP 8 recommends writing docstrings for all public modules, classes, functions and methods; the conventions for them are described in PEP 257.

The docstring for a class should summarize its purpose. It also helps to include an example of how the class is meant to be used.

The docstring for a function or method should likewise summarize its purpose and describe the types of its arguments and of its return value.
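As an illustration, here is the question's ConvLSTMCell1D with a class docstring and method docstrings added. The Google docstring style is just my choice (PEP 257 does not mandate a format), and the tensor shapes in the docstrings are inferred from how the code uses Conv1d, i.e. (batch, channels, length). The body is condensed from the question but computes the same thing:

```python
import torch
from torch import nn


class ConvLSTMCell1D(nn.Module):
    """One layer of a 1-D convolutional LSTM with peephole connections.

    Each gate receives a convolution of the current input, of the hidden
    state of the layer below at the same time step, and of this layer's
    hidden state from the previous time step.

    Example:
        >>> cell = ConvLSTMCell1D(input_dim=3, hidden_dim=8)
        >>> x = torch.randn(2, 3, 16)
        >>> h = c = torch.zeros(2, 8, 16)
        >>> h_next, c_next = cell(x, h, h, c)
        >>> h_next.shape
        torch.Size([2, 8, 16])
    """

    def __init__(self, input_dim, hidden_dim):
        """
        Args:
            input_dim (int): number of channels in the input x.
            hidden_dim (int): number of channels in the hidden and cell states.
        """
        super().__init__()
        self.hidden_dim = hidden_dim
        # Separate convolutions for the input and the two hidden states
        self.conv_x = nn.Conv1d(input_dim, 4 * hidden_dim, kernel_size=3, padding=1)
        self.conv_h_prev_layer = nn.Conv1d(hidden_dim, 4 * hidden_dim, kernel_size=3, padding=1)
        self.conv_h = nn.Conv1d(hidden_dim, 4 * hidden_dim, kernel_size=3, padding=1)
        # Peephole weights: one per channel, broadcast over the length axis
        self.W_ci = nn.Parameter(torch.zeros(1, hidden_dim, 1))
        self.W_cf = nn.Parameter(torch.zeros(1, hidden_dim, 1))
        self.W_co = nn.Parameter(torch.zeros(1, hidden_dim, 1))

    def forward(self, x, h_prev_layer, h_prev_time, c_prev):
        """Advance the cell by one time step.

        Args:
            x (Tensor): input at time t, shape (batch, input_dim, length).
            h_prev_layer (Tensor): hidden state of the layer below at time t,
                shape (batch, hidden_dim, length).
            h_prev_time (Tensor): this layer's hidden state at time t - 1,
                shape (batch, hidden_dim, length).
            c_prev (Tensor): this layer's cell state at time t - 1,
                shape (batch, hidden_dim, length).

        Returns:
            tuple[Tensor, Tensor]: (h_next, c_next), each of shape
            (batch, hidden_dim, length).
        """
        conv_output = self.conv_x(x) + self.conv_h_prev_layer(h_prev_layer) + self.conv_h(h_prev_time)
        cc_i, cc_f, cc_o, cc_g = torch.split(conv_output, self.hidden_dim, dim=1)
        i = torch.sigmoid(cc_i + self.W_ci * c_prev)
        f = torch.sigmoid(cc_f + self.W_cf * c_prev)
        g = torch.tanh(cc_g)
        c_next = f * c_prev + i * g
        o = torch.sigmoid(cc_o + self.W_co * c_next)
        return o * torch.tanh(c_next), c_next
```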

Comments

Some of the comments are helpful, but the ones which merely comment out lines of code should be deleted. For example:

#    print(hidden_dim)
  #  print(self.hidden_dim)

My guess is that these print statements were used while debugging the code. A common alternative is to define a __str__ method for the class: it lets you build a string containing whatever data you see fit, and then you can simply print the object.
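One PyTorch-specific wrinkle: nn.Module already defines __repr__, and print(module) falls back to it, producing the familiar tree of submodules. Overriding __str__ would replace that tree, so for modules the extra_repr hook is usually the better place to add your own fields. A minimal sketch, using a hypothetical stripped-down module TinyCell to keep the example short:

```python
from torch import nn


class TinyCell(nn.Module):
    """Hypothetical minimal module, used only to show extra_repr."""

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.conv_x = nn.Conv1d(input_dim, 4 * hidden_dim, kernel_size=3, padding=1)

    def extra_repr(self):
        # nn.Module.__repr__ inserts this text before the list of submodules.
        return f"input_dim={self.input_dim}, hidden_dim={self.hidden_dim}"


cell = TinyCell(3, 8)
print(cell)  # shows the custom fields plus the (conv_x) submodule
```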

For comments like these:

# W_{hn-1 hn} * h_t^{n-1}

It would be helpful to describe what items like W_ and h_t refer to.
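For reference, this is my reading of what ConvLSTMCell1D.forward computes, in the notation the code comments already hint at: superscript n is the layer index, subscript t the time step, * a 1-D convolution (kernel size 3, padding 1), ⊙ element-wise multiplication and σ the logistic sigmoid. The peephole weights w_{ci}, w_{cf}, w_{co} have shape (1, hidden_dim, 1), so they are per-channel and broadcast over the spatial axis, and each b stands for the sum of the three convolution biases. For the first layer, h_t^{n-1} is a zero tensor. A docstring containing equations like these would make the per-line comments much easier to follow:

```latex
\begin{aligned}
i_t^n &= \sigma\left(W_{xi} * X_t + W_{h^{n-1}i} * h_t^{n-1} + W_{hi} * h_{t-1}^n + b_i + w_{ci} \odot c_{t-1}^n\right)\\
f_t^n &= \sigma\left(W_{xf} * X_t + W_{h^{n-1}f} * h_t^{n-1} + W_{hf} * h_{t-1}^n + b_f + w_{cf} \odot c_{t-1}^n\right)\\
g_t^n &= \tanh\left(W_{xg} * X_t + W_{h^{n-1}g} * h_t^{n-1} + W_{hg} * h_{t-1}^n + b_g\right)\\
c_t^n &= f_t^n \odot c_{t-1}^n + i_t^n \odot g_t^n\\
o_t^n &= \sigma\left(W_{xo} * X_t + W_{h^{n-1}o} * h_t^{n-1} + W_{ho} * h_{t-1}^n + b_o + w_{co} \odot c_t^n\right)\\
h_t^n &= o_t^n \odot \tanh\left(c_t^n\right)
\end{aligned}
```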

Naming

For brief variable names like f and g, it would be good either to use longer names or to explain their conventional meaning in comments.
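For example, the gate arithmetic could use the conventional LSTM names (input, forget and output gates, plus the candidate cell content). The function below is a hypothetical standalone version of the second half of ConvLSTMCell1D.forward, not code from the question. It also swaps torch.split for torch.chunk, which splits a tensor into a given number of equal pieces and so avoids the manual shape[1] // 4:

```python
import torch


def peephole_lstm_update(conv_output, c_prev, w_ci, w_cf, w_co):
    """Hypothetical helper: turn the summed convolutions into (h_next, c_next).

    conv_output has 4 * hidden_dim channels, ordered as in the question:
    input gate, forget gate, output gate, candidate.
    """
    input_pre, forget_pre, output_pre, candidate_pre = torch.chunk(conv_output, 4, dim=1)

    input_gate = torch.sigmoid(input_pre + w_ci * c_prev)    # how much new content to write
    forget_gate = torch.sigmoid(forget_pre + w_cf * c_prev)  # how much old cell state to keep
    candidate = torch.tanh(candidate_pre)                    # proposed new cell content

    c_next = forget_gate * c_prev + input_gate * candidate
    output_gate = torch.sigmoid(output_pre + w_co * c_next)  # how much of the cell to expose
    h_next = output_gate * torch.tanh(c_next)
    return h_next, c_next
```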

Simpler

Consider this code:

    for t in range(seq_len):
 #  for t in range(4):

As mentioned already, the commented line should be deleted to reduce clutter.

Also, since the t variable is not used in the loop, the _ placeholder can be used:

    for _ in range(seq_len):

Layout

In the second class, it would be good to add a blank line before the forward method, since PEP 8 separates method definitions inside a class with a single blank line:

def forward(self, src):
