diff --git a/src/lib/onboard/machine/flow-phases/preflight-gateway.test.ts b/src/lib/onboard/machine/flow-phases/preflight-gateway.test.ts new file mode 100644 index 0000000000..ed0451c5d3 --- /dev/null +++ b/src/lib/onboard/machine/flow-phases/preflight-gateway.test.ts @@ -0,0 +1,76 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +import { describe, expect, it, vi } from "vitest"; + +import { createSession } from "../../../state/onboard-session"; +import { advanceTo } from "../result"; +import type { OnboardFlowContext } from "../flow-context"; +import { createGatewayPhase, createPreflightPhase } from "./preflight-gateway"; + +function context(): OnboardFlowContext { + return { + resume: false, + fresh: false, + session: createSession(), + agent: null, + recordedSandboxName: null, + requestedSandboxName: null, + sandboxName: null, + fromDockerfile: null, + model: null, + provider: null, + endpointUrl: null, + credentialEnv: null, + hermesAuthMethod: null, + hermesToolGateways: [], + preferredInferenceApi: null, + nimContainer: null, + webSearchConfig: null, + webSearchSupported: false, + selectedMessagingChannels: [], + gpu: null, + sandboxGpuConfig: null, + gpuPassthrough: false, + }; +} + +describe("preflight/gateway flow phases", () => { + it("maps preflight handler outputs into flow context and FSM result", async () => { + const session = createSession({ gpuPassthrough: true }); + const runPreflight = vi.fn(async () => ({ + session, + gpu: { type: "nvidia" }, + sandboxGpuConfig: { mode: "1" }, + gpuPassthrough: true, + result: advanceTo("gateway"), + })); + const phase = createPreflightPhase(runPreflight); + + const result = await phase.run(context()); + + expect(phase.state).toBe("preflight"); + expect(runPreflight).toHaveBeenCalledOnce(); + expect(result.context).toMatchObject({ + session, + gpu: { type: "nvidia" }, + sandboxGpuConfig: { mode: "1" }, + gpuPassthrough: true, + }); + expect(result.result).toMatchObject({ next: "gateway" }); + }); + + it("maps gateway handler outputs into flow context and FSM result", async () => { + const session = createSession({ sandboxName: "my-assistant" }); + const phase = createGatewayPhase(async () => ({ + session, + result: advanceTo("provider_selection"), + })); + + const result = await phase.run(context()); + + expect(phase.state).toBe("gateway"); + expect(result.context.session).toBe(session); + expect(result.result).toMatchObject({ next: "provider_selection" }); + }); +}); diff --git a/src/lib/onboard/machine/flow-phases/preflight-gateway.ts b/src/lib/onboard/machine/flow-phases/preflight-gateway.ts new file mode 100644 index 0000000000..5efba3949f --- /dev/null +++ b/src/lib/onboard/machine/flow-phases/preflight-gateway.ts @@ -0,0 +1,54 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +import type { OnboardFlowContext, OnboardFlowPhaseResult } from "../flow-context"; +import { mergeOnboardFlowContext, onboardFlowPhaseResult } from "../flow-context"; +import type { OnboardSequencePhase } from "../sequence-runner"; + +type PreflightPhaseHandler = (context: Context) => Promise<{ + session: Context["session"]; + gpu: Context["gpu"]; + sandboxGpuConfig: NonNullable; + gpuPassthrough: boolean; + result: OnboardFlowPhaseResult["result"]; +}>; + +type GatewayPhaseHandler = (context: Context) => Promise<{ + session: Context["session"]; + result: OnboardFlowPhaseResult["result"]; +}>; + +export function createPreflightPhase( + runPreflight: PreflightPhaseHandler, +): OnboardSequencePhase { + return { + state: "preflight", + async run(context) { + const result = await runPreflight(context); + return onboardFlowPhaseResult( + mergeOnboardFlowContext(context, { + session: result.session, + gpu: result.gpu, + sandboxGpuConfig: result.sandboxGpuConfig, + gpuPassthrough: result.gpuPassthrough, + } as Partial), + result.result, + ); + }, + }; +} + +export function createGatewayPhase( + runGateway: GatewayPhaseHandler, +): OnboardSequencePhase { + return { + state: "gateway", + async run(context) { + const result = await runGateway(context); + return onboardFlowPhaseResult( + mergeOnboardFlowContext(context, { session: result.session } as Partial), + result.result, + ); + }, + }; +}