
Can someone please explain how the number of network parameters (10) is calculated? Thanks in advance.

import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16*5*5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(x.size()[0], -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

net = Net()
print(net)
print(len(list(net.parameters())))

Output:

Net(
  (conv1): Conv2d (1, 6, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d (6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120)
  (fc2): Linear(in_features=120, out_features=84)
  (fc3): Linear(in_features=84, out_features=10)
)
10

Best, Zack

2 Answers


Most layer modules in PyTorch (e.g. Linear, Conv2d) store their learnable tensors as separate Parameter objects, typically a "weight" and a "bias". Each of the five layers in your network has one of each, and net.parameters() yields one entry per tensor, so len(list(net.parameters())) is 5 × 2 = 10.
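To see those ten tensors directly, here is a minimal sketch. As an assumption for self-containment, it rebuilds the same five layers in an nn.ModuleList instead of the Net class, so the names come out as 0.weight ... 4.bias rather than conv1.weight ... fc3.bias.

```python
import torch.nn as nn

# Same five layers as in the question, gathered in a ModuleList so that
# named_parameters() can walk them (no forward() is needed just to count).
layers = nn.ModuleList([
    nn.Conv2d(1, 6, 5),
    nn.Conv2d(6, 16, 5),
    nn.Linear(16 * 5 * 5, 120),
    nn.Linear(120, 84),
    nn.Linear(84, 10),
])

# Each layer registers one "weight" and one "bias" Parameter tensor.
names = [name for name, _ in layers.named_parameters()]
print(names)       # names run 0.weight, 0.bias, 1.weight, ..., 4.weight, 4.bias
print(len(names))  # 10

# With bias=False no bias Parameter is registered, so only one tensor remains.
print(len(list(nn.Linear(84, 10, bias=False).parameters())))  # 1
```

Passing bias=False to any of these layers would therefore lower the printed count by one.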

Of course, each of these "weight" and "bias" tensors holds many scalar values. For example, the weight of your first fully connected layer, self.fc1, has 16 * 5 * 5 * 120 = 48000 entries, and its bias adds another 120. So len(list(net.parameters())) does not tell you the number of parameters in the network; it only gives the number of parameter tensors ("groupings").
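A short sketch of the difference between counting tensors and counting scalars. It again assumes an nn.ModuleList stand-in with the question's layer shapes, where index 2 plays the role of fc1; Tensor.numel() returns the number of elements in a tensor.

```python
import torch.nn as nn

# Stand-in for Net: same layer shapes; layers[2] corresponds to fc1.
layers = nn.ModuleList([
    nn.Conv2d(1, 6, 5),
    nn.Conv2d(6, 16, 5),
    nn.Linear(16 * 5 * 5, 120),
    nn.Linear(120, 84),
    nn.Linear(84, 10),
])

params = list(layers.parameters())
print(len(params))                     # 10 tensors ("groupings")
print(layers[2].weight.numel())        # 48000 entries in the fc1-equivalent weight
print(layers[2].bias.numel())          # 120 entries in its bias
print(sum(p.numel() for p in params))  # 61706 scalar parameters in total
```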


Since Bill already explained why "10" is printed, here is a snippet you can use to find the number of parameters in each layer of your network.

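def count_parameters(model):
    """Print the shape and size of every trainable parameter tensor and return the total."""
    total_param = 0
    for name, param in model.named_parameters():
        if param.requires_grad:
            # Number of scalar entries in this tensor.
            num_param = param.numel()
            if param.dim() > 1:
                # Multi-dimensional weights: print the shape, e.g. 6x1x5x5.
                print(name, ':', 'x'.join(str(x) for x in param.size()), '=', num_param)
            else:
                # 1-D tensors (biases): print only the count.
                print(name, ':', num_param)
            total_param += num_param
    return total_param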

Use the above function as follows.

print('number of trainable parameters =', count_parameters(net))

Output:

conv1.weight : 6x1x5x5 = 150
conv1.bias : 6
conv2.weight : 16x6x5x5 = 2400
conv2.bias : 16
fc1.weight : 120x400 = 48000
fc1.bias : 120
fc2.weight : 84x120 = 10080
fc2.bias : 84
fc3.weight : 10x84 = 840
fc3.bias : 10
number of trainable parameters = 61706
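The per-layer numbers above follow a simple closed form: a Conv2d weight has out_channels × in_channels × k × k entries and a Linear weight has out_features × in_features, each plus one bias entry per output. The sketch below checks that arithmetic against PyTorch. conv2d_params and linear_params are hypothetical helpers (not PyTorch APIs) that assume bias=True, square kernels and groups=1; the ModuleList is again a stand-in for Net.

```python
import torch.nn as nn

# Hypothetical helpers: closed-form counts, assuming bias=True, square kernels, groups=1.
def conv2d_params(in_ch, out_ch, k):
    return out_ch * in_ch * k * k + out_ch   # weight + bias

def linear_params(in_f, out_f):
    return out_f * in_f + out_f              # weight + bias

# 156 + 2416 + 48120 + 10164 + 850
by_formula = (conv2d_params(1, 6, 5) + conv2d_params(6, 16, 5)
              + linear_params(400, 120) + linear_params(120, 84)
              + linear_params(84, 10))

layers = nn.ModuleList([nn.Conv2d(1, 6, 5), nn.Conv2d(6, 16, 5),
                        nn.Linear(400, 120), nn.Linear(120, 84), nn.Linear(84, 10)])
by_pytorch = sum(p.numel() for p in layers.parameters() if p.requires_grad)

print(by_formula, by_pytorch)  # 61706 61706
```

The one-liner sum(p.numel() for p in net.parameters() if p.requires_grad) gives the same total as count_parameters without the per-layer printout.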
