diff --git a/torchtext/models/roberta/bundler.py b/torchtext/models/roberta/bundler.py index 8b0c3d0e3b..dc4ce808d4 100644 --- a/torchtext/models/roberta/bundler.py +++ b/torchtext/models/roberta/bundler.py @@ -56,7 +56,10 @@ class RobertaModelBundle: _head: Optional[Module] = None transform: Optional[Callable] = None - def get_model(self, head: Optional[Module] = None, *, dl_kwargs=None) -> RobertaModel: + def get_model(self, head: Optional[Module] = None, load_weights=True, *, dl_kwargs=None) -> RobertaModel: + + if load_weights: + assert self._path is not None, "load_weights cannot be True. The pre-trained model weights are not available for the current object" if head is not None: input_head = head @@ -67,6 +70,9 @@ def get_model(self, head: Optional[Module] = None, *, dl_kwargs=None) -> Roberta model = _get_model(self._params, input_head) + if not load_weights: + return model + dl_kwargs = {} if dl_kwargs is None else dl_kwargs state_dict = load_state_dict_from_url(self._path, **dl_kwargs) if input_head is not None: