Skip to content

Commit 6ae4a57

Browse files
anthonjnochafik
andauthored
add custom headers on initial _startOrAuth call (modelcontextprotocol#318)
* add custom headers on initial _startOrAuth call * update client/sse.ts: align commonHeaders w/ streamableHttp version --------- Co-authored-by: Olivier Chafik <[email protected]>
1 parent 7d7896f commit 6ae4a57

File tree

2 files changed

+33
-13
lines changed

2 files changed

+33
-13
lines changed

src/client/sse.test.ts

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -382,6 +382,29 @@ describe("SSEClientTransport", () => {
382382
expect(mockAuthProvider.tokens).toHaveBeenCalled();
383383
});
384384

385+
it("attaches custom header from provider on initial SSE connection", async () => {
386+
mockAuthProvider.tokens.mockResolvedValue({
387+
access_token: "test-token",
388+
token_type: "Bearer"
389+
});
390+
const customHeaders = {
391+
"X-Custom-Header": "custom-value",
392+
};
393+
394+
transport = new SSEClientTransport(resourceBaseUrl, {
395+
authProvider: mockAuthProvider,
396+
requestInit: {
397+
headers: customHeaders,
398+
},
399+
});
400+
401+
await transport.start();
402+
403+
expect(lastServerRequest.headers.authorization).toBe("Bearer test-token");
404+
expect(lastServerRequest.headers["x-custom-header"]).toBe("custom-value");
405+
expect(mockAuthProvider.tokens).toHaveBeenCalled();
406+
});
407+
385408
it("attaches auth header from provider on POST requests", async () => {
386409
mockAuthProvider.tokens.mockResolvedValue({
387410
access_token: "test-token",

src/client/sse.ts

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -106,10 +106,8 @@ export class SSEClientTransport implements Transport {
106106
return await this._startOrAuth();
107107
}
108108

109-
private async _commonHeaders(): Promise<HeadersInit> {
110-
const headers = {
111-
...this._requestInit?.headers,
112-
} as HeadersInit & Record<string, string>;
109+
private async _commonHeaders(): Promise<Headers> {
110+
const headers: HeadersInit = {};
113111
if (this._authProvider) {
114112
const tokens = await this._authProvider.tokens();
115113
if (tokens) {
@@ -120,24 +118,24 @@ export class SSEClientTransport implements Transport {
120118
headers["mcp-protocol-version"] = this._protocolVersion;
121119
}
122120

123-
return headers;
121+
return new Headers(
122+
{ ...headers, ...this._requestInit?.headers }
123+
);
124124
}
125125

126126
private _startOrAuth(): Promise<void> {
127-
const fetchImpl = (this?._eventSourceInit?.fetch ?? this._fetch ?? fetch) as typeof fetch
127+
const fetchImpl = (this?._eventSourceInit?.fetch ?? this._fetch ?? fetch) as typeof fetch
128128
return new Promise((resolve, reject) => {
129129
this._eventSource = new EventSource(
130130
this._url.href,
131131
{
132132
...this._eventSourceInit,
133133
fetch: async (url, init) => {
134-
const headers = await this._commonHeaders()
134+
const headers = await this._commonHeaders();
135+
headers.set("Accept", "text/event-stream");
135136
const response = await fetchImpl(url, {
136137
...init,
137-
headers: new Headers({
138-
...headers,
139-
Accept: "text/event-stream"
140-
})
138+
headers,
141139
})
142140

143141
if (response.status === 401 && response.headers.has('www-authenticate')) {
@@ -238,8 +236,7 @@ const fetchImpl = (this?._eventSourceInit?.fetch ?? this._fetch ?? fetch) as typ
238236
}
239237

240238
try {
241-
const commonHeaders = await this._commonHeaders();
242-
const headers = new Headers(commonHeaders);
239+
const headers = await this._commonHeaders();
243240
headers.set("content-type", "application/json");
244241
const init = {
245242
...this._requestInit,

0 commit comments

Comments
 (0)