Skip to content

Delay SSE GET connection until after session ID is established #97

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
May 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
187 changes: 186 additions & 1 deletion Sources/MCP/Base/Transports/HTTPClientTransport.swift
Original file line number Diff line number Diff line change
Expand Up @@ -9,28 +9,66 @@ import Logging
import FoundationNetworking
#endif

public actor HTTPClientTransport: Actor, Transport {
/// An implementation of the MCP Streamable HTTP transport protocol for clients.
///
/// This transport implements the [Streamable HTTP transport](https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#streamable-http)
/// specification from the Model Context Protocol.
///
/// It supports:
/// - Sending JSON-RPC messages via HTTP POST requests
/// - Receiving responses via both direct JSON responses and SSE streams
/// - Session management using the `Mcp-Session-Id` header
/// - Automatic reconnection for dropped SSE streams
/// - Platform-specific optimizations for different operating systems
///
/// The transport supports two modes:
/// - Regular HTTP (`streaming=false`): Simple request/response pattern
/// - Streaming HTTP with SSE (`streaming=true`): Enables server-to-client push messages
///
/// - Important: Server-Sent Events (SSE) functionality is not supported on Linux platforms.
public actor HTTPClientTransport: Transport {
/// The server endpoint URL to connect to
public let endpoint: URL
private let session: URLSession

/// The session ID assigned by the server, used for maintaining state across requests
public private(set) var sessionID: String?
private let streaming: Bool
private var streamingTask: Task<Void, Never>?

/// Logger instance for transport-related events
public nonisolated let logger: Logger

/// Maximum time to wait for a session ID before proceeding with SSE connection
public let sseInitializationTimeout: TimeInterval

private var isConnected = false
private let messageStream: AsyncThrowingStream<Data, Swift.Error>
private let messageContinuation: AsyncThrowingStream<Data, Swift.Error>.Continuation

private var initialSessionIDSignalTask: Task<Void, Never>?
private var initialSessionIDContinuation: CheckedContinuation<Void, Never>?

/// Creates a new HTTP transport client with the specified endpoint
///
/// - Parameters:
/// - endpoint: The server URL to connect to
/// - configuration: URLSession configuration to use for HTTP requests
/// - streaming: Whether to enable SSE streaming mode (default: true)
/// - sseInitializationTimeout: Maximum time to wait for session ID before proceeding with SSE (default: 10 seconds)
/// - logger: Optional logger instance for transport events
public init(
endpoint: URL,
configuration: URLSessionConfiguration = .default,
streaming: Bool = true,
sseInitializationTimeout: TimeInterval = 10,
logger: Logger? = nil
) {
self.init(
endpoint: endpoint,
session: URLSession(configuration: configuration),
streaming: streaming,
sseInitializationTimeout: sseInitializationTimeout,
logger: logger
)
}
Expand All @@ -39,11 +77,13 @@ public actor HTTPClientTransport: Actor, Transport {
endpoint: URL,
session: URLSession,
streaming: Bool = false,
sseInitializationTimeout: TimeInterval = 10,
logger: Logger? = nil
) {
self.endpoint = endpoint
self.session = session
self.streaming = streaming
self.sseInitializationTimeout = sseInitializationTimeout

// Create message stream
var continuation: AsyncThrowingStream<Data, Swift.Error>.Continuation!
Expand All @@ -58,11 +98,37 @@ public actor HTTPClientTransport: Actor, Transport {
)
}

// Setup the initial session ID signal
private func setupInitialSessionIDSignal() {
self.initialSessionIDSignalTask = Task {
await withCheckedContinuation { continuation in
self.initialSessionIDContinuation = continuation
// This task will suspend here until continuation.resume() is called
}
}
}

// Trigger the initial session ID signal when a session ID is established
private func triggerInitialSessionIDSignal() {
if let continuation = self.initialSessionIDContinuation {
continuation.resume()
self.initialSessionIDContinuation = nil // Consume the continuation
logger.debug("Initial session ID signal triggered for SSE task.")
}
}

/// Establishes connection with the transport
///
/// This prepares the transport for communication and sets up SSE streaming
/// if streaming mode is enabled. The actual HTTP connection happens with the
/// first message sent.
public func connect() async throws {
guard !isConnected else { return }
isConnected = true

// Setup initial session ID signal
setupInitialSessionIDSignal()

if streaming {
// Start listening to server events
streamingTask = Task { await startListeningForServerEvents() }
Expand All @@ -72,6 +138,9 @@ public actor HTTPClientTransport: Actor, Transport {
}

/// Disconnects from the transport
///
/// This terminates any active connections, cancels the streaming task,
/// and releases any resources being used by the transport.
public func disconnect() async {
guard isConnected else { return }
isConnected = false
Expand All @@ -86,10 +155,28 @@ public actor HTTPClientTransport: Actor, Transport {
// Clean up message stream
messageContinuation.finish()

// Cancel the initial session ID signal task if active
initialSessionIDSignalTask?.cancel()
initialSessionIDSignalTask = nil
// Resume the continuation if it's still pending to avoid leaks
initialSessionIDContinuation?.resume()
initialSessionIDContinuation = nil

logger.info("HTTP clienttransport disconnected")
}

/// Sends data through an HTTP POST request
///
/// This sends a JSON-RPC message to the server via HTTP POST and processes
/// the response according to the MCP Streamable HTTP specification. It handles:
///
/// - Adding appropriate Accept headers for both JSON and SSE
/// - Including the session ID in requests if one has been established
/// - Processing different response types (JSON vs SSE)
/// - Handling HTTP error codes according to the specification
///
/// - Parameter data: The JSON-RPC message to send
/// - Throws: MCPError for transport failures or server errors
public func send(_ data: Data) async throws {
guard isConnected else {
throw MCPError.internalError("Transport not connected")
Expand Down Expand Up @@ -129,7 +216,12 @@ public actor HTTPClientTransport: Actor, Transport {

// Extract session ID if present
if let newSessionID = httpResponse.value(forHTTPHeaderField: "Mcp-Session-Id") {
let wasSessionIDNil = (self.sessionID == nil)
self.sessionID = newSessionID
if wasSessionIDNil {
// Trigger signal on first session ID
triggerInitialSessionIDSignal()
}
logger.debug("Session ID received", metadata: ["sessionID": "\(newSessionID)"])
}

Expand Down Expand Up @@ -161,7 +253,12 @@ public actor HTTPClientTransport: Actor, Transport {

// Extract session ID if present
if let newSessionID = httpResponse.value(forHTTPHeaderField: "Mcp-Session-Id") {
let wasSessionIDNil = (self.sessionID == nil)
self.sessionID = newSessionID
if wasSessionIDNil {
// Trigger signal on first session ID
triggerInitialSessionIDSignal()
}
logger.debug("Session ID received", metadata: ["sessionID": "\(newSessionID)"])
}

Expand Down Expand Up @@ -238,13 +335,29 @@ public actor HTTPClientTransport: Actor, Transport {
}

/// Receives data in an async sequence
///
/// This returns an AsyncThrowingStream that emits Data objects representing
/// each JSON-RPC message received from the server. This includes:
///
/// - Direct responses to client requests
/// - Server-initiated messages delivered via SSE streams
///
/// - Returns: An AsyncThrowingStream of Data objects
public func receive() -> AsyncThrowingStream<Data, Swift.Error> {
return messageStream
}

// MARK: - SSE

/// Starts listening for server events using SSE
///
/// This establishes a long-lived HTTP connection using Server-Sent Events (SSE)
/// to enable server-to-client push messaging. It handles:
///
/// - Waiting for session ID if needed
/// - Opening the SSE connection
/// - Automatic reconnection on connection drops
/// - Processing received events
private func startListeningForServerEvents() async {
#if os(Linux)
// SSE is not fully supported on Linux
Expand All @@ -257,6 +370,63 @@ public actor HTTPClientTransport: Actor, Transport {
// This is the original code for platforms that support SSE
guard isConnected else { return }

// Wait for the initial session ID signal, but only if sessionID isn't already set
if self.sessionID == nil, let signalTask = self.initialSessionIDSignalTask {
logger.debug("SSE streaming task waiting for initial sessionID signal...")

// Race the signalTask against a timeout
let timeoutTask = Task {
try? await Task.sleep(for: .seconds(self.sseInitializationTimeout))
return false
}

let signalCompletionTask = Task {
await signalTask.value
return true // Indicates signal received
}

// Use TaskGroup to race the two tasks
var signalReceived = false
do {
signalReceived = try await withThrowingTaskGroup(of: Bool.self) { group in
group.addTask {
await signalCompletionTask.value
}
group.addTask {
await timeoutTask.value
}

// Take the first result and cancel the other task
if let firstResult = try await group.next() {
group.cancelAll()
return firstResult
}
return false
}
} catch {
logger.error("Error while waiting for session ID signal: \(error)")
}

// Clean up tasks
timeoutTask.cancel()

if signalReceived {
logger.debug("SSE streaming task proceeding after initial sessionID signal.")
} else {
logger.warning(
"Timeout waiting for initial sessionID signal. SSE stream will proceed (sessionID might be nil)."
)
}
} else if self.sessionID != nil {
logger.debug(
"Initial sessionID already available. Proceeding with SSE streaming task immediately."
)
} else {
logger.info(
"Proceeding with SSE connection attempt; sessionID is nil. This might be expected for stateless servers or if initialize hasn't provided one yet."
)
}

// Retry loop for connection drops
while isConnected && !Task.isCancelled {
do {
Expand All @@ -274,6 +444,11 @@ public actor HTTPClientTransport: Actor, Transport {

#if !os(Linux)
/// Establishes an SSE connection to the server
///
/// This initiates a GET request to the server endpoint with appropriate
/// headers to establish an SSE stream according to the MCP specification.
///
/// - Throws: MCPError for connection failures or server errors
private func connectToEventStream() async throws {
guard isConnected else { return }

Expand Down Expand Up @@ -309,13 +484,23 @@ public actor HTTPClientTransport: Actor, Transport {

// Extract session ID if present
if let newSessionID = httpResponse.value(forHTTPHeaderField: "Mcp-Session-Id") {
let wasSessionIDNil = (self.sessionID == nil)
self.sessionID = newSessionID
if wasSessionIDNil {
// Trigger signal on first session ID, though this is unlikely to happen here
// as GET usually follows a POST that would have already set the session ID
triggerInitialSessionIDSignal()
}
logger.debug("Session ID received", metadata: ["sessionID": "\(newSessionID)"])
}

try await self.processSSE(stream)
}

/// Processes an SSE byte stream, extracting events and delivering them
///
/// - Parameter stream: The URLSession.AsyncBytes stream to process
/// - Throws: Error for stream processing failures
private func processSSE(_ stream: URLSession.AsyncBytes) async throws {
do {
for try await event in stream.events {
Expand Down
Loading