Skip to content

Move notification handler registrations to capabilities #207

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Address more feedback and further cleanup
  • Loading branch information
stephentoub committed Apr 4, 2025
commit 824d01d1dd71e209d3d231b6ac4a014550e41152
1 change: 0 additions & 1 deletion samples/QuickstartWeatherServer/Tools/WeatherTools.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
using ModelContextProtocol;
using ModelContextProtocol.Server;
using System.ComponentModel;
using System.Net.Http.Json;
using System.Text.Json;

namespace QuickstartWeatherServer.Tools;
Expand Down
24 changes: 8 additions & 16 deletions src/ModelContextProtocol/Client/McpClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,18 @@
using ModelContextProtocol.Protocol.Types;
using ModelContextProtocol.Shared;
using ModelContextProtocol.Utils.Json;
using System.Diagnostics;
using System.Reflection;
using System.Text.Json;

namespace ModelContextProtocol.Client;

/// <inheritdoc/>
internal sealed class McpClient : McpEndpoint, IMcpClient
{
/// <summary>Cached naming information used for client name/version when none is specified.</summary>
private static readonly AssemblyName s_asmName = (Assembly.GetEntryAssembly() ?? Assembly.GetExecutingAssembly()).GetName();
private static Implementation DefaultImplementation { get; } = new()
{
Name = DefaultAssemblyName.Name ?? nameof(McpClient),
Version = DefaultAssemblyName.Version?.ToString() ?? "1.0.0",
};

private readonly IClientTransport _clientTransport;
private readonly McpClientOptions _options;
Expand All @@ -37,17 +38,9 @@ internal sealed class McpClient : McpEndpoint, IMcpClient
public McpClient(IClientTransport clientTransport, McpClientOptions? options, McpServerConfig serverConfig, ILoggerFactory? loggerFactory)
: base(loggerFactory)
{
_clientTransport = clientTransport;
options ??= new();

if (options?.ClientInfo is null)
{
options = options?.Clone() ?? new();
options.ClientInfo = new()
{
Name = s_asmName.Name ?? nameof(McpClient),
Version = s_asmName.Version?.ToString() ?? "1.0.0",
};
}
_clientTransport = clientTransport;
_options = options;

EndpointName = $"Client ({serverConfig.Id}: {serverConfig.Name})";
Expand Down Expand Up @@ -122,14 +115,13 @@ public async Task ConnectAsync(CancellationToken cancellationToken = default)
try
{
// Send initialize request
Debug.Assert(_options.ClientInfo is not null, "ClientInfo should be set by the constructor");
var initializeResponse = await this.SendRequestAsync(
RequestMethods.Initialize,
new InitializeRequestParams
{
ProtocolVersion = _options.ProtocolVersion,
Capabilities = _options.Capabilities ?? new ClientCapabilities(),
ClientInfo = _options.ClientInfo!
ClientInfo = _options.ClientInfo ?? DefaultImplementation,
},
McpJsonUtilities.JsonContext.Default.InitializeRequestParams,
McpJsonUtilities.JsonContext.Default.InitializeResult,
Expand Down
1 change: 0 additions & 1 deletion src/ModelContextProtocol/Client/McpClientFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
using ModelContextProtocol.Utils;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
using System.Reflection;

namespace ModelContextProtocol.Client;

Expand Down
10 changes: 0 additions & 10 deletions src/ModelContextProtocol/Client/McpClientOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,4 @@ public class McpClientOptions
/// Timeout for initialization sequence.
/// </summary>
public TimeSpan InitializationTimeout { get; set; } = TimeSpan.FromSeconds(60);

/// <summary>Creates a shallow clone of the options.</summary>
internal McpClientOptions Clone() =>
new()
{
ClientInfo = ClientInfo,
Capabilities = Capabilities,
ProtocolVersion = ProtocolVersion,
InitializationTimeout = InitializationTimeout
};
}
1 change: 0 additions & 1 deletion src/ModelContextProtocol/Client/McpClientTool.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
using ModelContextProtocol.Protocol.Types;
using ModelContextProtocol.Utils.Json;
using ModelContextProtocol.Utils;
using Microsoft.Extensions.AI;
using System.Text.Json;

Expand Down
15 changes: 1 addition & 14 deletions src/ModelContextProtocol/Configuration/McpServerOptionsSetup.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
using System.Reflection;
using ModelContextProtocol.Server;
using ModelContextProtocol.Server;
using Microsoft.Extensions.Options;
using ModelContextProtocol.Utils;

Expand All @@ -25,18 +24,6 @@ public void Configure(McpServerOptions options)
{
Throw.IfNull(options);

// Configure the option's server information based on the current process,
// if it otherwise lacks server information.
if (options.ServerInfo is not { } serverInfo)
{
var assemblyName = Assembly.GetEntryAssembly()?.GetName();
options.ServerInfo = new()
{
Name = assemblyName?.Name ?? "McpServer",
Version = assemblyName?.Version?.ToString() ?? "1.0.0",
};
}

// Collect all of the provided tools into a tools collection. If the options already has
// a collection, add to it, otherwise create a new one. We want to maintain the identity
// of an existing collection in case someone has provided their own derived type, wants
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
using Microsoft.Extensions.Logging.Abstractions;
using ModelContextProtocol.Logging;
using ModelContextProtocol.Utils;
using System.ComponentModel;
using System.Diagnostics;
using System.Text;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,7 @@ public StdioServerTransport(string serverName, ILoggerFactory? loggerFactory = n
private static string GetServerName(McpServerOptions serverOptions)
{
Throw.IfNull(serverOptions);
Throw.IfNull(serverOptions.ServerInfo);
Throw.IfNull(serverOptions.ServerInfo.Name);

return serverOptions.ServerInfo.Name;
return serverOptions.ServerInfo?.Name ?? McpServer.DefaultImplementation.Name;
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
using ModelContextProtocol.Protocol.Messages;

namespace ModelContextProtocol.Protocol.Types;
namespace ModelContextProtocol.Protocol.Types;

/// <summary>
/// A request from the server to get a list of root URIs from the client.
Expand Down
12 changes: 10 additions & 2 deletions src/ModelContextProtocol/Server/McpServer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@ namespace ModelContextProtocol.Server;
/// <inheritdoc />
internal sealed class McpServer : McpEndpoint, IMcpServer
{
internal static Implementation DefaultImplementation { get; } = new()
{
Name = DefaultAssemblyName.Name ?? nameof(McpServer),
Version = DefaultAssemblyName.Version?.ToString() ?? "1.0.0",
};

private readonly EventHandler? _toolsChangedDelegate;
private readonly EventHandler? _promptsChangedDelegate;

Expand All @@ -32,9 +38,11 @@ public McpServer(ITransport transport, McpServerOptions options, ILoggerFactory?
Throw.IfNull(transport);
Throw.IfNull(options);

options ??= new();

ServerOptions = options;
Services = serviceProvider;
_endpointName = $"Server ({options.ServerInfo.Name} {options.ServerInfo.Version})";
_endpointName = $"Server ({options.ServerInfo?.Name ?? DefaultImplementation.Name} {options.ServerInfo?.Version ?? DefaultImplementation.Version})";

_toolsChangedDelegate = delegate
{
Expand Down Expand Up @@ -158,7 +166,7 @@ private void SetInitializeHandler(McpServerOptions options)
{
ProtocolVersion = options.ProtocolVersion,
Instructions = options.ServerInstructions,
ServerInfo = options.ServerInfo,
ServerInfo = options.ServerInfo ?? DefaultImplementation,
Capabilities = ServerCapabilities ?? new(),
});
},
Expand Down
1 change: 0 additions & 1 deletion src/ModelContextProtocol/Server/McpServerExtensions.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
using Microsoft.Extensions.AI;
using ModelContextProtocol.Client;
using ModelContextProtocol.Protocol.Messages;
using ModelContextProtocol.Protocol.Types;
using ModelContextProtocol.Utils;
Expand Down
2 changes: 1 addition & 1 deletion src/ModelContextProtocol/Server/McpServerOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ public class McpServerOptions
/// <summary>
/// Information about this server implementation.
/// </summary>
public required Implementation ServerInfo { get; set; }
public Implementation? ServerInfo { get; set; }

/// <summary>
/// Server capabilities to advertise to the server.
Expand Down
5 changes: 4 additions & 1 deletion src/ModelContextProtocol/Shared/McpEndpoint.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
using ModelContextProtocol.Server;
using ModelContextProtocol.Utils;
using System.Diagnostics.CodeAnalysis;
using System.Text.Json.Serialization.Metadata;
using System.Reflection;

namespace ModelContextProtocol.Shared;

Expand All @@ -19,6 +19,9 @@ namespace ModelContextProtocol.Shared;
/// </summary>
internal abstract class McpEndpoint : IAsyncDisposable
{
/// <summary>Cached naming information used for name/version when none is specified.</summary>
internal static AssemblyName DefaultAssemblyName { get; } = (Assembly.GetEntryAssembly() ?? Assembly.GetExecutingAssembly()).GetName();

private McpSession? _session;
private CancellationTokenSource? _sessionCts;

Expand Down
1 change: 0 additions & 1 deletion src/ModelContextProtocol/Shared/McpSession.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
using ModelContextProtocol.Logging;
using ModelContextProtocol.Protocol.Messages;
using ModelContextProtocol.Protocol.Transport;
using ModelContextProtocol.Server;
using ModelContextProtocol.Utils;
using ModelContextProtocol.Utils.Json;
using System.Collections.Concurrent;
Expand Down
1 change: 0 additions & 1 deletion tests/ModelContextProtocol.TestServer/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ private static async Task Main(string[] args)

McpServerOptions options = new()
{
ServerInfo = new Implementation() { Name = "TestServer", Version = "1.0.0" },
Capabilities = new ServerCapabilities()
{
Tools = ConfigureTools(),
Expand Down
1 change: 0 additions & 1 deletion tests/ModelContextProtocol.TestSseServer/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ private static void ConfigureSerilog(ILoggingBuilder loggingBuilder)

private static void ConfigureOptions(McpServerOptions options)
{
options.ServerInfo = new Implementation() { Name = "TestServer", Version = "1.0.0" };
options.Capabilities = new ServerCapabilities()
{
Tools = new(),
Expand Down
30 changes: 8 additions & 22 deletions tests/ModelContextProtocol.Tests/Client/McpClientFactoryTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,29 +11,24 @@ namespace ModelContextProtocol.Tests.Client;

public class McpClientFactoryTests
{
private readonly McpClientOptions _defaultOptions = new()
{
ClientInfo = new() { Name = "TestClient", Version = "1.0.0" }
};

[Fact]
public async Task CreateAsync_WithInvalidArgs_Throws()
{
await Assert.ThrowsAsync<ArgumentNullException>("serverConfig", () => McpClientFactory.CreateAsync((McpServerConfig)null!, _defaultOptions, cancellationToken: TestContext.Current.CancellationToken));
await Assert.ThrowsAsync<ArgumentNullException>("serverConfig", () => McpClientFactory.CreateAsync((McpServerConfig)null!, cancellationToken: TestContext.Current.CancellationToken));

await Assert.ThrowsAsync<ArgumentException>("serverConfig", () => McpClientFactory.CreateAsync(new McpServerConfig()
{
Name = "name",
Id = "id",
TransportType = "somethingunsupported",
}, _defaultOptions, cancellationToken: TestContext.Current.CancellationToken));
}, cancellationToken: TestContext.Current.CancellationToken));

await Assert.ThrowsAsync<InvalidOperationException>(() => McpClientFactory.CreateAsync(new McpServerConfig()
{
Name = "name",
Id = "id",
TransportType = TransportTypes.StdIo,
}, _defaultOptions, (_, __) => null!, cancellationToken: TestContext.Current.CancellationToken));
}, createTransportFunc: (_, __) => null!, cancellationToken: TestContext.Current.CancellationToken));
}

[Fact]
Expand Down Expand Up @@ -78,8 +73,7 @@ public async Task CreateAsync_WithValidStdioConfig_CreatesNewClient()
// Act
await using var client = await McpClientFactory.CreateAsync(
serverConfig,
_defaultOptions,
(_, __) => new NopTransport(),
createTransportFunc: (_, __) => new NopTransport(),
cancellationToken: TestContext.Current.CancellationToken);

// Assert
Expand All @@ -102,8 +96,7 @@ public async Task CreateAsync_WithNoTransportOptions_CreatesNewClient()
// Act
await using var client = await McpClientFactory.CreateAsync(
serverConfig,
_defaultOptions,
(_, __) => new NopTransport(),
createTransportFunc: (_, __) => new NopTransport(),
cancellationToken: TestContext.Current.CancellationToken);

// Assert
Expand All @@ -126,8 +119,7 @@ public async Task CreateAsync_WithValidSseConfig_CreatesNewClient()
// Act
await using var client = await McpClientFactory.CreateAsync(
serverConfig,
_defaultOptions,
(_, __) => new NopTransport(),
createTransportFunc: (_, __) => new NopTransport(),
cancellationToken: TestContext.Current.CancellationToken);

// Assert
Expand Down Expand Up @@ -157,8 +149,7 @@ public async Task CreateAsync_WithSse_CreatesCorrectTransportOptions()
// Act
await using var client = await McpClientFactory.CreateAsync(
serverConfig,
_defaultOptions,
(_, __) => new NopTransport(),
createTransportFunc: (_, __) => new NopTransport(),
cancellationToken: TestContext.Current.CancellationToken);

// Assert
Expand Down Expand Up @@ -186,7 +177,7 @@ public async Task McpFactory_WithInvalidTransportOptions_ThrowsFormatException(s
};

// act & assert
await Assert.ThrowsAsync<ArgumentException>(() => McpClientFactory.CreateAsync(config, _defaultOptions, cancellationToken: TestContext.Current.CancellationToken));
await Assert.ThrowsAsync<ArgumentException>(() => McpClientFactory.CreateAsync(config, cancellationToken: TestContext.Current.CancellationToken));
}

[Theory]
Expand All @@ -205,11 +196,6 @@ public async Task CreateAsync_WithCapabilitiesOptions(Type transportType)

var clientOptions = new McpClientOptions
{
ClientInfo = new Implementation
{
Name = "TestClient",
Version = "1.0.0.0"
},
Capabilities = new ClientCapabilities
{
Sampling = new SamplingCapability
Expand Down
5 changes: 0 additions & 5 deletions tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,7 @@
using ModelContextProtocol.Protocol.Types;
using ModelContextProtocol.Tests.Utils;
using OpenAI;
using System.Text.Encodings.Web;
using System.Text.Json;
using System.Text.Json.Serialization;
using System.Text.Json.Serialization.Metadata;

namespace ModelContextProtocol.Tests;

Expand Down Expand Up @@ -367,7 +364,6 @@ public async Task Sampling_Stdio(string clientId)
int samplingHandlerCalls = 0;
await using var client = await _fixture.CreateClientAsync(clientId, new()
{
ClientInfo = new() { Name = "Sampling_Stdio", Version = "1.0.0" },
Capabilities = new()
{
Sampling = new()
Expand Down Expand Up @@ -532,7 +528,6 @@ public async Task SamplingViaChatClient_RequestResponseProperlyPropagated()
.CreateSamplingHandler();
await using var client = await McpClientFactory.CreateAsync(_fixture.EverythingServerConfig, new()
{
ClientInfo = new() { Name = nameof(SamplingViaChatClient_RequestResponseProperlyPropagated), Version = "1.0.0" },
Capabilities = new()
{
Sampling = new()
Expand Down
2 changes: 0 additions & 2 deletions tests/ModelContextProtocol.Tests/DiagnosticTests.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
using ModelContextProtocol.Client;
using ModelContextProtocol.Protocol.Transport;
using ModelContextProtocol.Protocol.Types;
using ModelContextProtocol.Server;
using OpenTelemetry.Trace;
using System.Diagnostics;
Expand Down Expand Up @@ -49,7 +48,6 @@ private static async Task RunConnected(Func<IMcpClient, IMcpServer, Task> action

await using (IMcpServer server = McpServerFactory.Create(serverTransport, new()
{
ServerInfo = new Implementation { Name = "TestServer", Version = "1.0.0" },
Capabilities = new()
{
Tools = new()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
using ModelContextProtocol.Protocol.Types;
using ModelContextProtocol.Server;
using ModelContextProtocol.Server;
using ModelContextProtocol.Tests.Utils;

namespace ModelContextProtocol.Tests.Server;
Expand All @@ -13,7 +12,6 @@ public McpServerFactoryTests(ITestOutputHelper testOutputHelper)
{
_options = new McpServerOptions
{
ServerInfo = new Implementation { Name = "TestServer", Version = "1.0" },
ProtocolVersion = "1.0",
InitializationTimeout = TimeSpan.FromSeconds(30)
};
Expand Down
Loading