Commit 9197adf

v0.9
1 parent c888b70 commit 9197adf

18 files changed: +2306, -464 lines

llm2clip/eva_clip/model_configs/EVA02-CLIP-L-14-336.json

Lines changed: 1 addition & 1 deletion

@@ -1,7 +1,7 @@
 {
     "embed_dim": 1280,
     "vision_cfg": {
-        "image_size": 336,
+        "image_size": 448,
         "layers": 24,
         "width": 1024,
         "drop_path_rate": 0,

llm2clip/run.sh

Lines changed: 8 additions & 8 deletions

@@ -1,11 +1,11 @@
 MODEL=EVA02-CLIP-L-14-336
 PRETRAINED=eva_clip
-python -m torch.distributed.launch --nproc_per_node=2 \
+python -m torch.distributed.launch --nproc_per_node=4 \
     --use_env training/main.py \
     --enable-deepspeed \
     --grad-checkpointing \
-    --name="mimic_B16448_8b_local" \
-    --save-frequency 10 \
+    --name="final_llm2clip_caption" \
+    --save-frequency 2 \
     --local-loss \
     --zeroshot-frequency 2 \
     --report-to="tensorboard, wandb" \
@@ -16,8 +16,8 @@ python -m torch.distributed.launch --nproc_per_node=2 \
     --pretrained=${PRETRAINED} \
     --precision "fp16" \
     --warmup 0 \
-    --batch-size=160 \
-    --eval-batch-size=160 \
+    --batch-size=128 \
+    --eval-batch-size=128 \
     --log-every-n-steps 200 \
     --epochs=20 \
     --lr=1e-5 \
@@ -35,14 +35,14 @@ python -m torch.distributed.launch --nproc_per_node=2 \
     --model=${MODEL} \
     --seed 4096 \
     --gather-with-grad \
-    --text-base="meta-llama/Meta-Llama-3.1-8B-Instruct" \
-    --llm2vec-path="/data/research/tmp/checkpoint-llama8b2_old/" \
+    --text-base="/model/llm2clip/llm2vec/8b_special/mntp/checkpoint-5779/" \
+    --llm2vec-path="/model/llm2clip/llm2vec/8b_special/supervised/checkpoint-12535/" \
     --force-custom-clip \
     --optimizer="ap_adamw" \
     --zero-stage=1 \
     --dataset-type "cxr" \
     --csv-img-key "img_path" \
-    --csv-caption-key "caption" \
+    --csv-caption-key "caption2_lite" \
     --rsna "/data/research/csv/rsna_test.csv" \
     --siim "/data/research/csv/siim_test.csv" \
     --openi "/data/csv/llm2clip/openi_clip_val.csv" \
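Note: the launch moves from 2 GPUs at a per-GPU batch of 160 to 4 GPUs at 128, so the effective global batch grows from 320 to 512 samples per step (assuming one process per GPU and no gradient accumulation); num_processes in llm_caption_contrastive/ac_zero2.yaml, at the end of this commit, gets the same 2-to-4 bump, presumably to match. --text-base also now points at a local MNTP checkpoint rather than the Hugging Face model id, pairing with the two-stage adapter loading added in training/main.py below.

# Worked arithmetic for the global batch size change
# (assumes one process per GPU and no gradient accumulation).
old_global = 2 * 160   # nproc_per_node=2, --batch-size=160 -> 320
new_global = 4 * 128   # nproc_per_node=4, --batch-size=128 -> 512
print(old_global, new_global)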

llm2clip/training/data.py

Lines changed: 9 additions & 5 deletions

@@ -51,17 +51,21 @@
 from PIL import Image
 import torch
 
-def apply_dropout(text):
+def apply_dropout(text, age_dropout=0.3, view_dropout=0.3, gender_dropout=0.3, bmi_dropout=0.3):
     # 30% chance to drop each attribute
-    if random.random() < 0.3:
+    if random.random() < view_dropout:
         # Replace view position
         text = re.sub(r"This is a (\w+) view", "This is a unknown view", text)
 
-    if random.random() < 0.3:
+    if random.random() < age_dropout:
         # Replace age
-        text = re.sub(r"The patient is (\d+) years old", "The patient is unknown years old", text)
+        text = re.sub(r"The patient's age is (\d+)", "The patient's age is unknown", text)
+
+    if random.random() < bmi_dropout:
+        # Replace BMI
+        text = re.sub(r"The patient's bmi is (\d+)", "The patient's bmi is unknown", text)
 
-    if random.random() < 0.3:
+    if random.random() < gender_dropout:
         # Replace gender
         text = re.sub(r"The patient's gender is (\w+)", "The patient's gender is unknown", text)
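For reference, a self-contained sketch of the updated function (reconstructed from the hunk above; the random/re imports and the trailing return are assumed, since the hunk cuts off before the function's tail). The hard-coded 0.3 probabilities become per-attribute parameters, and the age regex is reworded to track the new captions selected by --csv-caption-key "caption2_lite" in run.sh:

import random
import re

def apply_dropout(text, age_dropout=0.3, view_dropout=0.3,
                  gender_dropout=0.3, bmi_dropout=0.3):
    # Independently mask each patient attribute with its own probability.
    if random.random() < view_dropout:
        text = re.sub(r"This is a (\w+) view", "This is a unknown view", text)
    if random.random() < age_dropout:
        text = re.sub(r"The patient's age is (\d+)", "The patient's age is unknown", text)
    if random.random() < bmi_dropout:
        text = re.sub(r"The patient's bmi is (\d+)", "The patient's bmi is unknown", text)
    if random.random() < gender_dropout:
        text = re.sub(r"The patient's gender is (\w+)", "The patient's gender is unknown", text)
    return text  # assumed: the hunk ends before the function returns

# Example:
print(apply_dropout("This is a frontal view. The patient's age is 61."))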

llm2clip/training/main.py

Lines changed: 9 additions & 3 deletions

@@ -4,7 +4,7 @@
 import random
 from datetime import datetime
 sys.path.append(os.getcwd())
-
+from peft import PeftModel
 import numpy as np
 import torch
 from torch.cuda.amp import GradScaler
@@ -143,19 +143,25 @@ def main(args):
         cache_dir=args.cache_dir,
         skip_list=args.skip_list,
     )
-
+    logging.info("text_model is loading...")
     random_seed(args.seed, args.rank)
     if args.llm2vec_path:
         print("Using LLM2Vec")
         text_model = LLM2Vec.from_pretrained(
             base_model_name_or_path=args.text_base,
             enable_bidirectional=True,
-            peft_model_name_or_path=args.llm2vec_path,
+            peft_model_name_or_path=args.text_base,
             merge_peft=True,
             pooling_mode="mean",
             max_length=512,
             torch_dtype=torch.bfloat16,
         )
+        text_model.model = PeftModel.from_pretrained(
+            text_model.model,
+            args.llm2vec_path,
+        )
+
+        text_model.model = text_model.model.merge_and_unload()
     # Add a trainable projection layer
     projection_layer = nn.Sequential(
         nn.LayerNorm(text_model.config.hidden_size),
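Note: the text tower is now assembled in two stages, following the llm2vec recipe of an MNTP adapter followed by a supervised contrastive adapter. LLM2Vec.from_pretrained loads and merges the MNTP LoRA from --text-base (the checkpoint directory appears to hold the base config alongside that adapter, hence it is passed for both path arguments), and the supervised LoRA from --llm2vec-path is then applied with peft and folded into the weights. Condensed, the new flow is (a sketch; paths are the ones from run.sh above):

from llm2vec import LLM2Vec
from peft import PeftModel
import torch

mntp_ckpt = "/model/llm2clip/llm2vec/8b_special/mntp/checkpoint-5779/"
supervised_ckpt = "/model/llm2clip/llm2vec/8b_special/supervised/checkpoint-12535/"

# Stage 1: base weights + MNTP LoRA, merged at load time.
text_model = LLM2Vec.from_pretrained(
    base_model_name_or_path=mntp_ckpt,
    peft_model_name_or_path=mntp_ckpt,
    enable_bidirectional=True,
    merge_peft=True,
    pooling_mode="mean",
    max_length=512,
    torch_dtype=torch.bfloat16,
)

# Stage 2: supervised contrastive LoRA applied on top, then merged so
# downstream code sees a plain transformer without adapter indirection.
text_model.model = PeftModel.from_pretrained(text_model.model, supervised_ckpt)
text_model.model = text_model.model.merge_and_unload()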

llm2clip/training/train.py

Lines changed: 9 additions & 9 deletions

@@ -259,12 +259,12 @@ def evaluate_iter(model, tokenizer, data, iter_nums, epoch, args, tb_writer=None):
     model.eval()
     l2v = LLM2Vec(model.text.model, tokenizer, pooling_mode="mean", max_length=512) #TODO: modify this
     print('evaluating retrieval')
-    with torch.no_grad():
-        retrieval_zero_shot_metrics = retrieval_eval(model, l2v, data, epoch, args)
-        metrics.update(retrieval_zero_shot_metrics)
-        zero_shot_metrics = zero_shot_eval(model, l2v, data, epoch, args)
-        metrics.update(zero_shot_metrics)
-        print(zero_shot_metrics)
+    # with torch.no_grad():
+    #     retrieval_zero_shot_metrics = retrieval_eval(model, l2v, data, epoch, args)
+    #     metrics.update(retrieval_zero_shot_metrics)
+    #     zero_shot_metrics = zero_shot_eval(model, l2v, data, epoch, args)
+    #     metrics.update(zero_shot_metrics)
+    #     print(zero_shot_metrics)
     autocast = get_autocast(args.precision)
     cast_dtype = get_cast_dtype(args.precision)
     if 'val' in data:
@@ -356,9 +356,9 @@ def evaluate(model, tokenizer, data, epoch, args, tb_writer=None):
     l2v = LLM2Vec(model.text.model, tokenizer, pooling_mode="mean", max_length=512) #TODO: modify this
     retrieval_zero_shot_metrics = retrieval_eval(model, l2v, data, epoch, args)
     metrics.update(retrieval_zero_shot_metrics)
-    zero_shot_metrics = zero_shot_eval(model, l2v, data, epoch, args)
-    metrics.update(zero_shot_metrics)
-    print(zero_shot_metrics)
+    # zero_shot_metrics = zero_shot_eval(model, l2v, data, epoch, args)
+    # metrics.update(zero_shot_metrics)
+    # print(zero_shot_metrics)
     autocast = get_autocast(args.precision)
     cast_dtype = get_cast_dtype(args.precision)

llm2clip/training/zero_shot.py

Lines changed: 2 additions & 2 deletions

@@ -160,7 +160,7 @@ def zero_shot_eval(model, l2v, data, epoch, args):
     # Add medical condition evaluation
     if 'rsna' in data:
         text_categories = {
-            'pneumonia': ["Detected abnormalities : There is pneumonia.", "Detected abnormalities : There is no pneumonia"],
+            'pneumonia': ["pneumonia is present", "there is no pneumonia"],
         }
 
         logging.info('Building medical zero-shot classifier')
@@ -175,7 +175,7 @@ def zero_shot_eval(model, l2v, data, epoch, args):
 
     if 'siim' in data:
         text_categories = {
-            'pneumothorax': ['Detected abnormalities : There is pneumothorax.', 'Detected abnormalities : There is no pneumothorax.']
+            'pneumothorax': ['pneumothorax is present', 'there is no pneumothorax']
        }
         logging.info('Building medical zero-shot classifier')
         medical_classifier = zero_shot_classifier_medical(model, text_categories, l2v, args)
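Note: both prompt pairs drop the templated "Detected abnormalities :" prefix in favor of plain positive/negative statements. In a binary zero-shot setup like this, classification typically reduces to embedding the two prompts and taking the higher cosine similarity per image; a hedged sketch of that pattern (the actual zero_shot_classifier_medical may differ):

import torch
import torch.nn.functional as F

def binary_zero_shot(image_features: torch.Tensor,
                     prompt_features: torch.Tensor) -> torch.Tensor:
    # image_features: (N, D); prompt_features: (2, D), ordered as
    # [condition present, condition absent].
    img = F.normalize(image_features, dim=-1)
    txt = F.normalize(prompt_features, dim=-1)
    logits = img @ txt.t()           # cosine similarity to each prompt
    return logits.argmax(dim=-1)     # 0 = present, 1 = absent

# Example with random features (D=768 chosen only for illustration):
preds = binary_zero_shot(torch.randn(4, 768), torch.randn(2, 768))
print(preds)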

llm_caption_contrastive/ac_zero2.yaml

Lines changed: 1 addition & 1 deletion

@@ -15,7 +15,7 @@ machine_rank: 0
 main_training_function: main
 mixed_precision: bf16
 num_machines: 1
-num_processes: 2
+num_processes: 4
 rdzv_backend: static
 same_network: true
 tpu_env: []
