Commit c888b70

added 8b features

1 parent 292992e commit c888b70
File tree

4 files changed: +288 additions, -16 deletions

llm2clip/run.sh

Lines changed: 7 additions & 7 deletions

```diff
@@ -1,11 +1,11 @@
-MODEL=EVA02-CLIP-B-16
+MODEL=EVA02-CLIP-L-14-336
 PRETRAINED=eva_clip
 python -m torch.distributed.launch --nproc_per_node=2 \
     --use_env training/main.py \
     --enable-deepspeed \
     --grad-checkpointing \
-    --name="mimic_B16448_caption_local" \
-    --save-frequency 2 \
+    --name="mimic_B16448_8b_local" \
+    --save-frequency 10 \
     --local-loss \
     --zeroshot-frequency 2 \
     --report-to="tensorboard, wandb" \
@@ -16,8 +16,8 @@ python -m torch.distributed.launch --nproc_per_node=2 \
     --pretrained=${PRETRAINED} \
     --precision "fp16" \
     --warmup 0 \
-    --batch-size=256 \
-    --eval-batch-size=256 \
+    --batch-size=160 \
+    --eval-batch-size=160 \
     --log-every-n-steps 200 \
     --epochs=20 \
     --lr=1e-5 \
@@ -35,8 +35,8 @@ python -m torch.distributed.launch --nproc_per_node=2 \
     --model=${MODEL} \
     --seed 4096 \
     --gather-with-grad \
-    --text-base="meta-llama/Llama-3.2-3B" \
-    --llm2vec-path="/data/research/tmp/checkpoint-24678/" \
+    --text-base="meta-llama/Meta-Llama-3.1-8B-Instruct" \
+    --llm2vec-path="/data/research/tmp/checkpoint-llama8b2_old/" \
     --force-custom-clip \
     --optimizer="ap_adamw" \
     --zero-stage=1 \
```
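
The commit moves to the larger EVA02-CLIP-L-14-336 image tower and the 8B instruct text base, and reduces the per-GPU batch from 256 to 160, presumably to fit the larger models in memory. One side note, offered as a hedged sketch rather than a change to the commit: `python -m torch.distributed.launch` is deprecated in current PyTorch releases in favor of `torchrun`, which always exports `LOCAL_RANK` through the environment, so `--use_env` becomes unnecessary:

```bash
# Sketch of an equivalent launch with torchrun (PyTorch >= 1.10); not part of
# this commit. --use_env is dropped because torchrun always passes LOCAL_RANK
# via the environment. All training flags stay exactly as in run.sh above.
torchrun --nproc_per_node=2 training/main.py \
    --enable-deepspeed \
    --grad-checkpointing \
    --name="mimic_B16448_8b_local"
    # ...plus the remaining flags from run.sh, unchanged
```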

llm2clip/training/zero_shot.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -160,7 +160,7 @@ def zero_shot_eval(model, l2v, data, epoch, args):
     # Add medical condition evaluation
     if 'rsna' in data:
         text_categories = {
-            'pneumonia': ["Pneumonia is present.", "No signs of pneumonia"],
+            'pneumonia': ["Detected abnormalities : There is pneumonia.", "Detected abnormalities : There is no pneumonia"],
         }
 
         logging.info('Building medical zero-shot classifier')
@@ -175,7 +175,7 @@ def zero_shot_eval(model, l2v, data, epoch, args):
 
     if 'siim' in data:
         text_categories = {
-            'pneumothorax': ['There is pneumothorax.', 'There is no pneumothorax.']
+            'pneumothorax': ['Detected abnormalities : There is pneumothorax.', 'Detected abnormalities : There is no pneumothorax.']
         }
         logging.info('Building medical zero-shot classifier')
         medical_classifier = zero_shot_classifier_medical(model, text_categories, l2v, args)
```
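
Both edits rewrite the class prompts into the "Detected abnormalities : ..." template, matching the caption format the CXR dataset code produces (see the shuffle_sentences change below), so the zero-shot prompts stay in-distribution with the training captions. For intuition, a minimal sketch of how such a present/absent prompt pair drives binary zero-shot classification; the helper name and embedding shapes are assumptions, not the repo's `zero_shot_classifier_medical`:

```python
import torch
import torch.nn.functional as F

def binary_zero_shot_logits(image_emb, pos_emb, neg_emb, logit_scale=100.0):
    """Score images against a present/absent prompt pair.

    Hypothetical helper. image_emb: (B, D) image embeddings;
    pos_emb, neg_emb: (D,) text embeddings of the positive prompt
    ("... There is pneumonia.") and its negation.
    """
    image_emb = F.normalize(image_emb, dim=-1)  # unit-norm image embeddings
    text_emb = F.normalize(torch.stack([pos_emb, neg_emb]), dim=-1)  # (2, D)
    # CLIP-style logits: scaled cosine similarity; column 0 = condition present
    return logit_scale * image_emb @ text_emb.T  # (B, 2)

# Usage with random tensors standing in for real CLIP/LLM2Vec embeddings:
imgs = torch.randn(4, 512)
pos, neg = torch.randn(512), torch.randn(512)
is_present = binary_zero_shot_logits(imgs, pos, neg).argmax(dim=-1) == 0
```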

llm_caption_contrastive/dataset/CXR.py

Lines changed: 18 additions & 7 deletions

```diff
@@ -15,18 +15,29 @@
 }
 
 def shuffle_sentences(text, is_shuffle=True):
-    # Split the text into sentences using a regex to account for periods that end sentences
-    if is_shuffle:
-        sentences = re.split(r'(?<=[.!?])\s+', text)
+    # Split the text into parts based on the first colon
+    if ':' in text:
+        before_colon, after_colon = text.split(':', 1)
+    else:
+        before_colon, after_colon = text, ""
 
-        # Shuffle the sentences
+    if is_shuffle and after_colon:
+        # Split the text after the colon into sentences using regex
+        sentences = re.split(r'(?<=[.!?])\s+', after_colon.strip())
+
+        # Shuffle the sentences
         random.shuffle(sentences)
-
+
         # Join the shuffled sentences back into a single string
         shuffled_text = ' '.join(sentences)
     else:
-        shuffled_text = text
-    return shuffled_text
+        shuffled_text = after_colon.strip()
+
+    # Combine the part before the colon and the shuffled/unchanged part after the colon
+    if after_colon:
+        return f"{before_colon}: {shuffled_text}"
+    else:
+        return before_colon
 
 class CXRDataset(Dataset):
     def __init__(
```
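
The rewritten shuffle_sentences keeps everything before the first colon (e.g. a "Detected abnormalities" prefix) fixed and only permutes the sentences after it. A quick usage sketch, assuming the import path matches the file path above:

```python
import random

from dataset.CXR import shuffle_sentences  # import path assumed from the file path above

random.seed(0)  # fix the shuffle for a reproducible demo

caption = "Detected abnormalities : There is pneumonia. The heart size is normal."

# Prefix before the colon is preserved; sentence order after it may change, e.g.
# "Detected abnormalities : The heart size is normal. There is pneumonia."
print(shuffle_sentences(caption))

# With is_shuffle=False the caption comes back with its original sentence order.
print(shuffle_sentences(caption, is_shuffle=False))
```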

Lines changed: 261 additions & 0 deletions (new file; the file name is not shown in this view)

Note: the original commit referenced LlamaConfig, MistralConfig, GemmaConfig, and Qwen2Config without importing them; the transformers import below is extended accordingly so `pooling_mode="eos_token"` does not raise a NameError.

```diff
@@ -0,0 +1,261 @@
+import logging
+import os
+import sys
+from dataclasses import dataclass, field
+from typing import Any, Dict, List, Optional, Union
+
+import torch
+from torch.utils.data import DataLoader, SequentialSampler
+
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+
+import transformers
+from transformers import (
+    HfArgumentParser,
+    AutoTokenizer,
+    TrainingArguments,
+    set_seed,
+    LlamaConfig,
+    MistralConfig,
+    GemmaConfig,
+    Qwen2Config,
+)
+
+from llm2vec_wrapper import LLM2VecWrapper as LLM2Vec
+from dataset.utils import load_dataset
+from llm2vec.loss.utils import load_loss
+from llm2vec.experiment_utils import generate_experiment_id
+
+from tqdm import tqdm
+
+# Suppress unnecessary warnings
+transformers.logging.set_verbosity_error()
+
+# Configure logging
+logging.basicConfig(
+    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+    datefmt="%Y-%m-%d %H:%M:%S",
+    level=logging.INFO,
+)
+logger = get_logger(__name__, log_level="INFO")
+
+@dataclass
+class ModelArguments:
+    """
+    Arguments pertaining to which model/config/tokenizer we are going to use for testing.
+    """
+    model_name_or_path: Optional[str] = field(
+        default=None,
+        metadata={"help": "The model checkpoint for weights initialization."},
+    )
+    peft_model_name_or_path: Optional[str] = field(
+        default=None,
+        metadata={"help": "The PEFT model checkpoint to add on top of the base model."},
+    )
+    pooling_mode: Optional[str] = field(
+        default="mean",
+        metadata={
+            "help": "The pooling mode to use in the model.",
+            "choices": ["mean", "weighted_mean", "eos_token"],
+        },
+    )
+    max_seq_length: Optional[int] = field(
+        default=512,
+        metadata={
+            "help": "The maximum total input sequence length after tokenization. Sequences longer than this will be truncated."
+        },
+    )
+    torch_dtype: Optional[str] = field(
+        default="float16",
+        metadata={
+            "help": "Override the default `torch.dtype` and load the model under this dtype.",
+            "choices": ["auto", "bfloat16", "float16", "float32"],
+        },
+    )
+
+@dataclass
+class DataTrainingArguments:
+    """
+    Arguments pertaining to the test data input.
+    """
+    test_dataset_name: Optional[str] = field(
+        default=None,
+        metadata={"help": "The name of the test dataset to use."},
+    )
+    test_file_path: Optional[str] = field(
+        default=None,
+        metadata={"help": "The input test data file or folder."},
+    )
+    dataframe_path: Optional[str] = field(
+        default=None,
+        metadata={"help": "Path to the dataframe file for the test dataset."},
+    )
+    max_test_samples: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "For debugging purposes or quicker testing, truncate the number of test examples to this value if set."
+        },
+    )
+
+@dataclass
+class CustomArguments:
+    """
+    Custom arguments for the testing script.
+    """
+    batch_size: int = field(
+        default=16,
+        metadata={"help": "Batch size for testing."},
+    )
+    output_dir: Optional[str] = field(
+        default="./test_results",
+        metadata={"help": "Directory to save test results."},
+    )
+
+def prepare_for_tokenization(model, text, pooling_mode="mean"):
+    if model.config._name_or_path in ["meta-llama/Meta-Llama-3-8B-Instruct", "meta-llama/Llama-3.2-3B", "meta-llama/Llama-3.2-1B", "meta-llama/Llama-3.2-8B"]:
+        text = (
+            "<|start_header_id|>user<|end_header_id|>\n\n" + text.strip() + "<|eot_id|>"
+        )
+        return text
+    if model.config._name_or_path in [
+        "mistralai/Mistral-7B-Instruct-v0.2",
+        "meta-llama/Llama-2-7b-chat-hf",
+    ]:
+        text = "[INST] " + text.strip() + " [/INST]"
+    if model.config._name_or_path in [
+        "google/gemma-2-9b-it",
+    ]:
+        text = "<bos><start_of_turn>user\n" + text.strip() + "<end_of_turn>"
+    if model.config._name_or_path in [
+        "Qwen/Qwen2-1.5B-Instruct",
+        "Qwen/Qwen2-7B-Instruct",
+    ]:
+        text = "<|im_start|>user\n" + text.strip() + "<|im_end|>"
+    if pooling_mode == "eos_token":
+        if model.config._name_or_path == "meta-llama/Meta-Llama-3-8B":
+            text = text.strip() + "<|end_of_text|>"
+        elif isinstance(model.config, LlamaConfig) or isinstance(
+            model.config, MistralConfig
+        ):
+            text = text.strip() + " </s>"
+        elif isinstance(model.config, GemmaConfig):
+            text = text.strip() + "<eos>"
+        elif isinstance(model.config, Qwen2Config):
+            text = text.strip() + "<|endoftext|>"
+    return text
+
+def main():
+    # Initialize Accelerator
+    accelerator = Accelerator()
+    logger.info("Initialized Accelerator.")
+
+    # Argument parsing
+    parser = HfArgumentParser(
+        (ModelArguments, DataTrainingArguments, CustomArguments)
+    )
+    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
+        model_args, data_args, custom_args = parser.parse_json_file(
+            json_file=os.path.abspath(sys.argv[1])
+        )
+    else:
+        model_args, data_args, custom_args = parser.parse_args_into_dataclasses()
+
+    # Set seed for reproducibility
+    set_seed(42)
+
+    # Load tokenizer
+    logger.info("Loading tokenizer...")
+    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
+
+    # Load model with PEFT
+    logger.info("Loading the LLM2Vec model with PEFT adapter for testing...")
+    model = LLM2Vec.from_pretrained(
+        base_model_name_or_path=model_args.model_name_or_path,
+        enable_bidirectional=False,  # Typically not needed for testing
+        peft_model_name_or_path=model_args.peft_model_name_or_path,
+        merge_peft=True,
+        pooling_mode=model_args.pooling_mode,
+        max_length=model_args.max_seq_length,
+        torch_dtype=torch.float16 if model_args.torch_dtype == "float16" else getattr(torch, model_args.torch_dtype),
+    )
+
+    model.to(accelerator.device)
+    model.eval()
+    logger.info("Model loaded and set to evaluation mode.")
+
+    # Load test dataset
+    logger.info("Loading test dataset...")
+    test_dataset = load_dataset(
+        dataset_name=data_args.test_dataset_name,
+        split="test",
+        file_path=data_args.test_file_path,
+        dataframe_path=data_args.dataframe_path,
+    )
+
+    if data_args.max_test_samples is not None:
+        test_dataset = test_dataset.select(range(data_args.max_test_samples))
+        logger.info(f"Truncated test dataset to {data_args.max_test_samples} samples.")
+
+    # Prepare data collator
+    def collate_fn(batch):
+        texts = [prepare_for_tokenization(model.model, example['text'], pooling_mode=model_args.pooling_mode) for example in batch]
+        inputs = tokenizer(
+            texts,
+            padding=True,
+            truncation=True,
+            max_length=model_args.max_seq_length,
+            return_tensors="pt",
+        )
+        labels = torch.tensor([example['label'] for example in batch])
+        return {"input_ids": inputs["input_ids"], "attention_mask": inputs["attention_mask"], "labels": labels}
+
+    test_dataloader = DataLoader(
+        test_dataset,
+        sampler=SequentialSampler(test_dataset),
+        batch_size=custom_args.batch_size,
+        collate_fn=collate_fn,
+    )
+
+    # Prepare dataloader with accelerator
+    test_dataloader = accelerator.prepare(test_dataloader)
+    logger.info("Test DataLoader prepared.")
+
+    # Evaluation loop
+    logger.info("Starting evaluation...")
+    all_predictions = []
+    all_labels = []
+
+    with torch.no_grad():
+        for batch in tqdm(test_dataloader, desc="Testing"):
+            inputs = {
+                "input_ids": batch["input_ids"].to(accelerator.device),
+                "attention_mask": batch["attention_mask"].to(accelerator.device),
+            }
+            labels = batch["labels"].to(accelerator.device)
+
+            outputs = model.model(**inputs)
+            logits = outputs.logits
+            predictions = torch.argmax(logits, dim=-1)
+
+            all_predictions.extend(accelerator.gather(predictions).cpu().numpy())
+            all_labels.extend(accelerator.gather(labels).cpu().numpy())
+
+    # Compute metrics
+    from sklearn.metrics import accuracy_score, f1_score
+
+    accuracy = accuracy_score(all_labels, all_predictions)
+    f1 = f1_score(all_labels, all_predictions, average='weighted')
+
+    logger.info(f"Test Accuracy: {accuracy:.4f}")
+    logger.info(f"Test F1 Score: {f1:.4f}")
+
+    # Save results
+    os.makedirs(custom_args.output_dir, exist_ok=True)
+    results_path = os.path.join(custom_args.output_dir, "test_results.txt")
+    with open(results_path, "w") as f:
+        f.write(f"Test Accuracy: {accuracy:.4f}\n")
+        f.write(f"Test F1 Score: {f1:.4f}\n")
+
+    logger.info(f"Test results saved to {results_path}")
+
+if __name__ == "__main__":
+    main()
```
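
A hedged example of invoking the new evaluation script. The script name `test_llm2vec_cxr.py` and the dataset/checkpoint values are placeholder stand-ins (the actual file name is not visible in this view); the flag names follow the dataclass fields above, which HfArgumentParser exposes as command-line options:

```bash
# Placeholder script name and values; flag names come from the dataclasses above.
accelerate launch test_llm2vec_cxr.py \
    --model_name_or_path meta-llama/Meta-Llama-3.1-8B-Instruct \
    --peft_model_name_or_path /data/research/tmp/checkpoint-llama8b2_old/ \
    --pooling_mode mean \
    --test_dataset_name CXR \
    --batch_size 16 \
    --output_dir ./test_results
```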
