|
| 1 | +import dspy |
| 2 | + |
| 3 | +try: |
| 4 | + import graphviz # type: ignore |
| 5 | + graphviz_available = True |
| 6 | +except ImportError: |
| 7 | + graphviz_available = False |
| 8 | + |
| 9 | + |
class ModuleGraph:
    """Build a Graphviz diagram of a DSPy module, its predictors and their fields.

    On construction the graph is populated with:

    * one box node per configured component (``lm`` / ``rm``) found on
      ``dspy.settings``;
    * one box node per predictor found in the module;
    * one ellipse node per input/output field of each predictor's signature,
      connected to the predictor in the direction the data flows.

    Attributes:
        graph: The ``graphviz.Digraph`` being built (PNG output format).
        nodes: Set of node ids this class has explicitly added to ``graph``.
        module_name: Name used as the root prefix for node ids and as the
            default output filename.
        module: The module object being visualised.
    """

    def __init__(self, module_name, module):
        """Create the graph for ``module`` and populate it immediately.

        Args:
            module_name: Display name / id prefix for the root module.
            module: The module (or single predictor) to inspect.

        Raises:
            ImportError: If the optional ``graphviz`` package is not installed.
        """
        if not graphviz_available:
            raise ImportError(
                """Please install graphviz to use this feature.
                Run 'pip install graphviz'""")

        self.graph = graphviz.Digraph(format='png')
        self.nodes = set()
        self.module_name = module_name
        self.module = module
        # Settings are inspected first so that predictor nodes added afterwards
        # can be linked to the 'lm' node if one was configured.
        self.inspect_settings(dspy.settings)
        self.add_module(self.module_name, self.module)

    @staticmethod
    def _public_data_attributes(component):
        """Return ``{name: value}`` for the public, non-callable attributes of ``component``."""
        details = {}
        for attr in dir(component):
            if attr.startswith('_'):
                continue
            try:
                # Read each attribute exactly once: properties may be costly.
                value = getattr(component, attr)
            except Exception:  # noqa: BLE001
                # A property that fails to evaluate must not abort graph
                # construction; it is simply left out of the label.
                continue
            if not callable(value):
                details[attr] = value
        return details

    @staticmethod
    def _field_label(field_name, field):
        """Return the node label for a signature field, including its description if any."""
        extra = getattr(field, 'json_schema_extra', None)
        desc = extra.get('desc') if isinstance(extra, dict) else None
        if desc is None:
            # No description available: label the node with the bare field name.
            return f"{field_name}"
        return f"{field_name}: ({desc})"

    def inspect_settings(self, settings):
        """Add a box node for each configured component (``lm``, ``rm``) of ``settings``.

        The node label lists every public, non-callable attribute of the
        component except ``history``.

        Args:
            settings: Object exposing optional ``lm`` and ``rm`` attributes.
        """
        components = {
            'lm': getattr(settings, 'lm', None),
            'rm': getattr(settings, 'rm', None),
        }
        for component_name, component in components.items():
            if not component:
                continue
            details = self._public_data_attributes(component)
            component_details = f"{component_name.upper()} Details: " + ', '.join(
                f"{k}: {v}" for k, v in details.items() if k != 'history')
            self.graph.node(component_name, label=component_details, shape='box')
            self.nodes.add(component_name)

    def add_module(self, module_name, module):
        """Add ``module`` to the graph, either as a predictor or as a container.

        A module whose type string contains ``'dspy.predict'`` is drawn as a
        single predictor; anything else is scanned for nested submodules.

        Args:
            module_name: Node id (prefix) to use for this module.
            module: The module object to add.
        """
        module_type = type(module)

        if 'dspy.predict' in str(module_type):
            module_name = self.generate_module_name(module_name, module_type)
            self.process_submodule(module_name, module)
        else:
            self.process_submodules(module_name, module)

    def generate_module_name(self, base_name, module_type):  # noqa: ANN201
        """Return ``base_name`` suffixed with the predictor kind, if recognised.

        Args:
            base_name: Name to decorate.
            module_type: Type (or any object) whose ``str()`` is searched for a
                known predictor name.

        Returns:
            ``base_name`` plus ``'__<Kind>'`` for the first matching kind, or
            ``base_name`` unchanged when nothing matches.
        """
        type_map = {
            'Predict': '__Predict',
            'ReAct': '__ReAct',
            'ChainOfThought': '__ChainOfThought',
            'ProgramOfThought': '__ProgramOfThought',
            'MultiChainComparison': '__MultiChainComparison',
            'majority': '__majority',
        }

        type_repr = str(module_type)
        for key, suffix in type_map.items():
            if key in type_repr:
                return base_name + suffix
        return base_name

    def process_submodules(self, module_name, module):
        """Walk the instance attributes of ``module`` and add every submodule found.

        Args:
            module_name: Node id prefix of the containing module.
            module: Container whose ``__dict__`` is scanned.
        """
        for sub_module_name, sub_module in module.__dict__.items():
            if isinstance(sub_module, dspy.Predict):
                sub_module_name = self.generate_module_name(sub_module_name, type(sub_module))
                self.process_submodule(f"{module_name}__{sub_module_name}", sub_module)
            elif isinstance(sub_module, (dspy.Module, dspy.Retrieve)):
                self.add_module(f"{module_name}__{sub_module_name}", sub_module)
                # Only link RM to LM when both were actually configured;
                # otherwise Graphviz would invent unlabeled 'rm'/'lm' nodes.
                if isinstance(sub_module, dspy.Retrieve) and {'rm', 'lm'} <= self.nodes:
                    self.graph.edge("rm", 'lm', label='RM used in Module')

    def process_submodule(self, sub_module_name, sub_module):
        """Add one predictor, its signature fields and the connecting edges.

        Input fields point *to* the predictor node, output fields are pointed
        to *by* it. If an ``lm`` node exists, it is linked to the predictor.

        Args:
            sub_module_name: Node id for the predictor.
            sub_module: Object exposing ``signature.input_fields`` and
                ``signature.output_fields`` mappings.
        """
        signature = sub_module.signature
        for field_type, fields in (('input', signature.input_fields),
                                   ('output', signature.output_fields)):
            for field_name, field in fields.items():
                node_id = f"{sub_module_name}_{field_type}_{field_name}"
                if node_id not in self.nodes:
                    label = self._field_label(field_name, field)
                    self.graph.node(node_id, label=label, shape='ellipse')
                    self.nodes.add(node_id)
                if field_type == 'input':
                    edge_direction = (node_id, sub_module_name)
                else:
                    edge_direction = (sub_module_name, node_id)
                self.graph.edge(*edge_direction)

        # Add node for the submodule itself
        self.graph.node(sub_module_name, label=sub_module_name, shape='box')
        self.nodes.add(sub_module_name)

        # Connect submodule to LM if configured
        if 'lm' in self.nodes:
            self.graph.edge('lm', sub_module_name, label='LM used in Module')

    def render_graph(self, filename=None):
        """Render the graph to a file (PNG).

        Args:
            filename: Output name passed to ``Digraph.render``; defaults to
                ``self.module_name``.

        Returns:
            Whatever ``self.graph.render`` returns (for ``graphviz.Digraph``
            this is the path of the rendered file).
        """
        if filename is None:
            filename = self.module_name
        return self.graph.render(filename)
| 101 | + |
| 102 | + |
| 103 | +# Example usage of the ModuleGraph class: |
| 104 | +# import dspy |
| 105 | +# import os |
| 106 | +# from dotenv import load_dotenv |
| 107 | +# from dspy.experimental import ModuleGraph |
| 108 | + |
| 109 | +# load_dotenv() |
| 110 | + |
| 111 | +# # Configuration of dspy models |
| 112 | +# llm = dspy.OpenAI( |
| 113 | +# model='gpt-3.5-turbo', |
| 114 | +# api_key=os.environ['OPENAI_API_KEY'], |
| 115 | +# max_tokens=100 |
| 116 | +# ) |
| 117 | + |
| 118 | +# colbertv2_wiki = dspy.ColBERTv2(url='http://20.102.90.50:2017/wiki17_abstracts') |
| 119 | + |
| 120 | +# dspy.settings.configure(lm=llm, rm=colbertv2_wiki) |
| 121 | + |
| 122 | +# class GenerateAnswer(dspy.Signature): |
| 123 | +# "Answer with long and detailed answers" |
| 124 | +# context = dspy.InputField(desc="may contain relevant facts") |
| 125 | +# question = dspy.InputField() |
| 126 | +# answer = dspy.OutputField(desc="often between 10 and 50 words") |
| 127 | + |
| 128 | +# class RAG(dspy.Module): |
| 129 | +# def __init__(self, num_passages=3): |
| 130 | +# super().__init__() |
| 131 | +# self.retrieve = dspy.Retrieve(k=num_passages) |
| 132 | +# self.generate_answer = dspy.ChainOfThought(GenerateAnswer) |
| 133 | + |
| 134 | +# def forward(self, question): |
| 135 | +# context = self.retrieve(question).passages |
| 136 | +# prediction = self.generate_answer(context=context, question=question) |
| 137 | +# return dspy.Prediction(context=context, answer=prediction.answer) |
| 138 | + |
| 139 | +# rag_system = RAG() |
| 140 | +# graph = ModuleGraph("RAG", rag_system) |
| 141 | + |
| 142 | +# graph.render_graph() |
0 commit comments