Skip to content

Commit ea8265f

Browse files
authored
[3.x.x/plugins] backport fix for generation of self referencing types (ExpediaGroup#949)
* [3.x.x/plugins] backport fix for generation of self referencing types See: ExpediaGroup#948 for details * fix failing unit test
1 parent 1371cbf commit ea8265f

File tree

10 files changed

+177
-89
lines changed

10 files changed

+177
-89
lines changed

plugins/graphql-kotlin-plugin-core/README.md

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,6 @@ using [square/kotlinpoet](https://github.com/square/kotlinpoet) library.
2727

2828
## Code Generation Limitations
2929

30-
* Currently only Ktor Http Client is supported. Additional clients (e.g. Spring WebClient) might be supported in the future.
3130
* Due to the custom logic required for deserialization of polymorphic types and default enum values only Jackson is currently supported.
3231
* Only a single operation per GraphQL query file is supported.
3332
* Subscriptions are currently NOT supported.
34-
* Nested queries have limited support as same object will be used for ALL nested results. This means that you have to
35-
explicitly ask for data from ALL nested levels + the NULL/empty child following it (that may skip recursive field selection
36-
as it will be NULL)

plugins/graphql-kotlin-plugin-core/src/main/kotlin/com/expediagroup/graphql/plugin/generator/GraphQLClientGeneratorContext.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,5 +38,5 @@ data class GraphQLClientGeneratorContext(
3838
val classNameCache: MutableMap<String, MutableList<ClassName>> = mutableMapOf()
3939
val typeSpecs: MutableMap<String, TypeSpec> = mutableMapOf()
4040
val typeAliases: MutableMap<String, TypeAliasSpec> = mutableMapOf()
41-
val objectsWithTypeNameSelection: MutableSet<String> = mutableSetOf()
41+
val typeToSelectionSetMap: MutableMap<String, Set<String>> = mutableMapOf()
4242
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
/*
2+
* Copyright 2020 Expedia, Inc
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.expediagroup.graphql.plugin.generator.extensions
18+
19+
import com.expediagroup.graphql.plugin.generator.GraphQLClientGeneratorContext
20+
import graphql.language.Document
21+
import graphql.language.FragmentDefinition
22+
23+
internal fun Document.findFragmentDefinition(context: GraphQLClientGeneratorContext, targetFragment: String, targetType: String): FragmentDefinition =
24+
this.getDefinitionsOfType(FragmentDefinition::class.java)
25+
.find { it.name == targetFragment && context.graphQLSchema.getType(it.typeCondition.name).isPresent }
26+
?: throw RuntimeException("fragment not found")

plugins/graphql-kotlin-plugin-core/src/main/kotlin/com/expediagroup/graphql/plugin/generator/types/generateGraphQLObjectTypeSpec.kt

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@
1717
package com.expediagroup.graphql.plugin.generator.types
1818

1919
import com.expediagroup.graphql.plugin.generator.GraphQLClientGeneratorContext
20+
import com.expediagroup.graphql.plugin.generator.extensions.findFragmentDefinition
2021
import com.squareup.kotlinpoet.FunSpec
2122
import com.squareup.kotlinpoet.KModifier
2223
import com.squareup.kotlinpoet.TypeSpec
23-
import graphql.language.FragmentDefinition
2424
import graphql.language.FragmentSpread
2525
import graphql.language.ObjectTypeDefinition
2626
import graphql.language.SelectionSet
@@ -59,8 +59,7 @@ internal fun generateGraphQLObjectTypeSpec(
5959
selectionSet.getSelectionsOfType(FragmentSpread::class.java)
6060
.forEach { fragment ->
6161
val fragmentDefinition = context.queryDocument
62-
.getDefinitionsOfType(FragmentDefinition::class.java)
63-
.find { it.name == fragment.name } ?: throw RuntimeException("fragment not found")
62+
.findFragmentDefinition(context, fragment.name, objectDefinition.name)
6463
generatePropertySpecs(
6564
context = context,
6665
objectName = objectDefinition.name,

plugins/graphql-kotlin-plugin-core/src/main/kotlin/com/expediagroup/graphql/plugin/generator/types/generateInterfaceTypeSpec.kt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,6 @@ internal fun generateInterfaceTypeSpec(
131131
val unwrappedClassName = implementationClassName.copy(nullable = false)
132132
// we point to original implementation name as that will be value from the __typename
133133
jsonSubTypesCodeBlock.add("com.fasterxml.jackson.annotation.JsonSubTypes.Type(value = %T::class, name=%S)", unwrappedClassName, implementationName)
134-
context.objectsWithTypeNameSelection.add(implementationName)
135134
}
136135

137136
// add jackson annotations to handle deserialization

plugins/graphql-kotlin-plugin-core/src/main/kotlin/com/expediagroup/graphql/plugin/generator/types/generatePropertySpecs.kt

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,7 @@ internal fun generatePropertySpecs(
3636
fieldDefinitions: List<FieldDefinition>,
3737
abstract: Boolean = false
3838
): List<PropertySpec> = selectionSet.getSelectionsOfType(Field::class.java)
39-
.filterNot {
40-
val typeNameSelected = it.name == "__typename"
41-
if (typeNameSelected) {
42-
context.objectsWithTypeNameSelection.add(objectName)
43-
}
44-
typeNameSelected
45-
}
39+
.filterNot { it.name == "__typename" }
4640
.map { selectedField ->
4741
val fieldDefinition = fieldDefinitions.find { it.name == selectedField.name }
4842
?: throw RuntimeException("unable to find corresponding field definition of ${selectedField.name} in $objectName")

plugins/graphql-kotlin-plugin-core/src/main/kotlin/com/expediagroup/graphql/plugin/generator/types/generateTypeName.kt

Lines changed: 55 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -17,19 +17,18 @@
1717
package com.expediagroup.graphql.plugin.generator.types
1818

1919
import com.expediagroup.graphql.plugin.generator.GraphQLClientGeneratorContext
20+
import com.expediagroup.graphql.plugin.generator.extensions.findFragmentDefinition
2021
import com.squareup.kotlinpoet.BOOLEAN
2122
import com.squareup.kotlinpoet.ClassName
2223
import com.squareup.kotlinpoet.FLOAT
2324
import com.squareup.kotlinpoet.INT
2425
import com.squareup.kotlinpoet.LIST
25-
import com.squareup.kotlinpoet.ParameterizedTypeName
2626
import com.squareup.kotlinpoet.ParameterizedTypeName.Companion.parameterizedBy
2727
import com.squareup.kotlinpoet.STRING
2828
import com.squareup.kotlinpoet.TypeName
2929
import graphql.Scalars
3030
import graphql.language.EnumTypeDefinition
3131
import graphql.language.Field
32-
import graphql.language.FragmentDefinition
3332
import graphql.language.FragmentSpread
3433
import graphql.language.InlineFragment
3534
import graphql.language.InputObjectTypeDefinition
@@ -76,11 +75,15 @@ internal fun generateCustomClassName(context: GraphQLClientGeneratorContext, gra
7675

7776
return if (cachedTypeNames == null || cachedTypeNames.isEmpty()) {
7877
// build new custom type
79-
val className = if (graphQLTypeDefinition is ScalarTypeDefinition && context.scalarTypeToConverterMapping[graphQLTypeName] == null) {
78+
if (graphQLTypeDefinition is ScalarTypeDefinition && context.scalarTypeToConverterMapping[graphQLTypeName] == null) {
8079
val typeAlias = generateGraphQLCustomScalarTypeAlias(context, graphQLTypeDefinition)
81-
ClassName(context.packageName, typeAlias.name)
80+
val className = ClassName(context.packageName, typeAlias.name)
81+
context.classNameCache[graphQLTypeName] = mutableListOf(className)
82+
className
8283
} else {
83-
val typeSpec = when (graphQLTypeDefinition) {
84+
val className = generateClassName(context, graphQLTypeDefinition, selectionSet)
85+
// generate corresponding type spec
86+
when (graphQLTypeDefinition) {
8487
is ObjectTypeDefinition -> generateGraphQLObjectTypeSpec(context, graphQLTypeDefinition, selectionSet)
8588
is InputObjectTypeDefinition -> generateGraphQLInputObjectTypeSpec(context, graphQLTypeDefinition)
8689
is EnumTypeDefinition -> generateGraphQLEnumTypeSpec(context, graphQLTypeDefinition)
@@ -89,10 +92,8 @@ internal fun generateCustomClassName(context: GraphQLClientGeneratorContext, gra
8992
is ScalarTypeDefinition -> generateGraphQLCustomScalarTypeSpec(context, graphQLTypeDefinition)
9093
else -> throw RuntimeException("should never happen")
9194
}
92-
ClassName(context.packageName, "${context.rootType}.${typeSpec.name}")
95+
className
9396
}
94-
context.classNameCache[graphQLTypeName] = mutableListOf(className)
95-
className
9697
} else if (selectionSet == null) {
9798
cachedTypeNames.first()
9899
} else {
@@ -105,57 +106,56 @@ internal fun generateCustomClassName(context: GraphQLClientGeneratorContext, gra
105106

106107
// if different selection set we need to generate custom type
107108
val overriddenName = "$graphQLTypeName${cachedTypeNames.size + 1}"
108-
val typeSpec = when (graphQLTypeDefinition) {
109+
val className = generateClassName(context, graphQLTypeDefinition, selectionSet, overriddenName)
110+
111+
// generate new type spec
112+
when (graphQLTypeDefinition) {
109113
is ObjectTypeDefinition -> generateGraphQLObjectTypeSpec(context, graphQLTypeDefinition, selectionSet, overriddenName)
110114
is InterfaceTypeDefinition -> generateGraphQLInterfaceTypeSpec(context, graphQLTypeDefinition, selectionSet, overriddenName)
111115
is UnionTypeDefinition -> generateGraphQLUnionTypeSpec(context, graphQLTypeDefinition, selectionSet, overriddenName)
112116
else -> throw RuntimeException("should never happen")
113117
}
114-
val className = ClassName(context.packageName, "${context.rootType}.${typeSpec.name}")
115-
context.classNameCache[graphQLTypeName]?.add(className)
116118
className
117119
}
118120
}
119121

122+
/**
123+
* Generate custom [ClassName] reference to a Kotlin class representing GraphQL complex type (object, input object, enum, interface, union or custom scalar) and caches the value.
124+
*/
125+
internal fun generateClassName(
126+
context: GraphQLClientGeneratorContext,
127+
graphQLType: NamedNode<*>,
128+
selectionSet: SelectionSet? = null,
129+
nameOverride: String? = null
130+
): ClassName {
131+
val typeName = nameOverride ?: graphQLType.name
132+
val className = ClassName(context.packageName, "${context.rootType}.$typeName")
133+
val classNames = context.classNameCache.getOrDefault(graphQLType.name, mutableListOf())
134+
classNames.add(className)
135+
context.classNameCache[graphQLType.name] = classNames
136+
137+
if (selectionSet != null) {
138+
val selectedFields = calculateSelectedFields(context, typeName, selectionSet)
139+
context.typeToSelectionSetMap[typeName] = selectedFields
140+
}
141+
142+
return className
143+
}
144+
120145
private fun ClassName.simpleNameWithoutWrapper() = this.simpleName.substringAfter(".")
121146

122147
private fun isCachedTypeApplicable(context: GraphQLClientGeneratorContext, graphQLTypeName: String, graphQLTypeDefinition: TypeDefinition<*>, selectionSet: SelectionSet): Boolean =
123148
when (graphQLTypeDefinition) {
124-
is UnionTypeDefinition -> {
125-
val unionImplementations = graphQLTypeDefinition.memberTypes.filterIsInstance(graphql.language.TypeName::class.java).map { it.name }
126-
var result = true
127-
for (unionImplementation in unionImplementations) {
128-
result = result && verifySelectionSet(context, unionImplementation, selectionSet)
129-
if (!result) {
130-
break
131-
}
132-
}
133-
result
134-
}
135-
is InterfaceTypeDefinition -> {
136-
var result = verifySelectionSet(context, graphQLTypeName, selectionSet)
137-
if (result) {
138-
val implementations = context.graphQLSchema.getImplementationsOf(graphQLTypeDefinition).map { it.name }
139-
for (implementation in implementations) {
140-
result = result && verifySelectionSet(context, implementation, selectionSet)
141-
if (!result) {
142-
break
143-
}
144-
}
145-
}
146-
result
147-
}
149+
is UnionTypeDefinition -> verifySelectionSet(context, graphQLTypeName, selectionSet)
150+
is InterfaceTypeDefinition -> verifySelectionSet(context, graphQLTypeName, selectionSet)
148151
is ObjectTypeDefinition -> verifySelectionSet(context, graphQLTypeName, selectionSet)
149152
else -> true
150153
}
151154

152155
private fun verifySelectionSet(context: GraphQLClientGeneratorContext, graphQLTypeName: String, selectionSet: SelectionSet): Boolean {
153156
val selectedFields = calculateSelectedFields(context, graphQLTypeName, selectionSet)
154-
val properties = calculateGeneratedTypeProperties(context, graphQLTypeName)
155-
if (context.objectsWithTypeNameSelection.contains(graphQLTypeName)) {
156-
properties.add("__typename")
157-
}
158-
return selectedFields == properties
157+
val cachedTypeFields = context.typeToSelectionSetMap[graphQLTypeName]
158+
return selectedFields == cachedTypeFields
159159
}
160160

161161
private fun calculateSelectedFields(
@@ -173,42 +173,26 @@ private fun calculateSelectedFields(
173173
result.addAll(calculateSelectedFields(context, targetType, selection.selectionSet, "$path${selection.name}."))
174174
}
175175
}
176-
is InlineFragment -> if (selection.typeCondition.name == targetType) {
177-
result.addAll(calculateSelectedFields(context, targetType, selection.selectionSet))
176+
is InlineFragment -> {
177+
val targetFragmentType = selection.typeCondition.name
178+
val fragmentPathPrefix = if (targetFragmentType == targetType) {
179+
path
180+
} else {
181+
"$path$targetFragmentType."
182+
}
183+
result.addAll(calculateSelectedFields(context, targetType, selection.selectionSet, fragmentPathPrefix))
178184
}
179185
is FragmentSpread -> {
180-
val fragmentDefinition = context.queryDocument
181-
.getDefinitionsOfType(FragmentDefinition::class.java)
182-
.find { it.name == selection.name } ?: throw RuntimeException("fragment not found")
183-
if (fragmentDefinition.typeCondition.name == targetType) {
184-
result.addAll(calculateSelectedFields(context, targetType, fragmentDefinition.selectionSet))
186+
val fragmentDefinition = context.queryDocument.findFragmentDefinition(context, selection.name, targetType)
187+
val targetFragmentType = fragmentDefinition.typeCondition.name
188+
val fragmentPathPrefix = if (targetFragmentType == targetType) {
189+
path
190+
} else {
191+
"$path$targetFragmentType."
185192
}
193+
result.addAll(calculateSelectedFields(context, targetType, fragmentDefinition.selectionSet, fragmentPathPrefix))
186194
}
187195
}
188196
}
189197
return result
190198
}
191-
192-
private fun calculateGeneratedTypeProperties(context: GraphQLClientGeneratorContext, graphQLTypeName: String, path: String = ""): MutableSet<String> {
193-
val props = mutableSetOf<String>()
194-
195-
val typeSpec = context.typeSpecs[graphQLTypeName]
196-
for (property in typeSpec?.propertySpecs ?: emptyList()) {
197-
props.add(path + property.name)
198-
when (val propertyType = property.type) {
199-
is ParameterizedTypeName -> {
200-
val genericType = propertyType.typeArguments.firstOrNull() as? ClassName
201-
val genericTypeName = genericType?.simpleNameWithoutWrapper() ?: ""
202-
props.addAll(calculateGeneratedTypeProperties(context, genericTypeName, "$path${property.name}."))
203-
}
204-
is ClassName -> {
205-
val fieldTypeName = propertyType.simpleNameWithoutWrapper()
206-
// we need to check whether generated type is a custom scalar
207-
if (context.scalarTypeToConverterMapping[fieldTypeName] == null) {
208-
props.addAll(calculateGeneratedTypeProperties(context, fieldTypeName, "$path${property.name}."))
209-
}
210-
}
211-
}
212-
}
213-
return props
214-
}

plugins/graphql-kotlin-plugin-core/src/test/kotlin/com/expediagroup/graphql/plugin/generator/types/GenerateGraphQLInputObjectTypeSpecIT.kt

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,4 +61,69 @@ class GenerateGraphQLInputObjectTypeSpecIT {
6161
""".trimIndent()
6262
verifyGeneratedFileSpecContents(query, expected)
6363
}
64+
65+
@Test
66+
fun `verify generation of self referencing input object`() {
67+
val expected =
68+
"""
69+
package com.expediagroup.graphql.plugin.generator.integration
70+
71+
import com.expediagroup.graphql.client.GraphQLClient
72+
import com.expediagroup.graphql.types.GraphQLResponse
73+
import io.ktor.client.request.HttpRequestBuilder
74+
import kotlin.Boolean
75+
import kotlin.Float
76+
import kotlin.String
77+
import kotlin.Unit
78+
79+
const val INPUT_OBJECT_TEST_QUERY: String =
80+
"query InputObjectTestQuery(${'$'}{'${'$'}'}input: ComplexArgumentInput) {\n complexInputObjectQuery(criteria: ${'$'}{'${'$'}'}input)\n}"
81+
82+
class InputObjectTestQuery(
83+
private val graphQLClient: GraphQLClient<*>
84+
) {
85+
suspend fun execute(variables: InputObjectTestQuery.Variables,
86+
requestBuilder: HttpRequestBuilder.() -> Unit = {}):
87+
GraphQLResponse<InputObjectTestQuery.Result> = graphQLClient.execute(INPUT_OBJECT_TEST_QUERY,
88+
"InputObjectTestQuery", variables, requestBuilder)
89+
90+
data class Variables(
91+
val input: InputObjectTestQuery.ComplexArgumentInput? = null
92+
)
93+
94+
/**
95+
* Self referencing input object
96+
*/
97+
data class ComplexArgumentInput(
98+
/**
99+
* Maximum value for test criteria
100+
*/
101+
val max: Float? = null,
102+
/**
103+
* Minimum value for test criteria
104+
*/
105+
val min: Float? = null,
106+
/**
107+
* Next criteria
108+
*/
109+
val next: InputObjectTestQuery.ComplexArgumentInput? = null
110+
)
111+
112+
data class Result(
113+
/**
114+
* Query that accepts self referencing input object
115+
*/
116+
val complexInputObjectQuery: Boolean
117+
)
118+
}
119+
""".trimIndent()
120+
121+
val query =
122+
"""
123+
query InputObjectTestQuery(${'$'}input: ComplexArgumentInput) {
124+
complexInputObjectQuery(criteria: ${'$'}input)
125+
}
126+
""".trimIndent()
127+
verifyGeneratedFileSpecContents(query, expected)
128+
}
64129
}

plugins/graphql-kotlin-plugin-core/src/test/kotlin/com/expediagroup/graphql/plugin/generator/types/GenerateGraphQLObjectTypeSpecIT.kt

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -300,7 +300,7 @@ class GenerateGraphQLObjectTypeSpecIT {
300300
}
301301

302302
@Test
303-
fun `verify we can generate nested objects`() {
303+
fun `verify we can generate self-reference objects`() {
304304
val expected =
305305
"""
306306
package com.expediagroup.graphql.plugin.generator.integration
@@ -323,6 +323,20 @@ class GenerateGraphQLObjectTypeSpecIT {
323323
GraphQLResponse<NestedQuery.Result> = graphQLClient.execute(NESTED_QUERY, "NestedQuery", null,
324324
requestBuilder)
325325
326+
/**
327+
* Example of an object self-referencing itself
328+
*/
329+
data class NestedObject2(
330+
/**
331+
* Unique identifier
332+
*/
333+
val id: Int,
334+
/**
335+
* Name of the object
336+
*/
337+
val name: String
338+
)
339+
326340
/**
327341
* Example of an object self-referencing itself
328342
*/
@@ -338,7 +352,7 @@ class GenerateGraphQLObjectTypeSpecIT {
338352
/**
339353
* Children elements
340354
*/
341-
val children: List<NestedQuery.NestedObject>
355+
val children: List<NestedQuery.NestedObject2>
342356
)
343357
344358
data class Result(

0 commit comments

Comments
 (0)