I’ve been working with Qwen2.5-VL and Gemma3 locally, and I need to measure the similarity between text and image embeddings—similar to CLIP/SigLIP—but I’m resource-limited and can’t spin up additional models. I tried extracting embeddings and computing cosine similarity myself, but I’m not getting meaningful results. Am I doing something wrong? How should I correctly compute text–image similarity with Qwen2.5-VL and Gemma3?
What I’ve tried
import torch
from transformers import (
    Qwen2_5_VLForConditionalGeneration,
    AutoProcessor,
    BitsAndBytesConfig,
)
from transformers.image_utils import load_image
from torch.nn.functional import normalize, cosine_similarity

# 4-bit quantization so the 3B model fits in limited VRAM
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2.5-VL-3B-Instruct",
    torch_dtype="auto",
    device_map="auto",
    quantization_config=quant_config,
)

# Cap the number of visual tokens to keep memory down
min_pixels = 144 * 28 * 28
max_pixels = 256 * 28 * 28
processor = AutoProcessor.from_pretrained(
    "Qwen/Qwen2.5-VL-3B-Instruct",
    min_pixels=min_pixels,
    max_pixels=max_pixels,
    use_fast=True,
)

img = load_image("https://huggingface.co/datasets/merve/coco/resolve/main/val2017/000000039769.jpg")

inputs = processor(
    text=["This is a photo of 2 cats."],
    images=[img],
    return_tensors="pt",
).to("cuda")

with torch.no_grad():
    # Text side: raw token embeddings from the input embedding layer (no forward pass)
    txt_embeds = model.model.embed_tokens(inputs["input_ids"])
    # Image side: patch features from the vision tower
    img_embeds = model.visual(inputs["pixel_values"], grid_thw=inputs["image_grid_thw"])

# Mean-pool both sides and compare
sim = cosine_similarity(
    normalize(img_embeds.mean(dim=0), dim=-1),
    normalize(txt_embeds.mean(dim=1), dim=-1),
)
print("Cosine similarity:", sim.item())
This yields a similarity of 0.0069, which seems far too low.
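As a sanity check I'm considering (not something I've validated), I'd compare the same pooled image embedding against a matching caption and a clearly mismatched one, using the identical pooling, just to see whether the score at least ranks them correctly. The airplane caption is my own made-up negative, and this reuses the model, processor, and img_embeds objects from above:

# Hypothetical ranking check, same attribute access and pooling as above
captions = ["This is a photo of 2 cats.", "This is a photo of an airplane."]
scores = []
with torch.no_grad():
    for cap in captions:
        cap_inputs = processor(text=[cap], return_tensors="pt").to("cuda")
        cap_embeds = model.model.embed_tokens(cap_inputs["input_ids"])
        scores.append(
            cosine_similarity(
                normalize(img_embeds.mean(dim=0, keepdim=True), dim=-1),
                normalize(cap_embeds.mean(dim=1), dim=-1),
            ).item()
        )
print(dict(zip(captions, scores)))  # I'd expect the matching caption to score higher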
Also tried Gemma3
from transformers import Gemma3ForConditionalGeneration

access_token = "hf_..."  # my Hugging Face access token (placeholder)

model_id = "google/gemma-3-4b-it"
model = Gemma3ForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    quantization_config=quant_config,  # same 4-bit config as above
    token=access_token,
    output_hidden_states=True,
    return_dict=True,
).eval()

processor = AutoProcessor.from_pretrained(
    model_id,
    token=access_token,
    min_pixels=256*28*28,
    max_pixels=512*28*28,
    use_fast=True,
)

inputs_img = processor(
    text="<start_of_image>", images=img, return_tensors="pt", padding=True
).to(model.device, dtype=torch.bfloat16)
inputs_txt = processor(
    text="This is a photo of 2 cats.", return_tensors="pt"
).to(model.device, dtype=torch.bfloat16)

with torch.no_grad():
    # Image side: vision-tower features already projected into the LM space
    img_tokens = model.get_image_features(pixel_values=inputs_img["pixel_values"])
    img_feats = normalize(img_tokens.mean(dim=1), dim=-1)

    # Text side: I first tried the raw input embeddings...
    tok_embeds = model.get_input_embeddings()(inputs_txt["input_ids"])
    # ...then switched to the last hidden state of a full forward pass (this overwrites the line above)
    tok_embeds = model(
        **inputs_txt,
        output_hidden_states=True,
        return_dict=True,
    ).hidden_states[-1]
    txt_feats = normalize(tok_embeds.mean(dim=1), dim=-1)

sim = cosine_similarity(img_feats, txt_feats)
print("Cosine similarity:", sim.item())
This yields a similarity of 0.0403, which again seems too low to be meaningful.
By comparison, SigLIP’s example on Hugging Face does:
from transformers import AutoModel

# (model loading added here for completeness; I ran both checkpoints listed below)
siglip_id = "google/siglip2-large-patch16-256"
model = AutoModel.from_pretrained(siglip_id).to("cuda")
processor = AutoProcessor.from_pretrained(siglip_id)

texts = ["This is a photo of 2 cats."]
image = img

inputs = processor(
    text=texts,
    images=image,
    padding="max_length",
    max_length=64,
    return_tensors="pt",
).to(model.device)

with torch.no_grad():
    outputs = model(**inputs)

logits_per_image = outputs.logits_per_image
probs = torch.sigmoid(logits_per_image)
print("SigLIP similarity:", probs)
google/siglip2-large-patch16-256 → ~0.164
google/siglip2-so400m-patch14-384 → ~0.245
Questions:
1. Is my approach to extracting embeddings and computing cosine similarity incorrect?
2. Is there a recommended pooling or projection strategy for generative VLMs to produce a contrastive-like score?
3. Should I fine-tune Qwen2.5-VL or Gemma3, or attach a small head on top of them, to align the embeddings? (A rough sketch of the kind of head I mean follows the questions.)
4. Are there any existing libraries or implementations that standardize text–image similarity for these generative vision-language models under limited resources? In my case I have 6 GB of VRAM, so I can't load a separate model alongside the chat model.
5. Is there a vision-chat model that can compute text–image similarity natively? I don't have to stick with Qwen2.5-VL or Gemma3, but I haven't managed to find one.
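To make questions 2 and 3 concrete, this is roughly what I have in mind: freeze the VLM, pool its embeddings (mean pooling as above, or last-token pooling), and train only a small projection head per modality with a CLIP-style contrastive loss. Everything below (head sizes, loss, temperature) is my own guess at a setup, not something I've tried:

import torch
import torch.nn as nn
import torch.nn.functional as F

class ProjectionHead(nn.Module):
    """Small trainable head on top of frozen, pooled VLM embeddings (sizes are guesses)."""
    def __init__(self, in_dim: int, out_dim: int = 512):
        super().__init__()
        self.proj = nn.Sequential(
            nn.Linear(in_dim, out_dim),
            nn.GELU(),
            nn.Linear(out_dim, out_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.normalize(self.proj(x), dim=-1)

def clip_style_loss(img_feats: torch.Tensor, txt_feats: torch.Tensor, temperature: float = 0.07) -> torch.Tensor:
    # Symmetric InfoNCE over a batch of matched (image, text) pairs
    logits = img_feats @ txt_feats.t() / temperature
    targets = torch.arange(logits.size(0), device=logits.device)
    return (F.cross_entropy(logits, targets) + F.cross_entropy(logits.t(), targets)) / 2

# Intended use (pseudocode): img_pooled / txt_pooled are (batch, hidden) tensors pooled
# from the frozen VLM and detached from the graph; only the heads would be trained.
# img_head = ProjectionHead(in_dim=img_pooled.size(-1))
# txt_head = ProjectionHead(in_dim=txt_pooled.size(-1))
# loss = clip_style_loss(img_head(img_pooled), txt_head(txt_pooled))
# After training, cosine similarity of the head outputs would serve as the score.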
Any pointers, code examples, or best practices would be greatly appreciated!