Skip to content

Commit cf296ac

Browse files
rose-ajoemcbride
authored andcommitted
Fix subscribing to non object graph types (graphql-dotnet#989)
* use node.GetParentType(context.Schema) * fix subscription update execution * create test for subscribing to scalar fields * remove changes to verion numbers
1 parent d408bc4 commit cf296ac

File tree

5 files changed

+134
-9
lines changed

5 files changed

+134
-9
lines changed

src/GraphQL.Tests/Subscription/SubscriptionSchema.cs

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using System;
22
using System.Collections.Concurrent;
3+
using System.Collections.Generic;
34
using System.Linq;
45
using System.Reactive.Linq;
56
using System.Reactive.Subjects;
@@ -64,6 +65,22 @@ public ChatSubscriptions(IChat chat)
6465
Resolver = new FuncFieldResolver<Message>(ResolveMessage),
6566
AsyncSubscriber = new AsyncEventStreamResolver<Message>(SubscribeByIdAsync)
6667
});
68+
69+
AddField(new EventStreamFieldType
70+
{
71+
Name = "messageGetAll",
72+
Type = typeof(ListGraphType<MessageType>),
73+
Resolver = new FuncFieldResolver<List<Message>>(context => context.Source as List<Message>),
74+
Subscriber = new EventStreamResolver<List<Message>>(context => _chat.MessagesGetAll())
75+
});
76+
77+
AddField(new EventStreamFieldType
78+
{
79+
Name = "newMessageContent",
80+
Type = typeof(StringGraphType),
81+
Resolver = new FuncFieldResolver<string>(context => context.Source as string),
82+
Subscriber = new EventStreamResolver<string>(context => Subscribe(context).Select(message => message.Content))
83+
});
6784
}
6885

6986
private IObservable<Message> SubscribeById(ResolveEventStreamContext context)
@@ -193,6 +210,7 @@ public interface IChat
193210
Message AddMessage(Message message);
194211

195212
IObservable<Message> Messages();
213+
IObservable<List<Message>> MessagesGetAll();
196214

197215
Message AddMessage(ReceivedMessage message);
198216

@@ -202,7 +220,7 @@ public interface IChat
202220
public class Chat : IChat
203221
{
204222
private readonly ISubject<Message> _messageStream = new ReplaySubject<Message>(1);
205-
223+
private readonly ISubject<List<Message>> _allMessageStream = new ReplaySubject<List<Message>>(1);
206224

207225
public Chat()
208226
{
@@ -244,6 +262,14 @@ public async Task<IObservable<Message>> MessagesAsync()
244262
return Messages();
245263
}
246264

265+
public List<Message> AddMessageGetAll(Message message)
266+
{
267+
AllMessages.Push(message);
268+
var l = new List<Message>(AllMessages);
269+
_allMessageStream.OnNext(l);
270+
return l;
271+
}
272+
247273
public Message AddMessage(Message message)
248274
{
249275
AllMessages.Push(message);
@@ -256,6 +282,11 @@ public IObservable<Message> Messages()
256282
return _messageStream.AsObservable();
257283
}
258284

285+
public IObservable<List<Message>> MessagesGetAll()
286+
{
287+
return _allMessageStream.AsObservable();
288+
}
289+
259290
public void AddError(Exception exception)
260291
{
261292
_messageStream.OnError(exception);

src/GraphQL.Tests/Subscription/SubscriptionSchemaWithReflection.cs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using System;
2+
using System.Collections.Generic;
23
using System.Reactive.Linq;
34
using System.Threading.Tasks;
45
using GraphQL.Subscription;
@@ -25,6 +26,7 @@ type Subscription {
2526
messageAddedByUser(id: String!) : Message
2627
messageAddedAsync : Message
2728
messageAddedByUserAsync(id: String!) : Message
29+
messageGetAll : [Message]
2830
}
2931
";
3032

@@ -57,6 +59,18 @@ public Message ResolveMessageAdded(ResolveFieldContext context)
5759
return context.Source as Message;
5860
}
5961

62+
[GraphQLMetadata(Name = "messageGetAll", Type = ResolverType.Subscriber)]
63+
public IObservable<List<Message>> SubscribeMessageGetAll(ResolveEventStreamContext context)
64+
{
65+
return SubscriptionSchemaWithReflection.Chat.MessagesGetAll();
66+
}
67+
68+
[GraphQLMetadata(Name = "messageGetAll")]
69+
public List<Message> ResolveMessageGetAll(ResolveFieldContext context)
70+
{
71+
return context.Source as List<Message>;
72+
}
73+
6074
[GraphQLMetadata(Name = "messageAddedByUser", Type = ResolverType.Subscriber)]
6175
public IObservable<Message> SubscribeMessageAddedByUser(ResolveEventStreamContext context, string id)
6276
{

src/GraphQL.Tests/Subscription/SubscriptionTests.cs

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,81 @@ protected async Task<SubscriptionExecutionResult> ExecuteSubscribeAsync(Executio
2222
return (SubscriptionExecutionResult)result;
2323
}
2424

25+
[Fact]
26+
public async Task SubscribeGetAll()
27+
{
28+
/* Given */
29+
var addedMessage = new Message
30+
{
31+
Content = "test",
32+
From = new MessageFrom
33+
{
34+
DisplayName = "test",
35+
Id = "1"
36+
},
37+
SentAt = DateTime.Now
38+
};
39+
40+
var chat = new Chat();
41+
var schema = new ChatSchema(chat);
42+
43+
/* When */
44+
var result = await ExecuteSubscribeAsync(new ExecutionOptions
45+
{
46+
Query = "subscription messageGetAll { messageGetAll { from { id displayName } content sentAt } }",
47+
Schema = schema
48+
});
49+
50+
chat.AddMessageGetAll(addedMessage);
51+
52+
/* Then */
53+
var stream = result.Streams.Values.FirstOrDefault();
54+
var message = await stream.FirstOrDefaultAsync();
55+
56+
message.ShouldNotBeNull();
57+
var data = ((Dictionary<string, object>)message.Data);
58+
data.ShouldNotBeNull();
59+
data["messageGetAll"].ShouldNotBeNull();
60+
}
61+
62+
[Fact]
63+
public async Task SubscribeToContent()
64+
{
65+
/* Given */
66+
var addedMessage = new Message
67+
{
68+
Content = "test",
69+
From = new MessageFrom
70+
{
71+
DisplayName = "test",
72+
Id = "1"
73+
},
74+
SentAt = DateTime.Now
75+
};
76+
77+
var chat = new Chat();
78+
var schema = new ChatSchema(chat);
79+
80+
/* When */
81+
var result = await ExecuteSubscribeAsync(new ExecutionOptions
82+
{
83+
Query = "subscription newMessageContent { newMessageContent }",
84+
Schema = schema
85+
});
86+
87+
chat.AddMessage(addedMessage);
88+
89+
/* Then */
90+
var stream = result.Streams.Values.FirstOrDefault();
91+
var message = await stream.FirstOrDefaultAsync();
92+
93+
message.ShouldNotBeNull();
94+
var data = ((Dictionary<string, object>)message.Data);
95+
data.ShouldNotBeNull();
96+
data["newMessageContent"].ShouldNotBeNull();
97+
data["newMessageContent"].ToString().ShouldBe("test");
98+
}
99+
25100
[Fact]
26101
public async Task Subscribe()
27102
{

src/GraphQL/Execution/ParallelExecutionStrategy.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,10 @@ namespace GraphQL.Execution
66
{
77
public class ParallelExecutionStrategy : ExecutionStrategy
88
{
9-
protected override async Task ExecuteNodeTreeAsync(ExecutionContext context, ObjectExecutionNode rootNode)
9+
protected override Task ExecuteNodeTreeAsync(ExecutionContext context, ObjectExecutionNode rootNode)
10+
=> ExecuteNodeTreeAsync(context, rootNode);
11+
12+
protected async Task ExecuteNodeTreeAsync(ExecutionContext context, ExecutionNode rootNode)
1013
{
1114
var pendingNodes = new List<ExecutionNode>
1215
{

src/GraphQL/Execution/SubscriptionExecutionStrategy.cs

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using System;
22
using System.Collections.Generic;
3+
using System.Diagnostics;
34
using System.Reactive.Linq;
45
using System.Threading.Tasks;
56
using GraphQL.Language.AST;
@@ -66,7 +67,7 @@ protected virtual async Task<IObservable<ExecutionResult>> ResolveEventStreamAsy
6667
FieldAst = node.Field,
6768
FieldDefinition = node.FieldDefinition,
6869
ReturnType = node.FieldDefinition.ResolvedType,
69-
ParentType = node.GraphType as IObjectGraphType,
70+
ParentType = node.GetParentType(context.Schema),
7071
Arguments = arguments,
7172
Source = source,
7273
Schema = context.Schema,
@@ -101,11 +102,13 @@ protected virtual async Task<IObservable<ExecutionResult>> ResolveEventStreamAsy
101102
}
102103

103104
return subscription
104-
.Select(value => new ObjectExecutionNode(node.Parent, node.GraphType, node.Field, node.FieldDefinition, node.Path)
105+
.Select(value =>
105106
{
106-
Source = value
107+
var executionNode = BuildExecutionNode(node.Parent, node.GraphType, node.Field, node.FieldDefinition, node.Path);
108+
executionNode.Source = value;
109+
return executionNode;
107110
})
108-
.SelectMany(async objectNode =>
111+
.SelectMany(async executionNode =>
109112
{
110113
foreach (var listener in context.Listeners)
111114
{
@@ -114,8 +117,7 @@ await listener.BeforeExecutionAsync(context.UserContext, context.CancellationTok
114117
}
115118

116119
// Execute the whole execution tree and return the result
117-
await ExecuteNodeTreeAsync(context, objectNode)
118-
.ConfigureAwait(false);
120+
await ExecuteNodeTreeAsync(context, executionNode).ConfigureAwait(false);
119121

120122
foreach (var listener in context.Listeners)
121123
{
@@ -127,7 +129,7 @@ await listener.AfterExecutionAsync(context.UserContext, context.CancellationToke
127129
{
128130
Data = new Dictionary<string, object>
129131
{
130-
{ objectNode.Name, objectNode.ToValue() }
132+
{ executionNode.Name, executionNode.ToValue() }
131133
}
132134
}.With(context);
133135
})

0 commit comments

Comments
 (0)