
I have a model class that holds a pre-trained ResNet as a field, something like:

```python
import torch.nn as nn


class A(nn.Module):
    def __init__(self, **kwargs):
        super(A, self).__init__()
        self.resnet = get_resnet()  # my own helper that returns a pre-trained ResNet

    def forward(self, x):
        return self.resnet(x)
```

...

Now I'm doing:

```python
model = A()
...
model.eval()
```

Is this OK, or should I override the eval() and train() methods?

Answer


Short answer

It's OK.

Long answer

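nn.Module.train() sets the module's own training flag and then calls train() on each of its child modules, so it runs recursively. Simplified, its body is: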

```python
def train(self, mode=True):
    self.training = mode
    for module in self.children():
        module.train(mode)
    return self
```

nn.Module.eval() simply returns self.train(False).
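To see why a single call on the top-level model is enough, here is a toy re-implementation of just the train/eval logic in plain Python. MiniModule and all_flags are hypothetical names for illustration, not PyTorch code; the point is only that "set my own flag, then recurse into children" reaches every descendant.

```python
class MiniModule:
    """Toy stand-in for nn.Module: only the train/eval logic, nothing else."""

    def __init__(self, *children):
        self.training = True
        self._children = list(children)

    def children(self):
        return iter(self._children)

    def train(self, mode=True):
        # Same shape as nn.Module.train(): set own flag, then recurse.
        self.training = mode
        for module in self.children():
            module.train(mode)
        return self

    def eval(self):
        # eval() is just train(False)
        return self.train(False)


def all_flags(module):
    """Collect the training flag of a module and of all its descendants."""
    flags = [module.training]
    for child in module.children():
        flags.extend(all_flags(child))
    return flags


# A three-level tree: model -> backbone -> (block1, block2)
model = MiniModule(MiniModule(MiniModule(), MiniModule()))
model.eval()
print(all_flags(model))  # [False, False, False, False]
model.train()
print(all_flags(model))  # [True, True, True, True]
```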

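So as long as self.resnet is an instance of an nn.Module subclass (assigning it as an attribute in __init__ registers it as a child module), you don't need to override anything. Most of nn.Module's built-in methods other than forward(), for example to(), cuda(), parameters(), state_dict(), zero_grad() and apply(), likewise act on all registered submodules.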
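A minimal sketch of that recursion in a few other nn.Module methods, assuming a small hypothetical Conv/BatchNorm backbone in place of get_resnet() so it runs without pretrained weights. named_modules(), parameters(), apply() and to() all reach the nested layers without any override in A.

```python
import torch
import torch.nn as nn


class A(nn.Module):
    def __init__(self):
        super().__init__()
        # Hypothetical stand-in for get_resnet(): a tiny conv/BN backbone.
        self.resnet = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=3, padding=1),
            nn.BatchNorm2d(8),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.resnet(x)


model = A()

# named_modules() walks the whole tree, including the nested Sequential.
print([name for name, _ in model.named_modules()])
# ['', 'resnet', 'resnet.0', 'resnet.1', 'resnet.2']

# parameters() yields the nested conv and BN parameters: 216 + 8 + 8 + 8.
print(sum(p.numel() for p in model.parameters()))  # 240

# apply(fn) calls fn on every submodule (children first), then on model itself.
visited = []
model.apply(lambda m: visited.append(type(m).__name__))
print(visited)  # ['Conv2d', 'BatchNorm2d', 'ReLU', 'Sequential', 'A']

# to() converts the nested parameters and buffers as well.
model.to(torch.float64)
print(model.resnet[0].weight.dtype)  # torch.float64
```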

You can check this with:

```python
model = A()
...
model.eval()
print(model.resnet.training)  # should be False
```

If it prints False, everything is fine. If you get anything else (True, or an AttributeError), the problem lies in what get_resnet() returns.
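A slightly stronger, self-contained version of that check. Here get_resnet() is a hypothetical stand-in (a tiny Linear/BatchNorm/Dropout stack) so the snippet runs without downloading weights; it checks every submodule via modules() rather than only the direct child, and confirms that in eval mode the forward pass is deterministic.

```python
import torch
import torch.nn as nn


def get_resnet():
    # Hypothetical stand-in for the asker's helper; any nn.Module works here.
    return nn.Sequential(
        nn.Linear(4, 4),
        nn.BatchNorm1d(4),
        nn.Dropout(p=0.5),
    )


class A(nn.Module):
    def __init__(self):
        super().__init__()
        self.resnet = get_resnet()  # attribute assignment registers the child

    def forward(self, x):
        return self.resnet(x)


model = A()
model.eval()

# Check every submodule, not just model.resnet itself.
print(all(not m.training for m in model.modules()))  # True

# In eval mode Dropout is a no-op and BatchNorm uses its running statistics,
# so two forward passes on the same input give identical results.
x = torch.randn(2, 4)
with torch.no_grad():
    print(torch.equal(model(x), model(x)))  # True
```

The same two checks apply unchanged to a real pre-trained ResNet; only the module tree being traversed changes.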
