Skip to content

Commit e98366f

Browse files
authored
Initial implementation of correctness check (microsoft#8)
1 parent 5e2512f commit e98366f

File tree

11 files changed

+483
-88
lines changed

11 files changed

+483
-88
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ To uninstall, you can run the following commands:
127127
jupyter server extension disable coml
128128

129129
# Uninstall the Python package
130-
pip uninstall coml
130+
pip uninstall mlcopilot
131131
```
132132

133133
In development mode, you will also need to remove the symlink created by `jupyter labextension develop` command.

coml/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ def load_ipython_extension(ipython):
1717

1818
ipython.register_magics(CoMLMagics)
1919

20+
print(f"CoML {__version__} loaded.")
21+
2022

2123
def _jupyter_labextension_paths():
2224
return [{"src": "labextension", "dest": "coml"}]

coml/core.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,11 @@
1010
from langchain.schema import AIMessage, BaseMessage, HumanMessage, SystemMessage
1111

1212
from .prompt_utils import (
13+
CHECK_INSTRUCTION,
1314
EXPLAIN_INSTRUCTION,
1415
FIX_INSTRUCTION,
1516
GENERATE_INSTRUCTION,
17+
SANITY_CHECK_INSTRUCTION,
1618
SUGGEST_INSTRUCTION,
1719
FixContext,
1820
GenerateContext,
@@ -21,9 +23,11 @@
2123
InteractionIncomplete,
2224
cached_fix_fewshots,
2325
cached_generate_fewshots,
26+
render_check_context,
2427
render_fix_context,
2528
render_generate_context,
2629
render_ipython_cells,
30+
render_sanity_check_context,
2731
)
2832

2933

@@ -183,3 +187,45 @@ def explain(self, code: str) -> str:
183187
response = self.llm(messages)
184188
debug_messages(response)
185189
return response.content
190+
191+
def static_check(
192+
self, code: str, context: GenerateContext | FixContext
193+
) -> tuple[bool, str]:
194+
# Check the quality of code by looking at it (i.e., rubberduck)
195+
messages = [
196+
SystemMessage(content=CHECK_INSTRUCTION),
197+
HumanMessage(content=render_check_context(code, context)),
198+
]
199+
debug_messages(*messages)
200+
response = self.llm(messages)
201+
debug_messages(response)
202+
reason, last_line = response.content.rstrip().rsplit("\n", 1)
203+
if "INCORRECT" in last_line.upper():
204+
return False, reason
205+
if "CORRECT" in last_line.upper():
206+
return True, reason
207+
raise ValueError("Unable to parse the response.")
208+
209+
def output_sanity_check(
210+
self,
211+
code: str,
212+
context: GenerateContext | FixContext,
213+
error: str | None,
214+
output: str | None,
215+
) -> tuple[bool, str]:
216+
# Run a sanity check of the output of the code
217+
messages = [
218+
SystemMessage(content=SANITY_CHECK_INSTRUCTION),
219+
HumanMessage(
220+
content=render_sanity_check_context(code, context, error, output)
221+
),
222+
]
223+
debug_messages(*messages)
224+
response = self.llm(messages)
225+
debug_messages(response)
226+
reason, last_line = response.content.rstrip().rsplit("\n", 1)
227+
if "INCORRECT" in last_line.upper():
228+
return False, reason
229+
if "CORRECT" in last_line.upper():
230+
return True, reason
231+
raise ValueError("Unable to parse the response.")

coml/ipython_utils.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
import base64
24
import json
35
import re
@@ -80,6 +82,27 @@ def run_code_in_next_cell(python_code: str, metadata: Any = None) -> None:
8082
)
8183

8284

85+
def update_running_cell_metadata(metadata: Any) -> None:
86+
if is_jupyter_lab_environ():
87+
input(
88+
json.dumps(
89+
{"command": "update_running_cell_metadata", "metadata": metadata}
90+
)
91+
)
92+
else:
93+
encoded_metadata = base64.b64encode(json.dumps(metadata).encode()).decode()
94+
display(
95+
Javascript(
96+
"""
97+
const cell = comlGetCurrentCell();
98+
cell.metadata.coml = Object.assign(cell.metadata.coml || {}, JSON.parse(atob(\""""
99+
+ encoded_metadata
100+
+ """\")));
101+
"""
102+
)
103+
)
104+
105+
83106
def get_ipython_history(ipython: InteractiveShell) -> list[str]:
84107
codes = []
85108
for code in ipython.user_ns["In"]:
@@ -98,6 +121,11 @@ def get_ipython_history(ipython: InteractiveShell) -> list[str]:
98121
return codes
99122

100123

124+
def get_running_cell() -> dict[str, Any] | None:
125+
"""See `get_last_cell` for the output format."""
126+
return json.loads(input(json.dumps({"command": "running_cell"})))
127+
128+
101129
def get_last_cell() -> dict[str, Any] | None:
102130
"""The implementation is in nbclassic_init.js. This is a *hacked* RPC channel.
103131

coml/js/nbclassic_init.js

Lines changed: 39 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,52 @@
1-
function getDebugInformation() {
1+
window.comlGetRunningCellIndex = function () {
22
const runningCells = $(".running");
33
if (runningCells.length === 0) {
44
console.warn("No running cell");
55
return null;
66
}
77
const cellIndex = Jupyter.notebook.get_cell_elements().index(runningCells[0]);
8-
if (cellIndex <= 0) {
9-
console.warn("No previous cell");
8+
if (cellIndex < 0) {
9+
console.error("Running cell not found in cell list.");
1010
return null;
1111
}
12-
const cell = IPython.notebook.get_cell(cellIndex - 1);
13-
const cellDump = cell.toJSON();
14-
return cellDump;
12+
return cellIndex;
1513
}
1614

17-
IPython.CodeCell.prototype.native_handle_input_request = IPython.CodeCell.prototype.native_handle_input_request || IPython.CodeCell.prototype._handle_input_request;
18-
IPython.CodeCell.prototype._handle_input_request = function (msg) {
19-
try {
20-
// only apply the hack if the command is valid JSON
21-
const command = JSON.parse(msg.content.prompt);
22-
const kernel = IPython.notebook.kernel;
23-
if (command["command"] === "last_cell") {
24-
kernel.send_input_reply(JSON.stringify(getDebugInformation()));
25-
} else {
26-
console.log("Not a command", msg);
15+
window.comlGetCurrentCell = function () {
16+
const cell = comlGetRunningCellIndex();
17+
if (cell === null) {
18+
return null;
19+
}
20+
return IPython.notebook.get_cell(cell);
21+
}
22+
23+
window.comlGetLastCell = function () {
24+
const cellIndex = comlGetRunningCellIndex();
25+
if (cellIndex === null) {
26+
return null;
27+
}
28+
return IPython.notebook.get_cell(comlGetRunningCellIndex() - 1);
29+
}
30+
31+
if (window.IPython && IPython.CodeCell) {
32+
window.IPythonAvailable = true;
33+
IPython.CodeCell.prototype.native_handle_input_request = IPython.CodeCell.prototype.native_handle_input_request || IPython.CodeCell.prototype._handle_input_request;
34+
IPython.CodeCell.prototype._handle_input_request = function (msg) {
35+
try {
36+
// only apply the hack if the command is valid JSON
37+
const command = JSON.parse(msg.content.prompt);
38+
const kernel = IPython.notebook.kernel;
39+
if (command["command"] === "last_cell") {
40+
kernel.send_input_reply(JSON.stringify(comlGetLastCell().toJSON()));
41+
} else if (command["command"] === "running_cell") {
42+
kernel.send_input_reply(JSON.stringify(comlGetCurrentCell().toJSON()));
43+
} else {
44+
console.log("Not a command", msg);
45+
this.native_handle_input_request(msg);
46+
}
47+
} catch(err) {
48+
console.log("Not a command", msg, err);
2749
this.native_handle_input_request(msg);
2850
}
29-
} catch(err) {
30-
console.log("Not a command", msg, err);
31-
this.native_handle_input_request(msg);
3251
}
33-
}
52+
}

coml/linter.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import json
2+
import tempfile
3+
from io import StringIO
4+
from typing import Literal, Tuple
5+
6+
from pylint.lint import Run as PylintRun
7+
from pylint.reporters import JSONReporter
8+
9+
LinterResult = Literal["error", "warning", "info", "ok"]
10+
11+
12+
def lint(previous_code: str, new_code: str) -> Tuple[LinterResult, str]:
13+
# https://stackoverflow.com/q/75507725/6837658
14+
pylint_options = [
15+
"--disable=C0103", # Invalid name
16+
"--disable=C0114", # Missing module docstring
17+
"--disable=C0304", # Final new line missing
18+
]
19+
previous_lines = previous_code.count("\n") + 1
20+
with tempfile.NamedTemporaryFile(mode="w", suffix=".py") as f:
21+
f.write(previous_code + "\n" + new_code)
22+
f.flush()
23+
f.seek(0)
24+
25+
reporter_buffer = StringIO()
26+
results = PylintRun(
27+
[f.name] + pylint_options,
28+
reporter=JSONReporter(reporter_buffer),
29+
do_exit=False,
30+
)
31+
# Score is here.
32+
# score = results.linter.stats.global_note
33+
file_results = json.loads(reporter_buffer.getvalue())
34+
file_results = [e for e in file_results if e["line"] > previous_lines]
35+
36+
details = []
37+
for error in file_results:
38+
line = f"{error['line'] - previous_lines}:{error['column']}: {error['message-id']}: {error['message']}"
39+
details.append(line)
40+
details_joined = "\n".join(details)
41+
42+
if any(e["type"] in ("fatal", "error") for e in file_results):
43+
return "error", details_joined
44+
elif any(e["type"] == "warning" for e in file_results):
45+
return "warning", details_joined
46+
elif file_results:
47+
return "info", details_joined
48+
else:
49+
return "ok", "No issues found."

0 commit comments

Comments
 (0)