
I am new to PyTorch and am trying to port my previous code from TensorFlow to PyTorch because of memory issues. However, I keep running into problems when trying to reproduce the Flatten layer.

In my DataLoader object, the batch dimension is merged with the first dimension of the input: in my GNN, the input unpacked from the DataLoader is of size [batch_size*node_num, attribute_num], e.g. [4*896, 32] after the GCNConv layers. If I apply torch.flatten() after GCNConv, the samples are mixed together (into [4*896*32]) and the network produces only 1 output, while I expect batch_size outputs. If I use nn.Flatten() instead, nothing seems to happen (the shape stays [4*896, 32]). Should I make batch_size the first dimension of the input from the very beginning, or should I use view() directly? I tried view() and it (seemed to have) worked, although I am not sure whether it is equivalent to Flatten. Please see my code below. I am currently using global_max_pool because it works (it separates the batch directly).
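A minimal sketch of the two behaviours described above, using a random tensor of the size quoted in the question (4 graphs of 896 nodes with 32 features each; the values are made up):

```python
import torch
import torch.nn as nn

batch_size, node_num, attribute_num = 4, 896, 32

# PyG-style batched node features: the nodes of all graphs stacked along dim 0
x = torch.randn(batch_size * node_num, attribute_num)

# torch.flatten(x, 0) collapses every dimension, including the stacked batch
flat_all = torch.flatten(x, 0)
print(flat_all.shape)  # torch.Size([114688])

# nn.Flatten() defaults to start_dim=1, end_dim=-1; a 2-D tensor has only one
# dimension from 1 onward, so there is nothing to merge and the shape is unchanged
flat_from_1 = nn.Flatten()(x)
print(flat_from_1.shape)  # torch.Size([3584, 32])
```

Neither call recovers a separate per-graph dimension, which is why the number of outputs does not match batch_size.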

By the way, I am not sure why training is so slow in PyTorch... When node_num is raised to 13000, one epoch takes an hour, and I have 100 epochs per test fold and 10 test folds. In TensorFlow, the whole training process takes only several hours with the same network architecture and raw input data, as shown in another post of mine, which also describes the memory issues I ran into with TF.

I have been quite frustrated for a while. I checked a couple of related posts, but their problems seem to differ somewhat from mine. I would greatly appreciate any help!

Code:

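import gc
import numpy as np
import torch
from torch_geometric.data import Data, InMemoryDataset
from torch_geometric.data import DataLoader  # PyG >= 2.0: from torch_geometric.loader import DataLoader

# sample_size, X_all, y_all, edge_index, save_dir, X_train, X_val, feature_num,
# batch_size, learning_rate and num_epochs are defined elsewhere

# Generate dataset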
class STDataset(InMemoryDataset):
    def __init__(self, root, transform=None, pre_transform=None):
        super(STDataset, self).__init__(root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])
    
    @property
    def raw_file_names(self):
        return []

    @property
    def processed_file_names(self):
        return ['pygdata.pt']

    def download(self):
        pass

    def process(self):
        data_list= []
        for i in range(sample_size):
            data = Data(x=torch.tensor(X_all[i],dtype=torch.float),edge_index=edge_index,y=torch.FloatTensor(y_all[i]))
            data_list.append(data)

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])
        
dataset = STDataset(root=save_dir)
train_dataset = dataset[:len(X_train)]
val_dataset = dataset[len(X_train):(len(X_train)+len(X_val))]
test_dataset = dataset[(len(X_train)+len(X_val)):]


# Build network

from torch_geometric.nn import GCNConv, GATConv, TopKPooling, global_max_pool, global_mean_pool
from torch.nn import Flatten, Linear, ELU
import torch.nn.functional as F

class GCN(torch.nn.Module):
    def __init__(self):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(in_channels = feature_num, out_channels = 32)
        self.conv2 = GCNConv(in_channels = 32, out_channels = 32)
        self.fc1 = Flatten()
#         self.ln1 = Linear(in_features = batch_size*N*32, out_features = 512) 
        self.ln1 = Linear(in_features = 32, out_features = 32)
        self.ln2 = Linear(in_features = 32, out_features = 1) 

    
    def forward(self,x,edge_index,batch):   
#         x, edge_index, batch = data.x, data.edge_index, data.batch
#         print(np.shape(x),np.shape(edge_index),np.shape(batch))
        x = F.elu(self.conv1(x,edge_index))
#         x = x.squeeze(1) 
        x = F.elu(self.conv2(x,edge_index))
        print(np.shape(x))
        x = self.fc1(x)
#         x = torch.flatten(x,0)
#         x = torch.cat([global_max_pool(x,batch),global_mean_pool(x,batch)],dim=1)
        print(np.shape(x))
        x = self.ln1(x)
        x = F.relu(x)
        ## Dropout?
        print("o")
        x = torch.sigmoid(self.ln2(x))
        return x
        
# training
def train():
    model.train()
    loss_all=0
    correct = 0
    for i, data in enumerate(train_loader, 0):
        data = data.to(device)
        optimizer.zero_grad() 
        output = model(data.x, data.edge_index,data.batch)
        label = data.y.to(device)
        loss = loss_func(output, label)
        loss.backward()
        loss_all += loss.item()
        
        output = output.detach().cpu().numpy().squeeze()
        label = label.detach().cpu().numpy().squeeze()        
        correct += (abs(output-label)<0.5).sum()
        
        optimizer.step()
  
    return loss_all / len(train_dataset), correct / len(train_dataset)

device = torch.device('cuda')
model = GCN().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
loss_func = torch.nn.BCELoss()  # binary cross-entropy
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle = True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle = True)
for epoch in range(num_epochs):
    gc.collect()
    train_loss, train_acc = train()    


Error message for using torch.nn.Flatten(start_dim = 1) (code above):

ValueError                                Traceback (most recent call last)
<ipython-input-42-c96e8b058742> in <module>
     65 for epoch in range(num_epochs):
     66     gc.collect()
---> 67     train_loss, train_acc = train()

<ipython-input-42-c96e8b058742> in train()
     10         output = model(data.x, data.edge_index,data.batch)
     11         label = data.y.to(device)
---> 12         loss = loss_func(output, label)
     13         loss.backward()
     14         loss_all += loss.item()

~/miniconda3/envs/ST-Torch/lib/python3.7/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    530             result = self._slow_forward(*input, **kwargs)
    531         else:
--> 532             result = self.forward(*input, **kwargs)
    533         for hook in self._forward_hooks.values():
    534             hook_result = hook(self, input, result)

~/miniconda3/envs/ST-Torch/lib/python3.7/site-packages/torch/nn/modules/loss.py in forward(self, input, target)
    496 
    497     def forward(self, input, target):
--> 498         return F.binary_cross_entropy(input, target, weight=self.weight, reduction=self.reduction)
    499 
    500 

~/miniconda3/envs/ST-Torch/lib/python3.7/site-packages/torch/nn/functional.py in binary_cross_entropy(input, target, weight, size_average, reduce, reduction)
   2068     if input.numel() != target.numel():
   2069         raise ValueError("Target and input must have the same number of elements. target nelement ({}) "
-> 2070                          "!= input nelement ({})".format(target.numel(), input.numel()))
   2071 
   2072     if weight is not None:

ValueError: Target and input must have the same number of elements. target nelement (4) != input nelement (3584)
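The two numbers in the error follow from the shapes: nn.Flatten leaves the [3584, 32] node tensor intact, so ln2 produces one prediction per node ([3584, 1], i.e. 3584 elements), while data.y holds one label per graph (4 elements). A small sketch reproducing the mismatch with random stand-in tensors (the exact error text depends on the PyTorch version, but BCELoss raises a ValueError either way):

```python
import torch
import torch.nn as nn

node_outputs = torch.sigmoid(torch.randn(4 * 896, 1))  # one prediction per node
graph_labels = torch.randint(0, 2, (4,)).float()        # one label per graph

loss_func = nn.BCELoss()
try:
    loss_func(node_outputs, graph_labels)
    raised = False
except ValueError as err:
    raised = True
    print(err)

print(node_outputs.numel(), graph_labels.numel())  # 3584 4
```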

Answer:

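Wanting the shape to be [batch_size*node_num, attribute_num] is a bit unusual.

Usually it should be [batch_size, node_num*attribute_num], because each sample in the batch has to map to its own output. nn.Flatten in PyTorch does exactly that: by default (start_dim=1) it keeps dimension 0 as the batch and merges everything after it. That is also why it leaves a 2-D tensor such as [4*896, 32] unchanged: there is only one dimension after dimension 0, so there is nothing to merge.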
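A sketch of what that looks like for the graph batch in the question. It assumes every graph has exactly node_num nodes (which appears to be the case in the question) and that the DataLoader stacks each graph's nodes contiguously and in order, which is how PyTorch Geometric builds a batch. Random tensors stand in for the GCN output, and `head` is a hypothetical replacement for ln1:

```python
import torch
import torch.nn as nn

batch_size, node_num, attribute_num = 4, 896, 32

# Stand-in for the GCNConv output: [batch_size*node_num, attribute_num]
x = torch.randn(batch_size * node_num, attribute_num)
# PyG-style `batch` vector: graph index of every node, e.g. [0, 0, ..., 1, 1, ..., 3]
batch = torch.arange(batch_size).repeat_interleave(node_num)

# Split the stacked node dimension back into (graph, node)
x3d = x.view(batch_size, node_num, attribute_num)
# Sanity check: the rows belonging to graph 1 in the flat tensor are x3d[1]
assert torch.equal(x3d[1], x[batch == 1])

# Now nn.Flatten keeps the batch dimension and merges the rest
flat = nn.Flatten()(x3d)  # [batch_size, node_num*attribute_num]
head = nn.Linear(node_num * attribute_num, 1)  # hypothetical output layer
out = torch.sigmoid(head(flat))
print(flat.shape, out.shape)  # torch.Size([4, 28672]) torch.Size([4, 1])
```

Using `x.view(-1, node_num, attribute_num)` instead would also cope with a smaller final batch. The [4, 1] output has the same number of elements as the 4 labels; newer PyTorch versions also require BCELoss inputs and targets to have identical shapes, so the labels may need `label.view(-1, 1)`.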

If what you really want is [batch_size*node_num, attribute_num], then you are left with reshaping the tensor yourself using view or reshape. In fact, Flatten itself is just a reshape under the hood.

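tensor.view: returns a tensor with the new shape that shares the same underlying data, so if you edit the new tensor, the old one changes too. It only works when the new shape is compatible with the tensor's memory layout (for example, a contiguous tensor).

tensor.reshape: returns a view, like tensor.view, when possible and otherwise copies the data into a new tensor with the new shape, so you should not rely on whether the result shares memory with the original.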
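A small sketch of the difference on toy data, checking memory sharing with `data_ptr()` and using a transposed (non-contiguous) tensor to force the copy case:

```python
import torch

a = torch.arange(6)

# view: same storage, so writing through the view changes the original
v = a.view(2, 3)
v[0, 0] = 100
print(a[0].item())  # 100

# reshape on a contiguous tensor also returns a view (shared storage)
r = a.reshape(3, 2)
print(r.data_ptr() == a.data_ptr())  # True

# A transposed tensor is not contiguous: view fails, reshape copies
t = a.view(2, 3).t()
try:
    t.view(6)
    view_ok = True
except RuntimeError:
    view_ok = False
copied = t.reshape(6)
print(view_ok, copied.data_ptr() == a.data_ptr())  # False False
```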

    def forward(self,x,edge_index,batch):   
        x = F.elu(self.conv1(x,edge_index))
        x = F.elu(self.conv2(x,edge_index))

        # print(np.shape(x)) # don't use this
        print(x.size())  # use this

        # x = self.fc1(x)  # this is the old one
        ## choose one of these; -1 lets PyTorch infer batch_size*node_num (4*896 here)
        x = x.view(-1, 32)
        # x = x.reshape(-1, 32)

        # print(np.shape(x)) # don't use this
        print(x.size())  # use this

        x = self.ln1(x)
        x = F.relu(x)
        ## Dropout?
        print("o")
        x = torch.sigmoid(self.ln2(x))
        return x
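Note that with this forward the output still has one row per node ([4*896, 1]), so on its own it will not match the 4 per-graph labels. The global_max_pool workaround mentioned in the question instead reduces each graph's nodes to a single vector. The sketch below shows, in plain PyTorch, what that max-pooling computes, assuming equal-size graphs whose nodes are stored contiguously; it is an illustration, not PyG's implementation, which uses the `batch` vector and also handles graphs of different sizes:

```python
import torch

batch_size, node_num, attribute_num = 4, 896, 32
x = torch.randn(batch_size * node_num, attribute_num)          # node features
batch = torch.arange(batch_size).repeat_interleave(node_num)  # graph id per node

# Per-graph max over nodes via reshaping (equal-size graphs only)
pooled = x.view(batch_size, node_num, attribute_num).amax(dim=1)  # [4, 32]

# Same result computed graph by graph using the batch vector
reference = torch.stack([x[batch == g].max(dim=0).values for g in range(batch_size)])
print(pooled.shape, torch.equal(pooled, reference))  # torch.Size([4, 32]) True
```

The pooled [4, 32] tensor fits ln1 = Linear(32, 32) from the question unchanged, giving one prediction per graph.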

Edit 2: reshape vs. permute

Let's say we have the array [[[1, 1, 1], [2, 2, 2]]] with shape (1, 2, 3), which represents (batch, length, channel) as in TensorFlow.

If you want to use this data properly with PyTorch layers that expect channels first (such as nn.Conv1d), you need to make it (batch, channel, length), which is (1, 3, 2).

Here's the difference between permute and reshape:

>>> x = torch.tensor([[[1, 1, 1], [2, 2, 2]]])
>>> x.size()
torch.Size([1, 2, 3])
>>> x[0, 0, :]
tensor([1, 1, 1])
>>> y = x.reshape((1, 3, 2))
>>> y
tensor([[[1, 1],
         [1, 2],
         [2, 2]]])
>>> y[0, :, 0]
tensor([1, 1, 2])
>>> z = x.permute(0, 2, 1)
>>> z
tensor([[[1, 2],
         [1, 2],
         [1, 2]]])
>>> z[0, :, 0]
tensor([1, 1, 1])

As you can see, taking the first length position across all channels gives [1, 1, 1] in both x (x[0, 0, :]) and z (z[0, :, 0]), which is what we want, while the same selection on y gives [1, 1, 2].


Comments:

Thanks for your explanations! They help a lot. I actually want [batch_size, node_num, attribute_num], as it would be encoded in TensorFlow, but the DataLoader (used the way the tutorials show) just gives me [batch_size*node_num, attribute_num]. I thought this was normal PyTorch behavior, but it seems it is not? Honestly, I hope it is something else ;) Maybe I should manually reshape the input at the very beginning of the network?
@jasperhyp PyTorch actually works similarly to TensorFlow. You just need to move the channel axis from the last position to right after the batch axis.
If the tutorial code works, then it's fine. Otherwise, you can also use TensorFlow's data loader, move the channel axis, and then convert the result to a torch tensor.
Thanks for your comment! I am not sure I understand "move the channel axis from the last position to right after the batch axis". Do you mean reshaping the input data directly? I now think not having batch_size as a separate dimension is truly problematic, because layers that depend on the batch (e.g. batch normalization layers) do not work as expected.
Not reshape but permute. For NumPy arrays, the equivalents are np.rollaxis or np.swapaxes.
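A short sketch of the NumPy side of that comment, using the toy array from Edit 2. It uses np.moveaxis (which NumPy recommends over np.rollaxis) and np.swapaxes, and checks both against torch.permute:

```python
import numpy as np
import torch

a = np.array([[[1, 1, 1], [2, 2, 2]]])  # (batch, length, channel) as in TensorFlow

# Move the channel axis (last) to position 1: (batch, channel, length)
moved = np.moveaxis(a, -1, 1)
swapped = np.swapaxes(a, 1, 2)  # same result for a 3-D array

# Equivalent PyTorch permute
z = torch.from_numpy(a).permute(0, 2, 1)
print(moved.shape)  # (1, 3, 2)
print(np.array_equal(moved, swapped), np.array_equal(moved, z.numpy()))  # True True
```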
