@@ -132,9 +132,58 @@ def _pinned_memory_tensors(self):
     finally:
         pinned_dict = None
 
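+    # Moves `tensor` to the onload device by copying from `source_tensor` (which may be a
+    # pinned CPU staging buffer); with `record_stream`, the result is marked as in use on
+    # `current_stream` so its memory is not reused while the async copy is still in flight.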
+    def _transfer_tensor_to_device(self, tensor, source_tensor, current_stream=None):
+        tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
+        if self.record_stream and current_stream is not None:
+            tensor.data.record_stream(current_stream)
+
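+    # Transfers every parameter and buffer in the group (module-owned and standalone) to
+    # the onload device, sourcing from `pinned_memory` when it is provided.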
+    def _process_tensors_from_modules(self, pinned_memory=None, current_stream=None):
+        for group_module in self.modules:
+            for param in group_module.parameters():
+                source = pinned_memory[param] if pinned_memory else param.data
+                self._transfer_tensor_to_device(param, source, current_stream)
+            for buffer in group_module.buffers():
+                source = pinned_memory[buffer] if pinned_memory else buffer.data
+                self._transfer_tensor_to_device(buffer, source, current_stream)
+
+        for param in self.parameters:
+            source = pinned_memory[param] if pinned_memory else param.data
+            self._transfer_tensor_to_device(param, source, current_stream)
+
+        for buffer in self.buffers:
+            source = pinned_memory[buffer] if pinned_memory else buffer.data
+            self._transfer_tensor_to_device(buffer, source, current_stream)
+
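+    # Restores the group from its safetensors file. With a transfer stream, tensors are
+    # first loaded to CPU and staged through pinned memory for non-blocking copies;
+    # otherwise they are loaded directly onto the onload device.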
+    def _onload_from_disk(self, current_stream):
+        if self.stream is not None:
+            loaded_cpu_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu")
+
+            for key, tensor_obj in self.key_to_tensor.items():
+                self.cpu_param_dict[tensor_obj] = loaded_cpu_tensors[key]
+
+            with self._pinned_memory_tensors() as pinned_memory:
+                for key, tensor_obj in self.key_to_tensor.items():
+                    self._transfer_tensor_to_device(tensor_obj, pinned_memory[tensor_obj], current_stream)
+
+            self.cpu_param_dict.clear()
+
+        else:
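+            # safetensors' load_file takes a device string, so pass the torch.device's type.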
+            onload_device = (
+                self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
+            )
+            loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
+            for key, tensor_obj in self.key_to_tensor.items():
+                tensor_obj.data = loaded_tensors[key]
+
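+    # In-memory counterpart of _onload_from_disk: the source tensors already live in host
+    # memory, routed through the pinned staging buffers when a transfer stream is configured.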
+    def _onload_from_memory(self, current_stream):
+        if self.stream is not None:
+            with self._pinned_memory_tensors() as pinned_memory:
+                self._process_tensors_from_modules(pinned_memory, current_stream)
+        else:
+            self._process_tensors_from_modules(None, current_stream)
+
     @torch.compiler.disable()
     def onload_(self):
-        r"""Onloads the group of modules to the onload_device."""
         torch_accelerator_module = (
             getattr(torch, torch.accelerator.current_accelerator().type)
             if hasattr(torch, "accelerator")
@@ -172,67 +221,30 @@ def onload_(self):
             self.stream.synchronize()
 
         with context:
-            if self.stream is not None:
-                with self._pinned_memory_tensors() as pinned_memory:
-                    for group_module in self.modules:
-                        for param in group_module.parameters():
-                            param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)
-                            if self.record_stream:
-                                param.data.record_stream(current_stream)
-                        for buffer in group_module.buffers():
-                            buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking)
-                            if self.record_stream:
-                                buffer.data.record_stream(current_stream)
-
-                    for param in self.parameters:
-                        param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)
-                        if self.record_stream:
-                            param.data.record_stream(current_stream)
-
-                    for buffer in self.buffers:
-                        buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking)
-                        if self.record_stream:
-                            buffer.data.record_stream(current_stream)
-
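+            # Dispatch to the disk- or memory-backed onload path.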
+            if self.offload_to_disk_path:
+                self._onload_from_disk(current_stream)
             else:
-                for group_module in self.modules:
-                    for param in group_module.parameters():
-                        param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
-                    for buffer in group_module.buffers():
-                        buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
-
-                for param in self.parameters:
-                    param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
-
-                for buffer in self.buffers:
-                    buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
-                    if self.record_stream:
-                        buffer.data.record_stream(current_stream)
-
-    @torch.compiler.disable()
-    def offload_(self):
-        r"""Offloads the group of modules to the offload_device."""
-        if self.offload_to_disk_path:
-            # TODO: we can potentially optimize this code path by checking if _all_ the desired
-            # safetensor files exist on disk and if so, skip this step entirely, reducing IO
-            # overhead. Currently, we just check if the given `safetensors_file_path` exists and,
-            # if not, we perform a write.
-            # Check if the file has been saved in this session or if it already exists on disk.
-            if not self._is_offloaded_to_disk and not os.path.exists(self.safetensors_file_path):
-                os.makedirs(os.path.dirname(self.safetensors_file_path), exist_ok=True)
-                tensors_to_save = {
-                    key: tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items()
-                }
-                safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path)
-
-            # The group is now considered offloaded to disk for the rest of the session.
-            self._is_offloaded_to_disk = True
-
-            # We do this to free up the RAM which is still holding the tensor data.
-            for tensor_obj in self.tensor_to_key.keys():
-                tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device)
-            return
-
+                self._onload_from_memory(current_stream)
+
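+    # Serializes the group to its safetensors file (skipped when the file already exists),
+    # then rebinds each tensor's .data to an empty placeholder so the RAM holding the real
+    # data can be freed.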
+    def _offload_to_disk(self):
+        # TODO: we can potentially optimize this code path by checking if _all_ the desired
+        # safetensor files exist on disk and if so, skip this step entirely, reducing IO
+        # overhead. Currently, we just check if the given `safetensors_file_path` exists and,
+        # if not, we perform a write.
+        # Check if the file has been saved in this session or if it already exists on disk.
+        if not self._is_offloaded_to_disk and not os.path.exists(self.safetensors_file_path):
+            os.makedirs(os.path.dirname(self.safetensors_file_path), exist_ok=True)
+            tensors_to_save = {key: tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items()}
+            safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path)
+
+        # The group is now considered offloaded to disk for the rest of the session.
+        self._is_offloaded_to_disk = True
+
+        # We do this to free up the RAM which is still holding the tensor data.
+        for tensor_obj in self.tensor_to_key.keys():
+            tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device)
+
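+    # In-memory counterpart of _offload_to_disk: moves the group's tensors back to the
+    # offload device (host memory) rather than serializing them.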
+    def _offload_to_memory(self):
         torch_accelerator_module = (
             getattr(torch, torch.accelerator.current_accelerator().type)
             if hasattr(torch, "accelerator")
@@ -257,6 +269,14 @@ def offload_(self):
         for buffer in self.buffers:
             buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
 
+    @torch.compiler.disable()
+    def offload_(self):
+        r"""Offloads the group of modules to the offload_device."""
+        if self.offload_to_disk_path:
+            self._offload_to_disk()
+        else:
+            self._offload_to_memory()
+
 
 class GroupOffloadingHook(ModelHook):
     r"""
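Usage note: a minimal sketch of how these code paths get exercised, assuming the
apply_group_offloading entry point that installs these hooks; the toy model, device
choices, and offload path below are illustrative, not part of this diff.

    import torch
    import torch.nn as nn
    from diffusers.hooks import apply_group_offloading

    # Stand-in for a real diffusers model; any nn.Module with leaf submodules
    # works for leaf_level offloading.
    model = nn.Sequential(nn.Linear(1024, 1024), nn.GELU(), nn.Linear(1024, 1024)).to(torch.bfloat16)

    apply_group_offloading(
        model,
        onload_device=torch.device("cuda"),
        offload_device=torch.device("cpu"),
        offload_type="leaf_level",
        use_stream=True,  # takes the pinned-memory path in _onload_from_disk/_onload_from_memory
        offload_to_disk_path="/tmp/offload",  # routes offload_()/onload_() through the disk helpers
    )

    # During forward, each group's onload_() runs before its modules execute and
    # offload_() runs afterwards.
    with torch.no_grad():
        out = model(torch.randn(1, 1024, dtype=torch.bfloat16, device="cuda"))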