Skip to content

Commit fa3eeb4

Browse files
authored
Ensure finalize simulation stage is always run (#154)
* ensure finalize stage is always run * add test * update doc * black * fix PR link in release notes
1 parent 6385702 commit fa3eeb4

File tree

4 files changed

+50
-20
lines changed

4 files changed

+50
-20
lines changed

doc/framework.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,8 @@ A model run is divided into four successive stages:
202202
4. finalization
203203

204204
During a simulation, stages 1 and 4 are run only once while stages 2
205-
and 3 are repeated for a given number of (time) steps.
205+
and 3 are repeated for a given number of (time) steps. Stage 4 is run even if
206+
an exception is raised during stage 1, 2 or 3.
206207

207208
Each process-ified class may provide its own computation instructions
208209
by implementing specific methods named ``.initialize()``,

doc/whats_new.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ Bug fixes
3636
which inherit from process-decorated classes (:issue:`149`).
3737
- Fix :func:`~xarray.Dataset.xsimlab.update_clocks` when only the master clock is
3838
updated implicitly (:issue:`151`).
39+
- Ensure that the ``finalize`` simulation stage is always run, even when an
40+
exception is raised during the previous stages (:issue:`154`).
3941

4042
v0.4.1 (17 April 2020)
4143
----------------------

xsimlab/drivers.py

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -330,30 +330,33 @@ def _run(
330330

331331
in_vars = _get_input_vars(ds_init, model)
332332
model.update_state(in_vars, validate=validate_inputs, ignore_static=True)
333-
model.execute("initialize", rt_context, **execute_kwargs)
334333

335-
for step, (_, ds_step) in enumerate(ds_gby_steps):
334+
try:
335+
model.execute("initialize", rt_context, **execute_kwargs)
336336

337-
rt_context.update(
338-
step=step,
339-
step_start=ds_step["_clock_start"].values,
340-
step_end=ds_step["_clock_end"].values,
341-
step_delta=ds_step["_clock_diff"].values,
342-
)
343-
344-
in_vars = _get_input_vars(ds_step, model)
345-
model.update_state(in_vars, validate=validate_inputs, ignore_static=False)
346-
model.execute("run_step", rt_context, **execute_kwargs)
337+
for step, (_, ds_step) in enumerate(ds_gby_steps):
347338

348-
store.write_output_vars(batch, step, model=model)
339+
rt_context.update(
340+
step=step,
341+
step_start=ds_step["_clock_start"].values,
342+
step_end=ds_step["_clock_end"].values,
343+
step_delta=ds_step["_clock_diff"].values,
344+
)
349345

350-
model.execute("finalize_step", rt_context, **execute_kwargs)
346+
in_vars = _get_input_vars(ds_step, model)
347+
model.update_state(in_vars, validate=validate_inputs, ignore_static=False)
348+
model.execute("run_step", rt_context, **execute_kwargs)
351349

352-
store.write_output_vars(batch, -1, model=model)
350+
store.write_output_vars(batch, step, model=model)
353351

354-
model.execute("finalize", rt_context, **execute_kwargs)
352+
model.execute("finalize_step", rt_context, **execute_kwargs)
355353

356-
store.write_index_vars(model=model)
354+
store.write_output_vars(batch, -1, model=model)
355+
store.write_index_vars(model=model)
356+
except Exception as error:
357+
raise error
358+
finally:
359+
model.execute("finalize", rt_context, **execute_kwargs)
357360

358361

359362
class XarraySimulationDriver(BaseSimulationDriver):

xsimlab/tests/test_drivers.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def test_runtime_context_in_model(in_dataset, model):
3939
class P:
4040
@xs.runtime(args="not_a_runtime_arg")
4141
def run_step(self, arg):
42-
pass
42+
print(arg)
4343

4444
m = model.update_processes({"p": P})
4545

@@ -93,7 +93,7 @@ def test_constructor(self, in_dataset, model):
9393
with pytest.raises(KeyError, match=r"Missing variables.*"):
9494
XarraySimulationDriver(invalid_ds, model)
9595

96-
def test_run_model_get_results(self, in_dataset, out_dataset, xarray_driver):
96+
def test_run_model_get_results(self, out_dataset, xarray_driver):
9797
xarray_driver.run_model()
9898
out_ds_actual = xarray_driver.get_results()
9999

@@ -124,3 +124,27 @@ def test_multi_index(self, in_dataset, model):
124124
out_dataset = driver.get_results()
125125

126126
pd.testing.assert_index_equal(out_dataset.indexes["dummy"], midx)
127+
128+
129+
def test_finalize_always_called():
130+
@xs.process
131+
class P:
132+
var = xs.variable(intent="out")
133+
134+
def initialize(self):
135+
self.var = "initialized"
136+
raise RuntimeError()
137+
138+
def finalize(self):
139+
self.var = "finalized"
140+
141+
model = xs.Model({"p": P})
142+
in_dataset = xs.create_setup(model=model, clocks={"clock": [0, 1]})
143+
driver = XarraySimulationDriver(in_dataset, model)
144+
145+
try:
146+
driver.run_model()
147+
except RuntimeError:
148+
pass
149+
150+
assert model.state[("p", "var")] == "finalized"

0 commit comments

Comments
 (0)