
I am trying to re-implement the model.generate() function of Hugging Face transformers models so that I can apply a logit bias, which the standard function does not support. Before I got that far, though, I ran into problems with my top-p sampling.

Here's the code snippet:

generation_args = {
    "max_new_tokens": 500,
    "temperature": 0.4,  # Adjust temperature if needed for more or less randomness
    "do_sample": True,  # Enable sampling
    "top_p": 0.5,  # Set the cumulative probability for nucleus sampling
    "top_k": None,  # Optionally, you can set top_k if you want to use it alongside or instead of top_p
}


def top_p_filtering(logits, top_p):
    """Filter the logits using top-p (nucleus) sampling."""
    # Sort logits in descending order and get the sorted indices
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)

    # Compute the cumulative probabilities of the sorted logits
    cumulative_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)

    # Create a mask for the tokens to keep
    sorted_indices_to_keep = cumulative_probs <= top_p

    # Ensure that at least one token is kept (the first token, which has the highest logit)
    sorted_indices_to_keep[..., 0] = True

    # Filter out the tokens to remove by setting their logits to negative infinity
    logits[sorted_indices[~sorted_indices_to_keep]] = float('-inf')

    return logits


def custom_generate(input_ids, streamer, max_new_tokens, temperature, top_p):
    past_key_values = None
    attention_mask = torch.ones(input_ids.shape, device=input_ids.device)

    for _ in range(max_new_tokens):
        with torch.no_grad():
            outputs = model(
                input_ids=input_ids,
                past_key_values=past_key_values,
                attention_mask=attention_mask,
                use_cache=True
            )

        logits = outputs.logits[:, -1, :]  # Get logits of the last token

        # Apply temperature to logits
        if temperature != 1.0:
            logits = logits / temperature

        # Apply top-p sampling
        if top_p is not None and top_p < 1.0:
            logits = top_p_filtering(logits, top_p)
        print("1")
        next_token_probs = torch.nn.functional.softmax(logits, dim=-1)
        print("2")
        # Check if next_token_probs contains valid probabilities


        next_token_id = torch.multinomial(next_token_probs,
                                          num_samples=1)  
        print("3")
        streamer.put(next_token_id)  # Pass the tensor directly to the streamer

        input_ids = next_token_id  # Set the next input to the last generated token
        attention_mask = torch.cat(
            [attention_mask, torch.ones((attention_mask.shape[0], 1), device=attention_mask.device)], dim=1)

        past_key_values = outputs.past_key_values

        if next_token_id.item() == tokenizer.eos_token_id:  
            break

with torch.no_grad():
    custom_generate(input_ids, streamer, generation_args["max_new_tokens"], generation_args["temperature"], generation_args["top_p"])

The error that I face:

../aten/src/ATen/native/cuda/IndexKernel.cu:92: operator(): block: [10,0,0], thread: [63,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
Exception in thread Thread-18 (generate):
Traceback (most recent call last):
  File "/usr/lib/python3.10/threading.py", line 1016, in _bootstrap_inner
    self.run()
  File "/usr/lib/python3.10/threading.py", line 953, in run
    self._target(*self._args, **self._kwargs)
  File "/mnt/c/Users/User/Documents/EmpatheticChatBot/Inference-Server.py", line 130, in generate
    custom_generate(input_ids, streamer, generation_args["max_new_tokens"], generation_args["temperature"], generation_args["top_p"])
  File "/mnt/c/Users/User/Documents/EmpatheticChatBot/Inference-Server.py", line 108, in custom_generate
    next_token_id = torch.multinomial(next_token_probs,
RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

The problem only appeared after I added top-p sampling.

I expected my sampling to work, as I have looked through the code maybe 30 times already. ChatGPT says the code is correct and that the error is hard to debug. My hypothesis is that values are being filtered incorrectly or set to "bad" values.

1 Answer

The problem is the indexing you're doing at this line:

logits[sorted_indices[~sorted_indices_to_keep]] = float('-inf')

For reasons I'll explain, this is causing an index out of bounds error. Out of bounds indexing is a common cause of CUDA error: device-side assert triggered errors.
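
As a side note (a general debugging tip, not specific to this code), asynchronous CUDA asserts like this are much easier to localize if you force synchronous kernel launches, or reproduce the failing step on CPU, where PyTorch reports a plain IndexError with the offending index:

import os

# Must be set before CUDA is initialized (e.g. at the very top of the script,
# or as an environment variable when launching the process):
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# Alternatively, reproduce the suspect operation on CPU copies of the tensors
# (e.g. logits.detach().cpu()) to get a readable IndexError instead of a
# device-side assert.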

Consider the following:

import torch
import torch.nn as nn

torch.manual_seed(42)

top_p = 0.2

logits = torch.randn(8, 128) # random logits

# sort logits 
sorted_logits, sorted_indices = torch.sort(logits, descending=True)

# calculate cumulative probs
cumulative_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)

# apply top p threshold to cumulative probs
sorted_indices_to_keep = cumulative_probs <= top_p

# ensure at least one index is kept
sorted_indices_to_keep[..., 0] = True

# this is the problem: logits[sorted_indices[~sorted_indices_to_keep]] = float('-inf')
print(logits.shape, sorted_indices[~sorted_indices_to_keep].shape)
# output: torch.Size([8, 128]) torch.Size([989])

When you index sorted_indices[~sorted_indices_to_keep], both inputs are of shape (8, 128), but the output is of shape (989,) (or similar depending on the random seed for the dummy logits).

This happens because sorted_indices_to_keep has an irregular number of True values in each row, so the indexing operation cannot resolve the output into a clean 2D tensor where every row is the same size. PyTorch handles this by returning a flattened 1D vector containing the elements of sorted_indices at every True position in the mask.

This means that when you compute logits[sorted_indices[~sorted_indices_to_keep]], you are using a long 1D tensor of vocabulary indices (values up to 127 in this example) to index the first dimension of logits, which only has size 8. If you run this on CPU, you get an error like IndexError: index 20 is out of bounds for dimension 0 with size 8. On GPU, you get the CUDA assert error instead.
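
Here is a minimal, self-contained sketch (toy tensors, not from the original code) of the flattening behavior that produces the out-of-bounds index:

import torch

x = torch.arange(12).reshape(3, 4)   # shape (3, 4), values 0..11
mask = x % 3 == 0                    # irregular number of True values per row

selected = x[mask]                   # boolean indexing flattens the result to 1D
print(selected)                      # tensor([0, 3, 6, 9])

# Using those element values as indices into the first dimension (size 3) fails,
# because values like 3, 6 and 9 are out of range:
try:
    x[selected]
except IndexError as e:
    print(e)                         # e.g. "index 3 is out of bounds for dimension 0 with size 3"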

To fix this, use the scatter operation: it maps the keep/drop decision from sorted order back to the original vocabulary positions, producing a boolean mask with the same shape as logits. Use something like this:

def top_p_filtering(logits, top_p, shift_indices=True, debug=False):
    """Filter the logits using top-p (nucleus) sampling."""
    # Sort logits in descending order and get the sorted indices
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)

    # Compute the cumulative probabilities of the sorted logits
    cumulative_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)

    # Create a mask for the tokens to keep
    sorted_indices_to_keep = cumulative_probs <= top_p
    
    # Optional: shift indices to the right. This results in keeping the first 
    # token above the top_p threshold. Skip this line to ensure that all 
    # token probs are strictly below the top_p threshold
    if shift_indices:
        sorted_indices_to_keep[..., 1:] = sorted_indices_to_keep[..., :-1].clone()

    # Ensure that at least one token is kept (the first token, which has the highest logit)
    sorted_indices_to_keep[..., 0] = True
    
    # Use scatter to map the keep-mask from sorted order back to the original token positions
    mask = sorted_indices_to_keep.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_keep)
    
    # Optional debug check to make sure top_p is being honored.
    # Note: compute probs before masking, because applying softmax after
    # masking would renormalize and the sum would always be 1.
    if debug:
        probs = torch.nn.functional.softmax(logits, dim=-1)
        probs[~mask] = 0
        print(probs.sum(-1))
    
    # Use mask to set logit vals to -inf
    logits[~mask] = float('-inf')

    return logits
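
As a quick sanity check (a hedged sketch with dummy logits, not part of the answer above), you can run the fixed filter and sample from the result the same way custom_generate does:

import torch

torch.manual_seed(0)

dummy_logits = torch.randn(1, 128)                      # batch of 1, toy vocab of 128
filtered = top_p_filtering(dummy_logits.clone(), top_p=0.5, debug=True)

probs = torch.nn.functional.softmax(filtered, dim=-1)   # -inf logits become probability 0
next_token_id = torch.multinomial(probs, num_samples=1)
print(next_token_id.shape)                              # torch.Size([1, 1])

With debug=True, the printed value is the pre-mask probability mass of the kept tokens, which should land at or just above top_p when shift_indices=True.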

2 Comments

This solution works, but why? How come my logits are not a one-dimensional vector? (I am using this code for inference, so there is no batch.) And what makes my boolean mask behave weirdly? Anyway, thank you for the solution, it really helped!
For inference you still have a batch dimension, even if it is just one item. If that weren't the case, your line logits = outputs.logits[:, -1, :] would error. The behavior of sorted_indices[~sorted_indices_to_keep] is expected: it returns a 1D vector because there is no way to resolve the result into a regular 2D tensor.
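
To illustrate the point about the batch dimension (a small sketch with made-up shapes, not from the comment thread): even with a single prompt, the model's logits keep a leading batch dimension of 1.

import torch

# stand-in for outputs.logits with batch_size=1, seq_len=10, vocab_size=128
fake_logits = torch.randn(1, 10, 128)

last_token_logits = fake_logits[:, -1, :]
print(last_token_logits.shape)   # torch.Size([1, 128]) -- still 2D, batch dim preserved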
