I have trained a Variational Autoencoder (VAE) with an additional fully connected layer after the encoder for binary image classification. It is set up using PyTorch Lightning. The encoder and decoder are the resnet18 versions from the PyTorch Lightning Bolts repo.

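import torch
from torch import nn
from pytorch_lightning import LightningModule
from pl_bolts.models.autoencoders.components import (
    resnet18_encoder,
    resnet18_decoder,
)


class VariationalAutoencoder(LightningModule):
    def __init__(self):
        super().__init__()
        ...  # other layers omitted (fc_mu, fc_var, decoder and sample() are used in forward)

        self.first_conv: bool = False
        self.maxpool1: bool = False
        self.enc_out_dim: int = 512
        self.encoder = resnet18_encoder(self.first_conv, self.maxpool1)
        # Binary classification head on top of the 512-d encoder output
        self.fc_object_identity = nn.Linear(self.enc_out_dim, 1)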

    def forward(self, x):
        x_encoded = self.encoder(x)
        mu = self.fc_mu(x_encoded)
        log_var = self.fc_var(x_encoded)
        p, q, z = self.sample(mu, log_var)

        x_classification_score = torch.sigmoid(self.fc_object_identity(x_encoded))

        return self.decoder(z), x_classification_score

variational_autoencoder = VariationalAutoencoder.load_from_checkpoint(
    checkpoint_path=str(checkpoint_file_path)
)

with torch.no_grad():
    predicted_images, classification_score = variational_autoencoder(test_images)

The reconstructions work well both for single images and for multiple images passed through forward(). However, when I pass multiple images to forward() at once, I get different classification scores than when I pass each image tensor on its own:

# Image 1 (class=1) [1, 3, 64, 64]
x_classification_score = 0.9857

# Image 2 (class=0) [1, 3, 64, 64]
x_classification_score = 0.0175

# Image 1 and 2 [2, 3, 64, 64]
x_classification_score = [[0.8943],
                          [0.1736]]

Why is this happening?

  • Please provide the architecture for the encoder. You are probably not running the model in PyTorch's evaluation mode, hence the different results. Commented May 27, 2022 at 21:33
  • Ah, thanks @szymonmaszke, that seems to be it. I added variational_autoencoder.eval() before the with torch.no_grad(): line and the results are now consistent. So without eval(), the network changes its architecture between running inference on the first image and the second one when I pass multiple images? Commented May 28, 2022 at 5:28
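A minimal sketch of the consistency check described in this comment. It assumes a hypothetical TinyClassifier (one Conv2d, one BatchNorm2d and a sigmoid head) as a stand-in for the pl_bolts resnet18 encoder plus fc_object_identity; it is not the real architecture. The sketch scores two images as one batch and one at a time, first in training mode and then after eval().

```python
import torch
from torch import nn


class TinyClassifier(nn.Module):
    """Hypothetical stand-in for encoder + fc_object_identity (NOT the pl_bolts resnet18)."""

    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=3, padding=1),
            nn.BatchNorm2d(8),  # the layer whose behavior depends on train/eval mode
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
        )
        self.fc_object_identity = nn.Linear(8, 1)

    def forward(self, x):
        return torch.sigmoid(self.fc_object_identity(self.features(x)))


def single_vs_batched(model, images):
    """Score the images as one batch and one at a time, without gradient tracking."""
    with torch.no_grad():
        batched = model(images)
        single = torch.cat([model(img.unsqueeze(0)) for img in images])
    return batched, single


torch.manual_seed(0)
model = TinyClassifier()
images = torch.randn(2, 3, 64, 64)
images[1] = images[1] * 3 + 2  # give the second image different statistics

model.train()  # a freshly constructed nn.Module starts in training mode anyway
train_batched, train_single = single_vs_batched(model, images)

model.eval()  # BatchNorm now normalizes with its running statistics
eval_batched, eval_single = single_vs_batched(model, images)

print(torch.allclose(train_batched, train_single, atol=1e-5))
print(torch.allclose(eval_batched, eval_single, atol=1e-5))
```

In training mode the two sets of scores disagree, because each call normalizes with the statistics of whatever batch it received; after eval() they agree up to floating-point noise. The network's layers are the same in both cases; only the statistics BatchNorm uses change.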

1 Answer

You are using resnet18, which contains torch.nn.BatchNorm2d layers.
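To see which layers are affected in a given model, a small helper can list the BatchNorm submodules and, optionally, only those still in training mode. This is a sketch: batchnorm_layers and toy_encoder are hypothetical names, and the helper only checks the three standard classes BatchNorm1d, BatchNorm2d and BatchNorm3d (SyncBatchNorm, for example, is not a subclass of these and would be missed).

```python
from typing import List

from torch import nn

BATCHNORM_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)


def batchnorm_layers(model: nn.Module, only_training: bool = False) -> List[str]:
    """Qualified names of BatchNorm submodules; optionally only those in training mode."""
    return [
        name
        for name, module in model.named_modules()
        if isinstance(module, BATCHNORM_TYPES) and (module.training or not only_training)
    ]


# Hypothetical toy encoder standing in for resnet18_encoder
toy_encoder = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.BatchNorm2d(16),
    nn.ReLU(),
    nn.Conv2d(16, 32, kernel_size=3, padding=1),
    nn.BatchNorm2d(32),
    nn.ReLU(),
)

print(batchnorm_layers(toy_encoder))                      # ['1', '4']
print(batchnorm_layers(toy_encoder, only_training=True))  # ['1', '4']
toy_encoder.eval()
print(batchnorm_layers(toy_encoder, only_training=True))  # []
```

Calling the same helper with only_training=True on a loaded model right before inference is a quick way to confirm that eval() has actually been applied.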

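Its behavior depends on whether the module is in train or eval mode. In training mode it computes the mean and variance of each channel over the current batch (across the batch, height and width dimensions), so the output for one image depends on the other images in the same batch.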
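The training-mode computation can be reproduced by hand. This sketch assumes a freshly initialized nn.BatchNorm2d (weight 1, bias 0) applied to random data: it normalizes each channel with the mean and biased variance taken over dimensions (0, 2, 3), compares the result with the layer's own output, and then shows that the first image gets a different output when it is normalized in a batch of one.

```python
import torch
from torch import nn

torch.manual_seed(0)
x = torch.randn(4, 3, 8, 8)  # (N, C, H, W)

bn = nn.BatchNorm2d(3)  # freshly initialized: weight = 1, bias = 0
bn.train()

with torch.no_grad():
    out = bn(x)

    # Training mode: per-channel statistics over batch, height and width (dims 0, 2, 3)
    mean = x.mean(dim=(0, 2, 3), keepdim=True)
    var = ((x - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)  # biased variance
    manual = (x - mean) / torch.sqrt(var + bn.eps)
    manual = manual * bn.weight.view(1, -1, 1, 1) + bn.bias.view(1, -1, 1, 1)

    # Same first image, but normalized with statistics from a batch of one
    first_alone = bn(x[:1])

print(torch.allclose(out, manual, atol=1e-5))            # True
print(torch.allclose(out[:1], first_alone, atol=1e-5))   # False
```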

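In evaluation mode it instead uses the running mean and variance accumulated during training as a moving average. These do not depend on the current batch, so each image gets the same output whether it is passed alone or together with other images.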
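A sketch of the evaluation-mode behavior, assuming a standalone nn.BatchNorm2d with the default momentum of 0.1. A few training-mode passes fill running_mean and running_var; after eval(), every sample is normalized with those fixed statistics, so its output no longer depends on the rest of the batch. It also illustrates a side effect relevant to the question's code: in training mode the running statistics are updated on every forward pass even inside torch.no_grad(), so calling the model without eval() also changes the statistics it holds in memory.

```python
import torch
from torch import nn

torch.manual_seed(0)
bn = nn.BatchNorm2d(3)  # default momentum=0.1

# Train-mode forward passes update running_mean / running_var,
# even though gradients are disabled
bn.train()
with torch.no_grad():
    for _ in range(20):
        bn(torch.randn(8, 3, 16, 16) * 2.0 + 1.0)

print(bn.num_batches_tracked.item())  # 20

bn.eval()
x = torch.randn(4, 3, 16, 16)
with torch.no_grad():
    batched = bn(x)
    first_alone = bn(x[:1])

    # Eval mode: y = (x - running_mean) / sqrt(running_var + eps) * weight + bias
    rm = bn.running_mean.view(1, -1, 1, 1)
    rv = bn.running_var.view(1, -1, 1, 1)
    manual = (x - rm) / torch.sqrt(rv + bn.eps)
    manual = manual * bn.weight.view(1, -1, 1, 1) + bn.bias.view(1, -1, 1, 1)

print(torch.allclose(batched[:1], first_alone, atol=1e-6))  # True: batch-independent
print(torch.allclose(batched, manual, atol=1e-5))           # True
```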
