Support vLLM backend #1

Open · wants to merge 3 commits into main

2 changes: 2 additions & 0 deletions lm_eval/models/__init__.py
@@ -3,6 +3,7 @@
from . import textsynth
from . import dummy
from . import llama
from . import light

MODEL_REGISTRY = {
"hf": gpt2.HFLM,
@@ -11,6 +12,7 @@
"gpt3": gpt3.GPT3LM,
"textsynth": textsynth.TextSynthLM,
"dummy": dummy.DummyLM,
"lightllm": light.lightllm,
}


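With these two lines, the new backend resolves through the registry like any other lm-eval model. A minimal sketch (assuming lm_eval is importable and a lightllm server is already serving localhost:8000, as hard-coded in light.py below):

from lm_eval.models import MODEL_REGISTRY

# Instantiate the registered backend; the constructor only needs a tokenizer
# name, since generation itself is delegated to the HTTP server.
lm = MODEL_REGISTRY["lightllm"](pretrained="huggyllama/llama-7b")
completions = lm.greedy_until([("Question: What is 2 + 2?\nAnswer:", ["\n"])])
print(completions[0])
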
136 changes: 136 additions & 0 deletions lm_eval/models/light.py
@@ -0,0 +1,136 @@
import asyncio
import re
import requests
from typing import List

import httpx
import torch
import tqdm
from transformers import AutoTokenizer, LlamaTokenizer

from lm_eval.base import BaseLM


class lightllm(BaseLM):
def __init__(
self,
device="cuda",
pretrained="huggyllama/llama-7b",
revision="main",
subfolder=None,
tokenizer=None,
batch_size=1,
load_8bit=True,
):
self.tokenizer = LlamaTokenizer.from_pretrained(pretrained)
self.api_url = "http://localhost:8000/generate"

        # Client-side batch size used when fanning requests out to the server
        # (the batch_size constructor argument is not used here).
        self.batch_size_per_gpu = 128
print("eos_token: ", self.tokenizer.eos_token)
print("Using framework lightllm")

@property
def eot_token_id(self):
# we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
return self.tokenizer.eos_token_id

@property
def max_length(self):
return 128000
# try:
# return self.model.config.n_ctx
# except AttributeError:
# return self.model.config.max_position_embeddings

@property
def max_gen_toks(self):
return 256

@property
def batch_size(self):
return self.batch_size_per_gpu # * gpus

@property
def device(self):
return "cuda"

def tok_encode(self, string: str):
return self.tokenizer.encode(string, add_special_tokens=False)

def tok_decode(self, tokens: List[int]):
return self.tokenizer.decode(tokens)

def _model_call(self, inps):
"""
inps: a torch tensor of shape [batch, sequence]
the size of sequence may vary from call to call

returns: a torch tensor of shape [batch, sequence, vocab] with the
logits returned from the model
"""
        # Loglikelihood scoring is not supported by this HTTP backend; only
        # generation via `greedy_until` is implemented.
        raise NotImplementedError

async def async_run(self, prompt: str, **kwargs) -> str:
headers = {'Content-Type': 'application/json'}
pload = {
'inputs': prompt,
"parameters": {
'do_sample': False,
'ignore_eos': False,
'max_new_tokens': self.max_gen_toks,
}
}
async with httpx.AsyncClient(timeout=None) as client:
response = await client.post(self.api_url, headers=headers, json=pload)
text = response.json()['generated_text'][0]
return text

async def _async_run_batch(self, prompts: List[str], **kwargs) -> List[str]:
tasks: List[asyncio.Task] = [self.async_run(prompt, **kwargs) for prompt in prompts]
return await asyncio.gather(*tasks)

def run_batch(self, prompts: List[str], **kwargs) -> List[str]:
return asyncio.run(self._async_run_batch(prompts, **kwargs))

def _model_generate(self, context: torch.Tensor, max_length, eos_token_id):
prompt = self.tok_decode(context.tolist()[0])
headers = {'Content-Type': 'application/json'}
pload = {
'inputs': prompt,
"parameters": {
'do_sample': False,
'ignore_eos': False,
                'max_new_tokens': max_length - len(context[0]),
}
}
response = requests.post(self.api_url, headers=headers, json=pload, stream=False)
text = response.json()['generated_text'][0]
return torch.tensor([context.tolist()[0] + self.tok_encode(text)])

def greedy_until(self, reqs):
# TODO: implement fully general `until` that handles until that are
# multiple tokens or that span multiple tokens correctly

        res = []
        # Truncate each generation at a GSM8K-style final answer ("#### <number>").
        # Note: the per-request `until` stop sequences are not applied (see TODO above).
        real_match = re.compile(r'^(.*?#### \d+(\.\d+)?)', re.DOTALL)

prompts = [req[0] for req in reqs]
texts = self.run_batch(prompts)

for text in texts:

target = real_match.search(text)
if target is not None:
text = target.group(1)
else:
print("Warning: no match found for", text)

res.append(text)

return res
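
For reference, the payload built in `async_run` and `_model_generate` follows lightllm's HTTP generate API. A minimal standalone request against the hard-coded endpoint (assuming a lightllm server is already running on localhost:8000) would look roughly like this:

import requests

# Hedged sketch of the request format used by the class above; the endpoint
# and payload keys mirror the hard-coded values in light.py.
pload = {
    "inputs": "Question: What is 2 + 2?\nAnswer:",
    "parameters": {
        "do_sample": False,      # greedy decoding, matching the evaluation setup
        "ignore_eos": False,
        "max_new_tokens": 256,   # same cap as max_gen_toks above
    },
}
response = requests.post(
    "http://localhost:8000/generate",
    headers={"Content-Type": "application/json"},
    json=pload,
)
print(response.json()["generated_text"][0])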
5 changes: 4 additions & 1 deletion main.py
@@ -32,13 +32,16 @@ def main(task_name: str, **kwargs):
else:
print("Using lm-eval")
model_name = kwargs.pop("model_name")
model = "hf" if model_name in ["causal", "seq_to_seq"] else "llama"
model = "lightllm" if model_name == "lightllm" else model
results = evaluator.simple_evaluate(
model="hf" if model_name in ["causal", "seq_to_seq"] else "llama",
model=model,
model_args=f"pretrained={kwargs.pop('model_path')}",
tasks=[task_name],
num_fewshot=kwargs.get("ntrain", 0),
batch_size=1,
no_cache=True,
limit=kwargs.get("ntest", None),
device="0",
)
print(evaluator.make_table(results))
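With this change, routing evaluation through the new backend only requires passing model_name="lightllm". A hypothetical call (main's full signature and CLI wiring are not shown in this diff, so the argument names below are inferred from the kwargs used above):

from main import main

main(
    "gsm8k",                          # task_name forwarded to lm-eval (assumed task)
    model_name="lightllm",            # selects the new lightllm backend
    model_path="huggyllama/llama-7b",
    ntrain=0,                         # number of few-shot examples
    ntest=None,                       # evaluate the full test split
)
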
47 changes: 31 additions & 16 deletions mmlu.py
@@ -8,6 +8,7 @@
import numpy as np
import pandas as pd
from fire import Fire
import torch
from tqdm import tqdm

from modeling import select_model, EvalModel
@@ -133,28 +134,41 @@ def gen_prompt(train_df, subject, k=-1):
return prompt


def evaluate(args, subject, model: EvalModel, dev_df, test_df):
def evaluate(args, subject, model: EvalModel, dev_df, test_df, batch_size=1):
cors = []
all_probs = []

for i in range(test_df.shape[0]):
# get prompt and make sure it fits
k = args.ntrain
prompt_end = format_example(test_df, i, include_answer=False)
train_prompt = gen_prompt(dev_df, subject, k)
prompt = train_prompt + prompt_end
num_batches = int(np.ceil(test_df.shape[0] / float(batch_size)))

while not model.check_valid_length(prompt) and k > 0:
k -= 1
for batch_idx in range(num_batches):
start_idx = batch_idx * batch_size
end_idx = min((batch_idx + 1) * batch_size, test_df.shape[0])

batch_prompts = []
batch_labels = []
for i in range(start_idx, end_idx):
# get prompt and make sure it fits
k = args.ntrain
prompt_end = format_example(test_df, i, include_answer=False)
train_prompt = gen_prompt(dev_df, subject, k)
prompt = train_prompt + prompt_end

label = test_df.iloc[i, test_df.shape[1] - 1]
pred = model.run(prompt)
probs = [0 for _ in get_choices()]
cor = pred.strip().startswith(label)
cors.append(cor)
all_probs.append(probs)
while not model.check_valid_length(prompt) and k > 0:
k -= 1
train_prompt = gen_prompt(dev_df, subject, k)
prompt = train_prompt + prompt_end

label = test_df.iloc[i, test_df.shape[1] - 1]
batch_prompts.append(prompt)
batch_labels.append(label)

preds = model.run_batch(batch_prompts)
for pred, label in zip(preds, batch_labels):
probs = [0 for _ in get_choices()]
cor = pred.strip().startswith(label)
cors.append(cor)
all_probs.append(probs)
torch.cuda.empty_cache()

acc = np.mean(cors)
cors = np.array(cors)
@@ -165,6 +179,7 @@ def evaluate(args, subject, model: EvalModel, dev_df, test_df):
return cors, acc, all_probs



def main(data_dir: str = "data/mmlu", ntrain: int = 5, **kwargs):
args = Namespace(**locals())
model = select_model(max_input_length=2048, max_output_length=2, **kwargs)
Expand Down Expand Up @@ -350,4 +365,4 @@ def main(data_dir: str = "data/mmlu", ntrain: int = 5, **kwargs):


if __name__ == "__main__":
Fire()
Fire()
102 changes: 100 additions & 2 deletions modeling.py
@@ -1,9 +1,12 @@
import asyncio
import json
import signal
import time
from pathlib import Path
from typing import Optional, Tuple
import requests
from typing import Optional, Tuple, List

import httpx
import openai
import rwkv
import tiktoken
@@ -27,6 +30,8 @@
AutoModel,
LlamaConfig,
)
from vllm import LLM, SamplingParams


import quant

@@ -38,6 +43,9 @@ class EvalModel(BaseModel, arbitrary_types_allowed=True):

def run(self, prompt: str, **kwargs) -> str:
raise NotImplementedError

def run_batch(self, prompts: List[str], **kwargs) -> List[str]:
return [self.run(prompt, **kwargs) for prompt in prompts]

def count_text_length(self, text: str) -> int:
raise NotImplementedError
@@ -232,6 +240,94 @@ def get_choice(self, text: str, **kwargs) -> Tuple[float, float]:
return A, B


class VllmModel(EvalModel):
llm: Optional[LLM] = None
sampling_params: Optional[SamplingParams] = None

def load(self):
if self.llm is None:
self.llm = LLM(model=self.model_path)
self.sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

def run(self, prompt: str, **kwargs) -> str:
self.load()
output = self.llm.generate(prompt, self.sampling_params, use_tqdm=False)
return output[0].outputs[0].text

def run_batch(self, prompts: List[str], **kwargs) -> List[str]:
self.load()
outputs = self.llm.generate(prompts, self.sampling_params, use_tqdm=False)
return [output.outputs[0].text for output in outputs]

def count_text_length(self, text: str) -> int:
self.load()
tokenizer = self.llm.get_tokenizer()
return len(tokenizer.encode(text))

def get_choice(self, text: str, **kwargs) -> Tuple[float, float]:
self.load()
self.sampling_params = SamplingParams(temperature=0.8, top_p=0.95, logprobs=200)
output = self.llm.generate(text, self.sampling_params, use_tqdm=False)
tokenizer = self.llm.get_tokenizer()
A_index = tokenizer("A", add_special_tokens=False).input_ids[0]
B_index = tokenizer("B", add_special_tokens=False).input_ids[0]
        # Compare logprobs of the "A" and "B" tokens at the first generated position.
        first_token_logprobs = output[0].outputs[0].logprobs[0]
        A = float(first_token_logprobs[A_index]) if A_index in first_token_logprobs else 0.0
        B = float(first_token_logprobs[B_index]) if B_index in first_token_logprobs else 0.0
return A, B

class LightllmModel(EvalModel):
    api_url: Optional[str] = None
    tokenizer: Optional[PreTrainedTokenizer] = None

def load(self):
if self.tokenizer is None:
self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, use_fast=True)
self.api_url = "http://localhost:8000/generate"
self.max_input_length = 100000
print("load success")

def run(self, prompt: str, **kwargs) -> str:
self.load()
headers = {'Content-Type': 'application/json'}
pload = {
'inputs': prompt,
"parameters": {
'do_sample': False,
'ignore_eos': False,
'max_new_tokens': 2,
}
}
response = requests.post(self.api_url, headers=headers, json=pload, stream=False)
text = response.json()['generated_text'][0]
return text

async def async_run(self, prompt: str, **kwargs) -> str:
headers = {'Content-Type': 'application/json'}
pload = {
'inputs': prompt,
"parameters": {
'do_sample': False,
'ignore_eos': False,
'max_new_tokens': 2,
}
}
async with httpx.AsyncClient(timeout=None) as client:
response = await client.post(self.api_url, headers=headers, json=pload)
text = response.json()['generated_text'][0]
return text

async def _async_run_batch(self, prompts: List[str], **kwargs) -> List[str]:
tasks: List[asyncio.Task] = [self.async_run(prompt, **kwargs) for prompt in prompts]
return await asyncio.gather(*tasks)

def run_batch(self, prompts: List[str], **kwargs) -> List[str]:
return asyncio.run(self._async_run_batch(prompts, **kwargs))

def count_text_length(self, text: str) -> int:
self.load()
return len(self.tokenizer.encode(text))


class LlamaModel(SeqToSeqModel):
use_template: bool = False
"""
@@ -498,6 +594,8 @@ def select_model(model_name: str, **kwargs) -> EvalModel:
openai=OpenAIModel,
rwkv=RWKVModel,
gptq=GPTQModel,
vllm=VllmModel,
lightllm=LightllmModel,
)
model_class = model_map.get(model_name)
if model_class is None:
@@ -534,4 +632,4 @@ def test_model(


if __name__ == "__main__":
Fire()
Fire()
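
A hedged usage sketch for the two new backends registered in select_model (field names follow the existing EvalModel call sites in this repo, e.g. mmlu.py's select_model(max_input_length=2048, max_output_length=2, ...)):

from modeling import select_model

# "vllm" loads the model in-process via vLLM; "lightllm" expects a server
# already listening on http://localhost:8000/generate.
model = select_model(
    model_name="vllm",
    model_path="huggyllama/llama-7b",
    max_input_length=2048,
    max_output_length=2,
)
preds = model.run_batch([
    "Question: Is the sky blue on a clear day?\nA. yes\nB. no\nAnswer:",
])
print(preds[0])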