Skip to content

Commit bc14560

Browse files
rh314fchollet
authored andcommitted
Fix function serialization + deserialization where the function has values captured in its closure. (keras-team#8592)
* Fix function serialization + deserialization where the function has values captured in its closure. Also updated test. * Fix PEP8 issue for Travis * PEP8 linter fixes again... * As per review request: https://github.com/fchollet/keras/pull/8592/files/c49ce42e5b5d9071df73854b95361b6915ce281f#r153083188 - Test now includes original function. - Test is now modular with @pytest.mark.parametrize making it easy to add more tests in this category. * Make ensure_value_to_cell function private (used inside func_load function)
1 parent da1ff3d commit bc14560

File tree

2 files changed

+40
-3
lines changed

2 files changed

+40
-3
lines changed

keras/utils/generic_utils.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,27 @@ def func_load(code, defaults=None, closure=None, globs=None):
199199
code, defaults, closure = code
200200
if isinstance(defaults, list):
201201
defaults = tuple(defaults)
202+
203+
def ensure_value_to_cell(value):
204+
"""Ensures that a value is converted to a python cell object.
205+
206+
# Arguments
207+
value: Any value that needs to be casted to the cell type
208+
209+
# Returns
210+
A value wrapped as a cell object (see function "func_load")
211+
212+
"""
213+
def dummy_fn():
214+
value # just access it so it gets captured in .__closure__
215+
cell_value = dummy_fn.__closure__[0]
216+
if not isinstance(value, type(cell_value)):
217+
return cell_value
218+
else:
219+
return value
220+
221+
if closure is not None:
222+
closure = tuple(ensure_value_to_cell(_) for _ in closure)
202223
raw_code = codecs.decode(code.encode('ascii'), 'base64')
203224
code = marshal.loads(raw_code)
204225
if globs is None:

tests/keras/utils/generic_utils_test.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,9 +78,25 @@ def test_has_arg_positional_only():
7878
assert has_arg(pow, 'x') is False
7979

8080

81-
def test_func_dump_and_load():
82-
def test_func():
83-
return r'\u'
81+
@pytest.mark.parametrize(
82+
'test_funcion_type',
83+
('simple function', 'closured function'))
84+
def test_func_dump_and_load(test_funcion_type):
85+
86+
if test_funcion_type == 'simple function':
87+
def test_func():
88+
return r'\u'
89+
elif test_funcion_type == 'closured function':
90+
def get_test_func():
91+
x = r'\u'
92+
93+
def test_func():
94+
return x
95+
return test_func
96+
test_func = get_test_func()
97+
else:
98+
raise Exception('Unknown test case for test_func_dump_and_load')
99+
84100
serialized = func_dump(test_func)
85101
deserialized = func_load(serialized)
86102
assert deserialized.__code__ == test_func.__code__

0 commit comments

Comments
 (0)