add support for not loading weights #1424
Conversation
I have some questions but overall the change itself looks good.
torchtext/models/roberta/bundler.py (Outdated)
```diff
@@ -56,7 +56,10 @@ class RobertaModelBundle:
     _head: Optional[Module] = None
     transform: Optional[Callable] = None

-    def get_model(self, head: Optional[Module] = None, *, dl_kwargs=None) -> RobertaModel:
+    def get_model(self, load_weights=True, head: Optional[Module] = None, *, dl_kwargs=None) -> RobertaModel:
```
A few comments.
- Adding an argument in front of an existing one can be BC-breaking if this API is already part of a previous release.
- Between `head` and `load_weights`, which one do you think is used more often? I think the one used more frequently should come first.
Also this makes me wonder how the combination of `load_weights` and a custom head should behave. Is the provided custom head expected to be trained or untrained? If `load_weights=False`, this does not matter, but if `load_weights=True`, there are two cases: the provided custom head comes with pre-trained weights, or it does not. Now, looking at the logic where the state dict is loaded (`model.load_state_dict(state_dict, strict=False)`), isn't this code overwriting the weights for the given custom head if the keys match? Of course, if that's the spec and it is documented somewhere, that's okay (and this is out of scope for this PR), but I did not realize this when I reviewed the original PR for the `get_model` logic. What do you think?
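To illustrate the concern, here is a standalone toy example (not torchtext code; the module names and shapes are made up) showing that `strict=False` only relaxes checks for missing/unexpected keys, while matching keys are still loaded:

```python
import torch
from torch import nn


class Head(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)


model = nn.ModuleDict({"encoder": nn.Linear(4, 4), "head": Head()})

# A checkpoint that happens to contain weights under the same "head.*" keys.
state_dict = {
    "encoder.weight": torch.zeros(4, 4),
    "encoder.bias": torch.zeros(4),
    "head.linear.weight": torch.ones(2, 4),
    "head.linear.bias": torch.ones(2),
}

# strict=False does not protect matching keys: the user's head
# initialization is silently replaced by the checkpoint values.
model.load_state_dict(state_dict, strict=False)
print(model["head"].linear.weight)  # now all ones
```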
> A few comments.
>
> - Adding an argument in front of an existing one can be BC-breaking if this API is already part of a previous release.
They are not released yet, so we are OK breaking it if necessary.
> - Between `head` and `load_weights`, which one do you think is used more often? I think the one used more frequently should come first.
It's a good point. I cannot say for sure, but my guess is that users will want to plug in their custom heads more often than they set `load_weights=False` (since they would still like to use the pre-trained encoder weights). So I will change the order.
> Also this makes me wonder how the combination of `load_weights` and a custom head should behave. Is the provided custom head expected to be trained or untrained? If `load_weights=False`, this does not matter, but if `load_weights=True`, there are two cases: the provided custom head comes with pre-trained weights, or it does not. Now, looking at the logic where the state dict is loaded (`model.load_state_dict(state_dict, strict=False)`), isn't this code overwriting the weights for the given custom head if the keys match? Of course, if that's the spec and it is documented somewhere, that's okay (and this is out of scope for this PR), but I did not realize this when I reviewed the original PR for the `get_model` logic. What do you think?
Thanks for surfacing this. I have yet to figure out the final behavior and document it properly. One idea would be to make sure that when the user provides a custom head, we only load pre-trained weights for the encoder, leaving the custom head in the same state the user provided it in. In that case, we would not have to worry about overwriting weights when the keys match. WDYT?
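A minimal sketch of that idea, assuming encoder parameters live under an `encoder.` prefix in the checkpoint (the prefix and helper name are assumptions here, not the actual torchtext implementation):

```python
def load_encoder_weights_only(model, state_dict, user_provided_head):
    """Load checkpoint weights, skipping head keys when the user supplied a head."""
    if user_provided_head is not None:
        # Keep only encoder parameters so the custom head stays exactly as
        # the user initialized it, even if the checkpoint has matching keys.
        state_dict = {k: v for k, v in state_dict.items() if k.startswith("encoder.")}
    model.load_state_dict(state_dict, strict=False)
    return model
```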
Yeah, I think it makes more sense to leave the custom head provided by the user untouched.
torchtext/models/roberta/bundler.py (Outdated)
```python
    def get_model(self, load_weights=True, head: Optional[Module] = None, *, dl_kwargs=None) -> RobertaModel:

        if load_weights:
            assert self._path is not None, "load_weights cannot be True when _path is not set"
```
`self._path` is abstracted away from regular users, so I think rephrasing the message without referring to an internal attribute would be better. Otherwise, I would wonder, "Did I do something wrong with `_path`?"
Makes sense!
Follow-up on #1406
We would like to add support for not loading weights from pre-trained models, so that users can still instantiate the standard model architectures without initializing them from pre-trained weights, in order to support training from scratch.
Example usage:
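For instance, a sketch assuming the `XLMR_BASE_ENCODER` bundle (the keyword placement follows the diff above and may differ from the final API):

```python
from torchtext.models import XLMR_BASE_ENCODER

# Current behavior: standard architecture initialized with pre-trained weights.
model = XLMR_BASE_ENCODER.get_model()

# New behavior: same architecture, randomly initialized, for training from scratch.
model = XLMR_BASE_ENCODER.get_model(load_weights=False)
```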