
Commit 07df823

added testcode
1 parent 8ef86b5 commit 07df823

File tree

1 file changed: +148 -0 lines changed


llm2clip/test.py

import argparse
import os

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from peft import LoraConfig, get_peft_model
from PIL import Image
from tqdm import tqdm

from eva_clip import create_model_and_transforms, create_model_from_pretrained
from training.llm2vec_wrapper import LLM2VecWrapper as LLM2Vec

class LLM2VecWithProjection(nn.Module):
    """Wraps an LLM2Vec text encoder with a projection head. The lock and
    set_grad_checkpointing methods mirror the text-tower interface the CLIP
    training code expects."""

    def __init__(self, llm2vec_model, projection):
        super().__init__()
        self.model = llm2vec_model
        self.projection = projection
        # self.tokenizer = llm2vec_model.tokenizer

    def forward(self, text):
        embeddings = self.model(text)
        return self.projection(embeddings)

    def lock(self, unlocked_layers=0, freeze_layer_norm=True):
        # Freeze the LLM2Vec weights but keep the projection trainable.
        for param in self.model.parameters():
            param.requires_grad = False

    def set_grad_checkpointing(self, enable=True):
        if enable:
            self.model.gradient_checkpointing_enable()
        else:
            self.model.gradient_checkpointing_disable()

# Initialize the EVA-CLIP model and preprocessing transforms.
model, preprocess_train, preprocess_val = create_model_and_transforms(
    "EVA02-CLIP-L-14-336",
    "eva_clip",
    precision="fp16",
    device="cuda",
    force_custom_clip=True,
    image_mean=None,
    image_std=None,
    cache_dir=None,
    skip_list=None,
)
model.eval()

# Load the pre-trained LLM2Vec text model.
text_model = LLM2Vec.from_pretrained(
    base_model_name_or_path="meta-llama/Llama-3.2-3B",
    enable_bidirectional=True,
    peft_model_name_or_path='/data/research/tmp/checkpoint-12600/',
    merge_peft=True,
    pooling_mode="mean",
    max_length=512,
    torch_dtype=torch.bfloat16,
)

# Add a trainable projection layer mapping text embeddings into the
# CLIP image-embedding space.
projection_layer = nn.Sequential(
    nn.LayerNorm(text_model.config.hidden_size),
    nn.Linear(text_model.config.hidden_size, model.visual.head.out_features)
).to('cuda')

# Replace the CLIP text tower with LLM2Vec plus the projection head.
model.text = LLM2VecWithProjection(text_model.model, projection_layer)

# Load weights from a pretrained checkpoint.
model.load_state_dict(torch.load('/model/llm2clip/logs/T_vitl336_mimic-2025_01_06-11/checkpoints/output/pytorch_model.bin'), strict=False)

# Cast any remaining float32 parameters to float16 for inference.
for param in model.parameters():
    if param.dtype == torch.float32:
        param.data = param.data.to(torch.float16)
model.eval()
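As an aside, a hedged sketch of how the lock() method defined above would be used if this model were fine-tuned rather than just evaluated; everything below is illustrative and not part of the commit:

# Hypothetical fine-tuning setup (not in this commit): lock() freezes the
# LLM2Vec weights while the projection head stays trainable; the visual
# tower would need to be locked separately if desired.
# model.text.lock()
# num_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
# print(f"Trainable parameters: {num_trainable:,}")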

## Inference
df = pd.read_csv('/data/research/csv/rsna_test.csv')
print(df.head())

print(df['caption2'][0])

def load_image(image_path):
    """Load and preprocess the image."""
    image = Image.open(image_path).convert("RGB")
    image_tensor = preprocess_val(image).unsqueeze(0).to('cuda')
    return image_tensor.to(torch.float16)


def encode_image(image_tensor):
    """Encode the image using the visual model."""
    with torch.no_grad():
        image_embedding = model.visual(image_tensor)
    return image_embedding


def encode_text(text, model, text_model, image_tensor):
    """Encode the text using the LLM2Vec model with projection."""
    original = text_model.tokenizer(
        # ["The patient's gender is female", "The patient's gender is male", "The patient's gender is unknown"],
        # ["This is a PA view", "This is an AP view"],
        text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
    )
    # LLM2Vec expects an embed_mask alongside the attention mask; all zeros
    # means no instruction tokens are excluded from the pooled embedding.
    embed_mask = torch.zeros_like(original["attention_mask"])
    original["embed_mask"] = embed_mask
    # Re-wrap the encoder so LLM2Vec handles pooling over the tokenized batch.
    l2v = LLM2Vec(model.text.model, text_model.tokenizer, pooling_mode="mean", max_length=512).to('cuda:0')
    text_features = l2v.forward(original.to(device='cuda:0'))
    with torch.no_grad():
        # Cast to the image dtype (fp16) before projecting into the shared space.
        text_features = model.text.projection(text_features.to(dtype=image_tensor.dtype))
    return text_features


def compute_similarity(image_embedding, text_features):
    # L2-normalize both sets of vectors.
    image_embedding = image_embedding / image_embedding.norm(dim=1, keepdim=True)
    text_features = text_features / text_features.norm(dim=1, keepdim=True)

    # Cosine similarity: a [1, N] tensor with one score per text prompt.
    similarity = image_embedding @ text_features.T
    return similarity

index = 4
text_test = {
    'gender': ["The patient's gender is female.", "The patient's gender is male."],
    'view': ["This is a PA view, CXR image of the whole chest.", "This is an AP view, CXR image of the whole chest."],
    'report2': df['caption2'].tolist(),
    'report': df['caption'].tolist(),
    'age': [f"The patient's age is {x} years old." for x in df['age']],
    'pneumonia': ["Pneumonia is present.", "No signs of pneumonia."],
    'pneumothorax': ['There is pneumothorax.', 'There is no pneumothorax.'],
}
print(df['caption2'][index])

image_tensor = load_image(df['img_path'].tolist()[index])
image_embedding = encode_image(image_tensor)

text_embedding = encode_text(text_test['report'], model, text_model, image_tensor)
# CLIP stores logit_scale as a log-temperature, so exponentiate before scaling.
similarity = compute_similarity(image_embedding, text_embedding) * model.logit_scale.exp()
print(f"Similarities shape: {similarity.shape}")  # [1, number of text prompts]
print(f"Similarities: {similarity}")
