@@ -84,11 +84,51 @@ def __getitem__(self, idx):
8484 if self .transform :
8585 image = self .transform (image )
8686
87- # Process caption if tokenizer is provided
88- if self .tokenizer :
89- caption = self .tokenizer ([caption ])[0 ]
87+ # # Process caption if tokenizer is provided
88+ # if self.tokenizer:
89+ # caption = self.tokenizer([caption])[0]
9090
9191 return image , caption
def collate_fn(self, batch, max_length=512):
    """Collate a list of ``(image, caption)`` samples into model-ready tensors.

    Each caption may contain the sentinel separator ``"!@#$%^&*()"`` splitting
    it into a primary text and a secondary ("embedding") text. The full
    caption — with the sentinel removed — is tokenized as one padded batch;
    each secondary part, when present, is tokenized separately so its token
    positions can be flagged in an ``embed_mask`` aligned with ``input_ids``.

    Args:
        batch: list of ``(image_tensor, caption_str)`` pairs as produced by
            ``__getitem__``.
        max_length: truncation length for both tokenization passes
            (default 512, matching the previous hard-coded value).

    Returns:
        ``(images, encoded)`` where ``images`` is ``torch.stack`` of the image
        tensors and ``encoded`` is the tokenizer output dict with an added
        ``embed_mask`` tensor (same shape as ``attention_mask``) marking the
        secondary-text token positions.
    """
    images, texts = zip(*batch)
    images = torch.stack(images)

    separator = "!@#$%^&*()"
    secondary_texts = []
    original_texts = []
    for text in texts:
        parts = text.split(separator)
        secondary_texts.append(parts[1] if len(parts) > 1 else "")
        # Re-join WITHOUT the separator so the sentinel never reaches the model.
        original_texts.append("".join(parts))

    tokenizer = self.tokenizer
    encoded = tokenizer(
        original_texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_length,
    )

    # Flag the token positions belonging to each secondary text.
    # NOTE(review): writing to the LAST n_tokens positions assumes a
    # left-padding tokenizer (pad tokens at the front, real tokens at the
    # end), as is typical for decoder-style models — confirm the tokenizer's
    # padding_side; with right padding this would mark pad positions instead.
    embed_mask = torch.zeros_like(encoded["attention_mask"])
    for i, text in enumerate(secondary_texts):
        if not text:  # only captions that actually had a secondary part
            continue
        ids = tokenizer(
            [text],
            return_tensors="pt",
            truncation=True,
            max_length=max_length,
            # Positions only; special tokens already come from the full caption.
            add_special_tokens=False,
        )
        n_tokens = ids["input_ids"].shape[1]
        if n_tokens > 0:
            # Clamp so a secondary text longer than the padded width cannot
            # wrap the negative slice and silently flag the whole row.
            embed_mask[i, -min(n_tokens, embed_mask.shape[1]):] = 1

    encoded["embed_mask"] = embed_mask
    return images, encoded
131+
92132
93133# Example usage:
94134
@@ -799,6 +839,7 @@ def get_cxr_dataset(args, preprocess_fn, is_train, epoch=0, tokenizer=None):
799839 pin_memory = True ,
800840 sampler = sampler ,
801841 drop_last = is_train ,
842+ collate_fn = dataset .collate_fn
802843 )
803844 dataloader .num_samples = num_samples
804845 dataloader .num_batches = len (dataloader )
0 commit comments