Skip to content

Commit 3ac7b16

Browse files
Merge pull request stanfordnlp#967 from fsndzomga/graphviz-of-dspy-modules
feature(dspy): generate a graphical representation of modules
2 parents fc58b9c + 0c1326d commit 3ac7b16

File tree

2 files changed

+144
-0
lines changed

2 files changed

+144
-0
lines changed

dspy/experimental/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,4 @@
1+
from .module_graph import *  # relative import: module_graph is a sibling module inside this package
2+
13
from .synthesizer import *
24
from .synthetic_data import *

dspy/experimental/module_graph.py

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
import dspy
2+
3+
try:
4+
import graphviz # type: ignore
5+
graphviz_available = True
6+
except ImportError:
7+
graphviz_available = False
8+
9+
10+
class ModuleGraph:
    """Build a Graphviz diagram of a DSPy module and its predictors.

    On construction the configured LM/RM (read from ``dspy.settings``) are
    added as box nodes, then ``module`` is walked: every predictor found
    becomes a box node connected to ellipse nodes for its signature's input
    and output fields.

    Attributes:
        graph: The ``graphviz.Digraph`` being built (png output format).
        nodes: Set of node ids this class has explicitly created.
        module_name: Name used for the root module and as the default
            render filename.
        module: The module object that was inspected.
    """

    def __init__(self, module_name, module):
        """Create the graph for ``module``.

        Args:
            module_name: Label/prefix for the root module.
            module: The module instance to inspect.

        Raises:
            ImportError: If the optional ``graphviz`` package is missing.
        """
        if not graphviz_available:
            raise ImportError(
                "Please install graphviz to use this feature.\n"
                "Run 'pip install graphviz'")

        self.graph = graphviz.Digraph(format='png')
        self.nodes = set()
        self.module_name = module_name
        self.module = module
        # Settings first, so 'lm'/'rm' are known before predictors are linked.
        self.inspect_settings(dspy.settings)
        self.add_module(self.module_name, self.module)

    def inspect_settings(self, settings):
        """Add a box node for each configured component (``lm``, ``rm``).

        The node label lists every public, non-callable attribute of the
        component except ``history``.

        Args:
            settings: Object exposing optional ``lm`` and ``rm`` attributes.
        """
        for component_name in ('lm', 'rm'):
            component = getattr(settings, component_name, None)
            if not component:
                continue

            # NOTE(review): this dumps *all* public attributes into the label.
            # If the client keeps credentials (e.g. an API key) in a public
            # attribute they will appear in the rendered image -- confirm
            # before sharing generated graphs.
            details = {}
            for attr in dir(component):
                if attr.startswith('_') or attr == 'history':
                    continue
                value = getattr(component, attr)  # read once, reuse below
                if not callable(value):
                    details[attr] = value

            component_details = f"{component_name.upper()} Details: " + ', '.join(
                f"{k}: {v}" for k, v in details.items())
            self.graph.node(component_name, label=component_details, shape='box')
            self.nodes.add(component_name)

    def add_module(self, module_name, module):
        """Add ``module`` to the graph.

        Types whose string form contains ``'dspy.predict'`` are drawn
        directly as predictors; anything else is scanned for submodules.
        """
        module_type = type(module)

        if 'dspy.predict' in str(module_type):
            module_name = self.generate_module_name(module_name, module_type)
            self.process_submodule(module_name, module)
        else:
            self.process_submodules(module_name, module)

    def generate_module_name(self, base_name, module_type):  # noqa: ANN201
        """Return ``base_name`` suffixed with ``__<Kind>`` for known kinds.

        The first kind (in the order below) whose name occurs in
        ``str(module_type)`` wins; unknown types return ``base_name``
        unchanged.
        """
        known_kinds = (
            'Predict',
            'ReAct',
            'ChainOfThought',
            'ProgramOfThought',
            'MultiChainComparison',
            'majority',
        )
        type_repr = str(module_type)
        for kind in known_kinds:
            if kind in type_repr:
                return f"{base_name}__{kind}"
        return base_name

    def process_submodules(self, module_name, module):
        """Scan the instance attributes of ``module`` and add what is found.

        ``dspy.Predict`` instances are drawn as predictors; other
        ``dspy.Module`` / ``dspy.Retrieve`` instances are recursed into.
        """
        for sub_module_name, sub_module in vars(module).items():
            if isinstance(sub_module, dspy.Predict):
                sub_module_name = self.generate_module_name(sub_module_name, type(sub_module))
                self.process_submodule(f"{module_name}__{sub_module_name}", sub_module)
            elif isinstance(sub_module, (dspy.Module, dspy.Retrieve)):
                self.add_module(f"{module_name}__{sub_module_name}", sub_module)
                # Only link RM to LM when both were actually registered;
                # otherwise graphviz would silently create unlabeled
                # phantom 'rm'/'lm' nodes for the edge endpoints.
                if (isinstance(sub_module, dspy.Retrieve)
                        and 'rm' in self.nodes and 'lm' in self.nodes):
                    self.graph.edge("rm", 'lm', label='RM used in Module')

    def process_submodule(self, sub_module_name, sub_module):
        """Draw one predictor with its signature fields.

        Input fields point into the predictor node, the predictor points to
        its output fields, and the ``lm`` node (if registered) points to the
        predictor.

        Args:
            sub_module_name: Node id/label for the predictor.
            sub_module: Object with ``signature.input_fields`` and
                ``signature.output_fields`` mappings of name -> field.
        """
        signature = sub_module.signature
        for field_type, fields in (('input', signature.input_fields),
                                   ('output', signature.output_fields)):
            for field_name, field in fields.items():
                node_id = f"{sub_module_name}_{field_type}_{field_name}"
                if node_id not in self.nodes:
                    # The description is optional metadata: fall back to the
                    # bare field name instead of raising when it is absent.
                    extra = getattr(field, 'json_schema_extra', None)
                    desc = extra.get('desc') if isinstance(extra, dict) else None
                    label = field_name if desc is None else f"{field_name}: ({desc})"
                    self.graph.node(node_id, label=label, shape='ellipse')
                    self.nodes.add(node_id)
                if field_type == 'input':
                    self.graph.edge(node_id, sub_module_name)
                else:
                    self.graph.edge(sub_module_name, node_id)

        # Add node for the submodule itself
        self.graph.node(sub_module_name, label=sub_module_name, shape='box')
        self.nodes.add(sub_module_name)

        # Connect submodule to LM if configured
        if 'lm' in self.nodes:
            self.graph.edge('lm', sub_module_name, label='LM used in Module')

    def render_graph(self, filename=None):
        """Render the graph to a file (png).

        Args:
            filename: Name passed to ``graph.render``; defaults to
                ``self.module_name``.
        """
        if filename is None:
            filename = self.module_name
        self.graph.render(filename)
101+
102+
103+
# Example usage of the ModuleGraph class:
104+
# import dspy
105+
# import os
106+
# from dotenv import load_dotenv
107+
# from dspy.experimental import ModuleGraph
108+
109+
# load_dotenv()
110+
111+
# # Configuration of dspy models
112+
# llm = dspy.OpenAI(
113+
# model='gpt-3.5-turbo',
114+
# api_key=os.environ['OPENAI_API_KEY'],
115+
# max_tokens=100
116+
# )
117+
118+
# colbertv2_wiki = dspy.ColBERTv2(url='http://20.102.90.50:2017/wiki17_abstracts')
119+
120+
# dspy.settings.configure(lm=llm, rm=colbertv2_wiki)
121+
122+
# class GenerateAnswer(dspy.Signature):
123+
# "Answer with long and detailed answers"
124+
# context = dspy.InputField(desc="may contain relevant facts")
125+
# question = dspy.InputField()
126+
# answer = dspy.OutputField(desc="often between 10 and 50 words")
127+
128+
# class RAG(dspy.Module):
129+
# def __init__(self, num_passages=3):
130+
# super().__init__()
131+
# self.retrieve = dspy.Retrieve(k=num_passages)
132+
# self.generate_answer = dspy.ChainOfThought(GenerateAnswer)
133+
134+
# def forward(self, question):
135+
# context = self.retrieve(question).passages
136+
# prediction = self.generate_answer(context=context, question=question)
137+
# return dspy.Prediction(context=context, answer=prediction.answer)
138+
139+
# rag_system = RAG()
140+
# graph = ModuleGraph("RAG", rag_system)
141+
142+
# graph.render_graph()

0 commit comments

Comments
 (0)