I have a machine-translation model. It computes a vector for the source sentence, aggregates that vector with each output the RNN generates, and feeds the result back into the RNN to compute the next hidden state and the next output. Because of this feedback, I run the RNN one step at a time instead of over a whole sequence.
I want to utilize the GPU as much as I can, and I would also like to use multiple GPUs. The model computes the initial sentence transformation quickly, but the RNN part is very slow.
For example, with a batch size of 100, the model computes the sentence transformation with a single matrix multiplication per GPU (a batch of 50 on each of two GPUs), so it is fast. But the RNN processes one input at a time and does not fully utilize even one GPU.
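To make the asymmetry concrete, here is a rough toy sketch of the two patterns (made-up sizes, not my actual model):

    import torch
    import torch.nn as nn

    batch_size, repr_size, max_len = 100, 256, 50
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Sentence-transformation style: the whole batch goes through one large
    # matrix multiplication, so the GPU gets a single big kernel and stays busy.
    encoder = nn.Linear(repr_size, repr_size).to(device)
    x = torch.randn(batch_size, repr_size, device=device)
    sentence_vec = encoder(x)  # one call for the whole batch

    # RNN style: the same batch, but one timestep per Python iteration.
    # Each iteration launches a few tiny kernels, so the GPU mostly idles
    # while waiting on the Python loop.
    lstm = nn.LSTM(repr_size, repr_size, batch_first=True).to(device)
    step_input = sentence_vec[:, None, :]  # (batch, 1, repr)
    hidden = None
    for _ in range(max_len):
        step_output, hidden = lstm(step_input, hidden)  # sequence length 1 every call
        step_input = step_output  # the next step depends on this step's output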
I think this is also a problem in text-generation tasks. What is the best approach for training efficiently when a model takes its own output as input?
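To be explicit about the dependency I mean, here is a minimal self-contained sketch (toy sizes and layer names, not my real code):

    import torch
    import torch.nn as nn

    vocab_size, repr_size, batch_size, max_output_len = 1000, 256, 100, 20
    device = "cuda" if torch.cuda.is_available() else "cpu"

    embed = nn.Embedding(vocab_size, repr_size).to(device)
    cell = nn.LSTMCell(repr_size, repr_size).to(device)
    to_logits = nn.Linear(repr_size, vocab_size).to(device)

    prev_token = torch.zeros(batch_size, dtype=torch.long, device=device)  # made-up start token id
    h = torch.zeros(batch_size, repr_size, device=device)
    c = torch.zeros(batch_size, repr_size, device=device)

    for j in range(max_output_len):
        h, c = cell(embed(prev_token), (h, c))
        logits = to_logits(h)
        # Step j's output becomes step j+1's input, so the steps are inherently
        # sequential; the whole sequence cannot be handed to nn.LSTM in one call
        # the way the sentence transformation can, unless all inputs were known
        # in advance (e.g. feeding ground-truth target tokens during training).
        prev_token = logits.argmax(dim=-1)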
Here is my code:
class RLM(nn.Module):
    def __init__(self, input_size, representation_size, num_rnn_layers):
        super(RLM, self).__init__()
        self.rnn = nn.LSTM(input_size, representation_size, num_rnn_layers, batch_first=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input, hidden):
        output, hidden = self.rnn(input, hidden)
        # return the hidden state as well, so the caller can feed it back in
        return output, hidden
class RCTM(nn.Module):
    def __init__(self, vocab_size, max_output_len, representation_size, output_dim, rlm, csm):
        super(RCTM, self).__init__()
        self.csm = csm  # sentence transformation
        self.rlm = rlm  # RNN transformation
        ...
        ...
        self.token_linear = nn.Linear(vocab_size, representation_size)
        self.out_linear = nn.Linear(representation_size, output_dim)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, debug=False):
        csm_output = self.csm(x)  # translated sentence - (BATCH_SIZE, REPR_SIZE)
        h = RNN_HIDDEN_STATE_INITIALIZATION
        token_encoding = START_TOKEN_ENCODING       # start token given first
        token_encodings = REPEATED_FOR_BATCH_SIZE   # start token repeated for the batch
        rnn_output = INITIAL_RNN_OUTPUT             # e.g. zeros, so the first aggregation works
        # run for max_output_len steps, one token at a time
        for j in range(self.max_output_len):
            token_representation = self.token_linear(token_encodings)
            # aggregation (csm_output unsqueezed to (BATCH_SIZE, 1, REPR_SIZE) so the shapes match)
            rnn_input = token_representation[:, None, :] + csm_output[:, None, :] + rnn_output
            # compute one recurrent step (sequence length 1) and carry the hidden state forward
            rnn_output, h = self.rlm(rnn_input, h)
            # final computation to logits
            out = self.out_linear(rnn_output)
        return ...
The important part of the code is the decoding loop in forward. I am not sure how to make that part of the computation use the GPU efficiently.
IMPORTANT: I monitor GPU usage with nvidia-smi. During the RNN computation it is very slow and GPU usage sits at 0-1%, almost nothing, but it goes up to 100% during the loss computation.
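For what it's worth, to cross-check the nvidia-smi numbers I can also time the phases directly. A rough sketch (the synchronize calls matter because CUDA launches are asynchronous; the names in the usage comment are placeholders for my own code):

    import time
    import torch

    def timed(label, fn):
        # CUDA kernels launch asynchronously, so synchronize before and after
        # to measure how long this phase really takes on the GPU.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        start = time.perf_counter()
        result = fn()
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        print(f"{label}: {time.perf_counter() - start:.4f} s")
        return result

    # Hypothetical usage inside one training step:
    # csm_output = timed("csm", lambda: model.csm(x))
    # logits     = timed("rnn loop", lambda: model(x))
    # loss_value = timed("loss", lambda: criterion(logits, targets))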