
I know this seems to be a common problem, but I wasn't able to find a solution. I'm running a multi-label classification model and having issues with tensor sizing.

My full code looks like this:

from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification
import torch

# Instantiating tokenizer and model
tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-cased')
model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-cased')

# Instantiating quantized model
quantized_model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)

# Forming data tensors
input_ids = torch.tensor(tokenizer.encode(x_train[0], add_special_tokens=True)).unsqueeze(0)
labels = torch.tensor(Y[0]).unsqueeze(0)

# Forward pass (the model returns a loss when labels are provided)
outputs = quantized_model(input_ids, labels=labels)
loss, logits = outputs[:2]

Which yields the error:

ValueError: Expected input batch_size (1) to match target batch_size (11)

Input_ids looks like:

tensor([[  101,   789,   160,  1766,  1616,  1110,   170,  1205,  7727,  1113,
           170,  2463,  1128,  1336,  1309,  1138,   112,   119, 11882, 11545,
           119,   108, 15710,   108,  3645,   108,  3994,   102]])

with shape:

torch.Size([1, 28])

and labels looks like:

tensor([[0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1]])

with shape:

torch.Size([1, 11])

The size of input_ids will vary as the strings to be encoded vary in size.

I also noticed that when feeding in 5 values of Y to produce 5 labels, it yields the error:

ValueError: Expected input batch_size (1) to match target batch_size (55).

with labels shape:

torch.Size([1, 5, 11])

(Note that I didn't feed 5 input_ids, which is presumably why the input size remains constant.)

I've tried a few different approaches to getting these to work, but I'm currently at a loss. I'd really appreciate some guidance. Thanks!

  • Why are you unsqueezing the first dim? That's supposed to be the batch size. Commented May 9, 2020 at 22:11
  • I probably should have mentioned that I pulled most of this from a Hugging Face transformers example, which had the lines input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1 and labels = torch.tensor([1]).unsqueeze(0) # Batch size 1, so I left those as is. Commented May 9, 2020 at 22:17
  • When I change the labels unsqueeze value to 11 and the input_id unsqueeze value to -1, I get the following error: IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 11) @umbreon29 Commented May 9, 2020 at 22:23
  • @umbreon29 when I change the input_id unsqueeze value to -1 and the labels unsqueeze value to 0, I get the error: ValueError: Expected input batch_size (28) to match target batch_size (11) Commented May 9, 2020 at 22:32
  • I'm really confused on how to make the shapes line up Commented May 10, 2020 at 3:02

1 Answer


The labels for DistilBertForSequenceClassification need to have the size torch.Size([batch_size]) as mentioned in the documentation:

  • labels (torch.LongTensor of shape (batch_size,), optional, defaults to None) – Labels for computing the sequence classification/regression loss. Indices should be in [0, ..., config.num_labels - 1]. If config.num_labels == 1 a regression loss is computed (Mean-Square loss), If config.num_labels > 1 a classification loss is computed (Cross-Entropy).

In your case, your labels should have size torch.Size([1]).
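For illustration, a single-label call with the same tokenizer and model would look roughly like this (the label value 1 is just an example class index, not something from your data):

# Single-label sequence classification: one class index per sequence,
# so labels has shape (batch_size,) -- here torch.Size([1]).
input_ids = torch.tensor(
    tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)
).unsqueeze(0)              # shape (1, seq_len)
labels = torch.tensor([1])  # shape (1,): one class index, no extra unsqueeze

outputs = model(input_ids, labels=labels)
loss, logits = outputs[:2]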

That is not possible for your data, because sequence classification assigns exactly one label to each sequence, whereas you want multi-label classification.

As far as I'm aware, there is no multi-label model in Hugging Face's Transformers library that you could use out of the box. You would need to create your own model, which is not particularly difficult: these task-specific models all share the same base model and only add a classifier head appropriate to the task on top. HuggingFace - Multi-label Text Classification using BERT – The Mighty Transformer explains how this can be done; a rough sketch is shown below.
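A minimal sketch of such a model, assuming num_labels=11 to match your label vectors (the class name and the dropout rate are my own choices, not part of the library): it mirrors the structure of DistilBertForSequenceClassification but uses BCEWithLogitsLoss, so each of the 11 labels is treated as an independent binary decision.

import torch
import torch.nn as nn
from transformers import DistilBertModel

class DistilBertForMultiLabelSequenceClassification(nn.Module):
    # Hypothetical class, not provided by the transformers library.
    def __init__(self, num_labels=11, pretrained_name='distilbert-base-cased'):
        super().__init__()
        self.distilbert = DistilBertModel.from_pretrained(pretrained_name)
        dim = self.distilbert.config.dim              # hidden size (768 for distilbert-base)
        self.pre_classifier = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(0.2)                # assumed dropout rate
        self.classifier = nn.Linear(dim, num_labels)

    def forward(self, input_ids, attention_mask=None, labels=None):
        hidden_states = self.distilbert(input_ids, attention_mask=attention_mask)[0]
        pooled = hidden_states[:, 0]                           # embedding of the [CLS] token
        pooled = torch.relu(self.pre_classifier(pooled))
        logits = self.classifier(self.dropout(pooled))         # shape (batch_size, num_labels)

        if labels is not None:
            # Each label is an independent yes/no decision, so use BCE with logits
            # instead of the cross-entropy loss of the single-label model.
            loss = nn.BCEWithLogitsLoss()(logits, labels.float())
            return loss, logits
        return logits

With a model like this, labels of shape (batch_size, num_labels), e.g. your torch.Size([1, 11]) tensor, match the shape of the logits, so the batch-size mismatch disappears.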

