Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions UPGRADE.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,21 @@
UPGRADE FROM 0.9 to 0.10
========================

Platform
--------

* `PlatformInterface::invoke()` now accepts `string|Model` for its first argument, and
`ProviderInterface::invoke()` and `ProviderInterface::supports()` were widened the same way. Custom
implementations and decorators must widen their signatures accordingly:

```diff
-public function invoke(string $model, array|string|object $input, array $options = []): DeferredResult
+public function invoke(string|Model $model, array|string|object $input, array $options = []): DeferredResult

-public function supports(string $model): bool
+public function supports(string|Model $model): bool
```

Store
-----

Expand Down
29 changes: 29 additions & 0 deletions docs/components/platform.rst
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,35 @@ the default catalog. The ``ModelCatalog`` automatically queries the model inform
Message::ofUser(...)
));

Passing a Model Instance
~~~~~~~~~~~~~~~~~~~~~~~~

Instead of a model name string, you can hand a fully defined model instance to ``Platform::invoke()``. This
skips the catalog lookup entirely and is useful when a provider ships a model that is not (yet) part of the
shipped catalog, without registering it or replacing the catalog::

use Symfony\AI\Platform\Bridge\OpenAi\Gpt;
use Symfony\AI\Platform\Capability;
use Symfony\AI\Platform\Message\Message;
use Symfony\AI\Platform\Message\MessageBag;

$model = new Gpt('gpt-newest', [
Capability::INPUT_MESSAGES,
Capability::OUTPUT_TEXT,
Capability::TOOL_CALLING,
], ['temperature' => 0.5]);

$result = $platform->invoke($model, new MessageBag(Message::ofUser(...)));

.. note::

You must pass a **bridge-specific** model subclass (e.g. ``Gpt``, ``Claude``, ``Gemini``), not the base
:class:`Symfony\\AI\\Platform\\Model`. Model clients, result converters, and contract normalizers select
the right implementation via the concrete class, so a bare ``Model`` instance has no client to handle it.
The platform routes the instance to the first provider whose model clients accept it; in multi-provider
setups where the same class is shared (e.g. OpenAI and Azure both use ``Gpt``), the first matching provider
wins.

Supported Models & Platforms
----------------------------

Expand Down
5 changes: 5 additions & 0 deletions src/platform/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
CHANGELOG
=========

0.10
----

* Add support for passing a fully defined `Model` instance to `Platform::invoke()` (and `Provider::invoke()`) instead of a model name string, bypassing the model catalog; widen `ProviderInterface::supports()` to `string|Model` to route a model instance to the first provider whose model clients accept it

0.9
---

Expand Down
11 changes: 7 additions & 4 deletions src/platform/src/Bridge/Cache/CachePlatform.php
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

use Symfony\AI\Platform\Exception\InvalidArgumentException;
use Symfony\AI\Platform\Message\MessageBag;
use Symfony\AI\Platform\Model;
use Symfony\AI\Platform\ModelCatalog\ModelCatalogInterface;
use Symfony\AI\Platform\PlainConverter;
use Symfony\AI\Platform\PlatformInterface;
Expand Down Expand Up @@ -58,12 +59,14 @@ classDiscriminatorResolver: new ClassDiscriminatorFromClassMetadata(new ClassMet
) {
}

public function invoke(string $model, array|string|object $input, array $options = []): DeferredResult
public function invoke(string|Model $model, array|string|object $input, array $options = []): DeferredResult
{
if (null === $this->cache || !\array_key_exists('prompt_cache_key', $options) || '' === $options['prompt_cache_key']) {
return $this->platform->invoke($model, $input, $options);
}

$modelName = $model instanceof Model ? $model->getName() : $model;

$normalizedInput = match (true) {
\is_string($input) => md5($input),
\is_array($input) => json_encode($input),
Expand All @@ -73,16 +76,16 @@ public function invoke(string $model, array|string|object $input, array $options

$cacheKey = (new UnicodeString())->join([
$options['prompt_cache_key'] ?? $this->cacheKey,
(new UnicodeString($model))->camel(),
(new UnicodeString($modelName))->camel(),
$normalizedInput,
]);

$ttl = $options['prompt_cache_ttl'] ?? $this->cacheTtl;

unset($options['prompt_cache_key'], $options['prompt_cache_ttl']);

$cached = $this->cache->get($cacheKey, function (ItemInterface $item) use ($model, $input, $options, $cacheKey, $ttl): array {
$item->tag((new UnicodeString($model))->camel());
$cached = $this->cache->get($cacheKey, function (ItemInterface $item) use ($model, $modelName, $input, $options, $cacheKey, $ttl): array {
$item->tag((new UnicodeString($modelName))->camel());

if (null !== $ttl) {
$item->expiresAfter($ttl);
Expand Down
3 changes: 2 additions & 1 deletion src/platform/src/Bridge/Failover/FailoverPlatform.php
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
use Psr\Log\NullLogger;
use Symfony\AI\Platform\Exception\InvalidArgumentException;
use Symfony\AI\Platform\Exception\RuntimeException;
use Symfony\AI\Platform\Model;
use Symfony\AI\Platform\ModelCatalog\ModelCatalogInterface;
use Symfony\AI\Platform\PlatformInterface;
use Symfony\AI\Platform\Result\DeferredResult;
Expand Down Expand Up @@ -48,7 +49,7 @@ public function __construct(
$this->failedPlatforms = new \WeakMap();
}

public function invoke(string $model, object|array|string $input, array $options = []): DeferredResult
public function invoke(string|Model $model, object|array|string $input, array $options = []): DeferredResult
{
return $this->do(static fn (PlatformInterface $platform): DeferredResult => $platform->invoke($model, $input, $options));
}
Expand Down
21 changes: 20 additions & 1 deletion src/platform/src/Platform.php
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

use Symfony\AI\Platform\Event\ModelRoutingEvent;
use Symfony\AI\Platform\Exception\InvalidArgumentException;
use Symfony\AI\Platform\Exception\ModelNotFoundException;
use Symfony\AI\Platform\ModelCatalog\CompositeModelCatalog;
use Symfony\AI\Platform\ModelCatalog\ModelCatalogInterface;
use Symfony\AI\Platform\ModelRouter\CatalogBasedModelRouter;
Expand Down Expand Up @@ -44,8 +45,12 @@ public function __construct(
}
}

public function invoke(string $model, array|string|object $input, array $options = []): DeferredResult
public function invoke(string|Model $model, array|string|object $input, array $options = []): DeferredResult
{
if ($model instanceof Model) {
return $this->resolveProviderForModel($model)->invoke($model, $input, $options);
}

$event = new ModelRoutingEvent($model, $input, $options);
$this->eventDispatcher?->dispatch($event);

Expand All @@ -64,4 +69,18 @@ public function getModelCatalog(): ModelCatalogInterface
),
);
}

/**
* Routes a fully defined model to the first provider whose model clients accept it.
*/
private function resolveProviderForModel(Model $model): ProviderInterface
{
foreach ($this->providers as $provider) {
if ($provider->supports($model)) {
return $provider;
}
}

throw new ModelNotFoundException(\sprintf('No provider found for model "%s" (%s).', $model->getName(), $model::class));
}
}
4 changes: 2 additions & 2 deletions src/platform/src/PlatformInterface.php
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@
interface PlatformInterface
{
/**
* @param non-empty-string $model The model name
* @param non-empty-string|Model $model The model name to resolve via the catalog, or a fully defined model
* @param array<mixed>|string|object $input The input data
* @param array<string, mixed> $options The options to customize the model invocation
*/
public function invoke(string $model, array|string|object $input, array $options = []): DeferredResult;
public function invoke(string|Model $model, array|string|object $input, array $options = []): DeferredResult;

public function getModelCatalog(): ModelCatalogInterface;
}
20 changes: 16 additions & 4 deletions src/platform/src/Provider.php
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,20 @@ public function getName(): string
return $this->name;
}

public function supports(string $modelName): bool
public function supports(string|Model $model): bool
{
if ($model instanceof Model) {
foreach ($this->modelClients as $modelClient) {
if ($modelClient->supports($model)) {
return true;
}
}

return false;
}

try {
$this->resolvedModel = $this->modelCatalog->getModel($modelName);
$this->resolvedModel = $this->modelCatalog->getModel($model);

return true;
} catch (ModelNotFoundException) {
Expand All @@ -67,9 +77,11 @@ public function supports(string $modelName): bool
}
}

public function invoke(string $model, array|string|object $input, array $options = []): DeferredResult
public function invoke(string|Model $model, array|string|object $input, array $options = []): DeferredResult
{
if (null !== $this->resolvedModel && $this->resolvedModel->getName() === $model) {
if ($model instanceof Model) {
$this->resolvedModel = null;
} elseif (null !== $this->resolvedModel && $this->resolvedModel->getName() === $model) {
$model = $this->resolvedModel;
$this->resolvedModel = null;
} else {
Expand Down
10 changes: 5 additions & 5 deletions src/platform/src/ProviderInterface.php
Original file line number Diff line number Diff line change
Expand Up @@ -30,18 +30,18 @@ interface ProviderInterface
public function getName(): string;

/**
* Whether this provider can handle the given model name.
* Whether this provider can handle the given model name, or fully defined model.
*
* @param non-empty-string $modelName
* @param non-empty-string|Model $model
*/
public function supports(string $modelName): bool;
public function supports(string|Model $model): bool;

/**
* @param non-empty-string $model The model name
* @param non-empty-string|Model $model The model name to resolve via the catalog, or a fully defined model
* @param array<mixed>|string|object $input The input data
* @param array<string, mixed> $options The options to customize the model invocation
*/
public function invoke(string $model, array|string|object $input, array $options = []): DeferredResult;
public function invoke(string|Model $model, array|string|object $input, array $options = []): DeferredResult;

public function getModelCatalog(): ModelCatalogInterface;
}
16 changes: 9 additions & 7 deletions src/platform/src/Test/InMemoryPlatform.php
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,16 @@ public function __construct(private readonly \Closure|string $mockResult)
$this->modelCatalog = new FallbackModelCatalog();
}

public function invoke(string $model, array|string|object $input, array $options = []): DeferredResult
public function invoke(string|Model $model, array|string|object $input, array $options = []): DeferredResult
{
$model = new class($model) extends Model {
public function __construct(string $name)
{
parent::__construct($name);
}
};
if (!$model instanceof Model) {
$model = new class($model) extends Model {
public function __construct(string $name)
{
parent::__construct($name);
}
};
}
$result = \is_string($this->mockResult) ? $this->mockResult : ($this->mockResult)($model, $input, $options);

if ($result instanceof ResultInterface) {
Expand Down
2 changes: 1 addition & 1 deletion src/platform/src/TraceablePlatform.php
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ public function __construct(
$this->resultCache = new \WeakMap();
}

public function invoke(string $model, array|string|object $input, array $options = []): DeferredResult
public function invoke(string|Model $model, array|string|object $input, array $options = []): DeferredResult
{
$deferredResult = $this->platform->invoke($model, $input, $options);

Expand Down
83 changes: 83 additions & 0 deletions src/platform/tests/PlatformTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
use PHPUnit\Framework\TestCase;
use Symfony\AI\Platform\Event\ModelRoutingEvent;
use Symfony\AI\Platform\Exception\InvalidArgumentException;
use Symfony\AI\Platform\Exception\ModelNotFoundException;
use Symfony\AI\Platform\Model;
use Symfony\AI\Platform\ModelCatalog\CompositeModelCatalog;
use Symfony\AI\Platform\ModelCatalog\ModelCatalogInterface;
use Symfony\AI\Platform\ModelRouterInterface;
Expand Down Expand Up @@ -142,6 +144,87 @@ public function testModelRoutingEventProviderSkipsRouter()
$this->assertSame($deferredResult, $result);
}

public function testInvokeWithModelObjectRoutesViaSupports()
{
$model = new Model('custom-model', []);
$deferredResult = new DeferredResult(new PlainConverter(new TextResult('Hello')), $this->createStub(RawResultInterface::class));

$provider = $this->createMock(ProviderInterface::class);
$provider->expects($this->once())
->method('supports')
->with($model)
->willReturn(true);
$provider->expects($this->once())
->method('invoke')
->with($model, 'Hello', [])
->willReturn($deferredResult);

$router = $this->createMock(ModelRouterInterface::class);
$router->expects($this->never())->method('resolve');

$platform = new Platform([$provider], $router);

$result = $platform->invoke($model, 'Hello');

$this->assertSame($deferredResult, $result);
}

public function testInvokeWithModelObjectPicksFirstSupportingProvider()
{
$model = new Model('custom-model', []);
$deferredResult = new DeferredResult(new PlainConverter(new TextResult('Hello')), $this->createStub(RawResultInterface::class));

$firstProvider = $this->createMock(ProviderInterface::class);
$firstProvider->method('supports')->with($model)->willReturn(false);
$firstProvider->expects($this->never())->method('invoke');

$secondProvider = $this->createMock(ProviderInterface::class);
$secondProvider->method('supports')->with($model)->willReturn(true);
$secondProvider->expects($this->once())
->method('invoke')
->with($model, 'Hello', [])
->willReturn($deferredResult);

$platform = new Platform([$firstProvider, $secondProvider]);

$result = $platform->invoke($model, 'Hello');

$this->assertSame($deferredResult, $result);
}

public function testInvokeWithModelObjectThrowsWhenNoProviderSupports()
{
$model = new Model('custom-model', []);

$provider = $this->createMock(ProviderInterface::class);
$provider->method('supports')->with($model)->willReturn(false);
$provider->expects($this->never())->method('invoke');

$platform = new Platform([$provider]);

$this->expectException(ModelNotFoundException::class);
$this->expectExceptionMessage('No provider found for model "custom-model"');

$platform->invoke($model, 'Hello');
}

public function testInvokeWithModelObjectDoesNotDispatchModelRoutingEvent()
{
$model = new Model('custom-model', []);
$deferredResult = new DeferredResult(new PlainConverter(new TextResult('Hello')), $this->createStub(RawResultInterface::class));

$provider = $this->createStub(ProviderInterface::class);
$provider->method('supports')->willReturn(true);
$provider->method('invoke')->willReturn($deferredResult);

$eventDispatcher = $this->createMock(EventDispatcherInterface::class);
$eventDispatcher->expects($this->never())->method('dispatch');

$platform = new Platform([$provider], eventDispatcher: $eventDispatcher);

$platform->invoke($model, 'Hello');
}

public function testGetModelCatalogBuildsComposite()
{
$catalog1 = $this->createStub(ModelCatalogInterface::class);
Expand Down
Loading