Follow-up to #750 (comment)
- Return all beams (tensor of shape `(batch_size, num_beams, sequence_length)`).
- Return the beam scores as well (tensor of shape `(batch_size, num_beams)`).
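A minimal NumPy sketch of the proposed return contract, assuming `beam_scores` holds one accumulated log probability per beam. The helper `select_beam_outputs` and its `return_all_beams` flag are hypothetical names used only for illustration; they are not the existing keras-nlp API, and the real utility operates on TensorFlow tensors.

```python
import numpy as np


def select_beam_outputs(beams, beam_scores, return_all_beams=False):
    """Hypothetical helper: choose what a beam search utility returns.

    beams:       int array, shape (batch_size, num_beams, sequence_length)
    beam_scores: float array, shape (batch_size, num_beams), e.g. summed log probs
    """
    if return_all_beams:
        # Proposed behaviour: hand back every beam together with its score.
        return beams, beam_scores
    # Single-output behaviour: keep only the highest-scoring beam per batch element.
    best = np.argmax(beam_scores, axis=1)  # (batch_size,)
    return beams[np.arange(beams.shape[0]), best]  # (batch_size, sequence_length)


# Toy example: batch_size=2, num_beams=3, sequence_length=4.
beams = np.arange(2 * 3 * 4).reshape(2, 3, 4)
beam_scores = np.array([[-1.0, -0.5, -2.0], [-0.1, -3.0, -0.2]])

best = select_beam_outputs(beams, beam_scores)
all_beams, all_scores = select_beam_outputs(beams, beam_scores, return_all_beams=True)
```

With the flag off, the caller gets a `(batch_size, sequence_length)` result as before; with it on, the shapes match the two bullets above.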
The changes are to be made to this utility function: https://github.com/keras-team/keras-nlp/blob/master/keras_nlp/utils/text_generation.py#L213.
@mattdangerw, should we sort the outputs by log probability before returning them, or is that unnecessary overhead?
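If sorting is wanted, a sketch of the per-batch sort is below, again in NumPy and assuming higher scores (log probabilities) are better. `sort_beams_by_score` is a hypothetical name. In TensorFlow the same idea could likely be expressed with `tf.argsort(beam_scores, axis=1, direction="DESCENDING")` followed by `tf.gather(beams, order, axis=1, batch_dims=1)`.

```python
import numpy as np


def sort_beams_by_score(beams, beam_scores):
    """Hypothetical helper: order beams best-first within each batch element.

    beams:       (batch_size, num_beams, sequence_length)
    beam_scores: (batch_size, num_beams), e.g. summed log probabilities
    """
    # argsort is ascending, so negate the scores for best-first order;
    # a stable sort keeps tied beams in their original order.
    order = np.argsort(-beam_scores, axis=1, kind="stable")  # (batch_size, num_beams)
    sorted_scores = np.take_along_axis(beam_scores, order, axis=1)
    # Add a trailing axis so the indices broadcast over sequence_length.
    sorted_beams = np.take_along_axis(beams, order[:, :, None], axis=1)
    return sorted_beams, sorted_scores


beams = np.arange(2 * 3 * 4).reshape(2, 3, 4)
beam_scores = np.array([[-1.0, -0.5, -2.0], [-0.1, -3.0, -0.2]])
sorted_beams, sorted_scores = sort_beams_by_score(beams, beam_scores)
```

The sort only touches `num_beams` entries per batch element, so its cost is probably small compared with the decoding loop itself; a side benefit is that `sorted_beams[:, 0]` is then the single best sequence.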
I did some basic tinkering on a branch because a user wanted a quick fix: https://github.com/abheesht17/keras-nlp/blob/beam_search_return_all_seqs/keras_nlp/utils/text_generation.py#L400-L414