diff --git a/UPGRADE.md b/UPGRADE.md index 652b4de9ba..d95e663281 100644 --- a/UPGRADE.md +++ b/UPGRADE.md @@ -1,6 +1,21 @@ UPGRADE FROM 0.9 to 0.10 ======================== +Platform +-------- + + * `PlatformInterface::invoke()` now accepts `string|Model` for its first argument, and + `ProviderInterface::invoke()` and `ProviderInterface::supports()` were widened the same way. Custom + implementations and decorators must widen their signatures accordingly: + + ```diff + -public function invoke(string $model, array|string|object $input, array $options = []): DeferredResult + +public function invoke(string|Model $model, array|string|object $input, array $options = []): DeferredResult + + -public function supports(string $model): bool + +public function supports(string|Model $model): bool + ``` + Store ----- diff --git a/docs/components/platform.rst b/docs/components/platform.rst index d19ebf6a9b..72a3fb7473 100644 --- a/docs/components/platform.rst +++ b/docs/components/platform.rst @@ -94,6 +94,35 @@ the default catalog. The ``ModelCatalog`` automatically queries the model inform Message::ofUser(...) )); +Passing a Model Instance +~~~~~~~~~~~~~~~~~~~~~~~~ + +Instead of a model name string, you can hand a fully defined model instance to ``Platform::invoke()``. This +skips the catalog lookup entirely and is useful when a provider ships a model that is not (yet) part of the +shipped catalog, without registering it or replacing the catalog:: + + use Symfony\AI\Platform\Bridge\OpenAi\Gpt; + use Symfony\AI\Platform\Capability; + use Symfony\AI\Platform\Message\Message; + use Symfony\AI\Platform\Message\MessageBag; + + $model = new Gpt('gpt-newest', [ + Capability::INPUT_MESSAGES, + Capability::OUTPUT_TEXT, + Capability::TOOL_CALLING, + ], ['temperature' => 0.5]); + + $result = $platform->invoke($model, new MessageBag(Message::ofUser(...))); + +.. note:: + + You must pass a **bridge-specific** model subclass (e.g. ``Gpt``, ``Claude``, ``Gemini``), not the base + :class:`Symfony\\AI\\Platform\\Model`. Model clients, result converters, and contract normalizers select + the right implementation via the concrete class, so a bare ``Model`` instance has no client to handle it. + The platform routes the instance to the first provider whose model clients accept it; in multi-provider + setups where the same class is shared (e.g. OpenAI and Azure both use ``Gpt``), the first matching provider + wins. + Supported Models & Platforms ---------------------------- diff --git a/src/platform/CHANGELOG.md b/src/platform/CHANGELOG.md index 68f0e5d922..bfe9ab0657 100644 --- a/src/platform/CHANGELOG.md +++ b/src/platform/CHANGELOG.md @@ -1,6 +1,11 @@ CHANGELOG ========= +0.10 +---- + + * Add support for passing a fully defined `Model` instance to `Platform::invoke()` (and `Provider::invoke()`) instead of a model name string, bypassing the model catalog; widen `ProviderInterface::supports()` to `string|Model` to route a model instance to the first provider whose model clients accept it + 0.9 --- diff --git a/src/platform/src/Bridge/Cache/CachePlatform.php b/src/platform/src/Bridge/Cache/CachePlatform.php index ad5813a76a..b2a1b6616e 100644 --- a/src/platform/src/Bridge/Cache/CachePlatform.php +++ b/src/platform/src/Bridge/Cache/CachePlatform.php @@ -13,6 +13,7 @@ use Symfony\AI\Platform\Exception\InvalidArgumentException; use Symfony\AI\Platform\Message\MessageBag; +use Symfony\AI\Platform\Model; use Symfony\AI\Platform\ModelCatalog\ModelCatalogInterface; use Symfony\AI\Platform\PlainConverter; use Symfony\AI\Platform\PlatformInterface; @@ -58,12 +59,14 @@ classDiscriminatorResolver: new ClassDiscriminatorFromClassMetadata(new ClassMet ) { } - public function invoke(string $model, array|string|object $input, array $options = []): DeferredResult + public function invoke(string|Model $model, array|string|object $input, array $options = []): DeferredResult { if (null === $this->cache || !\array_key_exists('prompt_cache_key', $options) || '' === $options['prompt_cache_key']) { return $this->platform->invoke($model, $input, $options); } + $modelName = $model instanceof Model ? $model->getName() : $model; + $normalizedInput = match (true) { \is_string($input) => md5($input), \is_array($input) => json_encode($input), @@ -73,7 +76,7 @@ public function invoke(string $model, array|string|object $input, array $options $cacheKey = (new UnicodeString())->join([ $options['prompt_cache_key'] ?? $this->cacheKey, - (new UnicodeString($model))->camel(), + (new UnicodeString($modelName))->camel(), $normalizedInput, ]); @@ -81,8 +84,8 @@ public function invoke(string $model, array|string|object $input, array $options unset($options['prompt_cache_key'], $options['prompt_cache_ttl']); - $cached = $this->cache->get($cacheKey, function (ItemInterface $item) use ($model, $input, $options, $cacheKey, $ttl): array { - $item->tag((new UnicodeString($model))->camel()); + $cached = $this->cache->get($cacheKey, function (ItemInterface $item) use ($model, $modelName, $input, $options, $cacheKey, $ttl): array { + $item->tag((new UnicodeString($modelName))->camel()); if (null !== $ttl) { $item->expiresAfter($ttl); diff --git a/src/platform/src/Bridge/Failover/FailoverPlatform.php b/src/platform/src/Bridge/Failover/FailoverPlatform.php index 52a1e97b4c..fe01a884dd 100644 --- a/src/platform/src/Bridge/Failover/FailoverPlatform.php +++ b/src/platform/src/Bridge/Failover/FailoverPlatform.php @@ -15,6 +15,7 @@ use Psr\Log\NullLogger; use Symfony\AI\Platform\Exception\InvalidArgumentException; use Symfony\AI\Platform\Exception\RuntimeException; +use Symfony\AI\Platform\Model; use Symfony\AI\Platform\ModelCatalog\ModelCatalogInterface; use Symfony\AI\Platform\PlatformInterface; use Symfony\AI\Platform\Result\DeferredResult; @@ -48,7 +49,7 @@ public function __construct( $this->failedPlatforms = new \WeakMap(); } - public function invoke(string $model, object|array|string $input, array $options = []): DeferredResult + public function invoke(string|Model $model, object|array|string $input, array $options = []): DeferredResult { return $this->do(static fn (PlatformInterface $platform): DeferredResult => $platform->invoke($model, $input, $options)); } diff --git a/src/platform/src/Platform.php b/src/platform/src/Platform.php index 72f262b8eb..7512448478 100644 --- a/src/platform/src/Platform.php +++ b/src/platform/src/Platform.php @@ -13,6 +13,7 @@ use Symfony\AI\Platform\Event\ModelRoutingEvent; use Symfony\AI\Platform\Exception\InvalidArgumentException; +use Symfony\AI\Platform\Exception\ModelNotFoundException; use Symfony\AI\Platform\ModelCatalog\CompositeModelCatalog; use Symfony\AI\Platform\ModelCatalog\ModelCatalogInterface; use Symfony\AI\Platform\ModelRouter\CatalogBasedModelRouter; @@ -44,8 +45,12 @@ public function __construct( } } - public function invoke(string $model, array|string|object $input, array $options = []): DeferredResult + public function invoke(string|Model $model, array|string|object $input, array $options = []): DeferredResult { + if ($model instanceof Model) { + return $this->resolveProviderForModel($model)->invoke($model, $input, $options); + } + $event = new ModelRoutingEvent($model, $input, $options); $this->eventDispatcher?->dispatch($event); @@ -64,4 +69,18 @@ public function getModelCatalog(): ModelCatalogInterface ), ); } + + /** + * Routes a fully defined model to the first provider whose model clients accept it. + */ + private function resolveProviderForModel(Model $model): ProviderInterface + { + foreach ($this->providers as $provider) { + if ($provider->supports($model)) { + return $provider; + } + } + + throw new ModelNotFoundException(\sprintf('No provider found for model "%s" (%s).', $model->getName(), $model::class)); + } } diff --git a/src/platform/src/PlatformInterface.php b/src/platform/src/PlatformInterface.php index 109ad018ea..ef753e2bc0 100644 --- a/src/platform/src/PlatformInterface.php +++ b/src/platform/src/PlatformInterface.php @@ -20,11 +20,11 @@ interface PlatformInterface { /** - * @param non-empty-string $model The model name + * @param non-empty-string|Model $model The model name to resolve via the catalog, or a fully defined model * @param array|string|object $input The input data * @param array $options The options to customize the model invocation */ - public function invoke(string $model, array|string|object $input, array $options = []): DeferredResult; + public function invoke(string|Model $model, array|string|object $input, array $options = []): DeferredResult; public function getModelCatalog(): ModelCatalogInterface; } diff --git a/src/platform/src/Provider.php b/src/platform/src/Provider.php index f6ab46bfbb..cec14f8ee8 100644 --- a/src/platform/src/Provider.php +++ b/src/platform/src/Provider.php @@ -54,10 +54,20 @@ public function getName(): string return $this->name; } - public function supports(string $modelName): bool + public function supports(string|Model $model): bool { + if ($model instanceof Model) { + foreach ($this->modelClients as $modelClient) { + if ($modelClient->supports($model)) { + return true; + } + } + + return false; + } + try { - $this->resolvedModel = $this->modelCatalog->getModel($modelName); + $this->resolvedModel = $this->modelCatalog->getModel($model); return true; } catch (ModelNotFoundException) { @@ -67,9 +77,11 @@ public function supports(string $modelName): bool } } - public function invoke(string $model, array|string|object $input, array $options = []): DeferredResult + public function invoke(string|Model $model, array|string|object $input, array $options = []): DeferredResult { - if (null !== $this->resolvedModel && $this->resolvedModel->getName() === $model) { + if ($model instanceof Model) { + $this->resolvedModel = null; + } elseif (null !== $this->resolvedModel && $this->resolvedModel->getName() === $model) { $model = $this->resolvedModel; $this->resolvedModel = null; } else { diff --git a/src/platform/src/ProviderInterface.php b/src/platform/src/ProviderInterface.php index 40313a7bb6..bfa89b7b5d 100644 --- a/src/platform/src/ProviderInterface.php +++ b/src/platform/src/ProviderInterface.php @@ -30,18 +30,18 @@ interface ProviderInterface public function getName(): string; /** - * Whether this provider can handle the given model name. + * Whether this provider can handle the given model name, or fully defined model. * - * @param non-empty-string $modelName + * @param non-empty-string|Model $model */ - public function supports(string $modelName): bool; + public function supports(string|Model $model): bool; /** - * @param non-empty-string $model The model name + * @param non-empty-string|Model $model The model name to resolve via the catalog, or a fully defined model * @param array|string|object $input The input data * @param array $options The options to customize the model invocation */ - public function invoke(string $model, array|string|object $input, array $options = []): DeferredResult; + public function invoke(string|Model $model, array|string|object $input, array $options = []): DeferredResult; public function getModelCatalog(): ModelCatalogInterface; } diff --git a/src/platform/src/Test/InMemoryPlatform.php b/src/platform/src/Test/InMemoryPlatform.php index 49110a871d..9e2ad59180 100644 --- a/src/platform/src/Test/InMemoryPlatform.php +++ b/src/platform/src/Test/InMemoryPlatform.php @@ -41,14 +41,16 @@ public function __construct(private readonly \Closure|string $mockResult) $this->modelCatalog = new FallbackModelCatalog(); } - public function invoke(string $model, array|string|object $input, array $options = []): DeferredResult + public function invoke(string|Model $model, array|string|object $input, array $options = []): DeferredResult { - $model = new class($model) extends Model { - public function __construct(string $name) - { - parent::__construct($name); - } - }; + if (!$model instanceof Model) { + $model = new class($model) extends Model { + public function __construct(string $name) + { + parent::__construct($name); + } + }; + } $result = \is_string($this->mockResult) ? $this->mockResult : ($this->mockResult)($model, $input, $options); if ($result instanceof ResultInterface) { diff --git a/src/platform/src/TraceablePlatform.php b/src/platform/src/TraceablePlatform.php index 8f1163280c..557df2161e 100644 --- a/src/platform/src/TraceablePlatform.php +++ b/src/platform/src/TraceablePlatform.php @@ -46,7 +46,7 @@ public function __construct( $this->resultCache = new \WeakMap(); } - public function invoke(string $model, array|string|object $input, array $options = []): DeferredResult + public function invoke(string|Model $model, array|string|object $input, array $options = []): DeferredResult { $deferredResult = $this->platform->invoke($model, $input, $options); diff --git a/src/platform/tests/PlatformTest.php b/src/platform/tests/PlatformTest.php index 2435fcd24a..7146df64f0 100644 --- a/src/platform/tests/PlatformTest.php +++ b/src/platform/tests/PlatformTest.php @@ -14,6 +14,8 @@ use PHPUnit\Framework\TestCase; use Symfony\AI\Platform\Event\ModelRoutingEvent; use Symfony\AI\Platform\Exception\InvalidArgumentException; +use Symfony\AI\Platform\Exception\ModelNotFoundException; +use Symfony\AI\Platform\Model; use Symfony\AI\Platform\ModelCatalog\CompositeModelCatalog; use Symfony\AI\Platform\ModelCatalog\ModelCatalogInterface; use Symfony\AI\Platform\ModelRouterInterface; @@ -142,6 +144,87 @@ public function testModelRoutingEventProviderSkipsRouter() $this->assertSame($deferredResult, $result); } + public function testInvokeWithModelObjectRoutesViaSupports() + { + $model = new Model('custom-model', []); + $deferredResult = new DeferredResult(new PlainConverter(new TextResult('Hello')), $this->createStub(RawResultInterface::class)); + + $provider = $this->createMock(ProviderInterface::class); + $provider->expects($this->once()) + ->method('supports') + ->with($model) + ->willReturn(true); + $provider->expects($this->once()) + ->method('invoke') + ->with($model, 'Hello', []) + ->willReturn($deferredResult); + + $router = $this->createMock(ModelRouterInterface::class); + $router->expects($this->never())->method('resolve'); + + $platform = new Platform([$provider], $router); + + $result = $platform->invoke($model, 'Hello'); + + $this->assertSame($deferredResult, $result); + } + + public function testInvokeWithModelObjectPicksFirstSupportingProvider() + { + $model = new Model('custom-model', []); + $deferredResult = new DeferredResult(new PlainConverter(new TextResult('Hello')), $this->createStub(RawResultInterface::class)); + + $firstProvider = $this->createMock(ProviderInterface::class); + $firstProvider->method('supports')->with($model)->willReturn(false); + $firstProvider->expects($this->never())->method('invoke'); + + $secondProvider = $this->createMock(ProviderInterface::class); + $secondProvider->method('supports')->with($model)->willReturn(true); + $secondProvider->expects($this->once()) + ->method('invoke') + ->with($model, 'Hello', []) + ->willReturn($deferredResult); + + $platform = new Platform([$firstProvider, $secondProvider]); + + $result = $platform->invoke($model, 'Hello'); + + $this->assertSame($deferredResult, $result); + } + + public function testInvokeWithModelObjectThrowsWhenNoProviderSupports() + { + $model = new Model('custom-model', []); + + $provider = $this->createMock(ProviderInterface::class); + $provider->method('supports')->with($model)->willReturn(false); + $provider->expects($this->never())->method('invoke'); + + $platform = new Platform([$provider]); + + $this->expectException(ModelNotFoundException::class); + $this->expectExceptionMessage('No provider found for model "custom-model"'); + + $platform->invoke($model, 'Hello'); + } + + public function testInvokeWithModelObjectDoesNotDispatchModelRoutingEvent() + { + $model = new Model('custom-model', []); + $deferredResult = new DeferredResult(new PlainConverter(new TextResult('Hello')), $this->createStub(RawResultInterface::class)); + + $provider = $this->createStub(ProviderInterface::class); + $provider->method('supports')->willReturn(true); + $provider->method('invoke')->willReturn($deferredResult); + + $eventDispatcher = $this->createMock(EventDispatcherInterface::class); + $eventDispatcher->expects($this->never())->method('dispatch'); + + $platform = new Platform([$provider], eventDispatcher: $eventDispatcher); + + $platform->invoke($model, 'Hello'); + } + public function testGetModelCatalogBuildsComposite() { $catalog1 = $this->createStub(ModelCatalogInterface::class); diff --git a/src/platform/tests/ProviderTest.php b/src/platform/tests/ProviderTest.php index 5f80eee9db..6cf6fcf3d9 100644 --- a/src/platform/tests/ProviderTest.php +++ b/src/platform/tests/ProviderTest.php @@ -62,6 +62,56 @@ public function testSupportsReturnsFalseWhenCatalogDoesNotHaveModel() $this->assertFalse($provider->supports('unknown-model')); } + public function testSupportsWithModelObjectReturnsTrueWhenAModelClientSupportsIt() + { + $model = new Model('custom-model', [Capability::INPUT_MESSAGES]); + + $modelClient = $this->createStub(ModelClientInterface::class); + $modelClient->method('supports')->willReturn(true); + + $provider = new Provider('openai', [$modelClient], [], $this->createStub(ModelCatalogInterface::class)); + + $this->assertTrue($provider->supports($model)); + } + + public function testSupportsWithModelObjectReturnsFalseWhenNoModelClientSupportsIt() + { + $model = new Model('custom-model', [Capability::INPUT_MESSAGES]); + + $modelClient = $this->createStub(ModelClientInterface::class); + $modelClient->method('supports')->willReturn(false); + + $provider = new Provider('openai', [$modelClient], [], $this->createStub(ModelCatalogInterface::class)); + + $this->assertFalse($provider->supports($model)); + } + + public function testInvokeWithModelObjectSkipsCatalogResolution() + { + $model = new Model('custom-model', [Capability::INPUT_MESSAGES, Capability::OUTPUT_TEXT]); + $rawResult = $this->createStub(RawResultInterface::class); + + $catalog = $this->createMock(ModelCatalogInterface::class); + $catalog->expects($this->never())->method('getModel'); + + $modelClient = $this->createMock(ModelClientInterface::class); + $modelClient->method('supports')->with($model)->willReturn(true); + $modelClient->expects($this->once()) + ->method('request') + ->with($model) + ->willReturn($rawResult); + + $resultConverter = $this->createStub(ResultConverterInterface::class); + $resultConverter->method('supports')->willReturn(true); + $resultConverter->method('convert')->willReturn(new TextResult('Hello')); + + $provider = new Provider('openai', [$modelClient], [$resultConverter], $catalog); + + $result = $provider->invoke($model, 'Hello'); + + $this->assertInstanceOf(DeferredResult::class, $result); + } + public function testInvokeResolvesModelAndDelegates() { $model = new Model('gpt-4o', [Capability::INPUT_MESSAGES, Capability::OUTPUT_TEXT]);