Skip to content

Commit 9cfd4ef

Browse files
authored
Make BaseOutput dataclasses picklable (huggingface#5234)
* Make BaseOutput dataclasses picklable * make style * Test * Empty commit * Simpler and safer
1 parent 78a7851 commit 9cfd4ef

File tree

2 files changed

+19
-1
lines changed

2 files changed

+19
-1
lines changed

src/diffusers/utils/outputs.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
"""
1717

1818
from collections import OrderedDict
19-
from dataclasses import fields
19+
from dataclasses import fields, is_dataclass
2020
from typing import Any, Tuple
2121

2222
import numpy as np
@@ -101,6 +101,13 @@ def __setitem__(self, key, value):
101101
# Don't call self.__setattr__ to avoid recursion errors
102102
super().__setattr__(key, value)
103103

104+
def __reduce__(self):
105+
if not is_dataclass(self):
106+
return super().__reduce__()
107+
callable, _args, *remaining = super().__reduce__()
108+
args = tuple(getattr(self, field.name) for field in fields(self))
109+
return callable, args, *remaining
110+
104111
def to_tuple(self) -> Tuple[Any]:
105112
"""
106113
Convert self to a tuple containing all the attributes/keys that are not `None`.

tests/others/test_outputs.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import pickle as pkl
12
import unittest
23
from dataclasses import dataclass
34
from typing import List, Union
@@ -58,3 +59,13 @@ def test_outputs_dict_init(self):
5859
assert isinstance(outputs["images"][0], PIL.Image.Image)
5960
assert isinstance(outputs[0], list)
6061
assert isinstance(outputs[0][0], PIL.Image.Image)
62+
63+
def test_outputs_serialization(self):
64+
outputs_orig = CustomOutput(images=[PIL.Image.new("RGB", (4, 4))])
65+
serialized = pkl.dumps(outputs_orig)
66+
outputs_copy = pkl.loads(serialized)
67+
68+
# Check original and copy are equal
69+
assert dir(outputs_orig) == dir(outputs_copy)
70+
assert dict(outputs_orig) == dict(outputs_copy)
71+
assert vars(outputs_orig) == vars(outputs_copy)

0 commit comments

Comments
 (0)