Skip to content

Commit 1fee26e

Browse files
imomalievatodorov
authored andcommitted
infer model_class in foreignk key when models is a package
1 parent b8ec891 commit 1fee26e

File tree

1 file changed

+26
-6
lines changed

1 file changed

+26
-6
lines changed

pylint_django/transforms/foreignkey.py

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,23 @@ def is_foreignkey_in_class(node):
2525
return attr in ('OneToOneField', 'ForeignKey')
2626

2727

28+
def _get_model_class_defs_from_module(module, model_name, module_name):
29+
class_defs = []
30+
for module_node in module.lookup(model_name)[1]:
31+
if isinstance(module_node, nodes.ClassDef) and node_is_subclass(
32+
module_node, "django.db.models.base.Model"
33+
):
34+
class_defs.append(module_node)
35+
elif isinstance(module_node, nodes.ImportFrom):
36+
imported_module = module_node.do_import_module()
37+
class_defs.extend(
38+
_get_model_class_defs_from_module(
39+
imported_module, model_name, module_name
40+
)
41+
)
42+
return class_defs
43+
44+
2845
def infer_key_classes(node, context=None):
2946
keyword_args = [kw.value for kw in node.keywords]
3047
all_args = chain(node.args, keyword_args)
@@ -73,17 +90,20 @@ def infer_key_classes(node, context=None):
7390
# 'auth.models', 'User' which works nicely with the `endswith()`
7491
# comparison below
7592
module_name += '.models'
93+
# ensure that module is loaded in cache, for cases when models is a package
94+
if module_name not in MANAGER.astroid_cache:
95+
MANAGER.ast_from_module_name(module_name)
7696

77-
for module in MANAGER.astroid_cache.values():
97+
# create list from dict_values, because it may be modified in loop
98+
for module in list(MANAGER.astroid_cache.values()):
7899
# only load model classes from modules which match the module in
79100
# which *we think* they are defined. This will prevent infering
80101
# other models of the same name which are found elsewhere!
81102
if model_name in module.locals and module.name.endswith(module_name):
82-
class_defs = [
83-
module_node for module_node in module.lookup(model_name)[1]
84-
if isinstance(module_node, nodes.ClassDef)
85-
and node_is_subclass(module_node, 'django.db.models.base.Model')
86-
]
103+
class_defs = _get_model_class_defs_from_module(
104+
module, model_name, module_name
105+
)
106+
87107
if class_defs:
88108
return iter([class_defs[0].instantiate_class()])
89109
else:

0 commit comments

Comments
 (0)