Skip to content

Commit de88e8f

Browse files
authored
feat: support context protocol in PythonInterpreter (stanfordnlp#7984)
1 parent 2159597 commit de88e8f

File tree

2 files changed

+41
-33
lines changed

2 files changed

+41
-33
lines changed

dspy/primitives/python_interpreter.py

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
import json
22
import subprocess
3+
from types import TracebackType
34
from typing import Any, Dict, List, Optional
45
import os
56

7+
68
class InterpreterError(ValueError):
79
pass
810

11+
912
class PythonInterpreter:
1013
r"""
1114
PythonInterpreter that runs code in a sandboxed environment using Deno and Pyodide.
@@ -16,22 +19,15 @@ class PythonInterpreter:
1619
Example Usage:
1720
```python
1821
code_string = "print('Hello'); 1 + 2"
19-
interp = PythonInterpreter()
20-
output = interp(code_string)
21-
print(output) # If final statement is non-None, prints the numeric result, else prints captured output
22-
interp.shutdown()
22+
with PythonInterpreter() as interp:
23+
output = interp(code_string) # If final statement is non-None, prints the numeric result, else prints captured output
2324
```
2425
"""
2526

26-
def __init__(
27-
self,
28-
deno_command: Optional[List[str]] = None
29-
) -> None:
27+
def __init__(self, deno_command: Optional[List[str]] = None) -> None:
3028
if isinstance(deno_command, dict):
3129
deno_command = None # no-op, just a guard in case someone passes a dict
32-
self.deno_command = deno_command or [
33-
"deno", "run", "--allow-read", self._get_runner_path()
34-
]
30+
self.deno_command = deno_command or ["deno", "run", "--allow-read", self._get_runner_path()]
3531
self.deno_process = None
3632

3733
def _get_runner_path(self) -> str:
@@ -46,7 +42,7 @@ def _ensure_deno_process(self) -> None:
4642
stdin=subprocess.PIPE,
4743
stdout=subprocess.PIPE,
4844
stderr=subprocess.PIPE,
49-
text=True
45+
text=True,
5046
)
5147
except FileNotFoundError as e:
5248
install_instructions = (
@@ -77,7 +73,7 @@ def _serialize_value(self, value: Any) -> str:
7773
elif isinstance(value, (int, float, bool)):
7874
return str(value)
7975
elif value is None:
80-
return 'None'
76+
return "None"
8177
elif isinstance(value, list) or isinstance(value, dict):
8278
return json.dumps(value)
8379
else:
@@ -129,6 +125,18 @@ def execute(
129125
# If there's no error, return the "output" field
130126
return result.get("output", None)
131127

128+
def __enter__(self):
129+
return self
130+
131+
# All exception fields are ignored and the runtime will automatically re-raise the exception
132+
def __exit__(
133+
self,
134+
_exc_type: Optional[type[BaseException]],
135+
_exc_val: Optional[BaseException],
136+
_exc_tb: Optional[TracebackType],
137+
):
138+
self.shutdown()
139+
132140
def __call__(
133141
self,
134142
code: str,

tests/primitives/test_python_interpreter.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -10,35 +10,35 @@
1010

1111

1212
def test_execute_simple_code():
13-
interpreter = PythonInterpreter()
14-
code = "print('Hello, World!')"
15-
result = interpreter.execute(code)
16-
assert result == "Hello, World!\n", "Simple print statement should return 'Hello World!\n'"
13+
with PythonInterpreter() as interpreter:
14+
code = "print('Hello, World!')"
15+
result = interpreter.execute(code)
16+
assert result == "Hello, World!\n", "Simple print statement should return 'Hello World!\n'"
1717

1818

1919
def test_import():
20-
interpreter = PythonInterpreter()
21-
code = "import math\nresult = math.sqrt(4)\nresult"
22-
result = interpreter.execute(code)
23-
assert result == 2, "Should be able to import and use math.sqrt"
20+
with PythonInterpreter() as interpreter:
21+
code = "import math\nresult = math.sqrt(4)\nresult"
22+
result = interpreter.execute(code)
23+
assert result == 2, "Should be able to import and use math.sqrt"
2424

2525

2626
def test_user_variable_definitions():
27-
interpreter = PythonInterpreter()
28-
code = "result = number + 1\nresult"
29-
result = interpreter.execute(code, variables={"number": 4})
30-
assert result == 5, "User variable assignment should work"
27+
with PythonInterpreter() as interpreter:
28+
code = "result = number + 1\nresult"
29+
result = interpreter.execute(code, variables={"number": 4})
30+
assert result == 5, "User variable assignment should work"
3131

3232

3333
def test_failure_syntax_error():
34-
interpreter = PythonInterpreter()
35-
code = "+++"
36-
with pytest.raises(SyntaxError, match="Invalid Python syntax"):
37-
interpreter.execute(code)
34+
with PythonInterpreter() as interpreter:
35+
code = "+++"
36+
with pytest.raises(SyntaxError, match="Invalid Python syntax"):
37+
interpreter.execute(code)
3838

3939

4040
def test_failure_zero_division():
41-
interpreter = PythonInterpreter()
42-
code = "1+0/0"
43-
with pytest.raises(InterpreterError, match="ZeroDivisionError"):
44-
interpreter.execute(code)
41+
with PythonInterpreter() as interpreter:
42+
code = "1+0/0"
43+
with pytest.raises(InterpreterError, match="ZeroDivisionError"):
44+
interpreter.execute(code)

0 commit comments

Comments
 (0)