@@ -80,14 +80,18 @@ class ConfigMixin:
8080 - **config_name** (`str`) -- A filename under which the config should stored when calling
8181 [`~ConfigMixin.save_config`] (should be overridden by parent class).
8282 - **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be
83- overridden by parent class).
84- - **has_compatibles** (`bool`) -- Whether the class has compatible classes (should be overridden by parent
85- class).
83+ overridden by subclass).
84+ - **has_compatibles** (`bool`) -- Whether the class has compatible classes (should be overridden by subclass).
85+ - **_deprecated_kwargs** (`List[str]`) -- Keyword arguments that are deprecated. Note that the init function
86+ should only have a `kwargs` argument if at least one argument is deprecated (should be overridden by
87+ subclass).
8688 """
8789 config_name = None
8890 ignore_for_config = []
8991 has_compatibles = False
9092
93+ _deprecated_kwargs = []
94+
9195 def register_to_config (self , ** kwargs ):
9296 if self .config_name is None :
9397 raise NotImplementedError (f"Make sure that { self .__class__ } has defined a class name `config_name`" )
@@ -195,10 +199,10 @@ def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_un
195199 if "dtype" in unused_kwargs :
196200 init_dict ["dtype" ] = unused_kwargs .pop ("dtype" )
197201
198- if "predict_epsilon" in unused_kwargs and "prediction_type" not in init_dict :
199- deprecate ( "remove this" , "0.10.0" , "remove" )
200- predict_epsilon = unused_kwargs . pop ( "predict_epsilon" )
201- init_dict ["prediction_type" ] = "epsilon" if predict_epsilon else "sample"
202+ # add possible deprecated kwargs
203+ for deprecated_kwarg in cls . _deprecated_kwargs :
204+ if deprecated_kwarg in unused_kwargs :
205+ init_dict [deprecated_kwarg ] = unused_kwargs . pop ( deprecated_kwarg )
202206
203207 # Return model and optionally state and/or unused_kwargs
204208 model = cls (** init_dict )
@@ -526,7 +530,6 @@ def inner_init(self, *args, **kwargs):
526530 # Ignore private kwargs in the init.
527531 init_kwargs = {k : v for k , v in kwargs .items () if not k .startswith ("_" )}
528532 config_init_kwargs = {k : v for k , v in kwargs .items () if k .startswith ("_" )}
529- init (self , * args , ** init_kwargs )
530533 if not isinstance (self , ConfigMixin ):
531534 raise RuntimeError (
532535 f"`@register_for_config` was applied to { self .__class__ .__name__ } init method, but this class does "
@@ -553,6 +556,7 @@ def inner_init(self, *args, **kwargs):
553556 )
554557 new_kwargs = {** config_init_kwargs , ** new_kwargs }
555558 getattr (self , "register_to_config" )(** new_kwargs )
559+ init (self , * args , ** init_kwargs )
556560
557561 return inner_init
558562
0 commit comments