fix left-padding #2278

Open · wants to merge 12 commits into master
Conversation

pass-lin
Contributor

There was a problem with the previous left-padding implementation.
For example, for the input [[1, 2, 3], [1, 2, 3, 4]] with sequence_length=5, the result should be [[0, 1, 2, 3, 0], [1, 2, 3, 4, 0]],
while the previous implementation produced [[0, 0, 1, 2, 3], [0, 1, 2, 3, 4, 0]].
That implementation is wrong, so we fix it in this PR.
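For context, conventional left padding (the behavior the reviewer defends below) simply prepends pad values; a minimal pure-Python sketch, not the layer's actual code:

```python
def left_pad(seq, length, pad_value=0):
    """Conventional left padding: prepend pad_value up to `length`,
    truncating from the left if the sequence is too long."""
    if len(seq) >= length:
        return seq[len(seq) - length:]
    return [pad_value] * (length - len(seq)) + seq

batch = [[1, 2, 3], [1, 2, 3, 4]]
padded = [left_pad(s, 5) for s in batch]
# → [[0, 0, 1, 2, 3], [0, 1, 2, 3, 4]]
```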

Member

@mattdangerw mattdangerw left a comment


This does not look right to me, I think we might have introduced some bugs.

padding_shape = [tf.shape(outputs)[0]] + [1] * (len(outputs.shape) - 1)
padding_shape[axis] = shape[axis] - tf.shape(outputs)[axis]
padding_shape = tf.cast(padding_shape, "int64")
print(padding_shape, pad_value)
Member


no prints please!

@@ -141,7 +141,7 @@ def check_special_value_type(value, value_name):
self.start_value = start_value
self.end_value = end_value

self.pad_value = pad_value
self.pad_value = pad_value or 0
Member


why do we have this line for start end packer but not multi-segment packer? what's the difference?

Contributor Author


why do we have this line for start end packer but not multi-segment packer? what's the difference?

If you delete this line, the tests report an error.

Member


Thanks! But that's not the question. We want the implementations of the start end packer and multi segment packer to look similar where we can. Having a None pad value behave differently between these layers could lead to subtle bugs for end users. Is there a technical reason why we need this line for StartEndPacker and not MultiSegmentPacker? If this is just for tests, let's rework the tests. Let's try to keep the layers working roughly the same.

@@ -17,7 +17,7 @@ def test_dense_input(self):
sequence_length=5, padding_side="left"
)
output = start_end_packer(input_data)
expected_output = [0, 0, 5, 6, 7]
expected_output = [5, 6, 7, 0, 0]
Member


This doesn't make sense. Why is left padding the same as right padding in this case? The test case before looks correct, this looks wrong.

Contributor Author


This doesn't make sense. Why is left padding the same as right padding in this case? The test case before looks correct, this looks wrong.

No, it was clearly wrong before.
Left padding is meant for causal LMs. If expected_output = [0, 0, 5, 6, 7], where would generated tokens go?

@@ -40,7 +40,7 @@ def test_dense_2D_input(self):
sequence_length=5, padding_side="left"
)
output = start_end_packer(input_data)
expected_output = [[0, 0, 5, 6, 7]]
expected_output = [[5, 6, 7, 0, 0]]
Member


Ditto, we are showing a lot of right hand side padding for the left padding option. I think we have introduced a bug.

def pad(x, shape, padding_side, pad_value):
if padding_side == "left":
def pad(x, shape, padding_side, pad_value, axis=-1):
if padding_side == "left" and pad_value is not None:
Member


I don't fully understand what we are trying to do here, but I think it is buggy, we should go back to the reverse and to_tensor approach and avoid this manual padding.
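The reverse-and-pad idea the reviewer refers to can be sketched in pure Python (in the real layer this would use tf.reverse and tf.RaggedTensor.to_tensor, which always pads on the right): reverse each sequence, right-pad, then reverse again, yielding left padding with no manual shape arithmetic.

```python
def right_pad(seq, length, pad_value=0):
    # Right padding, analogous to what RaggedTensor.to_tensor produces.
    return (seq + [pad_value] * length)[:length]

def left_pad_via_reverse(seq, length, pad_value=0):
    # Reverse, right-pad, reverse back -> left padding.
    return right_pad(seq[::-1], length, pad_value)[::-1]

print(left_pad_via_reverse([5, 6, 7], 5))  # [0, 0, 5, 6, 7]
```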

@pass-lin
Contributor Author

pass-lin commented May 30, 2025

@mattdangerw
We should revisit the purpose of left padding.
Left padding is essentially designed for causal LMs: we need to reserve some token slots for inference.
If left padding yields [0, 0, 1, 2, 3, 4, 0, 0], the causal LM can generate into it as [0, 0, 1, 2, 3, 4, 5, 6]. But if left padding yields [0, 0, 0, 0, 1, 2, 3, 4], how is the causal LM supposed to continue generating?

Another strategy is for left padding to keep the previous implementation, and achieve this by modifying the CausalLMPreprocessor instead. If you prefer that implementation, please close this PR.

@mattdangerw
Member

@pass-lin got it, yes if you are hoping to make left padding work for generative inference that is a larger undertaking and needs more consideration.

Might be better to start with an end-to-end prototype of how the whole thing would work for generation. How does an end user specify left padding in a high-level flow? How are we going to make sure position embeddings are correctly computed for input sequences? We don't pass a full input of position ints right now, which would basically break an attempt to correctly left pad during generation.

We also have to think about API design in a layered way. Having StartEndPacker(sequence_length=10, padding_side="left")([1, 2, 3]) pad on the right hand side is just plain confusing. How can we make this behavior intuitive for users that might not even be doing generation with the low-level layer?

In short, a lot to figure out here for generation. We'd probably want to start with the ability to pass position tensor inputs to our backbones, and maybe further refactor our causal lm implementation so that we wouldn't have to duplicate logic here across a number of models. Not sure we are ready for this.
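One way to make positions correct under left padding (an illustration of the idea, not KerasHub's current API) is to derive position ids from the padding mask instead of assuming positions start at index 0. This sketch treats any pad_value token as padding, a simplification that breaks if real tokens share the pad id:

```python
def position_ids(token_ids, pad_value=0):
    """Derive per-token positions from the padding mask so the first
    real token gets position 0 regardless of the padding side."""
    positions, pos = [], 0
    for tok in token_ids:
        if tok == pad_value:
            positions.append(0)  # placeholder; masked out by attention
        else:
            positions.append(pos)
            pos += 1
    return positions

print(position_ids([0, 0, 5, 6, 7]))  # [0, 0, 0, 1, 2]
```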

@pass-lin
Contributor Author

pass-lin commented Jun 2, 2025

@pass-lin got it, yes if you are hoping to make left padding work for generative inference that is a larger undertaking and needs more consideration.

Might be better to start with an end-to-end prototype of how the whole thing would work for generation. How does an end user specify left padding in a high-level flow? How are we going to make sure position embeddings are correctly computed for input sequences? We don't pass a full input of position ints right now, which would basically break an attempt to correctly left pad during generation.

We also have to think about API design in a layered way. Having StartEndPacker(sequence_length=10, padding_side="left")([1, 2, 3]) pad on the right hand side is just plain confusing. How can we make this behavior intuitive for users that might not even be doing generation with the low-level layer?

In short, a lot to figure out here for generation. We'd probably want to start with the ability to pass position tensor inputs to our backbones, and maybe further refactor our causal lm implementation so that we wouldn't have to duplicate logic here across a number of models. Not sure we are ready for this.

We need to confirm one thing: for CausalLM.generate, we should use left padding. Should we use the modification in this PR, or modify CausalLMPreprocessor?

And is a "causal LM batch generate" task in the keras-team's plans? If the keras-team plans to implement this feature, then I will close this PR and everything will follow your plan.

@mattdangerw
Member

We need to confirm one thing: for CausalLM.generate, we should use left padding. Should we use the modification in this PR, or modify CausalLMPreprocessor?

It would be great to expose an option for left padding, but only if we could do it correctly, which would require many other changes (notably the need to pass position tensor input). Probably best to start with a prototype. Keep in mind generation might look fine without position inputs, but would not be correct (and we could not merge).

And is a "causal LM batch generate" task in the keras-team's plans? If the keras-team plans to implement this feature, then I will close this PR and everything will follow your plan.

I don't know what you mean here. Batch generation is already supported. The only difference between right and left padding would be when doing early stopping. In the case where you have multiple sequences left padding would sometimes lead to an earlier stopping condition.

What is it you are trying to do end to end?

@pass-lin
Contributor Author

pass-lin commented Jun 3, 2025

We need to confirm one thing: for CausalLM.generate, we should use left padding. Should we use the modification in this PR, or modify CausalLMPreprocessor?

Exposing a left-padding option would be great, but only if we can do it correctly, which would require many other changes (notably passing position tensor inputs). It would be best to start with a prototype. Keep in mind that generation without position inputs might look fine but would not be correct (and we could not merge it).

Is the "causal LM batch generate" task in the keras-team's plans? If the keras-team plans to implement this feature, I will close this PR and everything will follow your plan.

I don't know what you mean here. Batch generation is already supported. The only difference between right and left padding is early stopping: with multiple sequences, left padding can sometimes trigger the stopping condition earlier.

What are you trying to do end to end?

        if strip_prompt:
            outputs = [strip_prompt_function(generate(x), x) for x in inputs]
        else:
            outputs = [generate(x) for x in inputs]

        if self.preprocessor is not None:
            outputs = [postprocess(x) for x in outputs]

This is the current implementation of the generate function, but it does not truly achieve batch inference. Instead, it performs serial inference on each sample. Using left padding is intended to enable functionality similar to model.generate(x, batch_size=16), which is analogous to the batch inference usage of model.predict.
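The batched flow described here could be sketched roughly as follows (hypothetical helper names, not KerasHub code): group the inputs into fixed-size chunks, left-pad each chunk to a common length, and run the model once per chunk instead of once per sample.

```python
def batched_generate(inputs, generate_batch, batch_size=16, pad_value=0):
    """Hypothetical batch driver: left-pad each chunk to a common
    length, then call `generate_batch` once per chunk."""
    outputs = []
    for i in range(0, len(inputs), batch_size):
        chunk = inputs[i:i + batch_size]
        width = max(len(s) for s in chunk)
        padded = [[pad_value] * (width - len(s)) + s for s in chunk]
        outputs.extend(generate_batch(padded))
    return outputs

# Example with a stub "model" that just echoes its padded inputs:
result = batched_generate([[1, 2], [3, 4, 5]], lambda batch: batch, batch_size=2)
# → [[0, 1, 2], [3, 4, 5]]
```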
