|
21 | 21 |
|
22 | 22 | PathLike = Union[Path, str]
|
23 | 23 |
|
24 |
| -_DEFAULT_BATCH_SIZE = 1024 |
| 24 | +_DEFAULT_BATCH_SIZE = 256 |
25 | 25 |
|
26 | 26 |
|
27 | 27 | class ModulewithWeights(Protocol):
|
@@ -93,22 +93,34 @@ def create_embeddings(
|
93 | 93 |
|
94 | 94 | tokenized.extend([tokenizer.encode_plus(token, return_tensors="pt")["input_ids"][0] for token in tokens])
|
95 | 95 |
|
96 |
| - for batch_idx in tqdm(range(0, len(tokenized), _DEFAULT_BATCH_SIZE)): |
97 |
| - batch = tokenized[batch_idx : batch_idx + _DEFAULT_BATCH_SIZE] |
| 96 | + # Add token_type_ids only if the model supports it |
| 97 | + add_token_type_ids = "token_type_ids" in inspect.getfullargspec(model.forward).args |
| 98 | + |
| 99 | + lengths = np.asarray([len(sequence) for sequence in tokenized]) |
| 100 | + sort_order = np.argsort(lengths) |
| 101 | + |
| 102 | + sorted_tokenized = [tokenized[i] for i in sort_order] |
| 103 | + |
| 104 | + pbar = tqdm(total=len(sorted_tokenized), desc="Encoding tokens", unit=" tokens") |
| 105 | + |
| 106 | + for batch_idx in range(0, len(sorted_tokenized), _DEFAULT_BATCH_SIZE): |
| 107 | + batch = sorted_tokenized[batch_idx : batch_idx + _DEFAULT_BATCH_SIZE] |
98 | 108 |
|
99 | 109 | encoded = {}
|
100 | 110 | encoded["input_ids"] = pad_sequence(batch, batch_first=True, padding_value=pad_token_id)
|
101 | 111 | encoded["attention_mask"] = encoded["input_ids"] != pad_token_id
|
102 | 112 |
|
103 |
| - # Add token_type_ids only if the model supports it |
104 |
| - if "token_type_ids" in inspect.getfullargspec(model.forward).args: |
| 113 | + if add_token_type_ids: |
105 | 114 | encoded["token_type_ids"] = torch.zeros_like(encoded["input_ids"])
|
106 | 115 |
|
107 | 116 | out = _encode_mean_using_model(model, encoded)
|
108 |
| - intermediate_weights.append(out.numpy()) |
| 117 | + intermediate_weights.extend(out.numpy()) |
| 118 | + pbar.update(len(batch)) |
109 | 119 |
|
| 120 | + # Sort the output back to the original order |
| 121 | + intermediate_weights = [intermediate_weights[i] for i in np.argsort(sort_order)] |
110 | 122 | out_tokens.extend(tokens)
|
111 |
| - out_weights = np.concatenate(intermediate_weights) |
| 123 | + out_weights = np.stack(intermediate_weights) |
112 | 124 |
|
113 | 125 | return out_tokens, out_weights
|
114 | 126 |
|
|
0 commit comments