Skip to content

Commit 7b5e0d3

Browse files
authored
hf models can use auth tokens now (stanfordnlp#611)
1 parent da1f8ae commit 7b5e0d3

File tree

5 files changed

+848
-174
lines changed

5 files changed

+848
-174
lines changed

.pre-commit-config.yaml

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,15 @@ default_stages: [commit]
55
default_install_hook_types: [pre-commit, commit-msg]
66

77
repos:
8-
- repo: https://github.com/astral-sh/ruff-pre-commit
9-
# Ruff version.
10-
rev: v0.1.11
11-
hooks:
12-
# Run the linter.
13-
- id: ruff
14-
args: [--fix]
15-
# Run the formatter.
16-
- id: ruff-format
8+
# - repo: https://github.com/astral-sh/ruff-pre-commit
9+
# # Ruff version.
10+
# rev: v0.1.11
11+
# hooks:
12+
# # Run the linter.
13+
# - id: ruff
14+
# args: [--fix]
15+
# # Run the formatter.
16+
# - id: ruff-format
1717

1818
- repo: https://github.com/timothycrosley/isort
1919
rev: 5.12.0
@@ -50,14 +50,14 @@ repos:
5050
args:
5151
- "--autofix"
5252
- "--indent=2"
53-
- repo: local
54-
hooks:
55-
- id: validate-commit-msg
56-
name: Commit Message is Valid
57-
language: pygrep
58-
entry: ^(break|build|ci|docs|feat|fix|perf|refactor|style|test|ops|hotfix|release|maint|init|enh|revert)\([\w,\.,\-,\(,\),\/]+\)(!?)(:)\s{1}([\w,\W,:]+)
59-
stages: [commit-msg]
60-
args: [--negate]
53+
# - repo: local
54+
# hooks:
55+
# - id: validate-commit-msg
56+
# name: Commit Message is Valid
57+
# language: pygrep
58+
# entry: ^(break|build|ci|docs|feat|fix|perf|refactor|style|test|ops|hotfix|release|maint|init|enh|revert)\([\w,\.,\-,\(,\),\/]+\)(!?)(:)\s{1}([\w,\W,:]+)
59+
# stages: [commit-msg]
60+
# args: [--negate]
6161

6262
- repo: https://github.com/pre-commit/mirrors-prettier
6363
rev: v3.0.3

dsp/modules/hf.py

Lines changed: 59 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
# from peft import PeftConfig, PeftModel
22
# from transformers import AutoModelForSeq2SeqLM, AutoModelForCausalLM, AutoTokenizer, AutoConfig
3+
import os
34
from typing import Literal, Optional
45

56
from dsp.modules.lm import LM
67

78
# from dsp.modules.finetuning.finetune_hf import preprocess_prompt
89

10+
911
def openai_to_hf(**kwargs):
1012
hf_kwargs = {}
1113
for k, v in kwargs.items():
@@ -26,8 +28,19 @@ def openai_to_hf(**kwargs):
2628

2729

2830
class HFModel(LM):
29-
def __init__(self, model: str, checkpoint: Optional[str] = None, is_client: bool = False,
30-
hf_device_map: Literal["auto", "balanced", "balanced_low_0", "sequential"] = "auto"):
31+
def __init__(
32+
self,
33+
model: str,
34+
checkpoint: Optional[str] = None,
35+
is_client: bool = False,
36+
hf_device_map: Literal[
37+
"auto",
38+
"balanced",
39+
"balanced_low_0",
40+
"sequential",
41+
] = "auto",
42+
token: Optional[str] = None,
43+
):
3144
"""wrapper for Hugging Face models
3245
3346
Args:
@@ -42,6 +55,10 @@ def __init__(self, model: str, checkpoint: Optional[str] = None, is_client: bool
4255
self.provider = "hf"
4356
self.is_client = is_client
4457
self.device_map = hf_device_map
58+
59+
hf_autoconfig_kwargs = dict(token=token or os.environ.get("HF_TOKEN"))
60+
hf_autotokenizer_kwargs = hf_autoconfig_kwargs.copy()
61+
hf_automodel_kwargs = hf_autoconfig_kwargs.copy()
4562
if not self.is_client:
4663
try:
4764
import torch
@@ -52,40 +69,68 @@ def __init__(self, model: str, checkpoint: Optional[str] = None, is_client: bool
5269
) from exc
5370
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
5471
try:
55-
architecture = AutoConfig.from_pretrained(model).__dict__["architectures"][0]
56-
self.encoder_decoder_model = ("ConditionalGeneration" in architecture) or ("T5WithLMHeadModel" in architecture)
72+
architecture = AutoConfig.from_pretrained(
73+
model,
74+
**hf_autoconfig_kwargs,
75+
).__dict__["architectures"][0]
76+
self.encoder_decoder_model = ("ConditionalGeneration" in architecture) or (
77+
"T5WithLMHeadModel" in architecture
78+
)
5779
self.decoder_only_model = ("CausalLM" in architecture) or ("GPT2LMHeadModel" in architecture)
58-
assert self.encoder_decoder_model or self.decoder_only_model, f"Unknown HuggingFace model class: {model}"
59-
self.tokenizer = AutoTokenizer.from_pretrained(model if checkpoint is None else checkpoint)
80+
assert (
81+
self.encoder_decoder_model or self.decoder_only_model
82+
), f"Unknown HuggingFace model class: {model}"
83+
self.tokenizer = AutoTokenizer.from_pretrained(
84+
model if checkpoint is None else checkpoint,
85+
**hf_autotokenizer_kwargs,
86+
)
6087

6188
self.rationale = True
6289
AutoModelClass = AutoModelForSeq2SeqLM if self.encoder_decoder_model else AutoModelForCausalLM
6390
if checkpoint:
6491
# with open(os.path.join(checkpoint, '..', 'compiler_config.json'), 'r') as f:
6592
# config = json.load(f)
66-
self.rationale = False #config['rationale']
93+
self.rationale = False # config['rationale']
6794
# if config['peft']:
6895
# peft_config = PeftConfig.from_pretrained(checkpoint)
6996
# self.model = AutoModelClass.from_pretrained(peft_config.base_model_name_or_path, return_dict=True, load_in_8bit=True, device_map=hf_device_map)
7097
# self.model = PeftModel.from_pretrained(self.model, checkpoint)
7198
# else:
7299
if self.device_map:
73-
self.model = AutoModelClass.from_pretrained(checkpoint, device_map=self.device_map)
100+
self.model = AutoModelClass.from_pretrained(
101+
checkpoint,
102+
device_map=self.device_map,
103+
**hf_automodel_kwargs,
104+
)
74105
else:
75-
self.model = AutoModelClass.from_pretrained(checkpoint).to(self.device)
106+
self.model = AutoModelClass.from_pretrained(
107+
checkpoint,
108+
**hf_automodel_kwargs,
109+
).to(self.device)
76110
else:
77111
if self.device_map:
78-
self.model = AutoModelClass.from_pretrained(model, device_map=self.device_map)
112+
self.model = AutoModelClass.from_pretrained(
113+
model,
114+
device_map=self.device_map,
115+
**hf_automodel_kwargs,
116+
)
79117
else:
80-
self.model = AutoModelClass.from_pretrained(model).to(self.device)
118+
self.model = AutoModelClass.from_pretrained(
119+
model,
120+
**hf_automodel_kwargs,
121+
).to(self.device)
81122
self.drop_prompt_from_output = False
82123
except ValueError:
83124
self.model = AutoModelForCausalLM.from_pretrained(
84125
model if checkpoint is None else checkpoint,
85126
device_map=self.device_map,
127+
**hf_automodel_kwargs,
86128
)
87129
self.drop_prompt_from_output = True
88-
self.tokenizer = AutoTokenizer.from_pretrained(model)
130+
self.tokenizer = AutoTokenizer.from_pretrained(
131+
model,
132+
**hf_autotokenizer_kwargs,
133+
)
89134
self.drop_prompt_from_output = True
90135
self.history = []
91136

@@ -111,7 +156,7 @@ def _generate(self, prompt, **kwargs):
111156
# print(prompt)
112157
if isinstance(prompt, dict):
113158
try:
114-
prompt = prompt['messages'][0]['content']
159+
prompt = prompt["messages"][0]["content"]
115160
except (KeyError, IndexError, TypeError):
116161
print("Failed to extract 'content' from the prompt.")
117162
inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
@@ -121,10 +166,7 @@ def _generate(self, prompt, **kwargs):
121166
if self.drop_prompt_from_output:
122167
input_length = inputs.input_ids.shape[1]
123168
outputs = outputs[:, input_length:]
124-
completions = [
125-
{"text": c}
126-
for c in self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
127-
]
169+
completions = [{"text": c} for c in self.tokenizer.batch_decode(outputs, skip_special_tokens=True)]
128170
response = {
129171
"prompt": prompt,
130172
"choices": completions,

0 commit comments

Comments
 (0)