
Commit 889bfa4

Update our sampler documentation to reflect usage (#1444)
We will update our samplers in the near future to push the backend-specific compilation details out of the sampler classes (keras-team/keras-hub#1425). In general, we also want our documentation to reflect the main usage of our classes, which is with the `Seq2SeqLM` and `CausalLM` models. With that in mind, this change updates our sampler docs to show practical usage of the sampling classes with our modeling classes. For the base class, we show the main use case of overriding the `get_next_token()` method.
1 parent d5df710 commit 889bfa4
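The pattern the updated docstrings converge on is to pass a sampler to `compile()` either by name or as a configured object. A minimal sketch assembled from the diff below, assuming a working `keras_nlp` install with access to the `"gpt2_base_en"` preset:

```python
import keras_nlp

causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")

# Pass a sampler by name when the defaults are fine...
causal_lm.compile(sampler="top_k")
causal_lm.generate(["Keras is a"])

# ...or by object when non-default arguments are needed.
sampler = keras_nlp.samplers.TopKSampler(k=5, temperature=0.7)
causal_lm.compile(sampler=sampler)
causal_lm.generate(["Keras is a"])
```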

7 files changed, +77 −208 lines

keras_nlp/samplers/beam_sampler.py
Lines changed: 9 additions & 50 deletions

@@ -18,11 +18,8 @@
 from keras_nlp.api_export import keras_nlp_export
 from keras_nlp.backend import ops
 from keras_nlp.samplers.sampler import Sampler
-from keras_nlp.samplers.sampler import call_args_docstring
-from keras_nlp.utils.python_utils import format_docstring
 
 
-@format_docstring(call_args=call_args_docstring)
 @keras_nlp_export("keras_nlp.samplers.BeamSampler")
 class BeamSampler(Sampler):
     """Beam Sampler class.
@@ -42,55 +39,17 @@ class BeamSampler(Sampler):
         {{call_args}}
 
     Examples:
-    Return only the beam with the highest accumulated probability.
     ```python
-    # Use a simple alphabet of lowercase characters with ids in range [0, 25].
-    int_lookup = {i: chr(i + ord('a')) for i in range(26)}
-    char_lookup = {v: k for k, v in int_lookup.items()}
-    batch_size, length, vocab_size = 1, 12, len(int_lookup)
-
-    def next(prompt, cache, index):
-        prompt_batch_size = tf.shape(prompt)[0]
-        hidden_states = np.ones((prompt_batch_size, 10))
-        # A uniform distribution over our alphabet.
-        logits = np.ones((prompt_batch_size, vocab_size))
-        return logits, hidden_states, cache
-
-    output = keras_nlp.samplers.BeamSampler()(
-        next=next,
-        prompt=np.full((batch_size, length), char_lookup["z"], dtype="int32"),
-        index=5,
-    )
-    print(["".join([int_lookup[i] for i in s]) for s in output.numpy()])
-    # >>> ['zzzzzeeeeeee']
-    ```
+    causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")
 
-    Return all beams and their probabilities.
-    ```python
-    # Use a simple alphabet of lowercase characters with ids in range [0, 25].
-    int_lookup = {i: chr(i + ord('a')) for i in range(26)}
-    char_lookup = {v: k for k, v in int_lookup.items()}
-    batch_size, length, vocab_size = 1, 8, len(int_lookup)
-
-    def next(prompt, cache, index):
-        prompt_batch_size = tf.shape(prompt)[0]
-        hidden_states = np.ones((prompt_batch_size, 10))
-        # A uniform distribution over our alphabet.
-        logits = np.ones((batch_size, vocab_size))
-        return logits, hidden_states, cache
-
-    beams, probs = keras_nlp.samplers.BeamSampler(return_all_beams=True)(
-        next=next,
-        prompt=np.full((batch_size, length,), char_lookup['z'], dtype="int32"),
-        index=5,
-    )
-
-    print(beams.shape)
-    # >>> (1, 5, 8)
-    print(probs.shape)
-    # >>> (1, 5)
-    print(["".join([int_lookup[i] for i in s]) for s in beams[0].numpy()])
-    # >>> ['zzzzzeee', 'zzzzzeed', 'zzzzzeec', 'zzzzzeea', 'zzzzzeeb']
+    # Pass by name to compile.
+    causal_lm.compile(sampler="beam")
+    causal_lm.generate(["Keras is a"])
+
+    # Pass by object to compile.
+    sampler = keras_nlp.samplers.BeamSampler(num_beams=5)
+    causal_lm.compile(sampler=sampler)
+    causal_lm.generate(["Keras is a"])
     ```
     """
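The `return_all_beams=True` behavior shown in the example above is only removed from the docstring; this commit changes documentation, not sampler behavior. A condensed, self-contained sketch of that direct-call usage, adapted from the removed example (the toy `next` callback and shapes are illustrative):

```python
import numpy as np
import keras_nlp

batch_size, length, vocab_size = 1, 8, 26

def next(prompt, cache, index):
    # A uniform distribution over a toy 26-token vocabulary.
    logits = np.ones((batch_size, vocab_size))
    hidden_states = np.ones((batch_size, 10))
    return logits, hidden_states, cache

beams, probs = keras_nlp.samplers.BeamSampler(return_all_beams=True)(
    next=next,
    prompt=np.full((batch_size, length), 0, dtype="int32"),
    index=5,
)
print(beams.shape)  # (batch_size, num_beams, length), e.g. (1, 5, 8)
print(probs.shape)  # (batch_size, num_beams), e.g. (1, 5)
```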

keras_nlp/samplers/contrastive_sampler.py
Lines changed: 10 additions & 25 deletions

@@ -17,11 +17,8 @@
 from keras_nlp.api_export import keras_nlp_export
 from keras_nlp.backend import ops
 from keras_nlp.samplers.sampler import Sampler
-from keras_nlp.samplers.sampler import call_args_docstring
-from keras_nlp.utils.python_utils import format_docstring
 
 
-@format_docstring(call_args=call_args_docstring)
 @keras_nlp_export("keras_nlp.samplers.ContrastiveSampler")
 class ContrastiveSampler(Sampler):
     """Contrastive Sampler class.
@@ -44,28 +41,16 @@ class ContrastiveSampler(Sampler):
 
     Examples:
     ```python
-    # Use a simple alphabet of lowercase characters to [0, 26).
-    int_lookup = {i: chr(i + ord("a")) for i in range(26)}
-    char_lookup = {v: k for k, v in int_lookup.items()}
-    batch_size, length, vocab_size = 1, 12, len(int_lookup)
-    hidden_size = 5
-    index = 5
-
-    def next(prompt, cache, index):
-        prompt_batch_size = tf.shape(prompt)[0]
-        hidden_states = np.ones((prompt_batch_size, hidden_size))
-        # A uniform distribution over our alphabet.
-        logits = np.ones((prompt_batch_size, vocab_size))
-        return logits, hidden_states, cache
-
-    output = keras_nlp.samplers.ContrastiveSampler()(
-        next=next,
-        prompt=np.full((batch_size, length), char_lookup["z"], dtype="int32"),
-        index=index,
-        hidden_states=np.ones([batch_size, index, hidden_size]),
-    )
-    print(["".join([int_lookup[i] for i in s]) for s in output.numpy()])
-    # >>> "zzzzzeeeeeee"
+    causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")
+
+    # Pass by name to compile.
+    causal_lm.compile(sampler="contrastive")
+    causal_lm.generate(["Keras is a"])
+
+    # Pass by object to compile.
+    sampler = keras_nlp.samplers.ContrastiveSampler(k=5)
+    causal_lm.compile(sampler=sampler)
+    causal_lm.generate(["Keras is a"])
     ```
     """
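One detail visible only in the removed example is worth keeping in mind: when `ContrastiveSampler` is invoked directly rather than through `compile()`/`generate()`, it also requires a `hidden_states` argument covering the tokens already in the prompt, since contrastive search penalizes candidates whose representations are too similar to earlier ones. A minimal sketch with toy shapes, adapted from the removed docstring:

```python
import numpy as np
import keras_nlp

batch_size, length, vocab_size, hidden_size, index = 1, 12, 26, 5, 5

def next(prompt, cache, index):
    # Uniform logits and constant hidden states over a toy vocabulary.
    logits = np.ones((batch_size, vocab_size))
    hidden_states = np.ones((batch_size, hidden_size))
    return logits, hidden_states, cache

output = keras_nlp.samplers.ContrastiveSampler(k=5)(
    next=next,
    prompt=np.full((batch_size, length), 0, dtype="int32"),
    index=index,
    # Representations of the tokens already present in the prompt.
    hidden_states=np.ones([batch_size, index, hidden_size]),
)
```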

keras_nlp/samplers/greedy_sampler.py
Lines changed: 10 additions & 24 deletions

@@ -15,41 +15,27 @@
 from keras_nlp.api_export import keras_nlp_export
 from keras_nlp.backend import ops
 from keras_nlp.samplers.sampler import Sampler
-from keras_nlp.samplers.sampler import call_args_docstring
-from keras_nlp.utils.python_utils import format_docstring
 
 
-@format_docstring(call_args=call_args_docstring)
 @keras_nlp_export("keras_nlp.samplers.GreedySampler")
 class GreedySampler(Sampler):
     """Greedy sampler class.
 
     This sampler is implemented on greedy search, i.e., always picking up the
     token of the largest probability as the next token.
 
-    Call arguments:
-        {{call_args}}
-
     Examples:
     ```python
-    # Use a simple alphabet of lowercase characters with ids in range [0, 25].
-    int_lookup = {i: chr(i + ord('a')) for i in range(26)}
-    char_lookup = {v: k for k, v in int_lookup.items()}
-    batch_size, length, vocab_size = 1, 12, len(int_lookup)
-
-    def next(prompt, cache, index):
-        hidden_states = np.ones((batch_size, 10))
-        # A uniform distribution over our alphabet.
-        logits = np.ones((batch_size, vocab_size))
-        return logits, hidden_states, cache
-
-    output = keras_nlp.samplers.GreedySampler()(
-        next=next,
-        prompt=np.full((batch_size, length,), char_lookup['z'], dtype="int32"),
-        index=5,
-    )
-    print(["".join([int_lookup[i] for i in s]) for s in output.numpy()])
-    # >>> ['zzzzzaaaaaaa']
+    causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")
+
+    # Pass by name to compile.
+    causal_lm.compile(sampler="greedy")
+    causal_lm.generate(["Keras is a"])
+
+    # Pass by object to compile.
+    sampler = keras_nlp.samplers.GreedySampler()
+    causal_lm.compile(sampler=sampler)
+    causal_lm.generate(["Keras is a"])
     ```
     """
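Since the base class reduces each strategy to a `get_next_token()` decision, greedy search is conceptually just an argmax over the probability distribution. A sketch of the idea, not the library's exact source:

```python
from keras import ops  # Keras 3 style ops, which keras_nlp wraps as its backend.

def get_next_token(probabilities):
    # Greedy decoding: always take the highest-probability token id.
    return ops.argmax(probabilities, axis=-1)
```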

keras_nlp/samplers/random_sampler.py
Lines changed: 8 additions & 19 deletions

@@ -16,11 +16,8 @@
 from keras_nlp.backend import ops
 from keras_nlp.backend import random
 from keras_nlp.samplers.sampler import Sampler
-from keras_nlp.samplers.sampler import call_args_docstring
-from keras_nlp.utils.python_utils import format_docstring
 
 
-@format_docstring(call_args=call_args_docstring)
 @keras_nlp_export("keras_nlp.samplers.RandomSampler")
 class RandomSampler(Sampler):
     """Random Sampler class.
@@ -37,24 +34,16 @@
 
     Examples:
     ```python
-    # Use a simple alphabet of lowercase characters with ids in range [0, 25].
-    int_lookup = {i: chr(i + ord('a')) for i in range(26)}
-    char_lookup = {v: k for k, v in int_lookup.items()}
-    batch_size, length, vocab_size = 1, 12, len(int_lookup)
+    causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")
 
-    def next(prompt, state, index):
-        hidden_states = np.ones((batch_size, 10))
-        # A uniform distribution over our alphabet.
-        logits = np.ones((batch_size, vocab_size))
-        return logits, hidden_states, state
+    # Pass by name to compile.
+    causal_lm.compile(sampler="random")
+    causal_lm.generate(["Keras is a"])
 
-    output = keras_nlp.samplers.RandomSampler()(
-        next=next,
-        prompt=np.full((batch_size, length,), char_lookup['z'], dtype="int32"),
-        index=5,
-    )
-    print(["".join([int_lookup[i] for i in s]) for s in output.numpy()])
-    # >>> ['zzzzzcpnjqij']
+    # Pass by object to compile.
+    sampler = keras_nlp.samplers.RandomSampler(temperature=0.7)
+    causal_lm.compile(sampler=sampler)
+    causal_lm.generate(["Keras is a"])
     ```
     """
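The `temperature=0.7` argument above rescales the logits before the softmax: values below 1.0 concentrate probability on likely tokens, values above 1.0 flatten the distribution. A toy numpy illustration of the standard `softmax(logits / temperature)` formulation (an assumption about the convention, not a copy of the library code):

```python
import numpy as np

def softmax(x):
    e = np.exp(x - np.max(x))
    return e / e.sum()

logits = np.array([2.0, 1.0, 0.5])
print(softmax(logits / 1.0))  # baseline
print(softmax(logits / 0.7))  # sharper: top token gains probability
print(softmax(logits / 2.0))  # flatter: closer to uniform
```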

keras_nlp/samplers/sampler.py
Lines changed: 24 additions & 52 deletions

@@ -17,33 +17,8 @@
 from keras_nlp.backend import keras
 from keras_nlp.backend import ops
 from keras_nlp.backend import random
-from keras_nlp.utils.python_utils import format_docstring
-
-call_args_docstring = """next: A function which takes in the
-    `prompt, cache, index` of the current generation loop, and outputs
-    a tuple `(logits, hidden_states, cache)` with `logits` being the
-    logits of next token, `hidden_states` being the representation of
-    the next token, and `cache` for next iteration.
-prompt: A 2D integer tensor with shape `(batch_size, max_length)`. This
-    tensor will be iteratively updated column by column with new sampled
-    values, starting at `index`.
-cache: Optional. A tensor or nested structure of tensors that will be
-    updated by each call to `next`. This can be used to cache
-    computations from early iterations of the generative loop.
-index: Optional. The first index of `prompt` to start sampling at.
-    Usually this is set as the length of the shortest non-padded
-    sequence in `prompt`.
-mask: Optional. A 2D integer tensor with the same shape as `prompt`.
-    Locations which are `True` in the mask are never updated during
-    sampling. Usually used to mark all locations in the dense prompt
-    tensor which were present in a user input.
-end_token_id: Optional. The token marking the end of the sequence. If
-    specified, sampling will stop as soon as all sequences in the prompt
-    produce a `end_token_id` in a location where `mask` is `False`.
-"""
-
-
-@format_docstring(call_args=call_args_docstring)
+
+
 @keras_nlp_export("keras_nlp.samplers.Sampler")
 class Sampler:
     """Base sampler class.
@@ -57,35 +32,32 @@ class Sampler:
         {{call_args}}
 
     This base class can be extended to implement different auto-regressive
-    sampling methods. Subclasses can either:
-
-    - Override the `get_next_token()` method, which computes the next token
-      based on a probability distribution over all possible vocab entries.
-    - Override `__call__`, if the sampling method needs additional information
-      beyond the next tokens probability distribution to sample a sequence.
-
-    Please check available subclass samplers for examples.
+    sampling methods. To do so, override the `get_next_token()` method, which
+    computes the next token based on a probability distribution over all
+    possible vocab entries.
 
     Examples:
 
     ```python
-    # Use a simple alphabet of lowercase characters with ids in range [0, 25].
-    int_lookup = {i: chr(i + ord('a')) for i in range(26)}
-    char_lookup = {v: k for k, v in int_lookup.items()}
-    batch_size, length, vocab_size = 1, 12, len(int_lookup)
-
-    def next(prompt, cache, index):
-        # return a uniform distribution over our alphabet.
-        logits = ops.ones((batch_size, vocab_size))
-        return logits, None, cache
-
-    output = keras_nlp.samplers.GreedySampler()(
-        next=next,
-        prompt=ops.fill((batch_size, length,), char_lookup['z']),
-        index=5,
-    )
-    print(["".join([int_lookup[i] for i in s]) for s in output.numpy()])
-    # >>> ['zzzzzaaaaaaa']
+    causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")
+
+    # Greedy search with some tokens forbidden.
+    class CustomSampler(keras_nlp.samplers.Sampler):
+        def __init__(self, forbidden_tokens, **kwargs):
+            super().__init__(**kwargs)
+            self.forbidden_tokens = forbidden_tokens
+
+        def get_next_token(self, probs):
+            batch_size, vocab_size = keras.ops.shape(probs)
+            for id in self.forbidden_tokens:
+                update = keras.ops.zeros((batch_size, 1))
+                probs = keras.ops.slice_update(probs, (0, id), update)
+            return keras.ops.argmax(probs, axis=-1)
+
+    # 257 = "a" with a leading space, 262 = "the" with a leading space.
+    causal_lm.compile(sampler=CustomSampler(forbidden_tokens=[257, 262]))
+    causal_lm.summary()
+    causal_lm.generate(["That's strange"])
    ```
    """
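Although the `call_args_docstring` block is deleted above, the direct-call contract it described still underlies `generate()`: `next` maps `(prompt, cache, index)` to `(logits, hidden_states, cache)`, and the sampler fills in `prompt` column by column starting at `index`, optionally honoring `mask` and stopping at `end_token_id`. A condensed toy direct call, adapted from the removed base-class example:

```python
import numpy as np
import keras_nlp

batch_size, length, vocab_size = 1, 12, 26

def next(prompt, cache, index):
    # Uniform logits over a toy vocabulary; no hidden states or cache needed.
    logits = np.ones((batch_size, vocab_size))
    return logits, None, cache

output = keras_nlp.samplers.GreedySampler()(
    next=next,
    prompt=np.full((batch_size, length), 25, dtype="int32"),
    index=5,
)
# Positions before index=5 keep token 25; later positions get the greedy
# pick, which for uniform logits is token id 0.
print(output)
```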

keras_nlp/samplers/top_k_sampler.py
Lines changed: 8 additions & 19 deletions

@@ -16,11 +16,8 @@
 from keras_nlp.backend import ops
 from keras_nlp.backend import random
 from keras_nlp.samplers.sampler import Sampler
-from keras_nlp.samplers.sampler import call_args_docstring
-from keras_nlp.utils.python_utils import format_docstring
 
 
-@format_docstring(call_args=call_args_docstring)
 @keras_nlp_export("keras_nlp.samplers.TopKSampler")
 class TopKSampler(Sampler):
     """Top-K Sampler class.
@@ -38,24 +35,16 @@
 
     Examples:
     ```python
-    # Use a simple alphabet of lowercase characters with ids in range [0, 25].
-    int_lookup = {i: chr(i + ord('a')) for i in range(26)}
-    char_lookup = {v: k for k, v in int_lookup.items()}
-    batch_size, length, vocab_size = 1, 12, len(int_lookup)
+    causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")
 
-    def next(prompt, cache, index):
-        hidden_states = np.ones((batch_size, 10))
-        # A uniform distribution over our alphabet.
-        logits = np.ones((batch_size, vocab_size))
-        return logits, hidden_states, cache
+    # Pass by name to compile.
+    causal_lm.compile(sampler="top_k")
+    causal_lm.generate(["Keras is a"])
 
-    output = keras_nlp.samplers.TopKSampler(k=3)(
-        next=next,
-        prompt=np.full((batch_size, length,), char_lookup['z'], dtypes="int32"),
-        index=5,
-    )
-    print(["".join([int_lookup[i] for i in s]) for s in output.numpy()])
-    # >>> ['zzzzzacbbcaa']
+    # Pass by object to compile.
+    sampler = keras_nlp.samplers.TopKSampler(k=5, temperature=0.7)
+    causal_lm.compile(sampler=sampler)
+    causal_lm.generate(["Keras is a"])
     ```
     """
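For intuition, top-k sampling keeps only the `k` most probable tokens at each step, renormalizes, and samples from that truncated distribution. A minimal numpy sketch of a single step (illustrative only, not the library implementation):

```python
import numpy as np

def top_k_step(probs, k, seed=None):
    # Keep the k most probable token ids, renormalize, and sample one.
    rng = np.random.default_rng(seed)
    top_ids = np.argsort(probs)[-k:]
    top_probs = probs[top_ids] / probs[top_ids].sum()
    return rng.choice(top_ids, p=top_probs)

probs = np.array([0.5, 0.2, 0.15, 0.1, 0.05])
print(top_k_step(probs, k=3))  # always one of token ids 0, 1, or 2
```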
