
add support for not loading weights #1424


Merged
3 commits merged on Oct 26, 2021

Conversation

Contributor

@parmeet parmeet commented Oct 26, 2021

Follow-up on #1406

We would like to add support for not loading weights from pre-trained models, so that users can still instantiate the standard model architectures without initializing them with pre-trained weights, in order to support training from scratch.

Example usage:

import torchtext
xlmr_base = torchtext.models.XLMR_BASE_ENCODER
model_uninitialized = xlmr_base.get_model(load_weights=False)

Contributor

@mthrok mthrok left a comment


I have some questions but overall the change itself looks good.

@@ -56,7 +56,10 @@ class RobertaModelBundle:
     _head: Optional[Module] = None
     transform: Optional[Callable] = None

-    def get_model(self, head: Optional[Module] = None, *, dl_kwargs=None) -> RobertaModel:
+    def get_model(self, load_weights=True, head: Optional[Module] = None, *, dl_kwargs=None) -> RobertaModel:
Contributor

@mthrok mthrok Oct 26, 2021


A few comments.

  1. Adding an argument in front of an existing one can be BC-breaking if this API is already part of a previous release.
  2. Among head and load_weights, which one do you think is used more often? I think the one used more frequently should come first.

Also, this makes me wonder how the combination of load_weights and a custom head should behave. Is the provided custom head expected to be trained or untrained?

If load_weights=False, this does not matter, but if load_weights=True, there are two cases: the provided custom head comes with pre-trained weights, or it does not.

Now, looking at the logic where the state dict is loaded (model.load_state_dict(state_dict, strict=False)), isn't this code overwriting the weights for the given custom head if the keys match? Of course, if that's the spec and it is documented somewhere, that's okay (and this is out of the scope of this PR), but I did not realize this when I reviewed the original PR for the get_model logic. What do you think?
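To illustrate the concern, a minimal, self-contained sketch (TinyModel is a hypothetical stand-in, not torchtext code) showing that load_state_dict(..., strict=False) silently overwrites any parameter whose key matches, including a user-provided head:

import torch
from torch import nn

# Hypothetical stand-in for an encoder plus a custom head.
class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(4, 4)
        self.head = nn.Linear(4, 2)  # imagine this is the user's custom head

model = TinyModel()
custom_head_weight = model.head.weight.clone()

# Pretend this is the pre-trained checkpoint; it also contains "head.*" keys.
checkpoint = {k: torch.zeros_like(v) for k, v in model.state_dict().items()}
model.load_state_dict(checkpoint, strict=False)

# The custom head's weights were replaced because the keys matched.
print(torch.equal(model.head.weight, custom_head_weight))  # prints: False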

Contributor Author

@parmeet parmeet Oct 26, 2021


> A few comments.
>
>   1. Adding an argument in front of an existing one can be BC-breaking if this API is already part of a previous release.

They are not released yet, so we are OK breaking it if necessary.

>   2. Among head and load_weights, which one do you think is used more often? I think the one used more frequently should come first.

It's a good point. I cannot say for sure, but my guess is that users would want to supply their custom heads more often than they would set load_weights=False (since they would still want the pre-trained encoder weights). So I will change the order.

> Also, this makes me wonder how the combination of load_weights and a custom head should behave. Is the provided custom head expected to be trained or untrained?
>
> If load_weights=False, this does not matter, but if load_weights=True, there are two cases: the provided custom head comes with pre-trained weights, or it does not.
>
> Now, looking at the logic where the state dict is loaded (model.load_state_dict(state_dict, strict=False)), isn't this code overwriting the weights for the given custom head if the keys match? Of course, if that's the spec and it is documented somewhere, that's okay (and this is out of the scope of this PR), but I did not realize this when I reviewed the original PR for the get_model logic. What do you think?

Thanks for surfacing this. I have yet to figure out the final behavior and document it properly. One idea would be to make sure that when the user provides a custom head, we only load the pre-trained weights for the encoder, leaving the custom head in the same state as provided by the user. In that case, we do not have to worry about overwriting weights if the keys match. WDYT?

Contributor


Yeah, I think it makes more sense to leave the custom head provided by the user untouched.
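A minimal sketch of that approach (hypothetical helper name and an assumed "head." key prefix, not the merged implementation): when a custom head is supplied, drop the head entries from the downloaded state dict before loading, so only the encoder weights are restored.

from torch.hub import load_state_dict_from_url

# Hypothetical helper; "head." is an assumed key prefix in the checkpoint.
def _load_pretrained(model, weights_url, custom_head_provided, dl_kwargs=None):
    state_dict = load_state_dict_from_url(weights_url, **(dl_kwargs or {}))
    if custom_head_provided:
        # Keep only encoder weights; the user's head stays exactly as provided.
        state_dict = {k: v for k, v in state_dict.items() if not k.startswith("head.")}
    model.load_state_dict(state_dict, strict=False)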

    def get_model(self, load_weights=True, head: Optional[Module] = None, *, dl_kwargs=None) -> RobertaModel:

        if load_weights:
            assert self._path is not None, "load_weights cannot be True when _path is not set"
Contributor


self._path is abstracted away from regular users, so I think rephrasing the message without referring to an internal attribute would be better. Otherwise, I would wonder, "Did I do something wrong with _path?"

Contributor Author


Makes sense!
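One possible rewording, as a sketch only (the wording in the merged code may differ), keeping the internal check but describing the problem in terms the user controls:

        if load_weights:
            # Hypothetical message; avoids mentioning the private _path attribute.
            assert self._path is not None, (
                "This bundle does not come with pre-trained weights; "
                "pass load_weights=False to build an uninitialized model."
            )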
