Skip to content

Commit 9a4d670

Browse files
cxxxxxnultmaster
andauthored
fix: line chart verification bug && unknow status (microsoft#15)
* fix: line chart verification bug && unknow status * Minor improvements * Reformat * Minor improvements --------- Co-authored-by: Yuge Zhang <[email protected]>
1 parent 1329ffc commit 9a4d670

File tree

4 files changed

+113
-49
lines changed

4 files changed

+113
-49
lines changed

coml/core.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ def explain(self, code: str) -> str:
195195

196196
def static_check(
197197
self, code: str, context: GenerateContext | FixContext
198-
) -> tuple[bool, str]:
198+
) -> tuple[bool | None, str]:
199199
# Check the quality of code by looking at it (i.e., rubberduck)
200200
messages = [
201201
SystemMessage(content=CHECK_INSTRUCTION),
@@ -209,15 +209,15 @@ def static_check(
209209
return False, reason
210210
if "CORRECT" in last_line.upper():
211211
return True, reason
212-
raise ValueError("Unable to parse the response.")
212+
return None, response.content
213213

214214
def output_sanity_check(
215215
self,
216216
code: str,
217217
context: GenerateContext | FixContext,
218218
error: str | None,
219219
output: str | None,
220-
) -> tuple[bool, str]:
220+
) -> tuple[bool | None, str]:
221221
# Run a sanity check of the output of the code
222222
messages = [
223223
SystemMessage(content=SANITY_CHECK_INSTRUCTION),
@@ -233,7 +233,7 @@ def output_sanity_check(
233233
return False, reason
234234
if "CORRECT" in last_line.upper():
235235
return True, reason
236-
raise ValueError("Unable to parse the response.")
236+
return None, response.content
237237

238238
def visualization_check(
239239
self,
@@ -242,12 +242,20 @@ def visualization_check(
242242
svg_string: str,
243243
variable_descriptions: dict[str, str],
244244
source,
245-
) -> tuple[bool, list[tuple[bool, str]]]:
245+
) -> tuple[bool | None, list[tuple[bool | None, str]]]:
246246
vis_verifier = VisVerifier(self.llm, self)
247247
verifications = vis_verifier.verify(
248248
request, previous_code, svg_string, variable_descriptions, source
249249
)
250-
pass_verify = all([verification["answer"] for verification in verifications])
250+
251+
answers = [verification["answer"] for verification in verifications]
252+
if False in answers:
253+
pass_verify = False
254+
elif None in answers:
255+
pass_verify = None
256+
else:
257+
pass_verify = True
258+
251259
reason = []
252260
for verification in verifications:
253261
answer = verification["answer"]

coml/magics.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,16 @@
6666
</style>
6767
"""
6868

69+
VERIFY_STATUS_ICON = {
70+
"error": "❌",
71+
"warning": "⚠️",
72+
"info": "ℹ️",
73+
"ok": "✅",
74+
None: "❔",
75+
True: "✅",
76+
False: "❌",
77+
}
78+
6979

7080
@magics_class
7181
class CoMLMagics(Magics):
@@ -265,14 +275,6 @@ def display_statuses(statuses):
265275
elif error or output:
266276
display_names["sanity"] = "Output sanity check"
267277

268-
status_icon = {
269-
"error": "❌",
270-
"warning": "⚠️",
271-
"info": "ℹ️",
272-
"ok": "✅",
273-
True: "✅",
274-
False: "❌",
275-
}
276278
loading = "<span class='loader'></span>"
277279
message_template = "<details><summary><b>{}:</b> {}</summary>\n{}</details>"
278280
for name in display_names:
@@ -285,7 +287,7 @@ def display_statuses(statuses):
285287
display_names[name],
286288
loading
287289
if name not in statuses
288-
else status_icon[statuses[name]["result"]],
290+
else VERIFY_STATUS_ICON[statuses[name]["result"]],
289291
detail_message,
290292
)
291293

@@ -318,14 +320,14 @@ def display_statuses(statuses):
318320
visualization_check_details,
319321
) = self.agent.visualization_check(
320322
context["request"],
321-
"\n".join(self._get_code_context()),
323+
"\n".join(context["codes"]),
322324
output.replace("<image/svg+xml>", ""),
323325
context["variables"],
324326
vis_framework,
325327
)
326328
details = ""
327329
for detail in visualization_check_details:
328-
details += ("✅" if detail[0] else "❌") + " " + detail[1] + "\n"
330+
details += VERIFY_STATUS_ICON[detail[0]] + " " + detail[1] + "\n"
329331
result["vis"] = {
330332
"result": visualization_check_result,
331333
"details": details,

coml/vis_utils/deconstruct.py

Lines changed: 52 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -329,7 +329,10 @@ def process_legend_matplotlib(spec):
329329
spec["type"] = "legend"
330330
labels = []
331331
examples = []
332-
for i in range(1, len(spec["children"])):
332+
# element(index 0) might be background
333+
# todo: A more accurate way to recognize background
334+
first = 0 if spec["children"][0]["tag"] == "text" else 1
335+
for i in range(first, len(spec["children"])):
333336
child = spec["children"][i]
334337
if child["tag"] == "text":
335338
labels.append(child)
@@ -1054,7 +1057,7 @@ def analysis_mark(nodes, spec):
10541057
lines = []
10551058
if "path" in nodes:
10561059
paths = nodes["path"]
1057-
lines += identify_mark_lines(paths, spec)
1060+
lines += identify_mark_lines(paths)
10581061
if "line" in nodes:
10591062
lines += nodes["line"]
10601063
# line chart that only has two points
@@ -1213,26 +1216,51 @@ def deconstruct(svg, source="matplotlib"):
12131216
# matplotlib parser
12141217
defss = {}
12151218
spec = parser_node(svg, None, defss, [0, 0], {}, source)
1216-
1217-
for child in spec["children"][0]["children"]:
1218-
if "type" in child and child["type"] == "subplot":
1219-
child["encoding"] = {}
1220-
others = {}
1221-
for child2 in child["children"]:
1222-
if "type" in child2:
1223-
if child2["type"] == "xaxis" or child2["type"] == "yaxis":
1224-
analysis_axis(child2, child["encoding"])
1225-
elif child2["type"] == "legend":
1226-
analysis_legend(child2, child["encoding"])
1219+
subplots = [
1220+
child
1221+
for child in spec["children"][0]["children"]
1222+
if ("type" in child and child["type"] == "subplot")
1223+
]
1224+
if len(subplots) != 1:
1225+
return None
1226+
subplot = subplots[0]
1227+
1228+
# find legend
1229+
legends = [
1230+
child
1231+
for child in subplot["children"]
1232+
if ("type" in child and child["type"] == "legend")
1233+
]
1234+
legend = None
1235+
if len(legends) > 1:
1236+
return None
1237+
elif len(legends) == 1:
1238+
legend = legends[0]
1239+
else:
1240+
# legend may out of subplot
1241+
legends = [
1242+
child
1243+
for child in spec["children"][0]["children"]
1244+
if ("type" in child and child["type"] == "legend")
1245+
]
1246+
if len(legends) == 1:
1247+
legend = legends[0]
1248+
1249+
subplot["encoding"] = {}
1250+
if legend is not None:
1251+
analysis_legend(legend, subplot["encoding"])
1252+
others = {}
1253+
for child in subplot["children"]:
1254+
if "type" in child:
1255+
if child["type"] == "xaxis" or child["type"] == "yaxis":
1256+
analysis_axis(child, subplot["encoding"])
1257+
else:
1258+
nodes = get_leaf_nodes(child)
1259+
for node in nodes:
1260+
if node["tag"] not in others:
1261+
others[node["tag"]] = [node]
12271262
else:
1228-
nodes = get_leaf_nodes(child2)
1229-
for node in nodes:
1230-
if node["tag"] not in others:
1231-
others[node["tag"]] = [node]
1232-
else:
1233-
others[node["tag"]].append(node)
1234-
analysis_scale(child)
1235-
analysis_mark(others, child)
1236-
return child
1237-
1238-
return None
1263+
others[node["tag"]].append(node)
1264+
analysis_scale(subplot)
1265+
analysis_mark(others, subplot)
1266+
return subplot

coml/vis_utils/verifier.py

Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -158,8 +158,9 @@ def answer_question(
158158
[previous_code],
159159
)
160160
code = generating_context["answer"]
161-
162161
final_code = previous_code + "\n" + code
162+
# do not show figure
163+
final_code = final_code.replace("plt.show()", "")
163164
global_env = {"finding": None}
164165
try:
165166
exec(final_code, global_env)
@@ -607,9 +608,9 @@ def check_order(order: dict, chart_info: dict):
607608
if not is_sorted:
608609
result["answer"] = False
609610

610-
result["rationale"] = f"Sort {order['channel']} in {order['order']} order."
611+
result["rationale"] = f"{order['channel']} is sorted in {order['order']} order."
611612
if result["answer"] is False:
612-
result["rationale"] = result["rationale"].replace("Sort", "Doesn't sort")
613+
result["rationale"] = result["rationale"].replace("is sorted", "is not sorted")
613614

614615
return result
615616

@@ -627,7 +628,13 @@ def __init__(self, llm: BaseChatModel, agent):
627628
def _add_verification(self, verification):
628629
self.verifications.append(verification)
629630
# display
630-
answer = "✅" if verification["answer"] else "❌"
631+
answer = ""
632+
if verification["answer"] is True:
633+
answer = "✅"
634+
elif verification["answer"] is False:
635+
answer = "❌"
636+
elif verification["answer"] is None:
637+
answer = "❔"
631638
aspect = verification["aspect"].capitalize()
632639
rationale = verification["rationale"]
633640
print(answer + " " + aspect + ": " + rationale)
@@ -642,7 +649,7 @@ def verify(
642649
):
643650
self.verifications = []
644651
understand_fail_result = {
645-
"answer": False,
652+
"answer": None,
646653
"aspect": "Visualization understanding",
647654
"rationale": "Cannot understand the visualization.",
648655
}
@@ -658,7 +665,10 @@ def verify(
658665
# STEP2: check chart type, data encoding and title
659666
self.verify_chart_info(request, chart_info, variable_descriptions)
660667
pass_verify = all(
661-
[verification["answer"] for verification in self.verifications]
668+
[
669+
verification["answer"] is True
670+
for verification in self.verifications
671+
]
662672
)
663673
if pass_verify:
664674
# STEP3: check visualization data
@@ -707,6 +717,22 @@ def verify_data(
707717
try:
708718
# STEP 1: Spot-Check
709719
data = chart_info["data"]
720+
encoding = chart_info["encoding"]
721+
# check label
722+
for channel in encoding.keys():
723+
if "title" not in encoding[channel]:
724+
verification = {
725+
"aspect": channel + " label",
726+
"answer": None,
727+
"rationale": "Channel "
728+
+ channel
729+
+ " is not labeled, so accurate understanding of the data on the graph is difficult.",
730+
}
731+
self._add_verification(verification)
732+
verifications.append(verification)
733+
if len(verifications) > 0:
734+
return verifications
735+
710736
# random pick NUM_SAMPLE data points
711737
indexes = range(len(data))
712738
sampled_indexes = random.sample(indexes, NUM_SAMPLE)
@@ -737,11 +763,11 @@ def verify_data(
737763
if verification:
738764
self._add_verification(verification)
739765
verifications.append(verification)
740-
if verification["answer"] is False:
766+
if verification["answer"] is not True:
741767
break
742768

743769
pass_verify = all(
744-
[verification["answer"] for verification in verifications]
770+
[verification["answer"] is True for verification in verifications]
745771
)
746772
if not pass_verify:
747773
return verifications

0 commit comments

Comments
 (0)