
In Keras, the Flatten() layer retains the batch dimension. For example, if the input to Flatten has shape (32, 100, 100), the Keras output has shape (32, 10000), but in PyTorch flattening produces a tensor of 320000 elements. Why is that?

2 Answers


As OP already pointed out in their answer, the tensor operations do not default to considering a batch dimension. You can use torch.flatten() or Tensor.flatten() with start_dim=1 to start the flattening operation after the batch dimension.

Alternatively since PyTorch 1.2.0 you can define an nn.Flatten() layer in your model which defaults to start_dim=1.
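A minimal sketch of the three options described above, contrasting the batch-preserving calls with the default full flatten:

```python
import torch
import torch.nn as nn

x = torch.randn(32, 100, 100)  # (batch, height, width)

# Tensor.flatten with start_dim=1 keeps the batch dimension
y = x.flatten(start_dim=1)
print(y.shape)  # torch.Size([32, 10000])

# nn.Flatten defaults to start_dim=1, matching Keras' Flatten layer
flat = nn.Flatten()
z = flat(x)
print(z.shape)  # torch.Size([32, 10000])

# With no arguments, torch.flatten collapses every dimension,
# batch included
w = torch.flatten(x)
print(w.shape)  # torch.Size([320000])
```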

Sign up to request clarification or add additional context in comments.

1 Comment

I tried both methods when training a GNN, but I seem to be missing something. In my DataLoader, the batch size is merged into the first dimension of the input (the tensor unpacked from the DataLoader has shape [batch_size*node_num, attribute_num]). So if I use torch.flatten(), samples are mixed together and the network produces a single output, while I expect batch_size outputs. And if I use nn.Flatten(), nothing seems to happen: the output of that layer is still [batch_size*node_num, attribute_num]. How can I deal with this?
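For the situation in the comment above, flatten alone cannot help, because the batch and node dimensions are already merged. A minimal sketch of one workaround, assuming every graph in the batch has the same (hypothetical) node_num, is to reshape first so the batch dimension is restored; batches with variable graph sizes would need a pooling/scatter operation instead:

```python
import torch

batch_size, node_num, attribute_num = 4, 10, 8

# The DataLoader delivers node features with the batch and node
# dimensions merged into one, as described in the comment
x = torch.randn(batch_size * node_num, attribute_num)

# Assuming a fixed node_num per graph, reshape to separate the
# batch dimension, flattening each sample's nodes and attributes
y = x.reshape(batch_size, node_num * attribute_num)
print(y.shape)  # torch.Size([4, 80])
```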

Yes. As mentioned in this thread, PyTorch tensor operations such as flatten, view, and reshape do not preserve the batch dimension by default.

In general when using modules like Conv2d, you don't need to worry about batch size. PyTorch takes care of it. But when dealing directly with tensors, you need to take care of batch size.

In Keras, Flatten() is a layer. But in PyTorch, flatten() is an operation on the tensor. Hence, the batch size needs to be handled manually.
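The distinction above can be sketched as follows; the common manual idiom when working directly with tensors is to keep the batch dimension explicitly and let -1 infer the rest:

```python
import torch

x = torch.randn(32, 100, 100)

# Manual idiom: preserve the batch dim, infer the remainder with -1
y = x.view(x.size(0), -1)
print(y.shape)  # torch.Size([32, 10000])

# The bare tensor operation collapses everything, batch included,
# which is the behavior the question observed
z = torch.flatten(x)
print(z.shape)  # torch.Size([320000])
```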
