Commit e7e6d85

[Tests] improve quantization tests by additionally measuring the inference memory savings (#11021)
* memory usage tests
* fixes
* gguf
1 parent 8eefed6 commit e7e6d85

File tree

11 files changed (+136, -105 lines)


src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py

Lines changed: 2 additions & 0 deletions
@@ -135,6 +135,7 @@ def create_quantized_param(
         target_device: "torch.device",
         state_dict: Dict[str, Any],
         unexpected_keys: Optional[List[str]] = None,
+        **kwargs,
     ):
         import bitsandbytes as bnb

@@ -445,6 +446,7 @@ def create_quantized_param(
         target_device: "torch.device",
         state_dict: Dict[str, Any],
         unexpected_keys: Optional[List[str]] = None,
+        **kwargs,
     ):
         import bitsandbytes as bnb

src/diffusers/quantizers/gguf/gguf_quantizer.py

Lines changed: 1 addition & 0 deletions
@@ -108,6 +108,7 @@ def create_quantized_param(
         target_device: "torch.device",
         state_dict: Optional[Dict[str, Any]] = None,
         unexpected_keys: Optional[List[str]] = None,
+        **kwargs,
     ):
         module, tensor_name = get_module_from_name(model, param_name)
         if tensor_name not in module._parameters and tensor_name not in module._buffers:

src/diffusers/quantizers/torchao/torchao_quantizer.py

Lines changed: 1 addition & 0 deletions
@@ -215,6 +215,7 @@ def create_quantized_param(
         target_device: "torch.device",
         state_dict: Dict[str, Any],
         unexpected_keys: List[str],
+        **kwargs,
     ):
         r"""
         Each nn.Linear layer that needs to be quantized is processsed here. First, we set the value the weight tensor,
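
All three quantizer backends above get the same one-line change: a **kwargs catch-all on create_quantized_param, presumably so the shared model-loading path can pass extra keyword arguments without each backend having to enumerate them. A minimal sketch of the resulting calling convention, using a toy class rather than the actual diffusers quantizers and a purely hypothetical extra keyword argument:

# Illustrative sketch only; not the diffusers API.
from typing import Any, Dict, List, Optional

import torch


class ToyQuantizer:
    def create_quantized_param(
        self,
        model: torch.nn.Module,
        param_value: torch.Tensor,
        param_name: str,
        target_device: "torch.device",
        state_dict: Optional[Dict[str, Any]] = None,
        unexpected_keys: Optional[List[str]] = None,
        **kwargs,  # extra loader-side arguments are accepted and ignored
    ) -> None:
        # A real implementation would quantize param_value and assign it to the module.
        ...


quantizer = ToyQuantizer()
quantizer.create_quantized_param(
    model=torch.nn.Linear(4, 4),
    param_value=torch.zeros(4, 4),
    param_name="weight",
    target_device=torch.device("cpu"),
    keep_in_fp32_modules=[],  # hypothetical extra kwarg, swallowed by **kwargs
)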

tests/quantization/__init__.py

Whitespace-only changes.

tests/quantization/bnb/test_4bit.py

Lines changed: 33 additions & 24 deletions
@@ -54,29 +54,8 @@ def get_some_linear_layer(model):
 
 if is_torch_available():
     import torch
-    import torch.nn as nn
 
-    class LoRALayer(nn.Module):
-        """Wraps a linear layer with LoRA-like adapter - Used for testing purposes only
-
-        Taken from
-        https://github.com/huggingface/transformers/blob/566302686a71de14125717dea9a6a45b24d42b37/tests/quantization/bnb/test_4bit.py#L62C5-L78C77
-        """
-
-        def __init__(self, module: nn.Module, rank: int):
-            super().__init__()
-            self.module = module
-            self.adapter = nn.Sequential(
-                nn.Linear(module.in_features, rank, bias=False),
-                nn.Linear(rank, module.out_features, bias=False),
-            )
-            small_std = (2.0 / (5 * min(module.in_features, module.out_features))) ** 0.5
-            nn.init.normal_(self.adapter[0].weight, std=small_std)
-            nn.init.zeros_(self.adapter[1].weight)
-            self.adapter.to(module.weight.device)
-
-        def forward(self, input, *args, **kwargs):
-            return self.module(input, *args, **kwargs) + self.adapter(input)
+    from ..utils import LoRALayer, get_memory_consumption_stat
 
 
 if is_bitsandbytes_available():

@@ -96,6 +75,8 @@ class Base4bitTests(unittest.TestCase):
     # This was obtained on audace so the number might slightly change
     expected_rel_difference = 3.69
 
+    expected_memory_saving_ratio = 0.8
+
     prompt = "a beautiful sunset amidst the mountains."
     num_inference_steps = 10
     seed = 0

@@ -140,8 +121,10 @@ def setUp(self):
         )
 
     def tearDown(self):
-        del self.model_fp16
-        del self.model_4bit
+        if hasattr(self, "model_fp16"):
+            del self.model_fp16
+        if hasattr(self, "model_4bit"):
+            del self.model_4bit
 
         gc.collect()
         torch.cuda.empty_cache()

@@ -180,6 +163,32 @@ def test_memory_footprint(self):
         linear = get_some_linear_layer(self.model_4bit)
         self.assertTrue(linear.weight.__class__ == bnb.nn.Params4bit)
 
+    def test_model_memory_usage(self):
+        # Delete to not let anything interfere.
+        del self.model_4bit, self.model_fp16
+
+        # Re-instantiate.
+        inputs = self.get_dummy_inputs()
+        inputs = {
+            k: v.to(device=torch_device, dtype=torch.float16) for k, v in inputs.items() if not isinstance(v, bool)
+        }
+        model_fp16 = SD3Transformer2DModel.from_pretrained(
+            self.model_name, subfolder="transformer", torch_dtype=torch.float16
+        ).to(torch_device)
+        unquantized_model_memory = get_memory_consumption_stat(model_fp16, inputs)
+        del model_fp16
+
+        nf4_config = BitsAndBytesConfig(
+            load_in_4bit=True,
+            bnb_4bit_quant_type="nf4",
+            bnb_4bit_compute_dtype=torch.float16,
+        )
+        model_4bit = SD3Transformer2DModel.from_pretrained(
+            self.model_name, subfolder="transformer", quantization_config=nf4_config, torch_dtype=torch.float16
+        )
+        quantized_model_memory = get_memory_consumption_stat(model_4bit, inputs)
+        assert unquantized_model_memory / quantized_model_memory >= self.expected_memory_saving_ratio
+
     def test_original_dtype(self):
         r"""
         A simple test to check if the model succesfully stores the original dtype
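
A note on the new assertion: with expected_memory_saving_ratio = 0.8, the check unquantized_model_memory / quantized_model_memory >= 0.8 tolerates the NF4 run peaking at up to 1 / 0.8 = 1.25x the fp16 peak during inference. The threshold therefore reads as a regression guard on peak inference memory rather than a guarantee of absolute savings, presumably because activations, not just the quantized weights, contribute to the peak of a forward pass.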

tests/quantization/bnb/test_mixed_int8.py

Lines changed: 30 additions & 25 deletions
@@ -60,29 +60,8 @@ def get_some_linear_layer(model):
 
 if is_torch_available():
     import torch
-    import torch.nn as nn
 
-    class LoRALayer(nn.Module):
-        """Wraps a linear layer with LoRA-like adapter - Used for testing purposes only
-
-        Taken from
-        https://github.com/huggingface/transformers/blob/566302686a71de14125717dea9a6a45b24d42b37/tests/quantization/bnb/test_8bit.py#L62C5-L78C77
-        """
-
-        def __init__(self, module: nn.Module, rank: int):
-            super().__init__()
-            self.module = module
-            self.adapter = nn.Sequential(
-                nn.Linear(module.in_features, rank, bias=False),
-                nn.Linear(rank, module.out_features, bias=False),
-            )
-            small_std = (2.0 / (5 * min(module.in_features, module.out_features))) ** 0.5
-            nn.init.normal_(self.adapter[0].weight, std=small_std)
-            nn.init.zeros_(self.adapter[1].weight)
-            self.adapter.to(module.weight.device)
-
-        def forward(self, input, *args, **kwargs):
-            return self.module(input, *args, **kwargs) + self.adapter(input)
+    from ..utils import LoRALayer, get_memory_consumption_stat
 
 
 if is_bitsandbytes_available():

@@ -102,6 +81,8 @@ class Base8bitTests(unittest.TestCase):
     # This was obtained on audace so the number might slightly change
     expected_rel_difference = 1.94
 
+    expected_memory_saving_ratio = 0.7
+
     prompt = "a beautiful sunset amidst the mountains."
     num_inference_steps = 10
     seed = 0

@@ -142,8 +123,10 @@ def setUp(self):
         )
 
     def tearDown(self):
-        del self.model_fp16
-        del self.model_8bit
+        if hasattr(self, "model_fp16"):
+            del self.model_fp16
+        if hasattr(self, "model_8bit"):
+            del self.model_8bit
 
         gc.collect()
         torch.cuda.empty_cache()

@@ -182,6 +165,28 @@ def test_memory_footprint(self):
         linear = get_some_linear_layer(self.model_8bit)
         self.assertTrue(linear.weight.__class__ == bnb.nn.Int8Params)
 
+    def test_model_memory_usage(self):
+        # Delete to not let anything interfere.
+        del self.model_8bit, self.model_fp16
+
+        # Re-instantiate.
+        inputs = self.get_dummy_inputs()
+        inputs = {
+            k: v.to(device=torch_device, dtype=torch.float16) for k, v in inputs.items() if not isinstance(v, bool)
+        }
+        model_fp16 = SD3Transformer2DModel.from_pretrained(
+            self.model_name, subfolder="transformer", torch_dtype=torch.float16
+        ).to(torch_device)
+        unquantized_model_memory = get_memory_consumption_stat(model_fp16, inputs)
+        del model_fp16
+
+        config = BitsAndBytesConfig(load_in_8bit=True)
+        model_8bit = SD3Transformer2DModel.from_pretrained(
+            self.model_name, subfolder="transformer", quantization_config=config, torch_dtype=torch.float16
+        )
+        quantized_model_memory = get_memory_consumption_stat(model_8bit, inputs)
+        assert unquantized_model_memory / quantized_model_memory >= self.expected_memory_saving_ratio
+
     def test_original_dtype(self):
         r"""
         A simple test to check if the model succesfully stores the original dtype

@@ -248,7 +253,7 @@ def test_llm_skip(self):
         self.assertTrue(linear.weight.dtype == torch.int8)
         self.assertTrue(isinstance(linear, bnb.nn.Linear8bitLt))
 
-        self.assertTrue(isinstance(model_8bit.proj_out, nn.Linear))
+        self.assertTrue(isinstance(model_8bit.proj_out, torch.nn.Linear))
         self.assertTrue(model_8bit.proj_out.weight.dtype != torch.int8)
 
     def test_config_from_pretrained(self):

tests/quantization/quanto/__init__.py

Whitespace-only changes.

tests/quantization/quanto/test_quanto.py

Lines changed: 14 additions & 35 deletions
@@ -19,29 +19,8 @@
 
 if is_torch_available():
     import torch
-    import torch.nn as nn
 
-    class LoRALayer(nn.Module):
-        """Wraps a linear layer with LoRA-like adapter - Used for testing purposes only
-
-        Taken from
-        https://github.com/huggingface/transformers/blob/566302686a71de14125717dea9a6a45b24d42b37/tests/quantization/bnb/test_4bit.py#L62C5-L78C77
-        """
-
-        def __init__(self, module: nn.Module, rank: int):
-            super().__init__()
-            self.module = module
-            self.adapter = nn.Sequential(
-                nn.Linear(module.in_features, rank, bias=False),
-                nn.Linear(rank, module.out_features, bias=False),
-            )
-            small_std = (2.0 / (5 * min(module.in_features, module.out_features))) ** 0.5
-            nn.init.normal_(self.adapter[0].weight, std=small_std)
-            nn.init.zeros_(self.adapter[1].weight)
-            self.adapter.to(module.weight.device)
-
-        def forward(self, input, *args, **kwargs):
-            return self.module(input, *args, **kwargs) + self.adapter(input)
+    from ..utils import LoRALayer, get_memory_consumption_stat
 
 
 @nightly

@@ -85,20 +64,20 @@ def test_quanto_layers(self):
             assert isinstance(module, QLinear)
 
     def test_quanto_memory_usage(self):
-        unquantized_model = self.model_cls.from_pretrained(self.model_id, torch_dtype=self.torch_dtype)
-        unquantized_model_memory = unquantized_model.get_memory_footprint() / 1024**3
-
-        model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
         inputs = self.get_dummy_inputs()
+        inputs = {
+            k: v.to(device=torch_device, dtype=torch.bfloat16) for k, v in inputs.items() if not isinstance(v, bool)
+        }
 
-        torch.cuda.reset_peak_memory_stats()
-        torch.cuda.empty_cache()
+        unquantized_model = self.model_cls.from_pretrained(self.model_id, torch_dtype=self.torch_dtype)
+        unquantized_model.to(torch_device)
+        unquantized_model_memory = get_memory_consumption_stat(unquantized_model, inputs)
 
-        model.to(torch_device)
-        with torch.no_grad():
-            model(**inputs)
-        max_memory = torch.cuda.max_memory_allocated() / 1024**3
-        assert (1.0 - (max_memory / unquantized_model_memory)) >= self.expected_memory_reduction
+        quantized_model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
+        quantized_model.to(torch_device)
+        quantized_model_memory = get_memory_consumption_stat(quantized_model, inputs)
+
+        assert unquantized_model_memory / quantized_model_memory >= self.expected_memory_reduction
 
     def test_keep_modules_in_fp32(self):
         r"""

@@ -318,14 +297,14 @@ def test_training(self):
 
 
 class FluxTransformerFloat8WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase):
-    expected_memory_reduction = 0.3
+    expected_memory_reduction = 0.6
 
     def get_dummy_init_kwargs(self):
         return {"weights_dtype": "float8"}
 
 
 class FluxTransformerInt8WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase):
-    expected_memory_reduction = 0.3
+    expected_memory_reduction = 0.6
     _test_torch_compile = True
 
     def get_dummy_init_kwargs(self):
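
The rewritten test_quanto_memory_usage also changes what is measured: the old version compared the quantized model's peak allocation against unquantized_model.get_memory_footprint(), a weights-only size, through a fractional-reduction formula, while the new version runs both models through get_memory_consumption_stat with the same inputs and compares peak inference memory directly. The class-level expectations moving from 0.3 to 0.6 reflect this change of metric, so the old and new thresholds are not directly comparable.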

tests/quantization/torchao/__init__.py

Whitespace-only changes.

tests/quantization/torchao/test_torchao.py

Lines changed: 17 additions & 21 deletions
@@ -50,27 +50,7 @@
     import torch
     import torch.nn as nn
 
-    class LoRALayer(nn.Module):
-        """Wraps a linear layer with LoRA-like adapter - Used for testing purposes only
-
-        Taken from
-        https://github.com/huggingface/transformers/blob/566302686a71de14125717dea9a6a45b24d42b37/tests/quantization/bnb/test_4bit.py#L62C5-L78C77
-        """
-
-        def __init__(self, module: nn.Module, rank: int):
-            super().__init__()
-            self.module = module
-            self.adapter = nn.Sequential(
-                nn.Linear(module.in_features, rank, bias=False),
-                nn.Linear(rank, module.out_features, bias=False),
-            )
-            small_std = (2.0 / (5 * min(module.in_features, module.out_features))) ** 0.5
-            nn.init.normal_(self.adapter[0].weight, std=small_std)
-            nn.init.zeros_(self.adapter[1].weight)
-            self.adapter.to(module.weight.device)
-
-        def forward(self, input, *args, **kwargs):
-            return self.module(input, *args, **kwargs) + self.adapter(input)
+    from ..utils import LoRALayer, get_memory_consumption_stat
 
 
 if is_torchao_available():

@@ -503,6 +483,22 @@ def test_memory_footprint(self):
         # there is additional overhead of scales and zero points
         self.assertTrue(total_bf16 < total_int4wo)
 
+    def test_model_memory_usage(self):
+        model_id = "hf-internal-testing/tiny-flux-pipe"
+        expected_memory_saving_ratio = 2.0
+
+        inputs = self.get_dummy_tensor_inputs(device=torch_device)
+
+        transformer_bf16 = self.get_dummy_components(None, model_id=model_id)["transformer"]
+        transformer_bf16.to(torch_device)
+        unquantized_model_memory = get_memory_consumption_stat(transformer_bf16, inputs)
+        del transformer_bf16
+
+        transformer_int8wo = self.get_dummy_components(TorchAoConfig("int8wo"), model_id=model_id)["transformer"]
+        transformer_int8wo.to(torch_device)
+        quantized_model_memory = get_memory_consumption_stat(transformer_int8wo, inputs)
+        assert unquantized_model_memory / quantized_model_memory >= expected_memory_saving_ratio
+
     def test_wrong_config(self):
         with self.assertRaises(ValueError):
             self.get_dummy_components(TorchAoConfig("int42"))
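
For torchao int8 weight-only quantization the bar is set higher: expected_memory_saving_ratio = 2.0 requires the quantized transformer's peak inference memory to be at most half of the bf16 baseline on the tiny-flux-pipe test checkpoint.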

tests/quantization/utils.py

Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
+from diffusers.utils import is_torch_available
+
+
+if is_torch_available():
+    import torch
+    import torch.nn as nn
+
+    class LoRALayer(nn.Module):
+        """Wraps a linear layer with LoRA-like adapter - Used for testing purposes only
+
+        Taken from
+        https://github.com/huggingface/transformers/blob/566302686a71de14125717dea9a6a45b24d42b37/tests/quantization/bnb/test_4bit.py#L62C5-L78C77
+        """
+
+        def __init__(self, module: nn.Module, rank: int):
+            super().__init__()
+            self.module = module
+            self.adapter = nn.Sequential(
+                nn.Linear(module.in_features, rank, bias=False),
+                nn.Linear(rank, module.out_features, bias=False),
+            )
+            small_std = (2.0 / (5 * min(module.in_features, module.out_features))) ** 0.5
+            nn.init.normal_(self.adapter[0].weight, std=small_std)
+            nn.init.zeros_(self.adapter[1].weight)
+            self.adapter.to(module.weight.device)
+
+        def forward(self, input, *args, **kwargs):
+            return self.module(input, *args, **kwargs) + self.adapter(input)
+
+    @torch.no_grad()
+    @torch.inference_mode()
+    def get_memory_consumption_stat(model, inputs):
+        torch.cuda.reset_peak_memory_stats()
+        torch.cuda.empty_cache()
+
+        model(**inputs)
+        max_memory_mem_allocated = torch.cuda.max_memory_allocated()
+        return max_memory_mem_allocated
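
For context, here is a self-contained sketch of how the new helper is exercised by the tests above: run a forward pass on each model with the same inputs and compare peak CUDA memory. The tiny module, the import path, and the 0.8 threshold are illustrative stand-ins, not part of this commit.

# Illustrative usage sketch of get_memory_consumption_stat; names below are hypothetical.
import torch
import torch.nn as nn

from tests.quantization.utils import get_memory_consumption_stat  # assumed import path


class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(64, 64)

    def forward(self, hidden_states):
        return self.proj(hidden_states)


if torch.cuda.is_available():
    inputs = {"hidden_states": torch.randn(1, 16, 64, device="cuda", dtype=torch.float16)}

    model_fp16 = TinyModel().to(device="cuda", dtype=torch.float16)
    unquantized_model_memory = get_memory_consumption_stat(model_fp16, inputs)
    del model_fp16

    # In the real tests this second model is a quantized checkpoint (NF4, int8, torchao, ...).
    model_quantized = TinyModel().to(device="cuda", dtype=torch.float16)
    quantized_model_memory = get_memory_consumption_stat(model_quantized, inputs)

    # Each test asserts a minimum ratio of peak inference memory, e.g. 0.8 for NF4.
    assert unquantized_model_memory / quantized_model_memory >= 0.8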
