Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/tasty-lobsters-learn.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@browserbasehq/stagehand": minor
---

Support hosted Stagehand `model.providerOptions` for Bedrock and Vertex model configuration.
114 changes: 101 additions & 13 deletions packages/core/lib/v3/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,14 @@ import type {
SerializableResponse,
AgentCacheTransferPayload,
} from "./types/private/index.js";
import type { ModelConfiguration } from "./types/public/model.js";
import type {
ClientOptions,
ModelConfiguration,
} from "./types/public/model.js";
import {
normalizeClientOptionsForModel,
toApiModelClientOptions,
} from "./modelProviderOptions.js";
import { toJsonSchema } from "./zodCompat.js";
import type { StagehandZodSchema } from "./zodCompat.js";

Expand Down Expand Up @@ -97,15 +104,17 @@ interface StagehandAPIConstructorParams {

/**
 * Parameters for starting a session via the API client.
 *
 * Mirrors Api.SessionStartRequest with two client-side differences:
 * - `modelClientOptions` is re-declared with the SDK `ClientOptions` type
 *   (the wire-typed field is omitted from the base via `Omit`).
 * - `modelApiKey` is added as a client-only field.
 *
 * Wire format: Api.SessionStartRequest (modelApiKey sent via header, not body)
 */
interface ClientSessionStartParams
  extends Omit<Api.SessionStartRequest, "modelClientOptions"> {
  /** Model API key - sent via x-model-api-key header, not in request body.
   * Optional: when omitted, requests are sent without the x-model-api-key header
   * and the server is expected to handle model authentication on its own. */
  modelApiKey?: string;
  /** SDK model client options serialized into the hosted API wire format. */
  modelClientOptions?: ClientOptions;
}

/**
Expand Down Expand Up @@ -180,6 +189,8 @@ export class StagehandAPIClient {
private sessionId?: string;
private modelApiKey?: string;
private modelProvider?: string;
/** Serialized session model config, resent on each hosted action when needed. */
private sessionModelConfig?: Api.ModelConfig;
private region?: BrowserbaseRegion;
private logger: (message: LogLine) => void;
private fetchWithCookies;
Expand All @@ -201,9 +212,26 @@ export class StagehandAPIClient {
this.fetchWithCookies = makeFetchCookie(fetch);
}

/**
 * Reports whether the model API key should be kept for the session.
 *
 * Returns `false` only when `modelClientOptions.providerConfig` is a
 * non-array object whose `provider` is `"bedrock"` or `"vertex"`; every
 * other input (including a missing or malformed `providerConfig`) yields
 * `true`.
 *
 * @param modelClientOptions - Wire-format model client options, if any.
 * @returns `true` unless the provider config names Bedrock or Vertex.
 */
private shouldSendModelApiKeyHeader(
  modelClientOptions?: Api.ModelClientOptions,
): boolean {
  const candidate = modelClientOptions?.providerConfig;

  // Anything that is not a plain (non-array) object carries no provider tag.
  if (!candidate || typeof candidate !== "object" || Array.isArray(candidate)) {
    return true;
  }

  const usesProviderNativeAuth =
    candidate.provider === "bedrock" || candidate.provider === "vertex";
  return !usesProviderNativeAuth;
}

async init({
modelName,
modelApiKey,
modelClientOptions,
domSettleTimeoutMs,
verbose,
systemPrompt,
Expand All @@ -212,7 +240,26 @@ export class StagehandAPIClient {
browserbaseSessionID,
// browser, TODO for local browsers
}: ClientSessionStartParams): Promise<Api.SessionStartResult> {
this.modelApiKey = modelApiKey;
const serializedModelClientOptions = this.toSessionStartModelClientOptions(
modelClientOptions,
modelName,
);
this.modelApiKey = this.shouldSendModelApiKeyHeader(
serializedModelClientOptions,
)
? modelApiKey
: undefined;
if (
modelName &&
serializedModelClientOptions &&
Object.keys(serializedModelClientOptions).length > 0
) {
this.sessionModelConfig = {
modelName,
...serializedModelClientOptions,
} as Api.ModelConfig;
}

// Extract provider from modelName (e.g., "openai/gpt-5-nano" -> "openai")
this.modelProvider = modelName?.includes("/")
? modelName.split("/")[0]
Expand All @@ -230,6 +277,7 @@ export class StagehandAPIClient {
// Build wire-format request body (Api.SessionStartRequest shape)
const requestBody: Api.SessionStartRequest = {
modelName,
modelClientOptions: serializedModelClientOptions,
domSettleTimeoutMs,
verbose,
systemPrompt,
Expand Down Expand Up @@ -294,6 +342,7 @@ export class StagehandAPIClient {
wireOptions = restOptions as unknown as Api.ActRequest["options"];
}
}
wireOptions = this.ensureModelConfig(wireOptions);

// Build wire-format request body
const requestBody: Api.ActRequest = {
Expand Down Expand Up @@ -332,6 +381,7 @@ export class StagehandAPIClient {
wireOptions = restOptions as unknown as Api.ExtractRequest["options"];
}
}
wireOptions = this.ensureModelConfig(wireOptions);

// Build wire-format request body
const requestBody: Api.ExtractRequest = {
Expand Down Expand Up @@ -367,6 +417,7 @@ export class StagehandAPIClient {
wireOptions = restOptions as unknown as Api.ObserveRequest["options"];
}
}
wireOptions = this.ensureModelConfig(wireOptions);

// Build wire-format request body
const requestBody: Api.ObserveRequest = {
Expand All @@ -387,7 +438,11 @@ export class StagehandAPIClient {
options?: Api.NavigateRequest["options"],
frameId?: string,
): Promise<SerializableResponse | null> {
const requestBody: Api.NavigateRequest = { url, options, frameId };
const requestBody: Api.NavigateRequest = {
url,
options: this.ensureModelConfig(options),
frameId,
};

return this.execute<SerializableResponse | null>({
method: "navigate",
Expand Down Expand Up @@ -424,7 +479,7 @@ export class StagehandAPIClient {
cua: agentConfig.mode === undefined ? agentConfig.cua : undefined,
model: agentConfig.model
? this.prepareModelConfig(agentConfig.model)
: undefined,
: this.sessionModelConfig,
executionModel: agentConfig.executionModel
? this.prepareModelConfig(agentConfig.executionModel)
: undefined,
Expand Down Expand Up @@ -606,7 +661,7 @@ export class StagehandAPIClient {
*/
private prepareModelConfig(
model: ModelConfiguration,
): { modelName: string; apiKey?: string } & Record<string, unknown> {
): { modelName: string } & Record<string, unknown> {
if (typeof model === "string") {
// Extract provider from model string (e.g., "openai/gpt-5-nano" -> "openai")
const provider = model.includes("/") ? model.split("/")[0] : undefined;
Expand All @@ -620,7 +675,14 @@ export class StagehandAPIClient {
};
}

if (!model.apiKey) {
const normalizedModel = {
modelName: model.modelName,
...(this.toSessionStartModelClientOptions(model, model.modelName) ?? {}),
};

const normalizedApiKey = (normalizedModel as Record<string, unknown>)
.apiKey;
if (typeof normalizedApiKey !== "string" || !normalizedApiKey) {
const provider = model.modelName?.includes("/")
? model.modelName.split("/")[0]
: undefined;
Expand All @@ -629,15 +691,41 @@ export class StagehandAPIClient {
? (loadApiKeyFromEnv(provider, this.logger) ?? this.modelApiKey)
: this.modelApiKey;
return {
...model,
...normalizedModel,
...(apiKey ? { apiKey } : {}),
};
}

return model as { modelName: string; apiKey: string } & Record<
string,
unknown
>;
return normalizedModel as { modelName: string } & Record<string, unknown>;
}

/**
 * Fills in the session's stored model config on wire options that do not
 * already carry a truthy `model`.
 *
 * The input is returned untouched when no session model config has been
 * stored or when the caller supplied its own `model`. Otherwise a shallow
 * copy is returned with `model` set to the session config; an `undefined`
 * input is treated as an empty options object.
 *
 * @param wireOptions - Wire-format options for a hosted action (may be undefined).
 * @returns The original options, or a copy with the session model attached.
 */
private ensureModelConfig<T extends { model?: unknown } | undefined>(
  wireOptions: T,
): T {
  const sessionDefault = this.sessionModelConfig;

  if (sessionDefault && !wireOptions?.model) {
    const base = wireOptions ?? {};
    return { ...base, model: sessionDefault } as T;
  }

  return wireOptions;
}

/**
 * Converts SDK client options into the hosted API's model client options.
 *
 * Runs `normalizeClientOptionsForModel` first and hands its result, together
 * with the same `modelName`, to `toApiModelClientOptions`.
 * NOTE(review): the exact normalization/serialization rules live in
 * `modelProviderOptions.js` and are not visible here.
 *
 * @param options - SDK-side client options, if any.
 * @param modelName - Model identifier passed through to both helpers.
 * @returns Whatever `toApiModelClientOptions` produces (possibly `undefined`).
 */
private toSessionStartModelClientOptions(
  options?: ClientOptions,
  modelName?: string,
): Api.ModelClientOptions | undefined {
  return toApiModelClientOptions(
    normalizeClientOptionsForModel(options, modelName),
    modelName,
  );
}

private consumeFinishedEventData<T>(): T | null {
Expand Down
30 changes: 29 additions & 1 deletion packages/core/lib/v3/llm/LLMProvider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import { LogLine } from "../types/public/logs.js";
import {
AvailableModel,
ClientOptions,
GoogleVertexProviderSettings,
ModelProvider,
} from "../types/public/model.js";
import { AISdkClient } from "./aisdk.js";
Expand Down Expand Up @@ -100,6 +101,32 @@ const modelToProviderMap: { [key in AvailableModel]: ModelProvider } = {
"gemini-2.5-pro-preview-03-25": "google",
};

/**
 * Type guard for a non-empty, string-valued plain record.
 *
 * Accepts only values that are non-null, non-array objects with at least one
 * own enumerable key and whose own enumerable values are all strings.
 * Empty objects are rejected on purpose so that `{}` does not count as a
 * meaningful record.
 *
 * @param value - Arbitrary input to inspect.
 * @returns `true` when `value` can be treated as `Record<string, string>`.
 */
function isStringRecord(value: unknown): value is Record<string, string> {
  if (typeof value !== "object" || value === null || Array.isArray(value)) {
    return false;
  }

  const items: unknown[] = Object.values(value);
  return items.length > 0 && items.every((item) => typeof item === "string");
}

/**
 * Detects whether the client options carry any Vertex-specific setting.
 *
 * The options are viewed as partial `GoogleVertexProviderSettings`. The
 * result is `true` when at least one of the following holds, checked in
 * this order:
 * - `project`, `location`, or `baseURL` is a string;
 * - `headers` passes `isStringRecord`;
 * - `googleAuthOptions` is a non-null object with at least one key.
 *
 * @param clientOptions - SDK client options, if any.
 * @returns `true` when a Vertex setting is present, otherwise `false`.
 */
function hasHostedVertexClientOptions(clientOptions?: ClientOptions): boolean {
  const settings = clientOptions as
    | Partial<GoogleVertexProviderSettings>
    | undefined;

  if (!settings) {
    return false;
  }

  if (typeof settings.project === "string") {
    return true;
  }
  if (typeof settings.location === "string") {
    return true;
  }
  if (typeof settings.baseURL === "string") {
    return true;
  }
  if (isStringRecord(settings.headers)) {
    return true;
  }

  // Auth options only count when they are a real object with content.
  return (
    typeof settings.googleAuthOptions === "object" &&
    settings.googleAuthOptions !== null &&
    Object.keys(settings.googleAuthOptions).length > 0
  );
}

export function getAISDKLanguageModel(
subProvider: string,
subModelName: string,
Expand Down Expand Up @@ -166,7 +193,8 @@ export class LLMProvider {
if (
subProvider === "vertex" &&
!options?.disableAPI &&
!options?.experimental
!options?.experimental &&
!hasHostedVertexClientOptions(clientOptions)
) {
throw new ExperimentalNotConfiguredError("Vertex provider");
}
Expand Down
Loading
Loading