Skip to content

Commit 6e137da

Browse files
kikehjkimbo
authored andcommitted
Check for filters defined on base filterset classes (#730)
* Check for filters defined on base filterset classes * Make python2.7 compatible and run black * Add filter method and use filter in test * Check article headline and reformat
1 parent 59f4f13 commit 6e137da

File tree

2 files changed

+116
-9
lines changed

2 files changed

+116
-9
lines changed

graphene_django/filter/tests/test_fields.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -818,3 +818,106 @@ class Query(ObjectType):
818818
}
819819
"""
820820
)
821+
822+
823+
def test_filter_filterset_based_on_mixin():
824+
class ArticleFilterMixin(FilterSet):
825+
@classmethod
826+
def get_filters(cls):
827+
filters = super(FilterSet, cls).get_filters()
828+
filters.update(
829+
{
830+
"viewer__email__in": django_filters.CharFilter(
831+
method="filter_email_in", field_name="reporter__email__in"
832+
)
833+
}
834+
)
835+
836+
return filters
837+
838+
def filter_email_in(cls, queryset, name, value):
839+
return queryset.filter(**{name: [value]})
840+
841+
class NewArticleFilter(ArticleFilterMixin, ArticleFilter):
842+
pass
843+
844+
class NewReporterNode(DjangoObjectType):
845+
class Meta:
846+
model = Reporter
847+
interfaces = (Node,)
848+
849+
class NewArticleFilterNode(DjangoObjectType):
850+
viewer = Field(NewReporterNode)
851+
852+
class Meta:
853+
model = Article
854+
interfaces = (Node,)
855+
filterset_class = NewArticleFilter
856+
857+
def resolve_viewer(self, info):
858+
return self.reporter
859+
860+
class Query(ObjectType):
861+
all_articles = DjangoFilterConnectionField(NewArticleFilterNode)
862+
863+
reporter_1 = Reporter.objects.create(
864+
first_name="John", last_name="Doe", email="[email protected]"
865+
)
866+
867+
article_1 = Article.objects.create(
868+
headline="Hello",
869+
reporter=reporter_1,
870+
editor=reporter_1,
871+
pub_date=datetime.now(),
872+
pub_date_time=datetime.now(),
873+
)
874+
875+
reporter_2 = Reporter.objects.create(
876+
first_name="Adam", last_name="Doe", email="[email protected]"
877+
)
878+
879+
article_2 = Article.objects.create(
880+
headline="Good Bye",
881+
reporter=reporter_2,
882+
editor=reporter_2,
883+
pub_date=datetime.now(),
884+
pub_date_time=datetime.now(),
885+
)
886+
887+
schema = Schema(query=Query)
888+
889+
query = (
890+
"""
891+
query NodeFilteringQuery {
892+
allArticles(viewer_Email_In: "%s") {
893+
edges {
894+
node {
895+
headline
896+
viewer {
897+
email
898+
}
899+
}
900+
}
901+
}
902+
}
903+
"""
904+
% reporter_1.email
905+
)
906+
907+
expected = {
908+
"allArticles": {
909+
"edges": [
910+
{
911+
"node": {
912+
"headline": article_1.headline,
913+
"viewer": {"email": reporter_1.email},
914+
}
915+
}
916+
]
917+
}
918+
}
919+
920+
result = schema.execute(query)
921+
922+
assert not result.errors
923+
assert result.data == expected

graphene_django/filter/utils.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,21 +13,25 @@ def get_filtering_args_from_filterset(filterset_class, type):
1313
args = {}
1414
model = filterset_class._meta.model
1515
for name, filter_field in six.iteritems(filterset_class.base_filters):
16+
form_field = None
17+
1618
if name in filterset_class.declared_filters:
1719
form_field = filter_field.field
1820
else:
1921
field_name = name.split("__", 1)[0]
20-
model_field = model._meta.get_field(field_name)
2122

22-
if hasattr(model_field, "formfield"):
23-
form_field = model_field.formfield(
24-
required=filter_field.extra.get("required", False)
25-
)
23+
if hasattr(model, field_name):
24+
model_field = model._meta.get_field(field_name)
25+
26+
if hasattr(model_field, "formfield"):
27+
form_field = model_field.formfield(
28+
required=filter_field.extra.get("required", False)
29+
)
2630

27-
# Fallback to field defined on filter if we can't get it from the
28-
# model field
29-
if not form_field:
30-
form_field = filter_field.field
31+
# Fallback to field defined on filter if we can't get it from the
32+
# model field
33+
if not form_field:
34+
form_field = filter_field.field
3135

3236
field_type = convert_form_field(form_field).Argument()
3337
field_type.description = filter_field.label

0 commit comments

Comments
 (0)