@@ -740,6 +740,7 @@ def _compose_mro(cls, types):
740740 # Remove entries which are already present in the __mro__ or unrelated.
741741 def is_related (typ ):
742742 return (typ not in bases and hasattr (typ , '__mro__' )
743+ and not isinstance (typ , GenericAlias )
743744 and issubclass (cls , typ ))
744745 types = [n for n in types if is_related (n )]
745746 # Remove entries which are strict bases of other entries (they will end up
@@ -837,16 +838,25 @@ def dispatch(cls):
837838 dispatch_cache [cls ] = impl
838839 return impl
839840
841+ def _is_valid_dispatch_type (cls ):
842+ return isinstance (cls , type ) and not isinstance (cls , GenericAlias )
843+
840844 def register (cls , func = None ):
841845 """generic_func.register(cls, func) -> func
842846
843847 Registers a new implementation for the given *cls* on a *generic_func*.
844848
845849 """
846850 nonlocal cache_token
847- if func is None :
848- if isinstance ( cls , type ) :
851+ if _is_valid_dispatch_type ( cls ) :
852+ if func is None :
849853 return lambda f : register (cls , f )
854+ else :
855+ if func is not None :
856+ raise TypeError (
857+ f"Invalid first argument to `register()`. "
858+ f"{ cls !r} is not a class."
859+ )
850860 ann = getattr (cls , '__annotations__' , {})
851861 if not ann :
852862 raise TypeError (
@@ -859,11 +869,12 @@ def register(cls, func=None):
859869 # only import typing if annotation parsing is necessary
860870 from typing import get_type_hints
861871 argname , cls = next (iter (get_type_hints (func ).items ()))
862- if not isinstance (cls , type ):
872+ if not _is_valid_dispatch_type (cls ):
863873 raise TypeError (
864874 f"Invalid annotation for { argname !r} . "
865875 f"{ cls !r} is not a class."
866876 )
877+
867878 registry [cls ] = func
868879 if cache_token is None and hasattr (cls , '__abstractmethods__' ):
869880 cache_token = get_cache_token ()
0 commit comments