|
| 1 | +/*--------------------------------------------------------------------------------------------- |
| 2 | + * Copyright (c) Microsoft Corporation. All rights reserved. |
| 3 | + * Licensed under the MIT License. See License.txt in the project root for license information. |
| 4 | + *--------------------------------------------------------------------------------------------*/ |
| 5 | + |
| 6 | +import * as vscode from "vscode"; |
| 7 | +import * as TypeMoq from "typemoq"; |
| 8 | +import { expect } from "chai"; |
| 9 | +import { createSqlAgentRequestHandler } from "../../src/chat/chatAgentRequestHandler"; |
| 10 | +import { CopilotService } from "../../src/services/copilotService"; |
| 11 | +import VscodeWrapper from "../../src/controllers/vscodeWrapper"; |
| 12 | +import * as Utils from "../../src/models/utils"; |
| 13 | +import * as sinon from "sinon"; |
| 14 | +import * as telemetry from "../../src/telemetry/telemetry"; |
| 15 | +import { GetNextMessageResponse, MessageType } from "../../src/models/contracts/copilot"; |
| 16 | +import { ActivityObject, ActivityStatus } from "../../src/sharedInterfaces/telemetry"; |
| 17 | + |
| 18 | +suite("Chat Agent Request Handler Tests", () => { |
| 19 | + let mockCopilotService: TypeMoq.IMock<CopilotService>; |
| 20 | + let mockVscodeWrapper: TypeMoq.IMock<VscodeWrapper>; |
| 21 | + let mockContext: TypeMoq.IMock<vscode.ExtensionContext>; |
| 22 | + let mockLmChat: TypeMoq.IMock<vscode.LanguageModelChat>; |
| 23 | + let mockChatStream: TypeMoq.IMock<vscode.ChatResponseStream>; |
| 24 | + let mockChatRequest: TypeMoq.IMock<vscode.ChatRequest>; |
| 25 | + let mockChatContext: TypeMoq.IMock<vscode.ChatContext>; |
| 26 | + let mockToken: TypeMoq.IMock<vscode.CancellationToken>; |
| 27 | + let mockTextDocument: TypeMoq.IMock<vscode.TextDocument>; |
| 28 | + let mockConfiguration: TypeMoq.IMock<vscode.WorkspaceConfiguration>; |
| 29 | + let mockLanguageModelChatResponse: TypeMoq.IMock<vscode.LanguageModelChatResponse>; |
| 30 | + let mockActivityObject: TypeMoq.IMock<ActivityObject>; |
| 31 | + let generateGuidStub: sinon.SinonStub; |
| 32 | + let selectChatModelsStub: sinon.SinonStub; |
| 33 | + let startActivityStub: sinon.SinonStub; |
| 34 | + let sendActionEventStub: sinon.SinonStub; |
| 35 | + let openTextDocumentStub: sinon.SinonStub; |
| 36 | + |
| 37 | + // Sample data for tests |
| 38 | + const sampleConnectionUri = "file:///path/to/sample.sql"; |
| 39 | + const sampleConversationUri = "conversationUri1"; |
| 40 | + const samplePrompt = "Tell me about my database schema"; |
| 41 | + const sampleCorrelationId = "12345678-1234-1234-1234-123456789012"; |
| 42 | + const sampleReplyText = "Here is information about your database schema"; |
| 43 | + |
| 44 | + setup(() => { |
| 45 | + // Create the mock activity object for startActivity to return |
| 46 | + mockActivityObject = TypeMoq.Mock.ofType<ActivityObject>(); |
| 47 | + mockActivityObject |
| 48 | + .setup((x) => x.end(TypeMoq.It.isAny(), TypeMoq.It.isAny())) |
| 49 | + .returns(() => undefined); |
| 50 | + mockActivityObject |
| 51 | + .setup((x) => |
| 52 | + x.endFailed( |
| 53 | + TypeMoq.It.isAny(), |
| 54 | + TypeMoq.It.isAny(), |
| 55 | + TypeMoq.It.isAny(), |
| 56 | + TypeMoq.It.isAny(), |
| 57 | + TypeMoq.It.isAny(), |
| 58 | + ), |
| 59 | + ) |
| 60 | + .returns(() => undefined); |
| 61 | + |
| 62 | + // Stub telemetry functions |
| 63 | + startActivityStub = sinon |
| 64 | + .stub(telemetry, "startActivity") |
| 65 | + .returns(mockActivityObject.object); |
| 66 | + sendActionEventStub = sinon.stub(telemetry, "sendActionEvent"); |
| 67 | + |
| 68 | + // Stub the generateGuid function using sinon |
| 69 | + generateGuidStub = sinon.stub(Utils, "generateGuid").returns(sampleCorrelationId); |
| 70 | + |
| 71 | + // Create a mock LanguageModelChat |
| 72 | + mockLmChat = TypeMoq.Mock.ofType<vscode.LanguageModelChat>(); |
| 73 | + |
| 74 | + // Stub the vscode.lm.selectChatModels function |
| 75 | + // First, ensure the lm object exists |
| 76 | + if (!vscode.lm) { |
| 77 | + // Create the object if it doesn't exist for testing |
| 78 | + (vscode as any).lm = { selectChatModels: () => Promise.resolve([]) }; |
| 79 | + } |
| 80 | + |
| 81 | + // Now stub the function |
| 82 | + selectChatModelsStub = sinon |
| 83 | + .stub(vscode.lm, "selectChatModels") |
| 84 | + .resolves([mockLmChat.object]); |
| 85 | + |
| 86 | + // Mock CopilotService |
| 87 | + mockCopilotService = TypeMoq.Mock.ofType<CopilotService>(); |
| 88 | + |
| 89 | + // Mock VscodeWrapper |
| 90 | + mockVscodeWrapper = TypeMoq.Mock.ofType<VscodeWrapper>(); |
| 91 | + mockVscodeWrapper.setup((x) => x.activeTextEditorUri).returns(() => sampleConnectionUri); |
| 92 | + |
| 93 | + // Mock configuration |
| 94 | + mockConfiguration = TypeMoq.Mock.ofType<vscode.WorkspaceConfiguration>(); |
| 95 | + mockConfiguration |
| 96 | + .setup((x) => x.get(TypeMoq.It.isAnyString(), TypeMoq.It.isAny())) |
| 97 | + .returns(() => false); |
| 98 | + mockVscodeWrapper |
| 99 | + .setup((x) => x.getConfiguration()) |
| 100 | + .returns(() => mockConfiguration.object); |
| 101 | + |
| 102 | + // Mock ExtensionContext |
| 103 | + mockContext = TypeMoq.Mock.ofType<vscode.ExtensionContext>(); |
| 104 | + mockContext |
| 105 | + .setup((x) => x.languageModelAccessInformation) |
| 106 | + .returns( |
| 107 | + () => |
| 108 | + ({ |
| 109 | + canSendRequest: () => "allowed", |
| 110 | + }) as any, |
| 111 | + ); |
| 112 | + |
| 113 | + // Mock ChatResponseStream |
| 114 | + mockChatStream = TypeMoq.Mock.ofType<vscode.ChatResponseStream>(); |
| 115 | + mockChatStream.setup((x) => x.progress(TypeMoq.It.isAnyString())).returns(() => undefined); |
| 116 | + mockChatStream.setup((x) => x.markdown(TypeMoq.It.isAnyString())).returns(() => undefined); |
| 117 | + |
| 118 | + // Mock Chat Request |
| 119 | + mockChatRequest = TypeMoq.Mock.ofType<vscode.ChatRequest>(); |
| 120 | + mockChatRequest.setup((x) => x.prompt).returns(() => samplePrompt); |
| 121 | + mockChatRequest.setup((x) => x.references).returns(() => []); |
| 122 | + |
| 123 | + // Mock Chat Context |
| 124 | + mockChatContext = TypeMoq.Mock.ofType<vscode.ChatContext>(); |
| 125 | + mockChatContext.setup((x) => x.history).returns(() => []); |
| 126 | + |
| 127 | + // Mock CancellationToken |
| 128 | + mockToken = TypeMoq.Mock.ofType<vscode.CancellationToken>(); |
| 129 | + |
| 130 | + // Mock LanguageModelChatResponse |
| 131 | + mockLanguageModelChatResponse = TypeMoq.Mock.ofType<vscode.LanguageModelChatResponse>(); |
| 132 | + mockLanguageModelChatResponse |
| 133 | + .setup((x) => x.stream) |
| 134 | + .returns(() => |
| 135 | + (async function* () { |
| 136 | + yield new vscode.LanguageModelTextPart(sampleReplyText); |
| 137 | + })(), |
| 138 | + ); |
| 139 | + |
| 140 | + // Mock Language Model API |
| 141 | + mockLmChat |
| 142 | + .setup((x) => x.sendRequest(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())) |
| 143 | + .returns(() => Promise.resolve(mockLanguageModelChatResponse.object)); |
| 144 | + |
| 145 | + // Mock TextDocument for reference handling |
| 146 | + mockTextDocument = TypeMoq.Mock.ofType<vscode.TextDocument>(); |
| 147 | + mockTextDocument |
| 148 | + .setup((x) => x.getText(TypeMoq.It.isAny())) |
| 149 | + .returns(() => "SELECT * FROM users"); |
| 150 | + mockTextDocument.setup((x) => x.languageId).returns(() => "sql"); |
| 151 | + |
| 152 | + // Stub the workspace.openTextDocument method instead of replacing the entire workspace object |
| 153 | + openTextDocumentStub = sinon |
| 154 | + .stub(vscode.workspace, "openTextDocument") |
| 155 | + .resolves(mockTextDocument.object); |
| 156 | + }); |
| 157 | + |
| 158 | + teardown(() => { |
| 159 | + // Restore all stubbed functions |
| 160 | + generateGuidStub.restore(); |
| 161 | + |
| 162 | + if (selectChatModelsStub) { |
| 163 | + selectChatModelsStub.restore(); |
| 164 | + } |
| 165 | + |
| 166 | + if (startActivityStub) { |
| 167 | + startActivityStub.restore(); |
| 168 | + } |
| 169 | + |
| 170 | + if (sendActionEventStub) { |
| 171 | + sendActionEventStub.restore(); |
| 172 | + } |
| 173 | + |
| 174 | + if (openTextDocumentStub) { |
| 175 | + openTextDocumentStub.restore(); |
| 176 | + } |
| 177 | + |
| 178 | + // Clean up any remaining stubs |
| 179 | + sinon.restore(); |
| 180 | + }); |
| 181 | + |
| 182 | + test("Creates a valid chat request handler", () => { |
| 183 | + const handler = createSqlAgentRequestHandler( |
| 184 | + mockCopilotService.object, |
| 185 | + mockVscodeWrapper.object, |
| 186 | + mockContext.object, |
| 187 | + ); |
| 188 | + |
| 189 | + expect(handler).to.be.a("function"); |
| 190 | + }); |
| 191 | + |
| 192 | + test("Returns early with a default response when no models are found", async () => { |
| 193 | + // Setup stub to return empty array for this specific test |
| 194 | + selectChatModelsStub.resolves([]); |
| 195 | + |
| 196 | + const handler = createSqlAgentRequestHandler( |
| 197 | + mockCopilotService.object, |
| 198 | + mockVscodeWrapper.object, |
| 199 | + mockContext.object, |
| 200 | + ); |
| 201 | + |
| 202 | + const result = await handler( |
| 203 | + mockChatRequest.object, |
| 204 | + mockChatContext.object, |
| 205 | + mockChatStream.object, |
| 206 | + mockToken.object, |
| 207 | + ); |
| 208 | + |
| 209 | + mockChatStream.verify((x) => x.markdown("No model found."), TypeMoq.Times.once()); |
| 210 | + expect(result).to.deep.equal({ |
| 211 | + metadata: { command: "", correlationId: sampleCorrelationId }, |
| 212 | + }); |
| 213 | + }); |
| 214 | + |
| 215 | + test("Handles successful conversation flow with complete message type", async () => { |
| 216 | + // Setup mocks for startConversation |
| 217 | + mockCopilotService |
| 218 | + .setup((x) => |
| 219 | + x.startConversation( |
| 220 | + TypeMoq.It.isAnyString(), |
| 221 | + TypeMoq.It.isAnyString(), |
| 222 | + TypeMoq.It.isAnyString(), |
| 223 | + ), |
| 224 | + ) |
| 225 | + .returns(() => Promise.resolve(true)); |
| 226 | + |
| 227 | + // Mock the getNextMessage to return a Complete message type |
| 228 | + const completeResponse: GetNextMessageResponse = { |
| 229 | + conversationUri: sampleConversationUri, |
| 230 | + messageType: MessageType.Complete, |
| 231 | + responseText: "Conversation completed", |
| 232 | + tools: [], |
| 233 | + requestMessages: [], |
| 234 | + }; |
| 235 | + |
| 236 | + mockCopilotService |
| 237 | + .setup((x) => |
| 238 | + x.getNextMessage( |
| 239 | + TypeMoq.It.isAnyString(), |
| 240 | + TypeMoq.It.isAnyString(), |
| 241 | + TypeMoq.It.isAny(), |
| 242 | + TypeMoq.It.isAny(), |
| 243 | + ), |
| 244 | + ) |
| 245 | + .returns(() => Promise.resolve(completeResponse)); |
| 246 | + |
| 247 | + const handler = createSqlAgentRequestHandler( |
| 248 | + mockCopilotService.object, |
| 249 | + mockVscodeWrapper.object, |
| 250 | + mockContext.object, |
| 251 | + ); |
| 252 | + |
| 253 | + const result = await handler( |
| 254 | + mockChatRequest.object, |
| 255 | + mockChatContext.object, |
| 256 | + mockChatStream.object, |
| 257 | + mockToken.object, |
| 258 | + ); |
| 259 | + |
| 260 | + mockCopilotService.verify( |
| 261 | + (x) => x.startConversation(TypeMoq.It.isAnyString(), sampleConnectionUri, samplePrompt), |
| 262 | + TypeMoq.Times.once(), |
| 263 | + ); |
| 264 | + |
| 265 | + mockCopilotService.verify( |
| 266 | + (x) => |
| 267 | + x.getNextMessage( |
| 268 | + TypeMoq.It.isAnyString(), |
| 269 | + TypeMoq.It.isAnyString(), |
| 270 | + TypeMoq.It.isAny(), |
| 271 | + TypeMoq.It.isAny(), |
| 272 | + ), |
| 273 | + TypeMoq.Times.once(), |
| 274 | + ); |
| 275 | + |
| 276 | + // Verify startActivity was called |
| 277 | + sinon.assert.called(startActivityStub); |
| 278 | + |
| 279 | + // Verify end was called on the activity object |
| 280 | + mockActivityObject.verify( |
| 281 | + (x) => x.end(ActivityStatus.Succeeded, TypeMoq.It.isAny()), |
| 282 | + TypeMoq.Times.once(), |
| 283 | + ); |
| 284 | + |
| 285 | + expect(result).to.deep.equal({ |
| 286 | + metadata: { command: "", correlationId: sampleCorrelationId }, |
| 287 | + }); |
| 288 | + }); |
| 289 | + |
| 290 | + test("Handles conversation with Fragment message type", async () => { |
| 291 | + // Setup mocks for startConversation |
| 292 | + mockCopilotService |
| 293 | + .setup((x) => |
| 294 | + x.startConversation( |
| 295 | + TypeMoq.It.isAnyString(), |
| 296 | + TypeMoq.It.isAnyString(), |
| 297 | + TypeMoq.It.isAnyString(), |
| 298 | + ), |
| 299 | + ) |
| 300 | + .returns(() => Promise.resolve(true)); |
| 301 | + |
| 302 | + // First return a Fragment message type |
| 303 | + const fragmentResponse: GetNextMessageResponse = { |
| 304 | + conversationUri: sampleConversationUri, |
| 305 | + messageType: MessageType.Fragment, |
| 306 | + responseText: "Fragment message", |
| 307 | + tools: [], |
| 308 | + requestMessages: [], |
| 309 | + }; |
| 310 | + |
| 311 | + // Then return a Complete message type |
| 312 | + const completeResponse: GetNextMessageResponse = { |
| 313 | + conversationUri: sampleConversationUri, |
| 314 | + messageType: MessageType.Complete, |
| 315 | + responseText: "Conversation completed", |
| 316 | + tools: [], |
| 317 | + requestMessages: [], |
| 318 | + }; |
| 319 | + |
| 320 | + let callCount = 0; |
| 321 | + const responses = [fragmentResponse, completeResponse]; |
| 322 | + |
| 323 | + mockCopilotService |
| 324 | + .setup((x) => |
| 325 | + x.getNextMessage( |
| 326 | + TypeMoq.It.isAnyString(), |
| 327 | + TypeMoq.It.isAnyString(), |
| 328 | + TypeMoq.It.isAny(), |
| 329 | + TypeMoq.It.isAny(), |
| 330 | + ), |
| 331 | + ) |
| 332 | + .returns(() => { |
| 333 | + return Promise.resolve(responses[callCount++]); |
| 334 | + }); |
| 335 | + |
| 336 | + const handler = createSqlAgentRequestHandler( |
| 337 | + mockCopilotService.object, |
| 338 | + mockVscodeWrapper.object, |
| 339 | + mockContext.object, |
| 340 | + ); |
| 341 | + |
| 342 | + await handler( |
| 343 | + mockChatRequest.object, |
| 344 | + mockChatContext.object, |
| 345 | + mockChatStream.object, |
| 346 | + mockToken.object, |
| 347 | + ); |
| 348 | + |
| 349 | + mockCopilotService.verify( |
| 350 | + (x) => |
| 351 | + x.getNextMessage( |
| 352 | + TypeMoq.It.isAnyString(), |
| 353 | + TypeMoq.It.isAnyString(), |
| 354 | + TypeMoq.It.isAny(), |
| 355 | + TypeMoq.It.isAny(), |
| 356 | + ), |
| 357 | + TypeMoq.Times.exactly(2), |
| 358 | + ); |
| 359 | + }); |
| 360 | + |
| 361 | + test("Handles errors during conversation gracefully", async () => { |
| 362 | + // Setup mocks for startConversation to throw |
| 363 | + mockCopilotService |
| 364 | + .setup((x) => |
| 365 | + x.startConversation( |
| 366 | + TypeMoq.It.isAnyString(), |
| 367 | + TypeMoq.It.isAnyString(), |
| 368 | + TypeMoq.It.isAnyString(), |
| 369 | + ), |
| 370 | + ) |
| 371 | + .throws(new Error("Connection failed")); |
| 372 | + |
| 373 | + const handler = createSqlAgentRequestHandler( |
| 374 | + mockCopilotService.object, |
| 375 | + mockVscodeWrapper.object, |
| 376 | + mockContext.object, |
| 377 | + ); |
| 378 | + |
| 379 | + await handler( |
| 380 | + mockChatRequest.object, |
| 381 | + mockChatContext.object, |
| 382 | + mockChatStream.object, |
| 383 | + mockToken.object, |
| 384 | + ); |
| 385 | + |
| 386 | + // Should show error message |
| 387 | + mockChatStream.verify( |
| 388 | + (x) => x.markdown(TypeMoq.It.is((msg) => msg.toString().includes("An error occurred"))), |
| 389 | + TypeMoq.Times.once(), |
| 390 | + ); |
| 391 | + }); |
| 392 | +}); |
0 commit comments