From f4da1a06a6450fb4f78eb589722ca554f4ec79eb Mon Sep 17 00:00:00 2001 From: Alexey Orlenko Date: Fri, 10 Apr 2026 17:31:55 +0200 Subject: [PATCH] feat(sql-runtime): add callback-based transaction API --- packages/2-sql/5-runtime/src/exports/index.ts | 6 +- packages/2-sql/5-runtime/src/sql-runtime.ts | 73 ++++++- .../2-sql/5-runtime/test/sql-runtime.test.ts | 193 +++++++++++++++++- .../postgres/src/runtime/postgres.ts | 18 ++ .../postgres/test/postgres.test.ts | 67 ++++++ .../postgres/test/transaction.types.test-d.ts | 29 +++ projects/orm-client-transaction-api/plan.md | 12 +- projects/orm-client-transaction-api/spec.md | 30 +-- test/e2e/framework/test/transaction.test.ts | 84 ++++++++ 9 files changed, 486 insertions(+), 26 deletions(-) create mode 100644 packages/3-extensions/postgres/test/transaction.types.test-d.ts create mode 100644 test/e2e/framework/test/transaction.test.ts diff --git a/packages/2-sql/5-runtime/src/exports/index.ts b/packages/2-sql/5-runtime/src/exports/index.ts index 142d4c6b8d..941500c1bf 100644 --- a/packages/2-sql/5-runtime/src/exports/index.ts +++ b/packages/2-sql/5-runtime/src/exports/index.ts @@ -43,8 +43,12 @@ export { export type { CreateRuntimeOptions, Runtime, + RuntimeConnection, + RuntimeQueryable, RuntimeTelemetryEvent, + RuntimeTransaction, RuntimeVerifyOptions, TelemetryOutcome, + TransactionContext, } from '../sql-runtime'; -export { createRuntime } from '../sql-runtime'; +export { createRuntime, withTransaction } from '../sql-runtime'; diff --git a/packages/2-sql/5-runtime/src/sql-runtime.ts b/packages/2-sql/5-runtime/src/sql-runtime.ts index 9e10243d82..367769e68f 100644 --- a/packages/2-sql/5-runtime/src/sql-runtime.ts +++ b/packages/2-sql/5-runtime/src/sql-runtime.ts @@ -12,7 +12,11 @@ import type { RuntimeVerifyOptions, TelemetryOutcome, } from '@prisma-next/runtime-executor'; -import { AsyncIterableResult, createRuntimeCore } from '@prisma-next/runtime-executor'; +import { + AsyncIterableResult, + createRuntimeCore, + runtimeError, +} from '@prisma-next/runtime-executor'; import type { SqlStorage } from '@prisma-next/sql-contract/types'; import type { Adapter, @@ -95,6 +99,10 @@ export interface RuntimeQueryable { ): AsyncIterableResult; } +export interface TransactionContext extends RuntimeQueryable { + readonly invalidated: boolean; +} + interface CoreQueryable { execute>(plan: ExecutionPlan): AsyncIterableResult; } @@ -238,6 +246,69 @@ class SqlRuntimeImpl = Contract( + runtime: Runtime, + fn: (tx: TransactionContext) => PromiseLike, +): Promise { + const connection = await runtime.connection(); + const transaction = await connection.transaction(); + + let invalidated = false; + const txContext: TransactionContext = { + get invalidated() { + return invalidated; + }, + execute>( + plan: ExecutionPlan | SqlQueryPlan, + ): AsyncIterableResult { + if (invalidated) { + throw transactionClosedError(); + } + const inner = transaction.execute(plan); + const guarded = async function* (): AsyncGenerator { + for await (const row of inner) { + if (invalidated) { + throw transactionClosedError(); + } + yield row; + } + }; + return new AsyncIterableResult(guarded()); + }, + }; + + try { + const result = await fn(txContext); + invalidated = true; + await transaction.commit(); + return result; + } catch (error) { + invalidated = true; + try { + await transaction.rollback(); + } catch (rollbackError) { + const wrapped = runtimeError( + 'RUNTIME.TRANSACTION_ROLLBACK_FAILED', + 'Transaction rollback failed after callback error', + { rollbackError }, + ); + wrapped.cause = error; + throw wrapped; + } + throw error; + } finally { + await connection.release(); + } +} + export function createRuntime, TTargetId extends string>( options: CreateRuntimeOptions, ): Runtime { diff --git a/packages/2-sql/5-runtime/test/sql-runtime.test.ts b/packages/2-sql/5-runtime/test/sql-runtime.test.ts index 4516a9f8dd..830efa7c4d 100644 --- a/packages/2-sql/5-runtime/test/sql-runtime.test.ts +++ b/packages/2-sql/5-runtime/test/sql-runtime.test.ts @@ -21,7 +21,7 @@ import type { SqlRuntimeTargetDescriptor, } from '../src/sql-context'; import { createExecutionContext, createSqlExecutionStack } from '../src/sql-context'; -import { createRuntime } from '../src/sql-runtime'; +import { createRuntime, withTransaction } from '../src/sql-runtime'; const testContract: Contract = { targetFamily: 'sql', @@ -35,13 +35,16 @@ const testContract: Contract = { meta: {}, }; -interface DriverExecuteSpies { +interface DriverMockSpies { rootExecute: ReturnType; connectionExecute: ReturnType; transactionExecute: ReturnType; + connectionRelease: ReturnType; + transactionCommit: ReturnType; + transactionRollback: ReturnType; } -type MockSqlDriver = SqlDriver & { __spies: DriverExecuteSpies }; +type MockSqlDriver = SqlDriver & { __spies: DriverMockSpies }; function createStubCodecs(): CodecRegistry { const registry = createCodecRegistry(); @@ -128,6 +131,9 @@ function createMockDriver(): MockSqlDriver { rootExecute, connectionExecute, transactionExecute, + connectionRelease: connection.release, + transactionCommit: transaction.commit, + transactionRollback: transaction.rollback, }, }); } @@ -324,3 +330,184 @@ describe('createRuntime', () => { expect(driver.__spies.transactionExecute).not.toHaveBeenCalled(); }); }); + +describe('withTransaction', () => { + function createRuntimeForTransaction() { + const { stackInstance, context, driver } = createTestSetup(); + const runtime = createRuntime({ + stackInstance, + context, + driver, + verify: { mode: 'onFirstUse', requireMarker: false }, + }); + return { runtime, driver }; + } + + it('commits on successful callback and returns the result', async () => { + const { runtime, driver } = createRuntimeForTransaction(); + + const result = await withTransaction(runtime, async (tx) => { + await tx.execute(createRawExecutionPlan()).toArray(); + return 42; + }); + + expect(result).toBe(42); + expect(driver.__spies.transactionCommit).toHaveBeenCalledOnce(); + expect(driver.__spies.transactionRollback).not.toHaveBeenCalled(); + expect(driver.__spies.connectionRelease).toHaveBeenCalledOnce(); + }); + + it('rolls back on callback error and re-throws', async () => { + const { runtime, driver } = createRuntimeForTransaction(); + const error = new Error('test error'); + + await expect( + withTransaction(runtime, async () => { + throw error; + }), + ).rejects.toBe(error); + + expect(driver.__spies.transactionRollback).toHaveBeenCalledOnce(); + expect(driver.__spies.transactionCommit).not.toHaveBeenCalled(); + expect(driver.__spies.connectionRelease).toHaveBeenCalledOnce(); + }); + + it('releases connection after commit', async () => { + const { runtime, driver } = createRuntimeForTransaction(); + + await withTransaction(runtime, async () => 'ok'); + + expect(driver.__spies.connectionRelease).toHaveBeenCalledOnce(); + }); + + it('releases connection after rollback', async () => { + const { runtime, driver } = createRuntimeForTransaction(); + + await withTransaction(runtime, async () => { + throw new Error('fail'); + }).catch(() => {}); + + expect(driver.__spies.connectionRelease).toHaveBeenCalledOnce(); + }); + + it('propagates commit failure', async () => { + const { runtime, driver } = createRuntimeForTransaction(); + const commitError = new Error('commit failed'); + driver.__spies.transactionCommit.mockRejectedValueOnce(commitError); + + const result = withTransaction(runtime, async () => 'value'); + + await expect(result).rejects.toBe(commitError); + }); + + it('forwards the callback return value', async () => { + const { runtime } = createRuntimeForTransaction(); + + const result = await withTransaction(runtime, async () => ({ + name: 'test', + count: 3, + })); + + expect(result).toEqual({ name: 'test', count: 3 }); + }); + + it('executes queries against the transaction', async () => { + const { runtime, driver } = createRuntimeForTransaction(); + + await withTransaction(runtime, async (tx) => { + await tx.execute(createRawExecutionPlan()).toArray(); + }); + + expect(driver.__spies.transactionExecute).toHaveBeenCalledOnce(); + expect(driver.__spies.rootExecute).not.toHaveBeenCalled(); + expect(driver.__spies.connectionExecute).not.toHaveBeenCalled(); + }); + + it('throws on execute after commit (invalidation)', async () => { + const { runtime } = createRuntimeForTransaction(); + let savedTx: { execute: (plan: ExecutionPlan) => unknown } | undefined; + + await withTransaction(runtime, async (tx) => { + savedTx = tx; + }); + + expect(() => savedTx!.execute(createRawExecutionPlan())).toThrow( + 'Cannot read from a query result after the transaction has ended', + ); + }); + + it('throws on iteration of escaped AsyncIterableResult after commit', async () => { + const { runtime } = createRuntimeForTransaction(); + + const escaped = await withTransaction(runtime, async (tx) => { + return { result: tx.execute(createRawExecutionPlan()) }; + }); + + await expect(escaped.result.toArray()).rejects.toThrow( + 'Cannot read from a query result after the transaction has ended', + ); + }); + + it('sets invalidated flag after commit', async () => { + const { runtime } = createRuntimeForTransaction(); + let txRef: { invalidated: boolean } | undefined; + + await withTransaction(runtime, async (tx) => { + expect(tx.invalidated).toBe(false); + txRef = tx; + }); + + expect(txRef!.invalidated).toBe(true); + }); + + it('wraps original error when rollback fails', async () => { + const { runtime, driver } = createRuntimeForTransaction(); + const callbackError = new Error('callback failed'); + const rollbackError = new Error('rollback failed'); + driver.__spies.transactionRollback.mockRejectedValueOnce(rollbackError); + + const rejection = withTransaction(runtime, async () => { + throw callbackError; + }); + + await expect(rejection).rejects.toThrow('Transaction rollback failed after callback error'); + await expect(rejection).rejects.toMatchObject({ + code: 'RUNTIME.TRANSACTION_ROLLBACK_FAILED', + cause: callbackError, + details: { rollbackError }, + }); + expect(driver.__spies.connectionRelease).toHaveBeenCalledOnce(); + }); + + it('sets invalidated flag after rollback', async () => { + const { runtime } = createRuntimeForTransaction(); + let txRef: { invalidated: boolean } | undefined; + + await withTransaction(runtime, async (tx) => { + txRef = tx; + throw new Error('fail'); + }).catch(() => {}); + + expect(txRef!.invalidated).toBe(true); + }); + + it('releases connection independently across sequential transactions', async () => { + const { runtime, driver } = createRuntimeForTransaction(); + + await withTransaction(runtime, async (tx) => { + await tx.execute(createRawExecutionPlan()).toArray(); + }); + + await withTransaction(runtime, async (tx) => { + await tx.execute(createRawExecutionPlan()).toArray(); + }); + + await withTransaction(runtime, async () => { + throw new Error('fail'); + }).catch(() => {}); + + expect(driver.__spies.connectionRelease).toHaveBeenCalledTimes(3); + expect(driver.__spies.transactionCommit).toHaveBeenCalledTimes(2); + expect(driver.__spies.transactionRollback).toHaveBeenCalledTimes(1); + }); +}); diff --git a/packages/3-extensions/postgres/src/runtime/postgres.ts b/packages/3-extensions/postgres/src/runtime/postgres.ts index 4fb73e57ae..0bb3b318bb 100644 --- a/packages/3-extensions/postgres/src/runtime/postgres.ts +++ b/packages/3-extensions/postgres/src/runtime/postgres.ts @@ -15,11 +15,13 @@ import type { RuntimeVerifyOptions, SqlExecutionStackWithDriver, SqlRuntimeExtensionDescriptor, + TransactionContext, } from '@prisma-next/sql-runtime'; import { createExecutionContext, createRuntime, createSqlExecutionStack, + withTransaction, } from '@prisma-next/sql-runtime'; import postgresTarget from '@prisma-next/target-postgres/runtime'; import { type Client, Pool } from 'pg'; @@ -33,6 +35,11 @@ import { export type PostgresTargetId = 'postgres'; type OrmClient> = ReturnType>; +export interface PostgresTransactionContext> + extends TransactionContext { + readonly sql: Db; +} + export interface PostgresClient> { readonly sql: Db; readonly orm: OrmClient; @@ -40,6 +47,7 @@ export interface PostgresClient> { readonly stack: SqlExecutionStackWithDriver; connect(bindingInput?: PostgresBindingInput): Promise; runtime(): Runtime; + transaction(fn: (tx: PostgresTransactionContext) => PromiseLike): Promise; } export interface PostgresOptionsBase> { @@ -240,5 +248,15 @@ export default function postgres>( runtime() { return getRuntime(); }, + transaction(fn: (tx: PostgresTransactionContext) => PromiseLike): Promise { + return withTransaction(getRuntime(), (txCtx) => { + const txSql: Db = sqlBuilder({ context }); + const tx: PostgresTransactionContext = { + ...txCtx, + sql: txSql, + }; + return fn(tx); + }); + }, }; } diff --git a/packages/3-extensions/postgres/test/postgres.test.ts b/packages/3-extensions/postgres/test/postgres.test.ts index e556a31721..a2f79aa403 100644 --- a/packages/3-extensions/postgres/test/postgres.test.ts +++ b/packages/3-extensions/postgres/test/postgres.test.ts @@ -7,6 +7,7 @@ const mocks = vi.hoisted(() => ({ createRuntime: vi.fn(), createExecutionContext: vi.fn(), createSqlExecutionStack: vi.fn(), + withTransaction: vi.fn(), sqlBuilder: vi.fn(), driverCreate: vi.fn(), driverConnect: vi.fn(), @@ -22,6 +23,7 @@ vi.mock('@prisma-next/sql-runtime', () => ({ createExecutionContext: mocks.createExecutionContext, createSqlExecutionStack: mocks.createSqlExecutionStack, createRuntime: mocks.createRuntime, + withTransaction: mocks.withTransaction, })); vi.mock('@prisma-next/sql-contract/validate', () => ({ @@ -71,6 +73,7 @@ describe('postgres', () => { mocks.createRuntime.mockReset(); mocks.createExecutionContext.mockReset(); mocks.createSqlExecutionStack.mockReset(); + mocks.withTransaction.mockReset(); mocks.driverCreate.mockReset(); mocks.driverConnect.mockReset(); mocks.validateContract.mockReset(); @@ -95,6 +98,15 @@ describe('postgres', () => { mocks.createRuntime.mockReturnValue({ id: 'runtime-instance' }); mocks.validateContract.mockReturnValue(contract); mocks.sqlBuilder.mockReturnValue({ lane: 'sql' }); + mocks.withTransaction.mockImplementation( + async (_runtime: unknown, fn: (ctx: unknown) => unknown) => { + const mockTxCtx = { + invalidated: false, + execute: vi.fn(), + }; + return fn(mockTxCtx); + }, + ); }); it('sql is constructed eagerly without runtime; runtime and pool are deferred until runtime() is accessed', () => { @@ -388,4 +400,59 @@ describe('postgres', () => { }), ).toThrow('Unable to determine pg binding type from pg input'); }); + + it('transaction() delegates to withTransaction with the lazy runtime', async () => { + const db = postgres({ + contract, + url: 'postgres://localhost:5432/db', + }); + + const result = await db.transaction(async () => 'tx-value'); + + expect(mocks.withTransaction).toHaveBeenCalledOnce(); + expect(mocks.withTransaction).toHaveBeenCalledWith( + mocks.createRuntime.mock.results[0]?.value, + expect.any(Function), + ); + expect(result).toBe('tx-value'); + }); + + it('transaction() provides sql on the transaction context', async () => { + const txSqlProxy = { lane: 'tx-sql' }; + let callCount = 0; + mocks.sqlBuilder.mockImplementation(() => { + callCount++; + if (callCount === 1) return { lane: 'sql' }; + return txSqlProxy; + }); + + const db = postgres({ + contract, + url: 'postgres://localhost:5432/db', + }); + + let receivedTx: { sql?: unknown } | undefined; + await db.transaction(async (tx) => { + receivedTx = tx; + }); + + expect(receivedTx).toBeDefined(); + expect(receivedTx!.sql).toBe(txSqlProxy); + expect(mocks.sqlBuilder).toHaveBeenCalledTimes(2); + }); + + it('transaction() lazily creates runtime before connect()', async () => { + const db = postgres({ + contract, + url: 'postgres://localhost:5432/db', + }); + + expect(mocks.instantiateExecutionStack).not.toHaveBeenCalled(); + expect(mocks.createRuntime).not.toHaveBeenCalled(); + + await db.transaction(async () => 'value'); + + expect(mocks.instantiateExecutionStack).toHaveBeenCalledTimes(1); + expect(mocks.createRuntime).toHaveBeenCalledTimes(1); + }); }); diff --git a/packages/3-extensions/postgres/test/transaction.types.test-d.ts b/packages/3-extensions/postgres/test/transaction.types.test-d.ts new file mode 100644 index 0000000000..b0466b4785 --- /dev/null +++ b/packages/3-extensions/postgres/test/transaction.types.test-d.ts @@ -0,0 +1,29 @@ +import type { Contract } from '@prisma-next/contract/types'; +import type { SqlStorage } from '@prisma-next/sql-contract/types'; +import { expectTypeOf, test } from 'vitest'; +import type { PostgresClient, PostgresTransactionContext } from '../src/runtime/postgres'; + +type TestContract = Contract; + +test('transaction context does not expose a transaction method', () => { + type HasTransaction = 'transaction' extends keyof PostgresTransactionContext + ? true + : false; + expectTypeOf().toEqualTypeOf(); +}); + +test('db.transaction infers the callback return type correctly', () => { + const db = {} as PostgresClient; + + const numResult = db.transaction(async (_tx) => 42); + expectTypeOf(numResult).toEqualTypeOf>(); + + const objResult = db.transaction(async (_tx) => ({ name: 'test' as const, count: 3 })); + expectTypeOf(objResult).toEqualTypeOf>(); +}); + +test('tx.sql has the same type as db.sql', () => { + type DbSql = PostgresClient['sql']; + type TxSql = PostgresTransactionContext['sql']; + expectTypeOf().toEqualTypeOf(); +}); diff --git a/projects/orm-client-transaction-api/plan.md b/projects/orm-client-transaction-api/plan.md index 28b7b6e316..307b01a1d5 100644 --- a/projects/orm-client-transaction-api/plan.md +++ b/projects/orm-client-transaction-api/plan.md @@ -20,12 +20,12 @@ Implement the core transaction lifecycle in `sql-runtime` and expose it on `Post **Tasks:** -- [ ] **1.1** Add `TransactionContext` interface and `withTransaction` helper to `sql-runtime` (`packages/2-sql/5-runtime/src/sql-runtime.ts`). The helper acquires a connection from the `Runtime`, calls `beginTransaction()`, runs the callback with a `RuntimeQueryable` scoped to the transaction, commits on success, rolls back on throw, and releases the connection in `finally`. Add an `invalidated` flag to the transaction-scoped queryable that is set after commit/rollback — any subsequent `execute()` call throws a clear error per ADR 187. -- [ ] **1.2** Export `TransactionContext`, `withTransaction`, and related types from `sql-runtime` exports (`packages/2-sql/5-runtime/src/exports/index.ts`). Also export `RuntimeConnection` and `RuntimeTransaction` interfaces which are currently internal. -- [ ] **1.3** Add `transaction(fn: (tx: TransactionContext) => PromiseLike): Promise` to the `PostgresClient` interface and implementation (`packages/3-extensions/postgres/src/runtime/postgres.ts`). The implementation delegates to `withTransaction`, passing the runtime (lazy-initializing via `getRuntime()` like existing methods). `TransactionContext` exposes `execute` but NOT `transaction` (no nesting). -- [ ] **1.4** Wire `tx.sql` — create the `Db` proxy bound to the transaction's execute. The SQL builder (`sql()` from `sql-builder`) is stateless and only needs an `ExecutionContext`; the transaction context provides `tx.execute(plan)` for running built plans. Expose `tx.sql` on the transaction context so users can build queries against the same table proxies. -- [ ] **1.5** Unit tests for `withTransaction` lifecycle: successful commit, rollback on throw, connection release on both paths, error propagation, return value forwarding, COMMIT failure propagation. Test the invalidation flag — `execute()` after commit/rollback throws with actionable message. -- [ ] **1.6** Integration test against Postgres: two INSERTs in a transaction are both visible after commit. A throw after the first INSERT rolls back both — neither is visible. +- [x] **1.1** Add `TransactionContext` interface and `withTransaction` helper to `sql-runtime` (`packages/2-sql/5-runtime/src/sql-runtime.ts`). The helper acquires a connection from the `Runtime`, calls `beginTransaction()`, runs the callback with a `RuntimeQueryable` scoped to the transaction, commits on success, rolls back on throw, and releases the connection in `finally`. Add an `invalidated` flag to the transaction-scoped queryable that is set after commit/rollback — any subsequent `execute()` call throws a clear error per ADR 187. +- [x] **1.2** Export `TransactionContext`, `withTransaction`, and related types from `sql-runtime` exports (`packages/2-sql/5-runtime/src/exports/index.ts`). Also export `RuntimeConnection` and `RuntimeTransaction` interfaces which are currently internal. +- [x] **1.3** Add `transaction(fn: (tx: TransactionContext) => PromiseLike): Promise` to the `PostgresClient` interface and implementation (`packages/3-extensions/postgres/src/runtime/postgres.ts`). The implementation delegates to `withTransaction`, passing the runtime (lazy-initializing via `getRuntime()` like existing methods). `TransactionContext` exposes `execute` but NOT `transaction` (no nesting). +- [x] **1.4** Wire `tx.sql` — create the `Db` proxy bound to the transaction's execute. The SQL builder (`sql()` from `sql-builder`) is stateless and only needs an `ExecutionContext`; the transaction context provides `tx.execute(plan)` for running built plans. Expose `tx.sql` on the transaction context so users can build queries against the same table proxies. +- [x] **1.5** Unit tests for `withTransaction` lifecycle: successful commit, rollback on throw, connection release on both paths, error propagation, return value forwarding, COMMIT failure propagation. Test the invalidation flag — `execute()` after commit/rollback throws with actionable message. +- [x] **1.6** Integration test against Postgres: two INSERTs in a transaction are both visible after commit. A throw after the first INSERT rolls back both — neither is visible. ### Milestone 2: ORM integration diff --git a/projects/orm-client-transaction-api/spec.md b/projects/orm-client-transaction-api/spec.md index 23b9d70014..722574d5fc 100644 --- a/projects/orm-client-transaction-api/spec.md +++ b/projects/orm-client-transaction-api/spec.md @@ -83,23 +83,23 @@ await db.transaction(async (tx) => { ## Core API -- [ ] `db.transaction(callback)` is callable on `PostgresClient` and returns `Promise` where `T` is the callback's return value. +- [x] `db.transaction(callback)` is callable on `PostgresClient` and returns `Promise` where `T` is the callback's return value. - [ ] The callback receives a context object with `orm`, `sql`, and `execute` properties. - [ ] `tx.orm` has the same collection types and API as `db.orm`. -- [ ] `tx.sql` has the same table proxy types as `db.sql`. -- [ ] `tx.execute(plan)` executes a query plan against the transaction's connection. +- [x] `tx.sql` has the same table proxy types as `db.sql`. +- [x] `tx.execute(plan)` executes a query plan against the transaction's connection. ## Lifecycle -- [ ] A successful callback triggers `COMMIT` and the promise resolves with the return value. -- [ ] A throwing callback triggers `ROLLBACK` and the promise rejects with the original error. -- [ ] The connection is released after both commit and rollback paths. -- [ ] Multiple sequential transactions can reuse connections from the pool without leaking. +- [x] A successful callback triggers `COMMIT` and the promise resolves with the return value. +- [x] A throwing callback triggers `ROLLBACK` and the promise rejects with the original error. +- [x] The connection is released after both commit and rollback paths. +- [x] Multiple sequential transactions can reuse connections from the pool without leaking. ## Atomicity -- [ ] Two writes within `transaction` are both visible after commit (integration test against Postgres). -- [ ] A throw after the first write rolls back all writes — neither is visible (integration test). +- [x] Two writes within `transaction` are both visible after commit (integration test against Postgres). +- [x] A throw after the first write rolls back all writes — neither is visible (integration test). ## ORM integration @@ -108,18 +108,18 @@ await db.transaction(async (tx) => { ## Type safety -- [ ] `tx` does **not** have a `transaction` method (compile-time check via negative type test). -- [ ] `db.transaction` infers the callback return type correctly (type-level test). +- [x] `tx` does **not** have a `transaction` method (compile-time check via negative type test). +- [x] `db.transaction` infers the callback return type correctly (type-level test). ## Escaped result safety -- [ ] An `AsyncIterableResult` created inside a transaction that is consumed after commit/rollback produces a clear error message. -- [ ] `await db.transaction((tx) => tx.execute(plan))` drains eagerly via `PromiseLike` and returns `Row[]` (the safe common case). +- [x] An `AsyncIterableResult` created inside a transaction that is consumed after commit/rollback produces a clear error message. +- [x] `await db.transaction((tx) => tx.execute(plan))` drains eagerly via `PromiseLike` and returns `Row[]` (the safe common case). ## Edge cases -- [ ] Calling `transaction` before `connect()` auto-connects lazily (same as `db.runtime()`). -- [ ] If the callback returns without throwing but `COMMIT` fails, the promise rejects with the commit error. +- [x] Calling `transaction` before `connect()` auto-connects lazily (same as `db.runtime()`). +- [x] If the callback returns without throwing but `COMMIT` fails, the promise rejects with the commit error. # Other Considerations diff --git a/test/e2e/framework/test/transaction.test.ts b/test/e2e/framework/test/transaction.test.ts new file mode 100644 index 0000000000..f226913ecf --- /dev/null +++ b/test/e2e/framework/test/transaction.test.ts @@ -0,0 +1,84 @@ +import { dirname, resolve } from 'node:path'; +import { fileURLToPath } from 'node:url'; +import { withTransaction } from '@prisma-next/sql-runtime'; +import { describe, expect, it } from 'vitest'; +import type { Contract } from './fixtures/generated/contract.d'; +import { withTestRuntime } from './utils'; + +const __filename = fileURLToPath(import.meta.url); +const __dirname = dirname(__filename); +const contractJsonPath = resolve(__dirname, 'fixtures/generated/contract.json'); + +describe('transaction E2E', { timeout: 30000 }, () => { + it('commits both writes atomically', async () => { + await withTestRuntime(contractJsonPath, async ({ db, runtime, client }) => { + await withTransaction(runtime, async (tx) => { + await tx.execute(db.user.insert({ email: 'tx-user-1@example.com' }).build()); + await tx.execute(db.user.insert({ email: 'tx-user-2@example.com' }).build()); + }); + + const result = await client.query( + `SELECT email FROM "user" WHERE email IN ('tx-user-1@example.com', 'tx-user-2@example.com') ORDER BY email`, + ); + expect(result.rows).toEqual([ + { email: 'tx-user-1@example.com' }, + { email: 'tx-user-2@example.com' }, + ]); + }); + }); + + it('rolls back all writes on error', async () => { + await withTestRuntime(contractJsonPath, async ({ db, runtime, client }) => { + await expect( + withTransaction(runtime, async (tx) => { + await tx.execute(db.user.insert({ email: 'tx-rollback@example.com' }).build()); + throw new Error('deliberate rollback'); + }), + ).rejects.toThrow('deliberate rollback'); + + const result = await client.query( + `SELECT email FROM "user" WHERE email = 'tx-rollback@example.com'`, + ); + expect(result.rows).toEqual([]); + }); + }); + + it('forwards the callback return value after commit', async () => { + await withTestRuntime(contractJsonPath, async ({ db, runtime }) => { + const result = await withTransaction(runtime, async (tx) => { + await tx.execute(db.user.insert({ email: 'tx-return@example.com' }).build()); + return { inserted: true }; + }); + + expect(result).toEqual({ inserted: true }); + }); + }); + + it('collects returned stream before commit', async () => { + await withTestRuntime(contractJsonPath, async ({ db, runtime }) => { + const result = await withTransaction(runtime, async (tx) => { + await tx.execute(db.user.insert({ email: 'tx-user-1@example.com' }).build()); + await tx.execute(db.user.insert({ email: 'tx-user-2@example.com' }).build()); + return tx.execute(db.user.select('email').build()); + }); + + expect(result).toEqual([ + { email: 'tx-user-1@example.com' }, + { email: 'tx-user-2@example.com' }, + ]); + }); + }); + + it('rejects escaped AsyncIterableResult consumed after commit', async () => { + await withTestRuntime(contractJsonPath, async ({ db, runtime }) => { + const escaped = await withTransaction(runtime, async (tx) => { + await tx.execute(db.user.insert({ email: 'tx-escape@example.com' }).build()); + return { rows: tx.execute(db.user.select('email').build()) }; + }); + + await expect(escaped.rows.toArray()).rejects.toThrow( + 'Cannot read from a query result after the transaction has ended', + ); + }); + }); +});