import argparse
import os

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from peft import LoraConfig, get_peft_model
from PIL import Image
from tqdm import tqdm

from eva_clip import create_model_and_transforms, create_model_from_pretrained
from training.llm2vec_wrapper import LLM2VecWrapper as LLM2Vec

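# Zero-shot CXR inference pipeline: an EVA02-CLIP-L-14-336 vision tower paired
# with a bidirectional Llama-3.2-3B text encoder (LLM2Vec) and a trainable
# projection head that maps text embeddings into the CLIP embedding space.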
class LLM2VecWithProjection(nn.Module):
    """Wraps an LLM2Vec text encoder with a trainable projection head."""

    def __init__(self, llm2vec_model, projection):
        super().__init__()
        self.model = llm2vec_model
        self.projection = projection

    def forward(self, text):
        embeddings = self.model(text)
        return self.projection(embeddings)

    def lock(self, unlocked_layers=0, freeze_layer_norm=True):
        # Freeze the LLM2Vec backbone; the projection head stays trainable.
        for param in self.model.parameters():
            param.requires_grad = False

    def set_grad_checkpointing(self, enable=True):
        if enable:
            self.model.gradient_checkpointing_enable()
        else:
            self.model.gradient_checkpointing_disable()
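# Usage sketch (hypothetical): during fine-tuning, the text backbone can be
# frozen while the projection head keeps training:
#     text_tower = LLM2VecWithProjection(text_model.model, projection_layer)
#     text_tower.lock()  # backbone params get requires_grad=False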

# Initialize the EVA-CLIP model and image preprocessing transforms.
model, preprocess_train, preprocess_val = create_model_and_transforms(
    "EVA02-CLIP-L-14-336",
    "eva_clip",
    precision="fp16",
    device="cuda",
    force_custom_clip=True,
    image_mean=None,
    image_std=None,
    cache_dir=None,
    skip_list=None,
)

model.eval()

# Load the pre-trained LLM2Vec text model; the LoRA adapter at
# peft_model_name_or_path is merged into the base weights (merge_peft=True).
text_model = LLM2Vec.from_pretrained(
    base_model_name_or_path="meta-llama/Llama-3.2-3B",
    enable_bidirectional=True,
    peft_model_name_or_path='/data/research/tmp/checkpoint-12600/',
    merge_peft=True,
    pooling_mode="mean",
    max_length=512,
    torch_dtype=torch.bfloat16,
)

# Trainable projection head: LayerNorm followed by a linear layer mapping the
# LLM hidden size to the CLIP visual embedding dimension.
projection_layer = nn.Sequential(
    nn.LayerNorm(text_model.config.hidden_size),
    nn.Linear(text_model.config.hidden_size, model.visual.head.out_features)
).to('cuda')

# Replace the CLIP text tower with the projected LLM2Vec encoder.
model.text = LLM2VecWithProjection(text_model.model, projection_layer)

# Load weights from the pretrained checkpoint.
model.load_state_dict(
    torch.load('/model/llm2clip/logs/T_vitl336_mimic-2025_01_06-11/checkpoints/output/pytorch_model.bin'),
    strict=False,
)

# Cast any remaining float32 parameters to float16 for inference.
for param in model.parameters():
    if param.dtype == torch.float32:
        param.data = param.data.to(torch.float16)
model.eval()

## Inference
df = pd.read_csv('/data/research/csv/rsna_test.csv')
print(df.head())
print(df['caption2'][0])

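# The test CSV is expected to provide at least the columns used below:
# 'img_path', 'caption', 'caption2', and 'age'.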
def load_image(image_path):
    """Load and preprocess the image."""
    image = Image.open(image_path).convert("RGB")
    image_tensor = preprocess_val(image).unsqueeze(0).to('cuda')
    return image_tensor.to(torch.float16)


def encode_image(image_tensor):
    """Encode the image using the visual model."""
    with torch.no_grad():
        image_embedding = model.visual(image_tensor)
    return image_embedding

def encode_text(text, model, text_model, image_tensor):
    """Encode the text using the LLM2Vec model with projection."""
    original = text_model.tokenizer(
        text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
    )
    # LLM2Vec's forward pass expects an embed_mask alongside the attention mask.
    embed_mask = torch.zeros_like(original["attention_mask"])
    original["embed_mask"] = embed_mask
    l2v = LLM2Vec(model.text.model, text_model.tokenizer, pooling_mode="mean", max_length=512).to('cuda:0')
    with torch.no_grad():
        text_features = l2v.forward(original.to(device='cuda:0'))
        text_features = model.text.projection(text_features.to(dtype=image_tensor.dtype))
    return text_features

def compute_similarity(image_embedding, text_features):
    """Cosine similarity between one image embedding and N text features."""
    # Normalize both sets of vectors to unit length.
    image_embedding = image_embedding / image_embedding.norm(dim=1, keepdim=True)
    text_features = text_features / text_features.norm(dim=1, keepdim=True)

    # [1, N] tensor: the cosine similarity between the image embedding and
    # each text feature.
    similarity = image_embedding @ text_features.T
    return similarity


# Zero-shot prompt sets: binary prompts per attribute, plus the full report
# captions from the test CSV.
index = 4
text_test = {
    'gender': ["The patient's gender is female.", "The patient's gender is male."],
    'view': ["This is a PA view, CXR image of the whole chest.",
             "This is an AP view, CXR image of the whole chest."],
    'report2': df['caption2'].tolist(),
    'report': df['caption'].tolist(),
    'age': [f"The patient's age is {x} years old." for x in df['age']],
    'pneumonia': ["Pneumonia is present.", "No signs of pneumonia."],
    'pneumothorax': ['There is pneumothorax.', 'There is no pneumothorax.'],
}
print(df['caption2'][index])

image_tensor = load_image(df['img_path'].tolist()[index])
image_embedding = encode_image(image_tensor)

text_embedding = encode_text(text_test['report'], model, text_model, image_tensor)
# logit_scale is stored in log space, so apply .exp() to get the CLIP temperature.
similarity = compute_similarity(image_embedding, text_embedding) * model.logit_scale.exp()
print(f"Similarities shape: {similarity.shape}")  # [1, N] for N candidate texts
print(f"Similarities: {similarity}")