Commit a395758

add collate
1 parent 897c376 commit a395758

3 files changed: +51 -10 lines changed

llm2clip/run.sh

Lines changed: 6 additions & 6 deletions
@@ -5,18 +5,18 @@ python -m torch.distributed.launch --nproc_per_node=2 \
     --enable-deepspeed \
     --grad-checkpointing \
     --name="T_vitl336_mimic" \
-    --save-frequency 1 \
-    --zeroshot-frequency 1 \
+    --save-frequency 2 \
+    --zeroshot-frequency 2 \
     --report-to="tensorboard, wandb" \
     --wandb-project-name="LLM2CLIP" \
     --wandb-notes="EVA02-CLIP-L-14-336" \
     --train-data "/data/csv/llm2clip/mimic_clip.csv" \
     --pretrained=${PRETRAINED} \
     --precision "fp16" \
     --warmup 0 \
-    --batch-size=256 \
-    --eval-batch-size=256 \
-    --log-every-n-steps 100 \
+    --batch-size=150 \
+    --eval-batch-size=150 \
+    --log-every-n-steps 200 \
     --epochs=20 \
     --lr=1e-5 \
     --visual-lr=1e-5 \
@@ -41,4 +41,4 @@ python -m torch.distributed.launch --nproc_per_node=2 \
     --zero-stage=1 \
     --dataset-type "cxr" \
     --csv-img-key "img_path" \
-    --csv-caption-key "caption"
+    --csv-caption-key "caption2"
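(Note: the switch to the "caption2" CSV column presumably supplies captions in the two-segment format that the new collate_fn in data.py below splits on the "!@#$%^&*()" separator; the commit does not state this explicitly.)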

llm2clip/training/data.py

Lines changed: 44 additions & 3 deletions
@@ -84,11 +84,51 @@ def __getitem__(self, idx):
         if self.transform:
             image = self.transform(image)

-        # Process caption if tokenizer is provided
-        if self.tokenizer:
-            caption = self.tokenizer([caption])[0]
+        # # Process caption if tokenizer is provided
+        # if self.tokenizer:
+        #     caption = self.tokenizer([caption])[0]

         return image, caption
+    def collate_fn(self, batch):
+        images, texts = zip(*batch)
+        images = torch.stack(images)
+
+        # Split texts
+        texts_2 = []
+        original_texts = []
+        for text in texts:
+            t = text.split("!@#$%^&*()")
+            texts_2.append(t[1] if len(t) > 1 else "")
+            original_texts.append("".join(t))
+
+        # Tokenize original texts with padding
+
+        original = self.tokenizer(
+            original_texts,
+            return_tensors="pt",
+            padding=True,
+            truncation=True,
+            max_length=512,
+        )
+
+        # Process secondary texts and create embed masks
+        embed_mask = torch.zeros_like(original["attention_mask"])
+        for i, t in enumerate(texts_2):
+            if t:  # Only process non-empty secondary texts
+                ids = self.tokenizer(
+                    [t],
+                    return_tensors="pt",
+                    padding=True,
+                    truncation=True,
+                    max_length=512,
+                    add_special_tokens=False,
+                )
+                if len(ids["input_ids"][0]) > 0:
+                    embed_mask[i, -len(ids["input_ids"][0]):] = 1
+
+        original["embed_mask"] = embed_mask
+        return images, original
+

 # Example usage:

@@ -799,6 +839,7 @@ def get_cxr_dataset(args, preprocess_fn, is_train, epoch=0, tokenizer=None):
         pin_memory=True,
         sampler=sampler,
         drop_last=is_train,
+        collate_fn=dataset.collate_fn
     )
     dataloader.num_samples = num_samples
     dataloader.num_batches = len(dataloader)
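For reference, here is a minimal sketch of what this collate_fn produces, written against a stand-in Hugging Face tokenizer. The checkpoint name, SEP constant, and sample captions are placeholders rather than values from the repository; only the separator string is the one hard-coded in the diff above. Note that the trailing-slice assignment to embed_mask only lines up with the secondary text if the tokenizer pads on the left, which is presumably the case for the decoder-style LLM tokenizer this dataset is paired with.

    import torch
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # placeholder checkpoint
    tokenizer.padding_side = "left"  # right-align real tokens so the -n: slice
                                     # below lands on text, not on padding

    SEP = "!@#$%^&*()"  # separator used by collate_fn in the diff above
    captions = [
        "Findings: no acute process." + SEP + "Impression: normal chest.",
        "Findings: mild cardiomegaly.",  # no secondary segment
    ]

    # Mirror the collate logic: tokenize the concatenated segments...
    joined = ["".join(c.split(SEP)) for c in captions]
    batch = tokenizer(joined, return_tensors="pt", padding=True,
                      truncation=True, max_length=512)

    # ...then flag only the trailing (secondary-text) token positions.
    embed_mask = torch.zeros_like(batch["attention_mask"])
    for i, c in enumerate(captions):
        parts = c.split(SEP)
        if len(parts) > 1 and parts[1]:
            ids = tokenizer([parts[1]], return_tensors="pt",
                            add_special_tokens=False)
            n = ids["input_ids"].shape[1]
            if n > 0:
                embed_mask[i, -n:] = 1

    batch["embed_mask"] = embed_mask
    print(batch["embed_mask"])  # row 0 ends in 1s; row 1 is all zeros

In effect, attention_mask still covers the full concatenated caption, while embed_mask marks which token positions belong to the secondary segment, so a downstream text encoder can pool over just that span.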

llm2clip/training/params.py

Lines changed: 1 addition & 1 deletion
@@ -168,7 +168,7 @@ def parse_args(args):
     parser.add_argument(
         "--logs",
         type=str,
-        default="logs",
+        default="/model/llm2clip/logs",
         help="Where to store tensorboard logs. Use None to avoid storing logs.",
     )
     parser.add_argument(
