Skip to content

fix(amazonq): add cancel support to loading developer profiles #940

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import {
} from './qConfigurationServer'
import { TestFeatures } from '@aws/language-server-runtimes/testing'
import { CodeWhispererServiceToken } from '../../shared/codeWhispererService'
import { InitializeParams, Server } from '@aws/language-server-runtimes/server-interface'
import { CancellationTokenSource, InitializeParams, Server } from '@aws/language-server-runtimes/server-interface'
import { AmazonQTokenServiceManager } from '../../shared/amazonQServiceManager/AmazonQTokenServiceManager'
import { setCredentialsForAmazonQTokenServiceManagerFactory } from '../../shared/testUtils'
import { Q_CONFIGURATION_SECTION } from '../../shared/constants'
Expand Down Expand Up @@ -96,6 +96,7 @@ describe('ServerConfigurationProvider', () => {
let codeWhispererService: StubbedInstance<CodeWhispererServiceToken>
let testFeatures: TestFeatures
let listAvailableProfilesHandlerSpy: sinon.SinonSpy
let tokenSource: CancellationTokenSource

const setCredentials = setCredentialsForAmazonQTokenServiceManagerFactory(() => testFeatures)

Expand All @@ -119,6 +120,7 @@ describe('ServerConfigurationProvider', () => {
}

beforeEach(() => {
tokenSource = new CancellationTokenSource()
codeWhispererService = stubInterface<CodeWhispererServiceToken>()
codeWhispererService.listAvailableCustomizations.resolves({
customizations: [],
Expand Down Expand Up @@ -152,14 +154,14 @@ describe('ServerConfigurationProvider', () => {
it(`does not use listAvailableProfiles handler when developer profiles is disabled`, async () => {
setupServerConfigurationProvider(false)

const result = await serverConfigurationProvider.listAvailableProfiles()
const result = await serverConfigurationProvider.listAvailableProfiles(tokenSource.token)

sinon.assert.notCalled(listAvailableProfilesHandlerSpy)
assert.deepStrictEqual(result, [])
})

it(`uses listAvailableProfiles handler when developer profiles is enabled`, async () => {
await serverConfigurationProvider.listAvailableProfiles()
await serverConfigurationProvider.listAvailableProfiles(tokenSource.token)

sinon.assert.calledOnce(listAvailableProfilesHandlerSpy)
})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,11 @@ export const QConfigurationServerToken =
case Q_CONFIGURATION_SECTION:
;[customizations, developerProfiles] = await Promise.all([
serverConfigurationProvider.listAvailableCustomizations(),
serverConfigurationProvider.listAvailableProfiles(),
serverConfigurationProvider.listAvailableProfiles(token),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should cancel whole onGetConfigurationFromServer request and return RequestCanceled LSP Error from here to signal that client cancellation was handled by server correctly.

https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#errorCodes

	 * The client has canceled a request and a server has detected
	 * the cancel.
	 */
	export const RequestCancelled: [integer](https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#integer) = -32800;

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed, we could for all switch cases here pass the token to a method e.g. handleTokenCancellation, which will throw a RequestCancelled error in case the request was cancelled.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure -- I'm going to refrain from adding it to the Q_CUSTOMIZATIONS_CONFIGURATION_SECTION case statement, because customizations code currently does not support tokens.

])

throwIfCancelled(token)

return amazonQServiceManager.getEnableDeveloperProfileSupport()
? { customizations, developerProfiles }
: { customizations }
Expand All @@ -91,7 +93,9 @@ export const QConfigurationServerToken =

return customizations
case Q_DEVELOPER_PROFILES_CONFIGURATION_SECTION:
developerProfiles = await serverConfigurationProvider.listAvailableProfiles()
developerProfiles = await serverConfigurationProvider.listAvailableProfiles(token)

throwIfCancelled(token)

return developerProfiles
default:
Expand All @@ -115,6 +119,12 @@ export const QConfigurationServerToken =
return () => {}
}

function throwIfCancelled(token: CancellationToken) {
if (token.isCancellationRequested) {
throw new ResponseError(LSPErrorCodes.RequestCancelled, 'Request cancelled')
}
}

const ON_GET_CONFIGURATION_FROM_SERVER_ERROR_PREFIX = 'Failed to fetch: '

export class ServerConfigurationProvider {
Expand All @@ -130,7 +140,7 @@ export class ServerConfigurationProvider {
)
}

async listAvailableProfiles(): Promise<AmazonQDeveloperProfile[]> {
async listAvailableProfiles(token: CancellationToken): Promise<AmazonQDeveloperProfile[]> {
if (!this.serviceManager.getEnableDeveloperProfileSupport()) {
this.logging.debug('Q developer profiles disabled - returning empty list')
return []
Expand All @@ -140,6 +150,7 @@ export class ServerConfigurationProvider {
const profiles = await this.listAllAvailableProfilesHandler({
connectionType: this.credentialsProvider.getConnectionType(),
logging: this.logging,
token: token,
})

return profiles
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,7 @@ export class AmazonQTokenServiceManager extends BaseAmazonQServiceManager<CodeWh
const profiles = await getListAllAvailableProfilesHandler(this.serviceFactory)({
connectionType: 'identityCenter',
logging: this.logging,
token: token,
})

this.handleTokenCancellationRequest(token)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@ import * as assert from 'assert'
import { StubbedInstance, stubInterface } from 'ts-sinon'
import { CodeWhispererServiceToken } from '../codeWhispererService'
import { SsoConnectionType } from '../utils'
import { AWSInitializationOptions, Logging } from '@aws/language-server-runtimes/server-interface'
import {
AWSInitializationOptions,
CancellationTokenSource,
Logging,
} from '@aws/language-server-runtimes/server-interface'
import {
AmazonQDeveloperProfile,
getListAllAvailableProfilesHandler,
Expand Down Expand Up @@ -35,6 +39,7 @@ describe('ListAllAvailableProfiles Handler', () => {

let codeWhispererService: StubbedInstance<CodeWhispererServiceToken>
let handler: ListAllAvailableProfilesHandler
let tokenSource: CancellationTokenSource

const listAvailableProfilesResponse = {
profiles: [
Expand All @@ -57,13 +62,15 @@ describe('ListAllAvailableProfiles Handler', () => {
codeWhispererService.listAvailableProfiles.resolves(listAvailableProfilesResponse)

handler = getListAllAvailableProfilesHandler(() => codeWhispererService)
tokenSource = new CancellationTokenSource()
})

it('should aggregrate profiles retrieved from different regions', async () => {
const profiles = await handler({
connectionType: 'identityCenter',
logging,
endpoints: SOME_AWS_Q_ENDPOINTS,
token: tokenSource.token,
})

assert.strictEqual(
Expand All @@ -78,6 +85,7 @@ describe('ListAllAvailableProfiles Handler', () => {
const profiles = await handler({
connectionType,
logging,
token: tokenSource.token,
})

assert.deepStrictEqual(profiles, [])
Expand All @@ -102,6 +110,7 @@ describe('ListAllAvailableProfiles Handler', () => {
connectionType: 'identityCenter',
logging,
endpoints: SOME_AWS_Q_ENDPOINT,
token: tokenSource.token,
})

assert.strictEqual(codeWhispererService.listAvailableProfiles.callCount, EXPECTED_CALLS)
Expand All @@ -115,6 +124,7 @@ describe('ListAllAvailableProfiles Handler', () => {
connectionType: 'identityCenter',
logging,
endpoints: SOME_AWS_Q_ENDPOINT,
token: tokenSource.token,
})

assert.strictEqual(codeWhispererService.listAvailableProfiles.callCount, MAX_EXPECTED_PAGES)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import {
AWSInitializationOptions,
CancellationToken,
Logging,
LSPErrorCodes,
ResponseError,
Expand All @@ -22,6 +23,7 @@ export interface ListAllAvailableProfilesHandlerParams {
connectionType: SsoConnectionType
logging: Logging
endpoints?: Map<string, string> // override option for flexibility, we default to all (AWS_Q_ENDPOINTS)
token: CancellationToken
}

export type ListAllAvailableProfilesHandler = (
Expand All @@ -33,7 +35,7 @@ const MAX_Q_DEVELOPER_PROFILES_PER_PAGE = 10

export const getListAllAvailableProfilesHandler =
(service: (region: string, endpoint: string) => CodeWhispererServiceToken): ListAllAvailableProfilesHandler =>
async ({ connectionType, logging, endpoints }) => {
async ({ connectionType, logging, endpoints, token }) => {
if (!connectionType || connectionType !== 'identityCenter') {
logging.debug('Connection type is not set or not identityCenter - returning empty response.')
return []
Expand All @@ -42,13 +44,21 @@ export const getListAllAvailableProfilesHandler =
let allProfiles: AmazonQDeveloperProfile[] = []
const qEndpoints = endpoints ?? AWS_Q_ENDPOINTS

if (token.isCancellationRequested) {
return []
}

const result = await Promise.allSettled(
Array.from(qEndpoints.entries(), ([region, endpoint]) => {
const codeWhispererService = service(region, endpoint)
return fetchProfilesFromRegion(codeWhispererService, region, logging)
return fetchProfilesFromRegion(codeWhispererService, region, logging, token)
})
)

if (token.isCancellationRequested) {
return []
}

const fulfilledResults = result.filter(settledResult => settledResult.status === 'fulfilled')

if (fulfilledResults.length === 0) {
Expand All @@ -63,7 +73,8 @@ export const getListAllAvailableProfilesHandler =
async function fetchProfilesFromRegion(
service: CodeWhispererServiceToken,
region: string,
logging: Logging
logging: Logging,
token: CancellationToken
): Promise<AmazonQDeveloperProfile[]> {
let allRegionalProfiles: AmazonQDeveloperProfile[] = []
let nextToken: string | undefined = undefined
Expand All @@ -73,6 +84,10 @@ async function fetchProfilesFromRegion(
do {
logging.debug(`Fetching profiles from region: ${region} (iteration: ${numberOfPages})`)

if (token.isCancellationRequested) {
return allRegionalProfiles
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I think we could just return an empty list here, request might be cancelled after some profiles are already aggregated due to pagination and higher up we we will return an empty list after cancellation anyways

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO this is a concern for the caller, whether the caller is within flare or from a client. In both places, the caller can decide what to do given that they have access to the token cancellation state.

}

const response = await service.listAvailableProfiles({
maxResults: MAX_Q_DEVELOPER_PROFILES_PER_PAGE,
nextToken: nextToken,
Expand Down
Loading