I have a machine-translation model. It computes a vector for a given sentence, aggregates that vector with each generated RNN output, and feeds the result back into the RNN to compute the next hidden state and generate the next output. Because of this feedback, I run the RNN one step at a time rather than over a whole batch of time steps.

I want to utilize the GPU as much as I can, and I would also like to use multiple GPUs. The model computes the initial sentence transformation quickly, but when it comes to the RNN it is very slow.

For example, with a batch size of 100 the model computes the sentence transformation with one matrix multiplication per GPU (a batch of 50 on each of two GPUs), so it is fast. But the RNN computation takes one time step at a time and barely utilizes even a single GPU.
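
To make the contrast concrete, here is a small self-contained illustration (not the model itself; the shapes are made up): a batched matrix multiplication launches one large GPU kernel, while looping over the batch launches many tiny kernels and leaves the GPU mostly idle.

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(100, 512, device=device)  # batch of 100 sentence vectors
W = torch.randn(512, 512, device=device)

# Batched: a single kernel processes the whole batch at once.
y_batched = x @ W

# Sequential: 100 launches of a kernel that does almost no work each.
y_looped = torch.stack([x[i] @ W for i in range(100)])

# Both produce the same (100, 512) result; only the launch pattern differs.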

I think this is also a problem in text-generation tasks. What is the best approach for training efficiently with models that take their own output as input?

Here is my code:

import torch
import torch.nn as nn

class RLM(nn.Module):
    def __init__(self, input_size, representation_size, num_rnn_layers):
        super(RLM, self).__init__()
        self.rnn = nn.LSTM(input_size, representation_size, num_rnn_layers, batch_first=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input, hidden):
        output, hidden = self.rnn(input, hidden)
        # return the hidden state as well, so the caller can carry it across steps
        return output, hidden

class RCTM(nn.Module):
    def __init__(self, vocab_size, max_output_len, representation_size, output_dim, rlm, csm):
        super(RCTM, self).__init__()
        self.csm = csm # sentence transformation
        self.rlm = rlm # rnn transformation
        ...
        ...
        self.token_linear = nn.Linear(vocab_size, representation_size)
        self.out_linear = nn.Linear(representation_size, output_dim)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, debug=False):
        csm_output = self.csm(x) # TRANSLATED SENTENCE - (BATCH_SIZE, REPR_SIZE)

        h = RNN_HIDDEN_STATE_INITIALIZATION                     # e.g. zero (h_0, c_0) tensors
        token_encoding = START_TOKEN_ENCODING                   # start token given first
        token_encodings = token_encoding.repeat(x.size(0), 1)   # repeated for the batch
        rnn_output = torch.zeros_like(csm_output[:, None, :])   # no RNN output before step 0

        outputs = []
        # Run for MAX_OUTPUT_LEN steps
        for j in range(self.max_output_len):
            token_representation = self.token_linear(token_encodings)

            # AGGREGATION (csm_output needs a time axis to broadcast correctly)
            rnn_input = token_representation[:, None, :] + csm_output[:, None, :] + rnn_output

            # COMPUTE recurrent NN, carrying the hidden state across steps
            rnn_output, h = self.rlm(rnn_input, h)

            # FINAL COMPUTATION TO LOGITS
            out = self.out_linear(rnn_output)
            outputs.append(out)
            # (when generating, token_encodings would presumably be updated from out here)

        return ...

The important part of the code is the generation loop. I am not sure how to make this computation utilize the GPU.

IMPORTANT: I monitor GPU usage with nvidia-smi. During the RNN computation it is very slow and GPU utilization is 0-1%, almost nothing, but it goes up to 100% during the loss computation.
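
One way to confirm that the loop is bound by per-step launch overhead rather than by arithmetic is to time many single-step LSTM calls against one multi-step call using CUDA events. This is a minimal sketch, assuming a CUDA device; the sizes and the step count of 60 are illustrative:

import torch

device = "cuda"
lstm = torch.nn.LSTM(512, 512, batch_first=True).to(device)
x = torch.randn(100, 1, 512, device=device)  # one time step for a batch of 100

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

# 60 sequential single-step calls, as in the loop above.
start.record()
h = None
for _ in range(60):
    out, h = lstm(x, h)
end.record()
torch.cuda.synchronize()
print(f"60 single-step calls: {start.elapsed_time(end):.1f} ms")

# The same 60 steps in a single call.
seq = torch.randn(100, 60, 512, device=device)
start.record()
out, _ = lstm(seq)
end.record()
torch.cuda.synchronize()
print(f"one 60-step call:     {start.elapsed_time(end):.1f} ms")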

  • This is a core property of RNNs: the output/input for the next step is processed sequentially. I don't have time to test and provide a workable answer, but one suggestion would be to push as much work as possible into the LSTM, since the recurrence is already captured in the hidden/cell states; I don't think the while loop is necessary here. Also, as a side note, I would favor nn.Embedding over nn.Linear if possible. Commented Oct 15 at 14:39
  • Thank you for the suggestion. What do you mean by "the while loop is not necessary"? How can I achieve this without a while loop? Commented Oct 15 at 15:02
  • Essentially, try to pack your entire sequence so that the LSTM is run once per batch. I suggest looking into 'teacher forcing' for RNNs if applicable (see the sketch after these comments). Commented Oct 15 at 15:18
  • Do you need the feedback to go back through the RNN again? If not, you can try running everything ahead of time: presumably you know all the target sequences in advance, and if so, you can process all the time steps in parallel. Commented Oct 29 at 14:18
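
Following the comments above, here is a hedged sketch of the teacher-forcing idea, not the asker's exact model: because the ground-truth target tokens are known at training time, the per-step Python loop can be replaced with a single LSTM call over the whole shifted target sequence, letting the recurrence live entirely in the LSTM's hidden/cell states. The names (TeacherForcedDecoder, target_tokens) are illustrative, and the explicit previous-output feedback from the original model is dropped here on the assumption that the hidden state can carry it.

import torch
import torch.nn as nn

class TeacherForcedDecoder(nn.Module):
    def __init__(self, vocab_size, representation_size, output_dim, num_rnn_layers):
        super().__init__()
        # nn.Embedding instead of nn.Linear over one-hot vectors, per the side note above
        self.embedding = nn.Embedding(vocab_size, representation_size)
        self.rnn = nn.LSTM(representation_size, representation_size,
                           num_rnn_layers, batch_first=True)
        self.out_linear = nn.Linear(representation_size, output_dim)

    def forward(self, target_tokens, csm_output):
        # target_tokens: (BATCH, SEQ_LEN) ground-truth ids, shifted right so
        #                position 0 is the start token (teacher forcing)
        # csm_output:    (BATCH, REPR_SIZE) sentence representation from the CSM
        token_repr = self.embedding(target_tokens)       # (B, T, R)
        rnn_input = token_repr + csm_output[:, None, :]  # broadcast the sentence vector over time
        rnn_output, _ = self.rnn(rnn_input)              # all T steps in one batched call
        return self.out_linear(rnn_output)               # (B, T, OUTPUT_DIM) logits

Inference-time generation still has to run step by step, but training now processes every time step of the batch in one call, which is exactly where the GPU sat idle before.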
