Skip to content

Commit 5eb3768

Browse files
authored
fix(amazonq): handle IAM credentials expiration field to be aws sdk versions compatible and add refresh logic to codewhisperer IAM client (#2349)
## Problem Chat and InlineSuggestions doesnt work in Sagemaker instances after initial credential expiration since refresh logic isnt using expiration field properly ## Solution - Standardizes IAM credentials expiration field handling across the codebase to support both AWS SDK v2 (`expireTime`) and v3 (`expiration`) formats - Add credential fetch and refresh lazy loading logic to Codewhisperer IAM client to ensure fresh credentials are used
1 parent f0364c3 commit 5eb3768

File tree

7 files changed

+125
-20
lines changed

7 files changed

+125
-20
lines changed

core/aws-lsp-core/src/credentials/credentialsProvider.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ export interface IamCredentials {
44
accessKeyId: string
55
secretAccessKey: string
66
sessionToken?: string
7+
expiration?: Date // v3 format
78
}
89

910
export interface BearerToken {
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
/*!
2+
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
import * as assert from 'assert'
7+
import { IdeCredentialsProvider } from './ideCredentialsProvider'
8+
import { IamCredentials } from './credentialsProvider'
9+
import { Connection } from 'vscode-languageserver'
10+
import * as sinon from 'sinon'
11+
12+
describe('IdeCredentialsProvider', function () {
13+
let provider: IdeCredentialsProvider
14+
let mockConnection: sinon.SinonStubbedInstance<Connection>
15+
16+
beforeEach(function () {
17+
mockConnection = {
18+
console: {
19+
info: sinon.stub(),
20+
log: sinon.stub(),
21+
warn: sinon.stub(),
22+
error: sinon.stub(),
23+
},
24+
} as any
25+
provider = new IdeCredentialsProvider(mockConnection as any)
26+
})
27+
28+
describe('validateIamCredentialsFields', function () {
29+
it('throws error when accessKeyId is missing', function () {
30+
const credentials = {
31+
secretAccessKey: 'secret',
32+
} as IamCredentials
33+
34+
assert.throws(() => provider['validateIamCredentialsFields'](credentials), /Missing property: accessKeyId/)
35+
})
36+
37+
it('throws error when secretAccessKey is missing', function () {
38+
const credentials = {
39+
accessKeyId: 'key',
40+
} as IamCredentials
41+
42+
assert.throws(
43+
() => provider['validateIamCredentialsFields'](credentials),
44+
/Missing property: secretAccessKey/
45+
)
46+
})
47+
})
48+
})

core/aws-lsp-core/src/credentials/ideCredentialsProvider.ts

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,15 @@ export class IdeCredentialsProvider implements CredentialsProvider {
5858
credentialsProtocolMethodNames.iamCredentialsUpdate,
5959
async (request: UpdateCredentialsRequest) => {
6060
try {
61-
const iamCredentials = await this.decodeCredentialsRequestToken<IamCredentials>(request)
61+
const rawCredentials = await this.decodeCredentialsRequestToken<
62+
IamCredentials & { expireTime?: Date }
63+
>(request)
64+
65+
// Normalize legacy expireTime field to standard expiration field
66+
const iamCredentials: IamCredentials = {
67+
...rawCredentials,
68+
expiration: rawCredentials.expiration || rawCredentials.expireTime,
69+
}
6270

6371
this.validateIamCredentialsFields(iamCredentials)
6472

server/aws-lsp-codewhisperer/src/shared/amazonQServiceManager/AmazonQIAMServiceManager.ts

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,13 @@ export class AmazonQIAMServiceManager extends BaseAmazonQServiceManager<
7878
return this.cachedStreamingClient
7979
}
8080

81-
public handleOnCredentialsDeleted(_type: CredentialsType): void {
82-
return
81+
public handleOnCredentialsDeleted(type: CredentialsType): void {
82+
if (type === 'iam') {
83+
this.cachedCodewhispererService?.abortInflightRequests()
84+
this.cachedCodewhispererService = undefined
85+
this.cachedStreamingClient?.abortInflightRequests()
86+
this.cachedStreamingClient = undefined
87+
}
8388
}
8489

8590
public override handleOnUpdateConfiguration(

server/aws-lsp-codewhisperer/src/shared/codeWhispererService.ts

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,46 @@ export class CodeWhispererServiceIAM extends CodeWhispererServiceBase {
234234
region: this.codeWhispererRegion,
235235
endpoint: this.codeWhispererEndpoint,
236236
credentialProvider: new CredentialProviderChain([
237-
() => credentialsProvider.getCredentials('iam') as Credentials,
237+
() => {
238+
const credentials = new Credentials({
239+
accessKeyId: '',
240+
secretAccessKey: '',
241+
sessionToken: '',
242+
})
243+
244+
credentials.get = callback => {
245+
logging.info('CodeWhispererServiceIAM: Attempting to get credentials')
246+
247+
Promise.resolve(credentialsProvider.getCredentials('iam'))
248+
.then((creds: any) => {
249+
logging.info('CodeWhispererServiceIAM: Successfully got credentials')
250+
251+
credentials.accessKeyId = creds.accessKeyId as string
252+
credentials.secretAccessKey = creds.secretAccessKey as string
253+
credentials.sessionToken = creds.sessionToken as string
254+
credentials.expireTime = creds.expiration as Date
255+
callback()
256+
})
257+
.catch(err => {
258+
logging.error(`CodeWhispererServiceIAM: Failed to get credentials: ${err.message}`)
259+
callback(err)
260+
})
261+
}
262+
263+
credentials.needsRefresh = () => {
264+
return (
265+
!credentials.accessKeyId ||
266+
!credentials.secretAccessKey ||
267+
(credentials.expireTime && credentials.expireTime.getTime() - Date.now() < 60000)
268+
) // 1 min buffer
269+
}
270+
271+
credentials.refresh = callback => {
272+
credentials.get(callback)
273+
}
274+
275+
return credentials
276+
},
238277
]),
239278
}
240279
this.client = createCodeWhispererSigv4Client(options, sdkInitializator, logging)

server/aws-lsp-codewhisperer/src/shared/streamingClientService.test.ts

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -272,19 +272,19 @@ describe('StreamingClientServiceIAM', () => {
272272
expect(streamingClientServiceIAM['inflightRequests'].size).to.eq(0)
273273
})
274274

275-
it('uses expireTime from credentials when available', async () => {
275+
it('uses expiration from credentials when available', async () => {
276276
// Get the credential provider function from the client config
277277
const credentialProvider = streamingClientServiceIAM.client.config.credentials
278278
expect(credentialProvider).to.not.be.undefined
279279

280280
// Reset call count on the stub
281281
features.credentialsProvider.getCredentials.resetHistory()
282282

283-
// Set up credentials with expireTime
283+
// Set up credentials with expiration
284284
const futureDate = new Date(Date.now() + 3600000) // 1 hour in the future
285285
const CREDENTIALS_WITH_EXPIRY = {
286286
...MOCKED_IAM_CREDENTIALS,
287-
expireTime: futureDate.toISOString(),
287+
expiration: futureDate,
288288
}
289289
features.credentialsProvider.getCredentials.withArgs('iam').returns(CREDENTIALS_WITH_EXPIRY)
290290

@@ -293,34 +293,29 @@ describe('StreamingClientServiceIAM', () => {
293293
await clock.tickAsync(TIME_TO_ADVANCE_MS)
294294
const credentials = await credentialsPromise
295295

296-
// Verify expiration is set to the expireTime from credentials
296+
// Verify expiration is set to the expiration from credentials
297297
expect(credentials.expiration).to.be.instanceOf(Date)
298298
expect(credentials.expiration.getTime()).to.equal(futureDate.getTime())
299299
})
300300

301-
it('falls back to current date when expireTime is not available', async () => {
301+
it('forces refresh when expiration is not available in credentials', async () => {
302302
// Get the credential provider function from the client config
303303
const credentialProvider = streamingClientServiceIAM.client.config.credentials
304304
expect(credentialProvider).to.not.be.undefined
305305

306306
// Reset call count on the stub
307307
features.credentialsProvider.getCredentials.resetHistory()
308308

309-
// Set up credentials without expireTime
309+
// Set up credentials without expiration
310310
features.credentialsProvider.getCredentials.withArgs('iam').returns(MOCKED_IAM_CREDENTIALS)
311311

312-
// Set a fixed time for testing
313-
const fixedNow = new Date()
314-
clock.tick(0) // Ensure clock is at the fixed time
315-
316312
// Call the credential provider
317313
const credentialsPromise = (credentialProvider as any)()
318314
await clock.tickAsync(TIME_TO_ADVANCE_MS)
319315
const credentials = await credentialsPromise
320316

321-
// Verify expiration is set to current date when expireTime is missing
317+
// Verify expiration is set to current date to force refresh when not provided in credentials
322318
expect(credentials.expiration).to.be.instanceOf(Date)
323-
// The expiration should be very close to the current time
324-
expect(credentials.expiration.getTime()).to.be.closeTo(fixedNow.getTime(), 100)
319+
expect(credentials.expiration.getTime()).to.be.closeTo(Date.now(), 1000)
325320
})
326321
})

server/aws-lsp-codewhisperer/src/shared/streamingClientService.ts

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,12 @@ import {
1212
SendMessageCommandInput as SendMessageCommandInputQDeveloperStreaming,
1313
SendMessageCommandOutput as SendMessageCommandOutputQDeveloperStreaming,
1414
} from '@amzn/amazon-q-developer-streaming-client'
15-
import { CredentialsProvider, SDKInitializator, Logging } from '@aws/language-server-runtimes/server-interface'
15+
import {
16+
CredentialsProvider,
17+
SDKInitializator,
18+
Logging,
19+
IamCredentials,
20+
} from '@aws/language-server-runtimes/server-interface'
1621
import { getBearerTokenFromProvider, isUsageLimitError } from './utils'
1722
import { ConfiguredRetryStrategy } from '@aws-sdk/util-retry'
1823
import { CredentialProviderChain, Credentials } from 'aws-sdk'
@@ -188,13 +193,17 @@ export class StreamingClientServiceIAM extends StreamingClientServiceBase {
188193

189194
// Create a credential provider that fetches fresh credentials on each request
190195
const iamCredentialProvider: AwsCredentialIdentityProvider = async (): Promise<AwsCredentialIdentity> => {
191-
const creds = (await credentialsProvider.getCredentials('iam')) as Credentials
196+
const creds = (await credentialsProvider.getCredentials('iam')) as IamCredentials
192197
logging.log(`Fetching new IAM credentials`)
198+
if (!creds) {
199+
logging.log('Failed to fetch IAM credentials: No IAM credentials found')
200+
throw new Error('No IAM credentials found')
201+
}
193202
return {
194203
accessKeyId: creds.accessKeyId,
195204
secretAccessKey: creds.secretAccessKey,
196205
sessionToken: creds.sessionToken,
197-
expiration: creds.expireTime ? new Date(creds.expireTime) : new Date(), // Force refresh on each request if creds do not have expiration time
206+
expiration: creds.expiration ? new Date(creds.expiration) : new Date(), // Force refresh if expiration field is not available
198207
}
199208
}
200209

0 commit comments

Comments
 (0)