Describe the bug
When loading a model with from_single_file() and a reduced torch_dtype, peak RAM usage is far higher than the final model size, most likely because the checkpoint weights are first materialized in FP32 and only converted to the requested dtype afterwards. The on-disk dtype of the checkpoint can be verified with the sketch below.
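As a quick check (not part of the reproduction itself), safetensors' safe_open can report the stored dtype of each tensor without loading the whole file; the file name matches the one used in the reproduction:

from huggingface_hub import hf_hub_download
from safetensors import safe_open

filename = hf_hub_download("stable-diffusion-v1-5/stable-diffusion-v1-5", filename="v1-5-pruned-emaonly.safetensors")

# Print the on-disk dtype of the first few tensors without loading the full checkpoint
with safe_open(filename, framework="pt", device="cpu") as f:
    for key in list(f.keys())[:3]:
        print(key, f.get_tensor(key).dtype)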
Reproduction
import threading
import time

import psutil
import torch
from huggingface_hub import hf_hub_download

from diffusers import UNet2DConditionModel

filename = hf_hub_download("stable-diffusion-v1-5/stable-diffusion-v1-5", filename="v1-5-pruned-emaonly.safetensors")

stop_monitoring = False

def log_memory_usage():
    process = psutil.Process()
    mem_info = process.memory_info()
    return mem_info.rss / (1024**2)  # Convert to MB

def monitor_memory(interval, peak_memory):
    while not stop_monitoring:
        current_memory = log_memory_usage()
        peak_memory[0] = max(peak_memory[0], current_memory)
        time.sleep(interval)

def load_model(filename, dtype):
    global stop_monitoring
    peak_memory = [0]  # Use a list to store peak memory so it can be updated in the thread
    initial_memory = log_memory_usage()
    print(f"Initial memory usage: {initial_memory:.2f} MB")

    monitor_thread = threading.Thread(target=monitor_memory, args=(0.01, peak_memory))
    monitor_thread.start()

    start_time = time.time()
    # Keep a reference to the model so it is not freed before the final measurement
    model = UNet2DConditionModel.from_single_file(filename, torch_dtype=dtype)
    end_time = time.time()

    stop_monitoring = True
    monitor_thread.join()  # Wait for the monitoring thread to finish

    print(f"Peak memory usage: {peak_memory[0]:.2f} MB")
    print(f"Time taken: {end_time - start_time:.2f} seconds")
    final_memory = log_memory_usage()
    print(f"Final memory usage: {final_memory:.2f} MB")
    return model

load_model(filename, torch.float8_e4m3fn)
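For comparison, a minimal sketch of per-tensor casting with safetensors keeps at most one full-precision tensor in memory at a time. This is not the diffusers loading path, and it deliberately skips the single-file-to-diffusers key conversion that from_single_file performs; it only illustrates the expected memory behavior:

import torch
from safetensors import safe_open

def load_state_dict_casted(filename, dtype):
    # Read one tensor at a time and cast it immediately, so peak RAM stays
    # close to the size of the casted state dict plus one FP32 tensor.
    state_dict = {}
    with safe_open(filename, framework="pt", device="cpu") as f:
        for key in f.keys():
            state_dict[key] = f.get_tensor(key).to(dtype)
    return state_dict

state_dict = load_state_dict_casted(filename, torch.float8_e4m3fn)  # filename from the script above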
Logs
Initial memory usage: 737.19 MB
Peak memory usage: 4867.43 MB
Time taken: 0.92 seconds
Final memory usage: 1578.99 MB
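Reading the numbers: the peak overhead (4867 − 737 ≈ 4130 MB) is roughly the size of the ~4.3 GB FP32 checkpoint, while the final overhead (1579 − 737 ≈ 842 MB) is roughly what the ~860M-parameter UNet occupies at one byte per parameter in float8_e4m3fn. This is consistent with the full checkpoint being materialized in FP32 before the cast.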
System Info
not relevant here