Skip to content

Commit 07d8e1d

Browse files
Update llamaindex.py
ruff fix
1 parent 0f85bc3 commit 07d8e1d

File tree

1 file changed

+17
-21
lines changed

1 file changed

+17
-21
lines changed

dspy/predict/llamaindex.py

Lines changed: 17 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,22 @@
1-
from llama_index.core.prompts import BasePromptTemplate
2-
from dspy import Predict
3-
import dspy
4-
from abc import abstractmethod
5-
from typing import Any, Optional, List, Dict, Callable
1+
import re
2+
from copy import deepcopy
3+
from typing import Any, Callable, Dict, List, Optional
4+
65
from llama_index.core.base.llms.base import BaseLLM
7-
from llama_index.core.base.llms.types import ChatMessage
86
from llama_index.core.base.llms.generic_utils import (
97
prompt_to_messages,
108
)
11-
from llama_index.core.base.query_pipeline.query import QueryComponent, InputKeys, OutputKeys
12-
from llama_index.core.query_pipeline import QueryPipeline
13-
from dspy.signatures.signature import ensure_signature, signature_to_template, infer_prefix, make_signature
14-
from dspy.signatures.field import InputField, OutputField
15-
from dspy.primitives import ProgramMeta
16-
import dsp
17-
from copy import deepcopy
18-
import re
9+
from llama_index.core.base.llms.types import ChatMessage
10+
from llama_index.core.base.query_pipeline.query import InputKeys, OutputKeys, QueryComponent
1911
from llama_index.core.callbacks.base import CallbackManager
20-
from llama_index.core.bridge.pydantic import BaseModel, create_model
21-
from llama_index.core.prompts import PromptTemplate
22-
12+
from llama_index.core.prompts import BasePromptTemplate, PromptTemplate
13+
from llama_index.core.query_pipeline import QueryPipeline
2314

15+
import dsp
16+
import dspy
17+
from dspy import Predict
18+
from dspy.signatures.field import InputField, OutputField
19+
from dspy.signatures.signature import ensure_signature, make_signature, signature_to_template
2420

2521

2622
def get_formatted_template(predict_module: Predict, kwargs: Dict[str, Any]) -> str:
@@ -78,7 +74,7 @@ def __init__(
7874
metadata: Optional[Dict[str, Any]] = None,
7975
template_var_mappings: Optional[Dict[str, Any]] = None,
8076
function_mappings: Optional[Dict[str, Callable]] = None,
81-
**kwargs: Any
77+
**kwargs: Any,
8278
) -> None:
8379
template = signature_to_template(predict_module.signature)
8480
template_vars = _input_keys_from_template(template)
@@ -116,7 +112,7 @@ def format(self, llm: Optional[BaseLLM] = None, **kwargs: Any) -> str:
116112
return get_formatted_template(self.predict_module, mapped_kwargs)
117113

118114
def format_messages(
119-
self, llm: Optional[BaseLLM] = None, **kwargs: Any
115+
self, llm: Optional[BaseLLM] = None, **kwargs: Any,
120116
) -> List[ChatMessage]:
121117
"""Formats the prompt template into chat messages."""
122118
del llm # unused
@@ -126,7 +122,7 @@ def format_messages(
126122
def get_template(self, llm: Optional[BaseLLM] = None) -> str:
127123
"""Get template."""
128124
# get kwarg templates
129-
kwarg_tmpl_map = {k: f"{{k}}" for k in self.template_vars}
125+
kwarg_tmpl_map = {k: "{k}" for k in self.template_vars}
130126

131127
# get "raw" template with all the values filled in with {var_name}
132128
template0 = get_formatted_template(self.predict_module, kwarg_tmpl_map)
@@ -264,4 +260,4 @@ def forward(self, **kwargs: Any) -> Dict[str, Any]:
264260
"""Forward."""
265261
output_dict = self.query_pipeline.run(**kwargs, return_values_direct=False)
266262
return dspy.Prediction(**output_dict)
267-
263+

0 commit comments

Comments
 (0)