Skip to content

Commit 5cbc17d

Browse files
authored
Merge pull request #98 from dlants/dlants-subagents
Dlants subagents
2 parents 8ed7921 + 486d4cf commit 5cbc17d

File tree

7 files changed

+155
-56
lines changed

7 files changed

+155
-56
lines changed

node/chat/thread.ts

Lines changed: 40 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import {
1212
type Msg as ToolManagerMsg,
1313
type ToolRequest,
1414
type ToolRequestId,
15+
CHAT_TOOL_SPECS,
1516
} from "../tools/toolManager.ts";
1617
import { Counter } from "../utils/uniqueId.ts";
1718
import { FileSnapshots } from "../tools/file-snapshots.ts";
@@ -179,6 +180,8 @@ export class Thread {
179180
},
180181
messages: [],
181182
};
183+
184+
this.updateTokenCount();
182185
}
183186

184187
update(msg: RootMsg): void {
@@ -319,13 +322,10 @@ export class Thread {
319322
event: msg.event,
320323
});
321324

322-
// setTimeout to avoid dispatch-in-dispatch
323-
setTimeout(() =>
324-
this.context.dispatch({
325-
type: "sidebar-update-token-count",
326-
tokenCount: this.getEstimatedTokenCount(),
327-
}),
328-
);
325+
// If this is a content_block_stop event, update the token count
326+
if (msg.event.type === "content_block_stop") {
327+
this.updateTokenCount();
328+
}
329329
return;
330330
}
331331

@@ -345,6 +345,9 @@ export class Thread {
345345
messages: [],
346346
};
347347
this.contextManager.reset();
348+
349+
this.updateTokenCount();
350+
348351
return undefined;
349352
}
350353

@@ -547,16 +550,22 @@ export class Thread {
547550

548551
async sendMessage(content?: string): Promise<void> {
549552
await this.prepareUserMessage(content);
553+
this.updateTokenCount();
550554
const messages = this.getMessages();
555+
551556
const request = getProvider(
552557
this.context.nvim,
553558
this.state.profile,
554-
).sendMessage(messages, (event) => {
555-
this.myDispatch({
556-
type: "stream-event",
557-
event,
558-
});
559-
});
559+
).sendMessage(
560+
messages,
561+
(event) => {
562+
this.myDispatch({
563+
type: "stream-event",
564+
event,
565+
});
566+
},
567+
CHAT_TOOL_SPECS,
568+
);
560569

561570
this.myDispatch({
562571
type: "conversation-state",
@@ -585,6 +594,8 @@ Use the compact_thread tool to analyze my next prompt and extract only the relev
585594
My next prompt will be:
586595
${content}`;
587596

597+
this.updateTokenCount();
598+
588599
const request = getProvider(
589600
this.context.nvim,
590601
this.state.profile,
@@ -757,6 +768,22 @@ Come up with a succinct thread title for this prompt. It should be less than 80
757768
0,
758769
);
759770
}
771+
772+
updateTokenCount() {
773+
const messages = this.getMessages();
774+
const tokenCount = getProvider(
775+
this.context.nvim,
776+
this.state.profile,
777+
).countTokens(messages, CHAT_TOOL_SPECS);
778+
779+
// setTimeout to avoid dispatch-in-dispatch
780+
setTimeout(() =>
781+
this.context.dispatch({
782+
type: "sidebar-update-token-count",
783+
tokenCount,
784+
}),
785+
);
786+
}
760787
}
761788

762789
/**

node/magenta.spec.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ Stopped (tool_use) [input: 0, output: 0]
172172
const displayState = driver.getVisibleState();
173173
{
174174
const winbar = await displayState.inputWindow.getOption("winbar");
175-
expect(winbar).toBe(`Magenta Input (claude-sonnet-3.7)`);
175+
expect(winbar).toContain(`Magenta Input (claude-sonnet-3.7)`);
176176
}
177177
await driver.nvim.call("nvim_command", ["Magenta profile gpt-4o"]);
178178
{
@@ -184,7 +184,7 @@ Stopped (tool_use) [input: 0, output: 0]
184184
apiKeyEnvVar: "OPENAI_API_KEY",
185185
});
186186
const winbar = await displayState.inputWindow.getOption("winbar");
187-
expect(winbar).toBe(`Magenta Input (gpt-4o)`);
187+
expect(winbar).toContain(`Magenta Input (gpt-4o)`);
188188
}
189189
});
190190
});

node/providers/anthropic.ts

Lines changed: 23 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import Anthropic from "@anthropic-ai/sdk";
2-
import * as ToolManager from "../tools/toolManager.ts";
32
import { extendError, type Result } from "../utils/result.ts";
3+
import type { ToolRequest, ToolRequestId } from "../tools/toolManager.ts";
44
import type { Nvim } from "../nvim/nvim-node";
55
import {
66
type Provider,
@@ -13,7 +13,6 @@ import {
1313
} from "./provider-types.ts";
1414
import { assertUnreachable } from "../utils/assertUnreachable.ts";
1515
import { DEFAULT_SYSTEM_PROMPT } from "./constants.ts";
16-
import type { ToolRequest, ToolRequestId } from "../tools/toolManager.ts";
1716
import { validateInput } from "../tools/helpers.ts";
1817

1918
export type MessageParam = Omit<Anthropic.MessageParam, "content"> & {
@@ -73,6 +72,7 @@ export class AnthropicProvider implements Provider {
7372

7473
createStreamParameters(
7574
messages: ProviderMessage[],
75+
tools: Array<ProviderToolSpec>,
7676
options?: { disableCaching?: boolean },
7777
): MessageStreamParams {
7878
const anthropicMessages = messages.map((m): MessageParam => {
@@ -155,14 +155,12 @@ export class AnthropicProvider implements Provider {
155155
cacheControlItemsPlaced = placeCacheBreakpoints(anthropicMessages);
156156
}
157157

158-
const tools: Anthropic.Tool[] = ToolManager.CHAT_TOOL_SPECS.map(
159-
(t): Anthropic.Tool => {
160-
return {
161-
...t,
162-
input_schema: t.input_schema as Anthropic.Messages.Tool.InputSchema,
163-
};
164-
},
165-
);
158+
const anthropicTools: Anthropic.Tool[] = tools.map((t): Anthropic.Tool => {
159+
return {
160+
...t,
161+
input_schema: t.input_schema as Anthropic.Messages.Tool.InputSchema,
162+
};
163+
});
166164

167165
return {
168166
messages: anthropicMessages,
@@ -189,7 +187,7 @@ export class AnthropicProvider implements Provider {
189187
disable_parallel_tool_use: this.disableParallelToolUseFlag || undefined,
190188
},
191189
tools: [
192-
...tools,
190+
...anthropicTools,
193191
{
194192
type: "web_search_20250305",
195193
name: "web_search",
@@ -199,27 +197,25 @@ export class AnthropicProvider implements Provider {
199197
};
200198
}
201199

202-
async countTokens(messages: Array<ProviderMessage>): Promise<number> {
203-
const params = this.createStreamParameters(messages);
204-
const lastMessage = params.messages[params.messages.length - 1];
205-
if (!lastMessage || lastMessage.role != "user") {
206-
params.messages.push({ role: "user", content: "test" });
207-
}
208-
const res = await this.client.messages.countTokens({
209-
messages: params.messages,
210-
model: params.model,
211-
system: params.system as Anthropic.TextBlockParam[],
212-
tools: params.tools as Anthropic.Tool[],
213-
});
214-
return res.input_tokens;
200+
countTokens(
201+
messages: Array<ProviderMessage>,
202+
tools: Array<ProviderToolSpec>,
203+
): number {
204+
const CHARS_PER_TOKEN = 4;
205+
206+
let charCount = DEFAULT_SYSTEM_PROMPT.length;
207+
charCount += JSON.stringify(tools).length;
208+
charCount += JSON.stringify(messages).length;
209+
210+
return Math.ceil(charCount / CHARS_PER_TOKEN);
215211
}
216212

217213
forceToolUse(
218214
messages: Array<ProviderMessage>,
219215
spec: ProviderToolSpec,
220216
): ProviderToolUseRequest {
221217
const request = this.client.messages.stream({
222-
...this.createStreamParameters(messages, { disableCaching: true }),
218+
...this.createStreamParameters(messages, [], { disableCaching: true }),
223219
tools: [
224220
{
225221
...spec,
@@ -348,13 +344,14 @@ export class AnthropicProvider implements Provider {
348344
sendMessage(
349345
messages: Array<ProviderMessage>,
350346
onStreamEvent: (event: ProviderStreamEvent) => void,
347+
tools: Array<ProviderToolSpec>,
351348
): ProviderStreamRequest {
352349
let requestActive = true;
353350
const request = this.client.messages
354351
.stream(
355352
this.createStreamParameters(
356353
messages,
357-
// Use default caching behavior for regular messages
354+
tools,
358355
) as Anthropic.Messages.MessageStreamParams,
359356
)
360357
.on("streamEvent", (e) => {

node/providers/mock.ts

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import {
1212
type ProviderStreamEvent,
1313
} from "./provider-types.ts";
1414
import { setClient } from "./provider.ts";
15+
import { DEFAULT_SYSTEM_PROMPT } from "./constants.ts";
1516

1617
type MockRequest = {
1718
messages: Array<ProviderMessage>;
@@ -39,13 +40,25 @@ export class MockProvider implements Provider {
3940

4041
setModel(_model: string): void {}
4142

42-
createStreamParameters(messages: Array<ProviderMessage>): unknown {
43-
return messages;
43+
createStreamParameters(
44+
messages: Array<ProviderMessage>,
45+
tools: Array<ProviderToolSpec>,
46+
_options?: { disableCaching?: boolean },
47+
): unknown {
48+
return { messages, tools };
4449
}
4550

46-
// eslint-disable-next-line @typescript-eslint/require-await
47-
async countTokens(messages: Array<ProviderMessage>): Promise<number> {
48-
return messages.length;
51+
countTokens(
52+
messages: Array<ProviderMessage>,
53+
tools: Array<ProviderToolSpec>,
54+
): number {
55+
const CHARS_PER_TOKEN = 4;
56+
57+
let charCount = DEFAULT_SYSTEM_PROMPT.length;
58+
charCount += JSON.stringify(tools).length;
59+
charCount += JSON.stringify(messages).length;
60+
61+
return Math.ceil(charCount / CHARS_PER_TOKEN);
4962
}
5063

5164
forceToolUse(
@@ -72,6 +85,7 @@ export class MockProvider implements Provider {
7285
sendMessage(
7386
messages: Array<ProviderMessage>,
7487
onStreamEvent: (event: ProviderStreamEvent) => void,
88+
_tools: Array<ProviderToolSpec>,
7589
): ProviderStreamRequest {
7690
const request: MockRequest = {
7791
messages,

node/providers/openai.ts

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import OpenAI from "openai";
2-
import * as ToolManager from "../tools/toolManager.ts";
32
import { type Result } from "../utils/result.ts";
3+
import type { ToolRequest, ToolRequestId } from "../tools/toolManager.ts";
44
import type {
55
StopReason,
66
Provider,
@@ -13,7 +13,6 @@ import type {
1313
ProviderToolUseResponse,
1414
} from "./provider-types.ts";
1515
import { assertUnreachable } from "../utils/assertUnreachable.ts";
16-
import type { ToolRequestId } from "../tools/toolManager.ts";
1716
import type { Nvim } from "../nvim/nvim-node";
1817
import type { Stream } from "openai/streaming.mjs";
1918
import { DEFAULT_SYSTEM_PROMPT } from "./constants.ts";
@@ -53,8 +52,21 @@ export class OpenAIProvider implements Provider {
5352
this.model = model;
5453
}
5554

55+
countTokens(
56+
messages: Array<ProviderMessage>,
57+
tools: Array<ProviderToolSpec>,
58+
): number {
59+
const CHARS_PER_TOKEN = 4;
60+
let charCount = DEFAULT_SYSTEM_PROMPT.length;
61+
charCount += JSON.stringify(tools).length;
62+
charCount += JSON.stringify(messages).length;
63+
return Math.ceil(charCount / CHARS_PER_TOKEN);
64+
}
65+
5666
createStreamParameters(
5767
messages: Array<ProviderMessage>,
68+
tools: Array<ProviderToolSpec>,
69+
_options?: { disableCaching?: boolean },
5870
): OpenAI.Responses.ResponseCreateParamsStreaming {
5971
const openaiMessages: OpenAI.Responses.ResponseInputItem[] = [
6072
{
@@ -118,7 +130,7 @@ export class OpenAIProvider implements Provider {
118130
// see https://platform.openai.com/docs/guides/function-calling#parallel-function-calling-and-structured-outputs
119131
// this recommends disabling parallel tool calls when strict adherence to schema is needed
120132
parallel_tool_calls: false,
121-
tools: ToolManager.CHAT_TOOL_SPECS.map((s): OpenAI.Responses.Tool => {
133+
tools: tools.map((s): OpenAI.Responses.Tool => {
122134
return {
123135
type: "function",
124136
name: s.name,
@@ -136,7 +148,7 @@ export class OpenAIProvider implements Provider {
136148
): ProviderToolUseRequest {
137149
let aborted = false;
138150
const promise = (async (): Promise<ProviderToolUseResponse> => {
139-
const params = this.createStreamParameters(messages);
151+
const params = this.createStreamParameters(messages, [spec]);
140152
const response = await this.client.responses.create({
141153
...params,
142154
tool_choice: "required",
@@ -153,7 +165,7 @@ export class OpenAIProvider implements Provider {
153165
});
154166

155167
const tool = response.output[0];
156-
let toolRequest: Result<ToolManager.ToolRequest, { rawRequest: unknown }>;
168+
let toolRequest: Result<ToolRequest, { rawRequest: unknown }>;
157169
try {
158170
if (!(tool && tool.type == "function_call")) {
159171
throw new Error(
@@ -177,7 +189,7 @@ export class OpenAIProvider implements Provider {
177189
toolName: tool.name,
178190
id: tool.call_id as unknown as ToolRequestId,
179191
input: input.value,
180-
} as ToolManager.ToolRequest,
192+
} as ToolRequest,
181193
}
182194
: { ...input, rawRequest: tool.arguments };
183195
} catch (error) {
@@ -220,6 +232,7 @@ export class OpenAIProvider implements Provider {
220232
sendMessage(
221233
messages: Array<ProviderMessage>,
222234
onStreamEvent: (event: ProviderStreamEvent) => void,
235+
tools: Array<ProviderToolSpec>,
223236
): ProviderStreamRequest {
224237
let request: Stream<OpenAI.Responses.ResponseStreamEvent>;
225238
let stopReason: StopReason | undefined;
@@ -230,7 +243,7 @@ export class OpenAIProvider implements Provider {
230243
stopReason: StopReason;
231244
}> => {
232245
request = await this.client.responses.create(
233-
this.createStreamParameters(messages),
246+
this.createStreamParameters(messages, tools),
234247
);
235248

236249
for await (const event of request) {

node/providers/provider-types.ts

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,15 @@ export type ProviderMessageContent =
9393

9494
export interface Provider {
9595
setModel(model: string): void;
96-
createStreamParameters(messages: Array<ProviderMessage>): unknown;
97-
// countTokens(messages: Array<ProviderMessage>): Promise<number>;
96+
createStreamParameters(
97+
messages: Array<ProviderMessage>,
98+
tools: Array<ProviderToolSpec>,
99+
options?: { disableCaching?: boolean },
100+
): unknown;
101+
countTokens(
102+
messages: Array<ProviderMessage>,
103+
tools: Array<ProviderToolSpec>,
104+
): number;
98105
forceToolUse(
99106
messages: Array<ProviderMessage>,
100107
spec: ProviderToolSpec,
@@ -103,6 +110,7 @@ export interface Provider {
103110
sendMessage(
104111
messages: Array<ProviderMessage>,
105112
onStreamEvent: (event: ProviderStreamEvent) => void,
113+
tools: Array<ProviderToolSpec>,
106114
): ProviderStreamRequest;
107115
}
108116

0 commit comments

Comments
 (0)