@@ -25,6 +25,23 @@ def is_foreignkey_in_class(node):
25
25
return attr in ('OneToOneField' , 'ForeignKey' )
26
26
27
27
28
+ def _get_model_class_defs_from_module (module , model_name , module_name ):
29
+ class_defs = []
30
+ for module_node in module .lookup (model_name )[1 ]:
31
+ if isinstance (module_node , nodes .ClassDef ) and node_is_subclass (
32
+ module_node , "django.db.models.base.Model"
33
+ ):
34
+ class_defs .append (module_node )
35
+ elif isinstance (module_node , nodes .ImportFrom ):
36
+ imported_module = module_node .do_import_module ()
37
+ class_defs .extend (
38
+ _get_model_class_defs_from_module (
39
+ imported_module , model_name , module_name
40
+ )
41
+ )
42
+ return class_defs
43
+
44
+
28
45
def infer_key_classes (node , context = None ):
29
46
keyword_args = [kw .value for kw in node .keywords ]
30
47
all_args = chain (node .args , keyword_args )
@@ -73,17 +90,20 @@ def infer_key_classes(node, context=None):
73
90
# 'auth.models', 'User' which works nicely with the `endswith()`
74
91
# comparison below
75
92
module_name += '.models'
93
+ # ensure that module is loaded in cache, for cases when models is a package
94
+ if module_name not in MANAGER .astroid_cache :
95
+ MANAGER .ast_from_module_name (module_name )
76
96
77
- for module in MANAGER .astroid_cache .values ():
97
+ # create list from dict_values, because it may be modified in loop
98
+ for module in list (MANAGER .astroid_cache .values ()):
78
99
# only load model classes from modules which match the module in
79
100
# which *we think* they are defined. This will prevent infering
80
101
# other models of the same name which are found elsewhere!
81
102
if model_name in module .locals and module .name .endswith (module_name ):
82
- class_defs = [
83
- module_node for module_node in module .lookup (model_name )[1 ]
84
- if isinstance (module_node , nodes .ClassDef )
85
- and node_is_subclass (module_node , 'django.db.models.base.Model' )
86
- ]
103
+ class_defs = _get_model_class_defs_from_module (
104
+ module , model_name , module_name
105
+ )
106
+
87
107
if class_defs :
88
108
return iter ([class_defs [0 ].instantiate_class ()])
89
109
else :
0 commit comments