
Commit 09993d2

Merge pull request stanfordnlp#1096 from Anindyadeep/anindya/trtllm
feat(dspy): TensorRT LLM Integration
2 parents 3c23b35 + b3353b6 commit 09993d2

File tree

4 files changed (+316, -3 lines)

Lines changed: 82 additions & 0 deletions
@@ -0,0 +1,82 @@
# dspy.TensorRTModel

TensorRT LLM by NVIDIA is one of the most optimized inference engines for running open-source Large Language Models locally or in production.

### Prerequisites

Install TensorRT LLM by following the instructions [here](https://nvidia.github.io/TensorRT-LLM/installation/linux.html). You need to install `dspy` inside the same Docker environment in which `tensorrt` is installed.

In order to use this module, you need the model weights in TensorRT engine format. To understand how to convert torch weights (from HuggingFace models) to the TensorRT engine format, check out [this documentation](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/llama#build-tensorrt-engines).
### Running TensorRT model inside dspy

```python
from dspy import TensorRTModel

engine_dir = "<your-path-to-engine-dir>"
model_name_or_path = "<hf-model-id-or-path-to-tokenizer>"

model = TensorRTModel(engine_dir=engine_dir, model_name_or_path=model_name_or_path)
```

You can customize model loading further. Below is a list of optional parameters supported when initializing the `dspy` TensorRT model; a usage sketch follows the note after the lists.

- **use_py_session** (`bool`, optional): Whether to use a Python session or not. Defaults to `False`.
- **lora_dir** (`str`): The directory of LoRA adapter weights.
- **lora_task_uids** (`List[str]`): List of LoRA task UIDs; use `-1` to disable the LoRA module.
- **lora_ckpt_source** (`str`): The source of the LoRA checkpoint.

If `use_py_session` is set to `False` (i.e. the model runs in the C++ runtime), the following kwargs are also supported:

- **max_batch_size** (`int`, optional): The maximum batch size. Defaults to `1`.
- **max_input_len** (`int`, optional): The maximum input context length. Defaults to `1024`.
- **max_output_len** (`int`, optional): The maximum output context length. Defaults to `1024`.
- **max_beam_width** (`int`, optional): The maximum beam width, similar to `n` in the OpenAI API. Defaults to `1`.
- **max_attention_window_size** (`int`, optional): The attention window size that controls the sliding window attention / cyclic KV cache behavior. Defaults to `None`.
- **sink_token_length** (`int`, optional): The sink token length. Defaults to `1`.

> Please note that you need to complete the build process properly before applying these customizations, because many of them depend on how the model engine was built. You can learn more [here](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/llama#build-tensorrt-engines).
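
For illustration, here is a minimal sketch of passing a few of these loading options; the paths are placeholders and the numeric values are arbitrary examples, not recommendations:

```python
from dspy import TensorRTModel

model = TensorRTModel(
    engine_dir="<your-path-to-engine-dir>",                    # folder containing the built .engine file
    model_name_or_path="<hf-model-id-or-path-to-tokenizer>",
    use_py_session=False,    # use the C++ runtime (ModelRunnerCpp)
    max_batch_size=1,
    max_input_len=2048,      # must be compatible with how the engine was built
    max_output_len=512,
    max_beam_width=1,
)
```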

Now, to run the model, add the following code:

```python
response = model("hello")
```

This produces a result like:

```
["nobody is perfect, and we all have our own unique struggles and challenges. But what sets us apart is how we respond to those challenges. Do we let them define us, or do we use them as opportunities to grow and learn?\nI know that I have my own personal struggles, and I'm sure you do too. But I also know that we are capable of overcoming them, and becoming the best versions of ourselves. So let's embrace our imperfections, and use them to fuel our growth and success.\nRemember, nobody is perfect, but everybody has the potential to be amazing. So let's go out there and make it happen!"]
```

You can also invoke chat mode by simply changing the prompt to chat format, like this:

```python
prompt = [{"role": "user", "content": "hello"}]
response = model(prompt)

print(response)
```

Output:

```
[" Hello! It's nice to meet you. Is there something I can help you with or would you like to chat?"]
```

Here are some optional parameters that are supported during generation:

- **max_new_tokens** (`int`): The maximum number of tokens to output. Defaults to `1024`.
- **max_attention_window_size** (`int`): Defaults to `None`.
- **sink_token_length** (`int`): Defaults to `None`.
- **end_id** (`int`): The end-of-sequence token ID of the tokenizer; defaults to the tokenizer's end ID.
- **pad_id** (`int`): The padding token ID of the tokenizer; defaults to the tokenizer's end ID.
- **temperature** (`float`): The temperature to control probabilistic behavior in generation. Defaults to `1.0`.
- **top_k** (`int`): Defaults to `1`.
- **top_p** (`float`): Defaults to `1`.
- **num_beams** (`int`): The number of responses to generate. Defaults to `1`.
- **length_penalty** (`float`): Defaults to `1.0`.
- **repetition_penalty** (`float`): Defaults to `1.0`.
- **presence_penalty** (`float`): Defaults to `0.0`.
- **frequency_penalty** (`float`): Defaults to `0.0`.
- **early_stopping** (`int`): Use this only when `num_beams` > 1. Defaults to `1`.
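
For example, a quick sketch of overriding a few of these at call time (the values are illustrative only):

```python
response = model(
    "hello",
    max_new_tokens=128,
    temperature=0.7,
    top_k=50,
)
print(response)  # a list with one string per beam (num_beams defaults to 1)
```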

dsp/modules/__init__.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -25,5 +25,5 @@
 from .sbert import *
 from .sentence_vectorizer import *
 from .snowflake import *
+from .tensorrt_llm import TensorRTModel
 from .watsonx import *
-
```
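
With this re-export in place, the client can also be imported from `dsp.modules` directly, which is presumably how the `from dspy import TensorRTModel` import shown in the docs above resolves:

```python
from dsp.modules import TensorRTModel  # available alongside the other LM clients
```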

dsp/modules/lm.py

Lines changed: 3 additions & 2 deletions
```diff
@@ -48,12 +48,13 @@ def inspect_history(self, n: int = 1, skip: int = 0):
             if prompt != last_prompt:
                 if provider in (
                     "clarifai",
-                    "cloudflare"
+                    "cloudflare",
                     "google",
                     "groq",
                     "Bedrock",
                     "Sagemaker",
                     "premai",
+                    "tensorrt_llm",
                 ):
                     printed.append((prompt, x["response"]))
                 elif provider == "anthropic":
@@ -86,7 +87,7 @@ def inspect_history(self, n: int = 1, skip: int = 0):
             printing_value += prompt

             text = ""
-            if provider in ("cohere", "Bedrock", "Sagemaker", "clarifai", "claude", "ibm", "premai"):
+            if provider in ("cohere", "Bedrock", "Sagemaker", "clarifai", "claude", "ibm", "premai", "tensorrt_llm"):
                 text = choices
             elif provider == "openai" or provider == "ollama":
                 text = " " + self._get_choice_text(choices[0]).strip()
```

dsp/modules/tensorrt_llm.py

Lines changed: 230 additions & 0 deletions
@@ -0,0 +1,230 @@

```python
from pathlib import Path
from typing import Any, Optional, Union

from dsp.modules.lm import LM

## Utility functions to load models


def load_tensorrt_model(
    engine_dir: Union[str, Path],
    use_py_session: Optional[bool] = False,
    **kwargs,
) -> tuple[Any, dict]:
    import tensorrt_llm
    from tensorrt_llm.runtime import ModelRunner, ModelRunnerCpp

    runtime_rank = tensorrt_llm.mpi_rank()
    runner_cls = ModelRunner if use_py_session else ModelRunnerCpp
    runner_kwargs = {
        "engine_dir": engine_dir,
        "lora_dir": kwargs.get("lora_dir", None),
        "rank": runtime_rank,
        "lora_ckpt_source": kwargs.get("lora_ckpt_source", "hf"),
    }

    if not use_py_session:
        engine_cpp_kwargs = {}
        defaults = {
            "max_batch_size": 1,
            "max_input_len": 1024,
            "max_output_len": 1024,
            "max_beam_width": 1,
            "max_attention_window_size": None,
            "sink_token_length": None,
        }

        for key, value in defaults.items():
            engine_cpp_kwargs[key] = kwargs.get(key, value)
        runner_kwargs.update(**engine_cpp_kwargs)

    runner = runner_cls.from_dir(**runner_kwargs)
    return runner, runner_kwargs


def tokenize(prompt: Union[list[dict], str], tokenizer: Any, **kwargs) -> list[int]:
    defaults = {
        "add_special_tokens": False,
        "max_input_length": 1024,
        "model_name": None,
        "model_version": None,
    }
    if not isinstance(prompt, str):
        prompt = tokenizer.apply_chat_template(prompt, tokenize=False)

    input_ids = [
        tokenizer.encode(
            prompt,
            add_special_tokens=kwargs.get("add_special_tokens", defaults["add_special_tokens"]),
            truncation=True,
            max_length=kwargs.get("max_input_length", defaults["max_input_length"]),
        ),
    ]
    if (
        kwargs.get("model_name", defaults["model_name"]) == "ChatGLMForCausalLM"
        and kwargs.get("model_version", defaults["model_version"]) == "glm"
    ):
        input_ids.append(tokenizer.stop_token_id)
    return input_ids


class TensorRTModel(LM):
    """TensorRT integration for dspy LM."""

    def __init__(self, model_name_or_path: str, engine_dir: str, **engine_kwargs: dict) -> None:
        """Initialize the TensorRTModel.

        Args:
            model_name_or_path (str): The Huggingface ID or the path where tokenizer files exist.
            engine_dir (str): The folder where the TensorRT .engine file exists.
            **engine_kwargs (Optional[dict]): Additional engine loading keyword arguments.

        Keyword Args:
            use_py_session (bool, optional): Whether to use a Python session or not. Defaults to False.
            lora_dir (str): The directory of LoRA adapter weights.
            lora_task_uids (list[str]): List of LoRA task UIDs; use -1 to disable the LoRA module.
            lora_ckpt_source (str): The source of the LoRA checkpoint.

            If use_py_session is set to False, the following kwargs are supported:
            max_batch_size (int, optional): The maximum batch size. Defaults to 1.
            max_input_len (int, optional): The maximum input context length. Defaults to 1024.
            max_output_len (int, optional): The maximum output context length. Defaults to 1024.
            max_beam_width (int, optional): The maximum beam width, similar to `n` in OpenAI API. Defaults to 1.
            max_attention_window_size (int, optional): The attention window size that controls the
                sliding window attention / cyclic KV cache behavior. Defaults to None.
            sink_token_length (int, optional): The sink token length. Defaults to 1.
        """
        self.model_name_or_path, self.engine_dir = model_name_or_path, engine_dir
        super().__init__(model=self.model_name_or_path)
        try:
            import tensorrt_llm
        except ImportError as exc:
            raise ModuleNotFoundError(
                "You need to install tensorrt-llm to use TensorRTModel",
            ) from exc

        try:
            from transformers import AutoTokenizer
        except ImportError as exc:
            raise ModuleNotFoundError(
                "You need to install torch and transformers ",
                "pip install transformers==4.38.2",
            ) from exc

        # Configure tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_name_or_path,
            legacy=False,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=True,
            use_fast=True,
        )

        self.pad_id = (
            self.tokenizer.eos_token_id if self.tokenizer.pad_token_id is None else self.tokenizer.pad_token_id
        )
        self.end_id = self.tokenizer.eos_token_id

        # Configure TensorRT
        self.runtime_rank = tensorrt_llm.mpi_rank()
        self.runner, self._runner_kwargs = load_tensorrt_model(engine_dir=self.engine_dir, **engine_kwargs)
        self.history: list[dict[str, Any]] = []
136+
import torch
137+
138+
input_ids = tokenize(prompt=prompt, tokenizer=self.tokenizer, **kwargs)
139+
input_ids = torch.tensor(input_ids, dtype=torch.int32)
140+
141+
run_kwargs = {}
142+
defaults = {
143+
"max_new_tokens": 1024,
144+
"max_attention_window_size": None,
145+
"sink_token_length": None,
146+
"end_id": self.end_id,
147+
"pad_id": self.pad_id,
148+
"temperature": 1.0,
149+
"top_k": 1,
150+
"top_p": 0.0,
151+
"num_beams": 1,
152+
"length_penalty": 1.0,
153+
"early_stopping": 1,
154+
"repetition_penalty": 1.0,
155+
"presence_penalty": 0.0,
156+
"frequency_penalty": 0.0,
157+
"stop_words_list": None,
158+
"bad_words_list": None,
159+
"streaming": False,
160+
"return_dict": True,
161+
"output_log_probs": False,
162+
"output_cum_log_probs": False,
163+
"output_sequence_lengths": True,
164+
}
165+
166+
for k, v in defaults.items():
167+
run_kwargs[k] = kwargs.get(k, v)
168+
169+
with torch.no_grad():
170+
outputs = self.runner.generate(input_ids, **run_kwargs)
171+
input_lengths = [x.size(0) for x in input_ids]
172+
173+
output_ids, sequence_lengths = outputs["output_ids"], outputs["sequence_lengths"]
174+
175+
# In case of current version of dspy it will always stay as 1
176+
_, num_beams, _ = output_ids.size()
177+
batch_idx, beams = 0, []
178+
179+
for beam in range(num_beams):
180+
output_begin = input_lengths[batch_idx]
181+
output_end = sequence_lengths[batch_idx][beam]
182+
outputs = output_ids[batch_idx][beam][output_begin:output_end].tolist()
183+
output_text = self.tokenizer.decode(outputs)
184+
beams.append(output_text)
185+
186+
return beams, run_kwargs
187+
188+
def basic_request(self, prompt, **kwargs: dict) -> list[str]:
189+
raw_kwargs = kwargs
190+
response, all_kwargs = self._generate(prompt, **kwargs)
191+
history = {
192+
"prompt": prompt,
193+
"response": response,
194+
"raw_kwargs": raw_kwargs,
195+
"kwargs": all_kwargs,
196+
}
197+
self.history.append(history)
198+
return response
199+
200+
def __call__(
201+
self,
202+
prompt: Union[list[dict[str, str]], str],
203+
**kwargs,
204+
):
205+
"""TensorRTLLM generate method in dspy.
206+
207+
Args:
208+
prompt (Union[list[dict[str, str]], str]): The prompt to pass. If prompt is not string
209+
then it will assume that chat mode / instruct mode is triggered.
210+
**kwargs (Optional[dict]): Optional keyword arguments.
211+
212+
Additional Parameters:
213+
max_new_tokens (int): The maximum number of tokens to output. Defaults to 1024
214+
max_attention_window_size (int) Defaults to None
215+
sink_token_length (int): Defaults to None
216+
end_id (int): The end of sequence of ID of tokenize, defaults to tokenizer's default
217+
end id
218+
pad_id (int): The pd sequence of ID of tokenize, defaults to tokenizer's default end id
219+
temperature (float): The temperature to control probabilistic behaviour in generation
220+
Defaults to 1.0
221+
top_k (int): Defaults to 1
222+
top_p (float): Defaults to 1
223+
num_beams: (int): The number of responses to generate. Defaults to 1
224+
length_penalty (float): Defaults to 1.0
225+
repetition_penalty (float): Defaults to 1.0
226+
presence_penalty (float): Defaults to 0.0
227+
frequency_penalty (float): Defaults to 0.0
228+
early_stopping (int): Use this only when num_beams > 1, Defaults to 1
229+
"""
230+
return self.request(prompt, **kwargs)
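
Finally, a minimal, untested sketch of wiring this client into a dspy program, assuming the usual `dspy.settings.configure` pattern and an already-built engine (paths are placeholders):

```python
import dspy
from dspy import TensorRTModel

model = TensorRTModel(
    engine_dir="<your-path-to-engine-dir>",
    model_name_or_path="<hf-model-id-or-path-to-tokenizer>",
)
dspy.settings.configure(lm=model)

qa = dspy.Predict("question -> answer")
print(qa(question="What does TensorRT LLM optimize?").answer)
```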
