Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions src/helpers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -64,15 +64,14 @@ export function setModelDefaults(providerID?: string, modelID?: string): void {
export function applyModelDefaults(
providerID?: string,
modelID?: string,
variant?: string,
): { providerID: string; modelID: string; variant?: string } | undefined {
): { providerID: string; modelID: string } | undefined {
// Explicit params take priority
if (providerID && modelID) {
return { providerID, modelID, ...(variant ? { variant } : {}) };
return { providerID, modelID };
}
// Fall back to env-var defaults
if (_defaultProviderID && _defaultModelID) {
return { providerID: _defaultProviderID, modelID: _defaultModelID, ...(variant ? { variant } : {}) };
return { providerID: _defaultProviderID, modelID: _defaultModelID };
}
// No defaults available — let the server decide
return undefined;
Expand Down
14 changes: 8 additions & 6 deletions src/tools/message.ts
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,9 @@ export function registerMessageTools(
const body: Record<string, unknown> = {
parts: [{ type: "text", text }],
};
const model = applyModelDefaults(providerID, modelID, variant);
const model = applyModelDefaults(providerID, modelID);
if (model) body.model = model;
if (variant) body.variant = variant;
if (agent) body.agent = agent;
if (noReply !== undefined) body.noReply = noReply;
if (system) body.system = system;
Expand Down Expand Up @@ -155,8 +156,9 @@ export function registerMessageTools(
const body: Record<string, unknown> = {
parts: [{ type: "text", text }],
};
const model = applyModelDefaults(providerID, modelID, variant);
const model = applyModelDefaults(providerID, modelID);
if (model) body.model = model;
if (variant) body.variant = variant;
if (agent) body.agent = agent;
await client.post(`/session/${sessionId}/prompt_async`, body, { directory });
return toolResult(
Expand Down Expand Up @@ -202,8 +204,9 @@ export function registerMessageTools(
arguments: args ?? "",
};
if (agent) body.agent = agent;
const cmdModel = applyModelDefaults(providerID, modelID, variant);
const cmdModel = applyModelDefaults(providerID, modelID);
if (cmdModel) body.model = cmdModel;
if (variant) body.variant = variant;
const result = await client.post(
`/session/${sessionId}/command`,
body,
Expand All @@ -225,13 +228,12 @@ export function registerMessageTools(
agent: z.string().describe("Agent to use for the shell command"),
providerID: z.string().optional().describe("Provider ID"),
modelID: z.string().optional().describe("Model ID"),
variant: z.string().optional().describe("Model variant"),
directory: directoryParam,
},
async ({ sessionId, command, agent, providerID, modelID, variant, directory }) => {
async ({ sessionId, command, agent, providerID, modelID, directory }) => {
try {
const body: Record<string, unknown> = { command, agent };
const shellModel = applyModelDefaults(providerID, modelID, variant);
const shellModel = applyModelDefaults(providerID, modelID);
if (shellModel) body.model = shellModel;
const result = await client.post(
`/session/${sessionId}/shell`,
Expand Down
15 changes: 10 additions & 5 deletions src/tools/workflow.ts
Original file line number Diff line number Diff line change
Expand Up @@ -261,8 +261,9 @@ export function registerWorkflowTools(
const body: Record<string, unknown> = {
parts: [{ type: "text", text: prompt }],
};
const model = applyModelDefaults(providerID, modelID, variant);
const model = applyModelDefaults(providerID, modelID);
if (model) body.model = model;
if (variant) body.variant = variant;
if (agent) body.agent = agent;
if (system) body.system = system;

Expand Down Expand Up @@ -308,8 +309,9 @@ export function registerWorkflowTools(
const body: Record<string, unknown> = {
parts: [{ type: "text", text: prompt }],
};
const model = applyModelDefaults(providerID, modelID, variant);
const model = applyModelDefaults(providerID, modelID);
if (model) body.model = model;
if (variant) body.variant = variant;
if (agent) body.agent = agent;

const response = await client.post(
Expand Down Expand Up @@ -607,10 +609,11 @@ export function registerWorkflowTools(
};
// Use the specified model, or let the provider pick its default
if (modelID) {
body.model = { providerID: providerId, modelID, ...(variant ? { variant } : {}) };
body.model = { providerID: providerId, modelID };
} else {
body.providerID = providerId;
}
if (variant) body.variant = variant;

const response = await client.post(
`/session/${sessionId}/message`,
Expand Down Expand Up @@ -688,8 +691,9 @@ export function registerWorkflowTools(
parts: [{ type: "text", text: prompt }],
noReply: false,
};
const model = applyModelDefaults(providerID, modelID, variant);
const model = applyModelDefaults(providerID, modelID);
if (model) body.model = model;
if (variant) body.variant = variant;
if (agent) body.agent = agent;

await client.post(`/session/${sid}/message`, body, { directory });
Expand Down Expand Up @@ -802,8 +806,9 @@ export function registerWorkflowTools(
parts: [{ type: "text", text: prompt }],
noReply: false,
};
const model = applyModelDefaults(providerID, modelID, variant);
const model = applyModelDefaults(providerID, modelID);
if (model) body.model = model;
if (variant) body.variant = variant;
if (agent) body.agent = agent;

await client.post(`/session/${sid}/message`, body, { directory });
Expand Down
86 changes: 84 additions & 2 deletions tests/tools.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,16 @@ function createMockClient(overrides: Record<string, unknown> = {}) {
// ─── Tool registration capture ───────────────────────────────────────────

function captureTools(registerFn: (server: McpServer, client: OpenCodeClient) => void) {
const tools = new Map<string, { description: string; handler: Function }>();
const tools = new Map<string, { description: string; schema: unknown; handler: Function }>();
const mockServer = {
tool: vi.fn((...args: unknown[]) => {
// Handle both 4-arg (name, desc, schema, handler) and
// 5-arg (name, desc, schema, annotations, handler) forms
const name = args[0] as string;
const description = args[1] as string;
const schema = args[2];
const handler = args[args.length - 1] as Function;
tools.set(name, { description, handler });
tools.set(name, { description, schema, handler });
}),
} as unknown as McpServer;
const mockClient = createMockClient();
Expand Down Expand Up @@ -103,6 +104,17 @@ describe("Tool registration", () => {
});
});

describe("registerMessageTools", () => {
it("does not expose unsupported variant parameter on shell execution", () => {
const { tools } = captureTools(registerMessageTools);
const shell = tools.get("opencode_shell_execute")!;
const schema = shell.schema as Record<string, unknown>;
expect(schema).toHaveProperty("providerID");
expect(schema).toHaveProperty("modelID");
expect(schema).not.toHaveProperty("variant");
});
});

describe("registerFileTools", () => {
it("registers 6 file tools", () => {
const { tools } = captureTools(registerFileTools);
Expand Down Expand Up @@ -230,6 +242,19 @@ describe("Tool handlers", () => {
expect(body.model).toEqual({ providerID: "anthropic", modelID: "claude-3" });
});

it("sends variant as a top-level prompt field", async () => {
await handler({
prompt: "test",
providerID: "openai",
modelID: "gpt-5.4",
variant: "xhigh",
});
const [, body] = (mockClient.post as ReturnType<typeof vi.fn>).mock.calls[1];
expect(body.model).toEqual({ providerID: "openai", modelID: "gpt-5.4" });
expect(body.model).not.toHaveProperty("variant");
expect(body.variant).toBe("xhigh");
});

it("includes agent when set", async () => {
await handler({ prompt: "test", agent: "build" });
const [, body] = (mockClient.post as ReturnType<typeof vi.fn>).mock.calls[1];
Expand Down Expand Up @@ -335,6 +360,34 @@ describe("Tool handlers", () => {
expect(result.content[0].text).toContain("Sure, here you go");
expect(result.content[0].text).not.toContain("WARNING");
});

it("sends variant as a top-level field", async () => {
const mockClient = createMockClient({
post: vi.fn().mockResolvedValueOnce({
info: { id: "m2", role: "assistant" },
parts: [{ type: "text", text: "Sure" }],
}),
});
const tools = new Map<string, Function>();
const mockServer = {
tool: vi.fn((...args: unknown[]) => {
tools.set(args[0] as string, args[args.length - 1] as Function);
}),
} as unknown as McpServer;
registerWorkflowTools(mockServer, mockClient);
const handler = tools.get("opencode_reply")!;
await handler({
sessionId: "s1",
prompt: "follow up",
providerID: "openai",
modelID: "gpt-5.4",
variant: "low",
});
const [, body] = (mockClient.post as ReturnType<typeof vi.fn>).mock.calls[0];
expect(body.model).toEqual({ providerID: "openai", modelID: "gpt-5.4" });
expect(body.model).not.toHaveProperty("variant");
expect(body.variant).toBe("low");
});
});

describe("opencode_context", () => {
Expand Down Expand Up @@ -1156,6 +1209,35 @@ describe("Tool handlers", () => {
expect(result.content[0].text).not.toContain("WARNING");
});

it("sends variant as a top-level field", async () => {
const mockClient = createMockClient({
post: vi.fn().mockResolvedValueOnce({
info: { id: "m1", role: "assistant" },
parts: [{ type: "text", text: "Here is the answer" }],
}),
});
const tools = new Map<string, Function>();
const mockServer = {
tool: vi.fn((...args: unknown[]) => {
tools.set(args[0] as string, args[args.length - 1] as Function);
}),
} as unknown as McpServer;
registerMessageTools(mockServer, mockClient);

const handler = tools.get("opencode_message_send")!;
await handler({
sessionId: "s1",
text: "hello",
providerID: "openai",
modelID: "gpt-5.4",
variant: "xhigh",
});
const [, body] = (mockClient.post as ReturnType<typeof vi.fn>).mock.calls[0];
expect(body.model).toEqual({ providerID: "openai", modelID: "gpt-5.4" });
expect(body.model).not.toHaveProperty("variant");
expect(body.variant).toBe("xhigh");
});

it("returns 'Empty response.' when formatted output is empty string", async () => {
const mockClient = createMockClient({
post: vi.fn().mockResolvedValueOnce(null),
Expand Down