In Keras, the Flatten() layer retains the batch dimension. For example, if the input to Flatten has shape (32, 100, 100), the Keras output has shape (32, 10000), but in PyTorch torch.flatten() returns a tensor with shape (320000,). Why is it so?
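A minimal sketch reproducing the shapes described above (assuming both TensorFlow/Keras and PyTorch are installed):

```python
import tensorflow as tf
import torch

x_keras = tf.zeros((32, 100, 100))
print(tf.keras.layers.Flatten()(x_keras).shape)  # (32, 10000) – batch dimension kept

x_torch = torch.zeros(32, 100, 100)
print(torch.flatten(x_torch).shape)              # torch.Size([320000]) – everything flattened
```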
2 Answers
As the OP already pointed out in their answer, PyTorch tensor operations do not consider a batch dimension by default. You can use torch.flatten() or Tensor.flatten() with start_dim=1 to start the flattening operation after the batch dimension.
Alternatively, since PyTorch 1.2.0 you can define an nn.Flatten() layer in your model, which defaults to start_dim=1.
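A short sketch of both options mentioned above:

```python
import torch
import torch.nn as nn

x = torch.randn(32, 100, 100)

# Functional / tensor-method form: keep dim 0 (the batch) and flatten the rest.
print(torch.flatten(x, start_dim=1).shape)  # torch.Size([32, 10000])
print(x.flatten(start_dim=1).shape)         # torch.Size([32, 10000])

# Module form (PyTorch >= 1.2.0): nn.Flatten() defaults to start_dim=1.
flatten = nn.Flatten()
print(flatten(x).shape)                     # torch.Size([32, 10000])
```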
1 Comment
When the input comes from a DataLoader object, batch_size is mixed with the first dimension of the input (in my GNN, the input unpacked from the DataLoader object has size [batch_size*node_num, attribute_num]), so if I use torch.flatten(), samples are mixed together and there is only one output from this network, while I expect batch_size outputs. And if I use nn.Flatten(), nothing seems to happen and the output of this layer is still [batch_size*node_num, attribute_num]. How can I deal with this?

Yes, as mentioned in this thread, PyTorch tensor operations such as flatten, view, and reshape operate on the whole tensor, so the batch dimension is not handled for you.
In general, when using modules like Conv2d, you don't need to worry about the batch size; PyTorch takes care of it. But when dealing directly with tensors, you have to handle the batch size yourself.
In Keras, Flatten() is a layer. In PyTorch, flatten() is an operation on the tensor. Hence, the batch size needs to be handled manually.
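A sketch illustrating the difference (the layer sizes are made up for illustration): inside a model, nn.Conv2d handles the batch dimension for you, but when you flatten the tensor yourself before a Linear layer, you have to preserve dim 0 explicitly.

```python
import torch
import torch.nn as nn

class SmallNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 8, kernel_size=3, padding=1)  # batch handled automatically
        self.fc = nn.Linear(8 * 28 * 28, 10)

    def forward(self, x):              # x: [batch, 1, 28, 28]
        x = self.conv(x)               # [batch, 8, 28, 28]
        x = x.flatten(start_dim=1)     # [batch, 8*28*28] – keep the batch dimension
        return self.fc(x)              # [batch, 10]

out = SmallNet()(torch.randn(32, 1, 28, 28))
print(out.shape)                       # torch.Size([32, 10])
```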