Commit 0402b98

feat: faster inference for large vocab (MinishLab#221)
1 parent 79900e7 commit 0402b98

1 file changed (+19, −7)

model2vec/distill/inference.py

Lines changed: 19 additions & 7 deletions
```diff
@@ -21,7 +21,7 @@
 
 PathLike = Union[Path, str]
 
-_DEFAULT_BATCH_SIZE = 1024
+_DEFAULT_BATCH_SIZE = 256
 
 
 class ModulewithWeights(Protocol):
@@ -93,22 +93,34 @@ def create_embeddings(
 
     tokenized.extend([tokenizer.encode_plus(token, return_tensors="pt")["input_ids"][0] for token in tokens])
 
-    for batch_idx in tqdm(range(0, len(tokenized), _DEFAULT_BATCH_SIZE)):
-        batch = tokenized[batch_idx : batch_idx + _DEFAULT_BATCH_SIZE]
+    # Add token_type_ids only if the model supports it
+    add_token_type_ids = "token_type_ids" in inspect.getfullargspec(model.forward).args
+
+    lengths = np.asarray([len(sequence) for sequence in tokenized])
+    sort_order = np.argsort(lengths)
+
+    sorted_tokenized = [tokenized[i] for i in sort_order]
+
+    pbar = tqdm(total=len(sorted_tokenized), desc="Encoding tokens", unit=" tokens")
+
+    for batch_idx in range(0, len(sorted_tokenized), _DEFAULT_BATCH_SIZE):
+        batch = sorted_tokenized[batch_idx : batch_idx + _DEFAULT_BATCH_SIZE]
 
         encoded = {}
         encoded["input_ids"] = pad_sequence(batch, batch_first=True, padding_value=pad_token_id)
         encoded["attention_mask"] = encoded["input_ids"] != pad_token_id
 
-        # Add token_type_ids only if the model supports it
-        if "token_type_ids" in inspect.getfullargspec(model.forward).args:
+        if add_token_type_ids:
             encoded["token_type_ids"] = torch.zeros_like(encoded["input_ids"])
 
         out = _encode_mean_using_model(model, encoded)
-        intermediate_weights.append(out.numpy())
+        intermediate_weights.extend(out.numpy())
+        pbar.update(len(batch))
 
+    # Sort the output back to the original order
+    intermediate_weights = [intermediate_weights[i] for i in np.argsort(sort_order)]
     out_tokens.extend(tokens)
-    out_weights = np.concatenate(intermediate_weights)
+    out_weights = np.stack(intermediate_weights)
 
     return out_tokens, out_weights
 
```

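The speedup comes from two changes visible in the diff: the `inspect.getfullargspec` check is hoisted out of the batch loop, and the tokenized vocabulary is sorted by sequence length before batching, so each batch pads only to its own longest sequence rather than the global maximum; `np.argsort(sort_order)` then restores the original token order. Below is a minimal, self-contained sketch of the same sort-then-restore pattern; `encode_batch` is a hypothetical mean-pooling stand-in for the real model call, and the sizes are made up.

```python
# A minimal sketch of the sort-by-length batching pattern, assuming
# nothing but torch + numpy. `encode_batch` is a hypothetical stand-in
# for the real model call: it mean-pools over non-padding positions.
import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence

BATCH_SIZE = 4
PAD_ID = 0


def encode_batch(input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    mask = attention_mask.unsqueeze(-1).float()
    return (input_ids.unsqueeze(-1).float() * mask).sum(dim=1) / mask.sum(dim=1)


# Toy "tokenized vocabulary": variable-length sequences of token ids.
sequences = [torch.randint(1, 100, (n,)) for n in (7, 2, 9, 3, 5, 1, 8, 4)]

# Sort by length so each batch pads to its own longest sequence,
# not to the global maximum.
lengths = np.asarray([len(s) for s in sequences])
sort_order = np.argsort(lengths)
sorted_seqs = [sequences[i] for i in sort_order]

outputs: list[np.ndarray] = []
for start in range(0, len(sorted_seqs), BATCH_SIZE):
    batch = sorted_seqs[start : start + BATCH_SIZE]
    input_ids = pad_sequence(batch, batch_first=True, padding_value=PAD_ID)
    attention_mask = input_ids != PAD_ID
    outputs.extend(encode_batch(input_ids, attention_mask).numpy())

# argsort of the sort order is its inverse permutation: it puts each
# embedding back at the index of its original sequence.
outputs = [outputs[i] for i in np.argsort(sort_order)]
out_weights = np.stack(outputs)
print(out_weights.shape)  # (8, 1)
```

With these toy lengths the first batch pads to length 4 rather than 9; over a large vocabulary the saved padding compute adds up.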
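The restore step relies on a small numpy identity: the argsort of a permutation is its inverse permutation. A quick check:

```python
import numpy as np

lengths = np.array([7, 2, 9, 3])
order = np.argsort(lengths)      # [1, 3, 0, 2]: indices that sort `lengths`
inverse = np.argsort(order)      # [2, 0, 3, 1]: inverse permutation

sorted_lengths = lengths[order]                    # [2, 3, 7, 9]
assert (sorted_lengths[inverse] == lengths).all()  # original order restored
```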