Skip to content

Commit 6eca224

Browse files
authored
Add optional cache for on-demand variables (#156)
* cache option for on-demand variables * unrelated tweaks and fixes * fix existing tests * update on_demand docstrings * black * update release notes * add test
1 parent fa3eeb4 commit 6eca224

File tree

6 files changed

+111
-18
lines changed

6 files changed

+111
-18
lines changed

doc/whats_new.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ Enhancements
2525
- Added ``%create_setup`` IPython (Jupyter) magic command to auto-generate code
2626
cells with a new simulation setup from a given model (:issue:`152`). The
2727
command is available after executing ``%load_ext xsimlab.ipython``.
28+
- Added an optional cache for on-demand variables (:issue:`156`). The ``@compute``
29+
decorator now has a ``cache`` option (deactivated by default).
2830

2931
Bug fixes
3032
~~~~~~~~~

xsimlab/model.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -494,6 +494,8 @@ def __init__(self, processes):
494494
self._index_vars = builder.get_variables(var_type=VarType.INDEX)
495495
self._index_vars_dict = None
496496

497+
self._od_vars = builder.get_variables(var_type=VarType.ON_DEMAND)
498+
497499
builder.ensure_no_intent_conflict()
498500

499501
self._input_vars = builder.get_input_variables()
@@ -787,6 +789,12 @@ def _merge_and_update_state(self, out_states):
787789
for p_obj in self._processes.values():
788790
p_obj.__xsimlab_state__ = self._state
789791

792+
def _clear_od_cache(self):
793+
"""Clear cached values of on-demand variables."""
794+
795+
for key in self._od_vars:
796+
self._state.pop(key, None)
797+
790798
def execute(
791799
self,
792800
stage,
@@ -856,21 +864,23 @@ def execute(
856864
stage = SimulationStage(stage)
857865
execute_args = (stage, runtime_context, hooks, validate)
858866

867+
self._clear_od_cache()
868+
859869
self._call_hooks(hooks, runtime_context, stage, "model", "pre")
860870

861871
if parallel:
862872
dsk = self._build_dask_graph(execute_args)
863873
out_states = dsk_get(dsk, "_gather", scheduler=scheduler)
864874

865875
# TODO: without this -> flaky tests (don't know why)
866-
# state is not well updated -> error when writing output vars in store
876+
# state is not properly updated -> error when writing output vars in store
867877
if isinstance(scheduler, Client):
868878
time.sleep(0.001)
869879

870880
self._merge_and_update_state(out_states)
871881

872882
else:
873-
for p_name, p_obj in self._processes.items():
883+
for p_obj in self._processes.values():
874884
self._execute_process(p_obj, *execute_args)
875885

876886
self._call_hooks(hooks, runtime_context, stage, "model", "post")

xsimlab/process.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -180,9 +180,14 @@ def get_from_state(self):
180180
return self.__xsimlab_state__[key]
181181

182182
def get_on_demand(self):
183-
p_name, v_name = self.__xsimlab_od_keys__[var_name]
184-
p_obj = self.__xsimlab_model__._processes[p_name]
185-
return getattr(p_obj, v_name)
183+
key = self.__xsimlab_od_keys__[var_name]
184+
try:
185+
# get from cache
186+
return self.__xsimlab_state__[key]
187+
except KeyError:
188+
p_name, v_name = key
189+
p_obj = self.__xsimlab_model__._processes[p_name]
190+
return getattr(p_obj, v_name)
186191

187192
def put_in_state(self, value):
188193
key = self.__xsimlab_state_keys__[var_name]
@@ -235,16 +240,27 @@ def _make_property_on_demand(var):
235240
This property is a simple wrapper around the variable's compute method.
236241
237242
"""
238-
if "compute" not in var.metadata:
243+
if "compute_method" not in var.metadata:
239244
raise KeyError(
240245
"No compute method found for on_demand variable "
241246
f"'{var.name}'. A method decorated with '@{var.name}.compute' "
242247
"is required in the class definition."
243248
)
244249

245-
get_method = var.metadata["compute"]
250+
var_name = var.name
251+
compute_method = var.metadata["compute_method"]
252+
compute_cache = var.metadata["compute_cache"]
253+
254+
def compute_and_cache(self):
255+
value = compute_method(self)
256+
257+
if compute_cache:
258+
key = self.__xsimlab_od_keys__[var_name]
259+
self.__xsimlab_state__[key] = value
260+
261+
return value
246262

247-
return property(fget=get_method, doc=var_details(var))
263+
return property(fget=compute_and_cache, doc=var_details(var))
248264

249265

250266
def _make_property_group(var):
@@ -315,6 +331,7 @@ def runtime(meth=None, args=None):
315331
- ``batch`` : current simulation number in the batch
316332
- ``sim_start`` : simulation start (date)time
317333
- ``sim_end`` : simulation end (date)time
334+
- ``nsteps``: total number of simulation steps
318335
- ``step`` : current step number
319336
- ``step_start`` : current step start (date)time
320337
- ``step_end``: current step end (date)time

xsimlab/tests/fixture_process.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ def __init__(self):
119119
model,
120120
state,
121121
state_keys={"some_var": ("some_process", "some_var")},
122+
od_keys={"some_od_var": ("some_process", "some_od_var")},
122123
)
123124
another_process = _init_process(
124125
AnotherProcess,
@@ -146,6 +147,7 @@ def __init__(self):
146147
"group_var": [("some_process", "some_var")],
147148
},
148149
od_keys={
150+
"od_var": ("example_process", "od_var"),
149151
"in_foreign_od_var": ("some_process", "some_od_var"),
150152
"group_var": [("some_process", "some_od_var")],
151153
},

xsimlab/tests/test_model.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,3 +319,51 @@ def test_context_manager(self):
319319

320320
def test_repr(self, simple_model, simple_model_repr):
321321
assert repr(simple_model) == simple_model_repr
322+
323+
324+
def test_on_demand_cache():
325+
@xs.process
326+
class P1:
327+
var = xs.on_demand(dims="x")
328+
cached_var = xs.on_demand(dims="x")
329+
330+
@var.compute
331+
def _compute_var(self):
332+
return np.random.rand(10)
333+
334+
@cached_var.compute(cache=True)
335+
def _compute_cached_var(self):
336+
return np.random.rand(10)
337+
338+
@xs.process
339+
class P2:
340+
var = xs.foreign(P1, "var")
341+
cached_var = xs.foreign(P1, "cached_var")
342+
view = xs.variable(dims="x", intent="out")
343+
cached_view = xs.variable(dims="x", intent="out")
344+
345+
def run_step(self):
346+
self.view = self.var
347+
self.cached_view = self.cached_var
348+
349+
@xs.process
350+
class P3:
351+
p1_view = xs.foreign(P1, "var")
352+
p1_cached_view = xs.foreign(P1, "cached_var")
353+
p2_view = xs.foreign(P2, "view")
354+
p2_cached_view = xs.foreign(P2, "cached_view")
355+
356+
def initialize(self):
357+
self._p1_cached_view_init = self.p1_cached_view
358+
359+
def run_step(self):
360+
# P1.var's compute method called twice
361+
assert not np.all(self.p1_view == self.p2_view)
362+
# P1.cached_var's compute method called once
363+
assert self.p1_cached_view is self.p2_cached_view
364+
# check cache cleared between simulation stages
365+
assert not np.all(self.p1_cached_view == self._p1_cached_view_init)
366+
367+
model = xs.Model({"p1": P1, "p2": P2, "p3": P3})
368+
model.execute("initialize", {})
369+
model.execute("run_step", {})

xsimlab/variable.py

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,22 @@ class VarIntent(Enum):
2323
INOUT = "inout"
2424

2525

26-
def compute(self, method):
26+
def compute(self, method=None, *, cache=False):
2727
"""A decorator that, when applied to an on-demand variable, returns a
2828
value for that variable.
2929
3030
"""
31-
self.metadata["compute"] = method
3231

33-
return method
32+
def attach_to_metadata(method):
33+
self.metadata["compute_method"] = method
34+
self.metadata["compute_cache"] = cache
35+
36+
return method
37+
38+
if method is None:
39+
return attach_to_metadata
40+
else:
41+
return attach_to_metadata(method)
3442

3543

3644
# monkey patch, waiting for cleaner solution:
@@ -284,13 +292,8 @@ def on_demand(
284292
Like other variables, such variable should be declared in a
285293
process class. Additionally, it requires its own method to compute
286294
its value, which must be defined in the same class and decorated
287-
(e.g., using `@myvar.compute` if the name of the variable is
288-
`myvar`).
289-
290-
An on-demand variable is always an output variable (i.e., intent='out').
291-
292-
Its computation usually involves other variables, although this is
293-
not required.
295+
(e.g., using ``@myvar.compute`` if the name of the variable is
296+
``myvar``).
294297
295298
These variables may be useful, e.g., for model diagnostics.
296299
@@ -318,6 +321,17 @@ def on_demand(
318321
and 'object_codec'. See :func:`zarr.creation.create` for details
319322
about these options. Other keys are ignored.
320323
324+
Notes
325+
-----
326+
An on-demand variable is always an output variable (i.e., intent='out').
327+
328+
Its computation usually involves other variables, although this is
329+
not required.
330+
331+
It is possible to cache its value at each simulation stage, by applying
332+
the compute decorator like this: ``@myvar.compute(cache=True)``. This is
333+
useful if the variable is meant to be accessed many times in other processes.
334+
321335
See Also
322336
--------
323337
:func:`variable`

0 commit comments

Comments
 (0)