diff --git a/core/src/auth/auth_credential.ts b/core/src/auth/auth_credential.ts index 3034e085..f7446e3a 100644 --- a/core/src/auth/auth_credential.ts +++ b/core/src/auth/auth_credential.ts @@ -44,6 +44,7 @@ export interface OAuth2Auth { * verify the state */ authUri?: string; + nonce?: string; state?: string; codeVerifier?: string; /** @@ -54,8 +55,11 @@ export interface OAuth2Auth { authCode?: string; accessToken?: string; refreshToken?: string; + idToken?: string; expiresAt?: number; expiresIn?: number; + audience?: string; + tokenEndpointAuthMethod?: string; } /** diff --git a/core/src/auth/auth_handler.ts b/core/src/auth/auth_handler.ts index d697f18f..4ba7ddef 100644 --- a/core/src/auth/auth_handler.ts +++ b/core/src/auth/auth_handler.ts @@ -5,11 +5,12 @@ */ import {State} from '../sessions/state.js'; +import {randomUUID} from '../utils/env_aware_utils.js'; import {AuthCredential} from './auth_credential.js'; import {AuthConfig} from './auth_tool.js'; +import {OAuth2CredentialExchanger} from './oauth2/oauth2_credential_exchanger.js'; -// TODO(b/425992518): Implement the rest /** * A handler that handles the auth flow in Agent Development Kit to help * orchestrates the credential request and response flow (e.g. OAuth flow) @@ -24,6 +25,28 @@ export class AuthHandler { return state.get(credentialKey); } + async parseAndStoreAuthResponse(state: State): Promise { + const credentialKey = 'temp:' + this.authConfig.credentialKey; + + if (this.authConfig.exchangedAuthCredential) { + state.set(credentialKey, this.authConfig.exchangedAuthCredential); + } + + const authSchemeType = this.authConfig.authScheme.type; + if (!['oauth2', 'openIdConnect'].includes(authSchemeType)) { + return; + } + + if (this.authConfig.exchangedAuthCredential) { + const exchanger = new OAuth2CredentialExchanger(); + const exchangedCredential = await exchanger.exchange({ + authCredential: this.authConfig.exchangedAuthCredential, + authScheme: this.authConfig.authScheme, + }); + state.set(credentialKey, exchangedCredential.credential); + } + } + generateAuthRequest(): AuthConfig { const authSchemeType = this.authConfig.authScheme.type; @@ -79,7 +102,66 @@ export class AuthHandler { * auth scheme. */ generateAuthUri(): AuthCredential | undefined { - return this.authConfig.rawAuthCredential; - // TODO - b/425992518: Implement the rest of the function + const authScheme = this.authConfig.authScheme; + const authCredential = this.authConfig.rawAuthCredential; + + if (!authCredential || !authCredential.oauth2) { + return authCredential; + } + + let authorizationEndpoint = ''; + let scopes: string[] = []; + + if ('authorizationEndpoint' in authScheme) { + authorizationEndpoint = authScheme.authorizationEndpoint; + scopes = authScheme.scopes || []; + } else if (authScheme.type === 'oauth2' && authScheme.flows) { + const flows = authScheme.flows; + const flow = + flows.implicit || + flows.authorizationCode || + flows.clientCredentials || + flows.password; + + if (flow) { + if ('authorizationUrl' in flow && flow.authorizationUrl) { + authorizationEndpoint = flow.authorizationUrl; + } else if ('tokenUrl' in flow && flow.tokenUrl) { + authorizationEndpoint = flow.tokenUrl; + } + + if (flow.scopes) { + scopes = Object.keys(flow.scopes); + } + } + } + + if (!authorizationEndpoint) { + throw new Error('Authorization endpoint not configured in auth scheme.'); + } + + const state = randomUUID(); + const url = new URL(authorizationEndpoint); + url.searchParams.set('client_id', authCredential.oauth2.clientId || ''); + url.searchParams.set( + 'redirect_uri', + authCredential.oauth2.redirectUri || '', + ); + url.searchParams.set('response_type', 'code'); + url.searchParams.set('scope', scopes.join(' ')); + url.searchParams.set('state', state); + url.searchParams.set('access_type', 'offline'); + url.searchParams.set('prompt', 'consent'); + + const exchangedAuthCredential: AuthCredential = { + ...authCredential, + oauth2: { + ...authCredential.oauth2, + authUri: url.toString(), + state, + }, + }; + + return exchangedAuthCredential; } } diff --git a/core/src/auth/auth_preprocessor.ts b/core/src/auth/auth_preprocessor.ts new file mode 100644 index 00000000..e247c57c --- /dev/null +++ b/core/src/auth/auth_preprocessor.ts @@ -0,0 +1,191 @@ +/** + * @license + * Copyright 2026 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + REQUEST_EUC_FUNCTION_CALL_NAME, + handleFunctionCallsAsync, +} from '../agents/functions.js'; +import {InvocationContext} from '../agents/invocation_context.js'; +import {isLlmAgent} from '../agents/llm_agent.js'; +import {BaseLlmRequestProcessor} from '../agents/processors/base_llm_processor.js'; +import {ReadonlyContext} from '../agents/readonly_context.js'; +import { + Event, + getFunctionCalls, + getFunctionResponses, +} from '../events/event.js'; +import {State} from '../sessions/state.js'; +import {BaseTool} from '../tools/base_tool.js'; +import {AuthHandler} from './auth_handler.js'; +import {AuthConfig, AuthToolArguments} from './auth_tool.js'; + +const TOOLSET_AUTH_CREDENTIAL_ID_PREFIX = '_adk_toolset_auth_'; + +async function storeAuthAndCollectResumeTargets( + events: Event[], + authFcIds: Set, + authResponses: Record, + state: State, +): Promise> { + const requestedAuthConfigById: Record = {}; + for (const event of events) { + const eventFunctionCalls = getFunctionCalls(event); + for (const functionCall of eventFunctionCalls) { + if ( + functionCall.id && + authFcIds.has(functionCall.id) && + functionCall.name === REQUEST_EUC_FUNCTION_CALL_NAME + ) { + const args = functionCall.args as unknown as AuthToolArguments; + if (args && args.authConfig) { + requestedAuthConfigById[functionCall.id] = args.authConfig; + } + } + } + } + + for (const fcId of authFcIds) { + if (!(fcId in authResponses)) { + continue; + } + const authConfig = authResponses[fcId] as AuthConfig; + const requestedAuthConfig = requestedAuthConfigById[fcId]; + if (requestedAuthConfig && requestedAuthConfig.credentialKey) { + authConfig.credentialKey = requestedAuthConfig.credentialKey; + } + await new AuthHandler(authConfig).parseAndStoreAuthResponse(state); + } + + const toolsToResume: Set = new Set(); + for (const fcId of authFcIds) { + const requestedAuthConfig = requestedAuthConfigById[fcId]; + if (!requestedAuthConfig) { + continue; + } + for (const event of events) { + const eventFunctionCalls = getFunctionCalls(event); + for (const functionCall of eventFunctionCalls) { + if ( + functionCall.id === fcId && + functionCall.name === REQUEST_EUC_FUNCTION_CALL_NAME + ) { + const args = functionCall.args as unknown as AuthToolArguments; + if (args && args.functionCallId) { + if ( + args.functionCallId.startsWith(TOOLSET_AUTH_CREDENTIAL_ID_PREFIX) + ) { + continue; + } + toolsToResume.add(args.functionCallId); + } + } + } + } + } + + return toolsToResume; +} + +export class AuthPreprocessor extends BaseLlmRequestProcessor { + override async *runAsync( + invocationContext: InvocationContext, + ): AsyncGenerator { + const agent = invocationContext.agent; + if (!isLlmAgent(agent)) { + return; + } + + const events = invocationContext.session.events; + if (!events || events.length === 0) { + return; + } + + let lastEventWithContent = null; + for (let i = events.length - 1; i >= 0; i--) { + const event = events[i]; + if (event.content !== undefined) { + lastEventWithContent = event; + break; + } + } + + if (!lastEventWithContent || lastEventWithContent.author !== 'user') { + return; + } + + const responses = getFunctionResponses(lastEventWithContent); + if (!responses || responses.length === 0) { + return; + } + + const authFcIds: Set = new Set(); + const authResponses: Record = {}; + + for (const functionCallResponse of responses) { + if (functionCallResponse.name !== REQUEST_EUC_FUNCTION_CALL_NAME) { + continue; + } + if (functionCallResponse.id) { + authFcIds.add(functionCallResponse.id); + authResponses[functionCallResponse.id] = functionCallResponse.response; + } + } + + if (authFcIds.size === 0) { + return; + } + + const state = new State(invocationContext.session.state); + const toolsToResume = await storeAuthAndCollectResumeTargets( + events, + authFcIds, + authResponses, + state, + ); + + if (toolsToResume.size === 0) { + return; + } + + for (let i = events.length - 2; i >= 0; i--) { + const event = events[i]; + const functionCalls = getFunctionCalls(event); + if (!functionCalls || functionCalls.length === 0) { + continue; + } + + const hasMatchingCall = functionCalls.some((call) => + call.id ? toolsToResume.has(call.id) : false, + ); + + if (hasMatchingCall) { + const canonicalTools = await agent.canonicalTools( + new ReadonlyContext(invocationContext), + ); + const toolsDict: Record = {}; + for (const tool of canonicalTools) { + toolsDict[tool.name] = tool; + } + + const functionResponseEvent = await handleFunctionCallsAsync({ + invocationContext, + functionCallEvent: event, + toolsDict, + beforeToolCallbacks: agent.canonicalBeforeToolCallbacks, + afterToolCallbacks: agent.canonicalAfterToolCallbacks, + filters: toolsToResume, + }); + + if (functionResponseEvent) { + yield functionResponseEvent; + } + return; + } + } + } +} + +export const AUTH_PREPROCESSOR = new AuthPreprocessor(); diff --git a/core/src/auth/oauth2/oauth2_utils.ts b/core/src/auth/oauth2/oauth2_utils.ts index 1cbd7715..135ad3cf 100644 --- a/core/src/auth/oauth2/oauth2_utils.ts +++ b/core/src/auth/oauth2/oauth2_utils.ts @@ -40,6 +40,7 @@ interface OAuth2TokenResponse { access_token?: string; refresh_token?: string; expires_in?: number; + id_token?: string; } /** @@ -67,6 +68,7 @@ export async function fetchOAuth2Tokens( return { accessToken: data.access_token, refreshToken: data.refresh_token, + idToken: data.id_token, expiresIn: data.expires_in, expiresAt: data.expires_in ? Date.now() + data.expires_in * 1000 diff --git a/core/test/auth/auth_handler_test.ts b/core/test/auth/auth_handler_test.ts index 19bbcc31..f657c68a 100644 --- a/core/test/auth/auth_handler_test.ts +++ b/core/test/auth/auth_handler_test.ts @@ -5,7 +5,19 @@ */ import {AuthConfig, AuthCredentialTypes, AuthHandler, State} from '@google/adk'; -import {describe, expect, it} from 'vitest'; +import {describe, expect, it, vi} from 'vitest'; + +vi.mock('../../src/auth/oauth2/oauth2_credential_exchanger.js', () => ({ + OAuth2CredentialExchanger: class { + exchange = vi.fn().mockResolvedValue({ + credential: { + authType: 'oauth2', + oauth2: {accessToken: 'mockAccessToken'}, + }, + wasExchanged: true, + }); + }, +})); describe('AuthHandler', () => { describe('getAuthResponse', () => { @@ -41,6 +53,70 @@ describe('AuthHandler', () => { }); }); + describe('parseAndStoreAuthResponse', () => { + it('stores exchangedAuthCredential when present for non-oauth2', async () => { + const authConfig: AuthConfig = { + credentialKey: 'testKey', + authScheme: {type: 'apiKey', name: 'testKey', in: 'header'}, + exchangedAuthCredential: { + authType: AuthCredentialTypes.API_KEY, + apiKey: 'testToken', + }, + }; + const handler = new AuthHandler(authConfig); + const state = new State(); + + await handler.parseAndStoreAuthResponse(state); + + expect(state.get('temp:testKey')).toEqual({ + authType: 'apiKey', + apiKey: 'testToken', + }); + }); + + it('returns early if scheme type is not oauth2 or openIdConnect', async () => { + const authConfig: AuthConfig = { + credentialKey: 'testKey', + authScheme: {type: 'apiKey', name: 'testKey', in: 'header'}, + }; + const handler = new AuthHandler(authConfig); + const state = new State(); + + await handler.parseAndStoreAuthResponse(state); + + expect(state.get('temp:testKey')).toBeUndefined(); + }); + + it('stores exchangedCredential.credential for oauth2 when exchange happens', async () => { + const authConfig: AuthConfig = { + credentialKey: 'testKey', + authScheme: { + type: 'oauth2', + flows: { + authorizationCode: { + authorizationUrl: 'https://auth.com', + tokenUrl: 'https://token.com', + scopes: {}, + }, + }, + }, + exchangedAuthCredential: { + authType: AuthCredentialTypes.OAUTH2, + oauth2: {authCode: '123'}, + }, + }; + const handler = new AuthHandler(authConfig); + const state = new State(); + + await handler.parseAndStoreAuthResponse(state); + + expect(state.get('temp:testKey')).toEqual({ + authType: 'oauth2', + oauth2: {accessToken: 'mockAccessToken'}, + }); + }); + }); + describe('generateAuthRequest', () => { it('returns original config if scheme type is not oauth2 or openIdConnect', () => { const authConfig: AuthConfig = { @@ -199,14 +275,15 @@ describe('AuthHandler', () => { const request = handler.generateAuthRequest(); - expect(request.exchangedAuthCredential).toEqual( - authConfig.rawAuthCredential, // As per current implementation of generateAuthUri returning rawAuthCredential + expect(request.exchangedAuthCredential).toBeDefined(); + expect(request.exchangedAuthCredential?.oauth2?.authUri).toContain( + 'https://auth.com', ); }); }); describe('generateAuthUri', () => { - it('returns rawAuthCredential (current implementation)', () => { + it('generates auth URI for oauth2 scheme with flows', () => { const authConfig: AuthConfig = { credentialKey: 'testKey', authScheme: { @@ -214,6 +291,101 @@ describe('AuthHandler', () => { flows: { authorizationCode: { authorizationUrl: 'https://auth.com', + tokenUrl: 'https://token.com', + scopes: {scope1: 'desc'}, + }, + }, + }, + rawAuthCredential: { + authType: AuthCredentialTypes.OAUTH2, + oauth2: { + clientId: 'id', + clientSecret: 'secret', + redirectUri: 'https://redirect.com', + }, + }, + }; + const handler = new AuthHandler(authConfig); + + const uri = handler.generateAuthUri(); + + expect(uri).toBeDefined(); + expect(uri?.oauth2?.authUri).toContain('https://auth.com'); + expect(uri?.oauth2?.authUri).toContain('client_id=id'); + expect(uri?.oauth2?.authUri).toContain( + 'redirect_uri=https%3A%2F%2Fredirect.com', + ); + expect(uri?.oauth2?.authUri).toContain('scope=scope1'); + expect(uri?.oauth2?.authUri).not.toContain('secret'); + expect(uri?.oauth2?.state).toBeDefined(); + }); + + it('throws if authorization endpoint is missing', () => { + const authConfig: AuthConfig = { + credentialKey: 'testKey', + authScheme: { + type: 'oauth2', + flows: { + clientCredentials: { + tokenUrl: '', + scopes: {}, + }, + }, + }, + rawAuthCredential: { + authType: AuthCredentialTypes.OAUTH2, + oauth2: {clientId: 'id'}, + }, + }; + const handler = new AuthHandler(authConfig); + + expect(() => handler.generateAuthUri()).toThrow( + 'Authorization endpoint not configured in auth scheme.', + ); + }); + + it('generates auth URI for scheme with authorizationEndpoint (OpenIdConnect)', () => { + const authConfig: AuthConfig = { + credentialKey: 'testKey', + authScheme: { + type: 'openIdConnect', + authorizationEndpoint: 'https://oidc-auth.com', + scopes: ['openid'], + tokenEndpoint: '', + openIdConnectUrl: 'https://oidc-auth.com', + }, + rawAuthCredential: { + authType: AuthCredentialTypes.OAUTH2, + oauth2: {clientId: 'id', redirectUri: 'https://redirect.com'}, + }, + }; + const handler = new AuthHandler(authConfig); + + const uri = handler.generateAuthUri(); + + expect(uri).toBeDefined(); + expect(uri?.oauth2?.authUri).toContain('https://oidc-auth.com'); + }); + + it('returns original credential if rawAuthCredential or oauth2 is missing', () => { + const authConfig: AuthConfig = { + credentialKey: 'testKey', + authScheme: {type: 'oauth2', flows: {}}, + }; + const handler = new AuthHandler(authConfig); + + const uri = handler.generateAuthUri(); + + expect(uri).toBeUndefined(); + }); + + it('uses tokenUrl as fallback for authorizationEndpoint if authorizationUrl is missing', () => { + const authConfig: AuthConfig = { + credentialKey: 'testKey', + authScheme: { + type: 'oauth2', + flows: { + clientCredentials: { tokenUrl: 'https://token.com', scopes: {}, }, @@ -228,7 +400,8 @@ describe('AuthHandler', () => { const uri = handler.generateAuthUri(); - expect(uri).toBe(authConfig.rawAuthCredential); + expect(uri).toBeDefined(); + expect(uri?.oauth2?.authUri).toContain('https://token.com'); }); }); }); diff --git a/core/test/auth/auth_preprocessor_test.ts b/core/test/auth/auth_preprocessor_test.ts new file mode 100644 index 00000000..8c816dde --- /dev/null +++ b/core/test/auth/auth_preprocessor_test.ts @@ -0,0 +1,166 @@ +/** + * @license + * Copyright 2026 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import {Event, createEvent} from '@google/adk'; +import {Mock, describe, expect, it, vi} from 'vitest'; +import {REQUEST_EUC_FUNCTION_CALL_NAME} from '../../src/agents/functions.js'; +import {InvocationContext} from '../../src/agents/invocation_context.js'; +import {AUTH_PREPROCESSOR} from '../../src/auth/auth_preprocessor.js'; + +vi.mock('../../src/agents/functions.js', async (importOriginal) => { + const actual = (await importOriginal()) as { + handleFunctionCallsAsync: Mock; + }; + return { + ...actual, + handleFunctionCallsAsync: vi.fn().mockResolvedValue({ + id: 'mockResponseEvent', + author: 'system', + } as Event), + }; +}); + +vi.mock('../../src/auth/auth_handler.js', () => ({ + AuthHandler: class { + parseAndStoreAuthResponse = vi.fn().mockResolvedValue(undefined); + }, +})); + +describe('AuthPreprocessor', () => { + const LLM_AGENT_SYMBOL = Symbol.for('google.adk.llmAgent'); + + it('skips if agent is not LlmAgent', async () => { + const invocationContext = { + agent: {}, // Not an LlmAgent + session: {events: []}, + } as unknown as InvocationContext; + + const generator = AUTH_PREPROCESSOR.runAsync(invocationContext); + const result = await generator.next(); + + expect(result.done).toBe(true); + }); + + it('skips if no events are present', async () => { + const invocationContext = { + agent: {[LLM_AGENT_SYMBOL]: true}, + session: {events: []}, + } as unknown as InvocationContext; + + const generator = AUTH_PREPROCESSOR.runAsync(invocationContext); + const result = await generator.next(); + + expect(result.done).toBe(true); + }); + + it('skips if last event is not from user', async () => { + const invocationContext = { + agent: {[LLM_AGENT_SYMBOL]: true}, + session: { + events: [ + {author: 'system', content: {parts: [{text: 'hello'}]}} as Event, + ], + }, + } as unknown as InvocationContext; + + const generator = AUTH_PREPROCESSOR.runAsync(invocationContext); + const result = await generator.next(); + + expect(result.done).toBe(true); + }); + + it('skips if no function responses for request_credential are found', async () => { + const invocationContext = { + agent: {[LLM_AGENT_SYMBOL]: true}, + session: { + events: [ + { + author: 'user', + content: { + parts: [{text: 'hello'}], + }, + } as Event, + ], + }, + } as unknown as InvocationContext; + + const generator = AUTH_PREPROCESSOR.runAsync(invocationContext); + const result = await generator.next(); + + expect(result.done).toBe(true); + }); + + it('processes adk_request_credential responses and resumes tools', async () => { + const invocationContext = { + agent: { + [LLM_AGENT_SYMBOL]: true, + canonicalTools: vi.fn().mockResolvedValue([]), + canonicalBeforeToolCallbacks: [], + canonicalAfterToolCallbacks: [], + }, + session: { + state: {}, + events: [ + createEvent({ + author: 'agent', + content: { + parts: [ + { + functionCall: { + id: 'toolFc1', + name: 'someTool', + args: {}, + }, + }, + ], + }, + }), + createEvent({ + author: 'agent', + id: 'originalEvent', + content: { + parts: [ + { + functionCall: { + id: 'fc1', + name: REQUEST_EUC_FUNCTION_CALL_NAME, + args: { + authConfig: {credentialKey: 'testKey'}, + functionCallId: 'toolFc1', + }, + }, + }, + ], + }, + }), + createEvent({ + author: 'user', + content: { + parts: [ + { + functionResponse: { + id: 'fc1', + name: REQUEST_EUC_FUNCTION_CALL_NAME, + response: {authType: 'apiKey', apiKey: 'test'}, + }, + }, + ], + }, + }), + ], + }, + } as unknown as InvocationContext; + + const generator = AUTH_PREPROCESSOR.runAsync(invocationContext); + const result = await generator.next(); + + expect(result.done).toBe(false); + expect(result.value).toEqual({ + id: 'mockResponseEvent', + author: 'system', + }); + }); +});