Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion packages/2-sql/5-runtime/src/exports/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,12 @@ export {
export type {
CreateRuntimeOptions,
Runtime,
RuntimeConnection,
RuntimeQueryable,
RuntimeTelemetryEvent,
RuntimeTransaction,
RuntimeVerifyOptions,
TelemetryOutcome,
TransactionContext,
} from '../sql-runtime';
export { createRuntime } from '../sql-runtime';
export { createRuntime, withTransaction } from '../sql-runtime';
73 changes: 72 additions & 1 deletion packages/2-sql/5-runtime/src/sql-runtime.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@ import type {
RuntimeVerifyOptions,
TelemetryOutcome,
} from '@prisma-next/runtime-executor';
import { AsyncIterableResult, createRuntimeCore } from '@prisma-next/runtime-executor';
import {
AsyncIterableResult,
createRuntimeCore,
runtimeError,
} from '@prisma-next/runtime-executor';
import type { SqlStorage } from '@prisma-next/sql-contract/types';
import type {
Adapter,
Expand Down Expand Up @@ -95,6 +99,10 @@ export interface RuntimeQueryable {
): AsyncIterableResult<Row>;
}

export interface TransactionContext extends RuntimeQueryable {
readonly invalidated: boolean;
}

interface CoreQueryable {
execute<Row = Record<string, unknown>>(plan: ExecutionPlan<Row>): AsyncIterableResult<Row>;
}
Expand Down Expand Up @@ -238,6 +246,69 @@ class SqlRuntimeImpl<TContract extends Contract<SqlStorage> = Contract<SqlStorag
}
}

function transactionClosedError(): Error {
return runtimeError(
'RUNTIME.TRANSACTION_CLOSED',
'Cannot read from a query result after the transaction has ended. Await the result or call .toArray() inside the transaction callback.',
{},
);
}

export async function withTransaction<R>(
runtime: Runtime,
fn: (tx: TransactionContext) => PromiseLike<R>,
): Promise<R> {
const connection = await runtime.connection();
const transaction = await connection.transaction();

let invalidated = false;
const txContext: TransactionContext = {
get invalidated() {
return invalidated;
},
execute<Row = Record<string, unknown>>(
plan: ExecutionPlan<Row> | SqlQueryPlan<Row>,
): AsyncIterableResult<Row> {
if (invalidated) {
throw transactionClosedError();
}
const inner = transaction.execute(plan);
const guarded = async function* (): AsyncGenerator<Row, void, unknown> {
for await (const row of inner) {
if (invalidated) {
throw transactionClosedError();
}
yield row;
}
};
return new AsyncIterableResult(guarded());
},
};

try {
const result = await fn(txContext);
invalidated = true;
await transaction.commit();
return result;
} catch (error) {
invalidated = true;
try {
await transaction.rollback();
} catch (rollbackError) {
const wrapped = runtimeError(
'RUNTIME.TRANSACTION_ROLLBACK_FAILED',
'Transaction rollback failed after callback error',
{ rollbackError },
);
wrapped.cause = error;
throw wrapped;
}
throw error;
} finally {
await connection.release();
}
}

export function createRuntime<TContract extends Contract<SqlStorage>, TTargetId extends string>(
options: CreateRuntimeOptions<TContract, TTargetId>,
): Runtime {
Expand Down
193 changes: 190 additions & 3 deletions packages/2-sql/5-runtime/test/sql-runtime.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import type {
SqlRuntimeTargetDescriptor,
} from '../src/sql-context';
import { createExecutionContext, createSqlExecutionStack } from '../src/sql-context';
import { createRuntime } from '../src/sql-runtime';
import { createRuntime, withTransaction } from '../src/sql-runtime';

const testContract: Contract<SqlStorage> = {
targetFamily: 'sql',
Expand All @@ -35,13 +35,16 @@ const testContract: Contract<SqlStorage> = {
meta: {},
};

interface DriverExecuteSpies {
interface DriverMockSpies {
rootExecute: ReturnType<typeof vi.fn>;
connectionExecute: ReturnType<typeof vi.fn>;
transactionExecute: ReturnType<typeof vi.fn>;
connectionRelease: ReturnType<typeof vi.fn>;
transactionCommit: ReturnType<typeof vi.fn>;
transactionRollback: ReturnType<typeof vi.fn>;
}

type MockSqlDriver = SqlDriver & { __spies: DriverExecuteSpies };
type MockSqlDriver = SqlDriver & { __spies: DriverMockSpies };

function createStubCodecs(): CodecRegistry {
const registry = createCodecRegistry();
Expand Down Expand Up @@ -128,6 +131,9 @@ function createMockDriver(): MockSqlDriver {
rootExecute,
connectionExecute,
transactionExecute,
connectionRelease: connection.release,
transactionCommit: transaction.commit,
transactionRollback: transaction.rollback,
},
});
}
Expand Down Expand Up @@ -324,3 +330,184 @@ describe('createRuntime', () => {
expect(driver.__spies.transactionExecute).not.toHaveBeenCalled();
});
});

describe('withTransaction', () => {
function createRuntimeForTransaction() {
const { stackInstance, context, driver } = createTestSetup();
const runtime = createRuntime({
stackInstance,
context,
driver,
verify: { mode: 'onFirstUse', requireMarker: false },
});
return { runtime, driver };
}

it('commits on successful callback and returns the result', async () => {
const { runtime, driver } = createRuntimeForTransaction();

const result = await withTransaction(runtime, async (tx) => {
await tx.execute(createRawExecutionPlan()).toArray();
return 42;
});

expect(result).toBe(42);
expect(driver.__spies.transactionCommit).toHaveBeenCalledOnce();
expect(driver.__spies.transactionRollback).not.toHaveBeenCalled();
expect(driver.__spies.connectionRelease).toHaveBeenCalledOnce();
});

it('rolls back on callback error and re-throws', async () => {
const { runtime, driver } = createRuntimeForTransaction();
const error = new Error('test error');

await expect(
withTransaction(runtime, async () => {
throw error;
}),
).rejects.toBe(error);

expect(driver.__spies.transactionRollback).toHaveBeenCalledOnce();
expect(driver.__spies.transactionCommit).not.toHaveBeenCalled();
expect(driver.__spies.connectionRelease).toHaveBeenCalledOnce();
});

it('releases connection after commit', async () => {
const { runtime, driver } = createRuntimeForTransaction();

await withTransaction(runtime, async () => 'ok');

expect(driver.__spies.connectionRelease).toHaveBeenCalledOnce();
});

it('releases connection after rollback', async () => {
const { runtime, driver } = createRuntimeForTransaction();

await withTransaction(runtime, async () => {
throw new Error('fail');
}).catch(() => {});

expect(driver.__spies.connectionRelease).toHaveBeenCalledOnce();
});

it('propagates commit failure', async () => {
const { runtime, driver } = createRuntimeForTransaction();
const commitError = new Error('commit failed');
driver.__spies.transactionCommit.mockRejectedValueOnce(commitError);

const result = withTransaction(runtime, async () => 'value');

await expect(result).rejects.toBe(commitError);
});

it('forwards the callback return value', async () => {
const { runtime } = createRuntimeForTransaction();

const result = await withTransaction(runtime, async () => ({
name: 'test',
count: 3,
}));

expect(result).toEqual({ name: 'test', count: 3 });
});

it('executes queries against the transaction', async () => {
const { runtime, driver } = createRuntimeForTransaction();

await withTransaction(runtime, async (tx) => {
await tx.execute(createRawExecutionPlan()).toArray();
});

expect(driver.__spies.transactionExecute).toHaveBeenCalledOnce();
expect(driver.__spies.rootExecute).not.toHaveBeenCalled();
expect(driver.__spies.connectionExecute).not.toHaveBeenCalled();
});

it('throws on execute after commit (invalidation)', async () => {
const { runtime } = createRuntimeForTransaction();
let savedTx: { execute: (plan: ExecutionPlan) => unknown } | undefined;

await withTransaction(runtime, async (tx) => {
savedTx = tx;
});

expect(() => savedTx!.execute(createRawExecutionPlan())).toThrow(
'Cannot read from a query result after the transaction has ended',
);
});

it('throws on iteration of escaped AsyncIterableResult after commit', async () => {
const { runtime } = createRuntimeForTransaction();

const escaped = await withTransaction(runtime, async (tx) => {
return { result: tx.execute(createRawExecutionPlan()) };
});

await expect(escaped.result.toArray()).rejects.toThrow(
'Cannot read from a query result after the transaction has ended',
);
});

it('sets invalidated flag after commit', async () => {
const { runtime } = createRuntimeForTransaction();
let txRef: { invalidated: boolean } | undefined;

await withTransaction(runtime, async (tx) => {
expect(tx.invalidated).toBe(false);
txRef = tx;
});

expect(txRef!.invalidated).toBe(true);
});

it('wraps original error when rollback fails', async () => {
const { runtime, driver } = createRuntimeForTransaction();
const callbackError = new Error('callback failed');
const rollbackError = new Error('rollback failed');
driver.__spies.transactionRollback.mockRejectedValueOnce(rollbackError);

const rejection = withTransaction(runtime, async () => {
throw callbackError;
});

await expect(rejection).rejects.toThrow('Transaction rollback failed after callback error');
await expect(rejection).rejects.toMatchObject({
code: 'RUNTIME.TRANSACTION_ROLLBACK_FAILED',
cause: callbackError,
details: { rollbackError },
});
expect(driver.__spies.connectionRelease).toHaveBeenCalledOnce();
});

it('sets invalidated flag after rollback', async () => {
const { runtime } = createRuntimeForTransaction();
let txRef: { invalidated: boolean } | undefined;

await withTransaction(runtime, async (tx) => {
txRef = tx;
throw new Error('fail');
}).catch(() => {});

expect(txRef!.invalidated).toBe(true);
});

it('releases connection independently across sequential transactions', async () => {
const { runtime, driver } = createRuntimeForTransaction();

await withTransaction(runtime, async (tx) => {
await tx.execute(createRawExecutionPlan()).toArray();
});

await withTransaction(runtime, async (tx) => {
await tx.execute(createRawExecutionPlan()).toArray();
});

await withTransaction(runtime, async () => {
throw new Error('fail');
}).catch(() => {});

expect(driver.__spies.connectionRelease).toHaveBeenCalledTimes(3);
expect(driver.__spies.transactionCommit).toHaveBeenCalledTimes(2);
expect(driver.__spies.transactionRollback).toHaveBeenCalledTimes(1);
});
});
18 changes: 18 additions & 0 deletions packages/3-extensions/postgres/src/runtime/postgres.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@ import type {
RuntimeVerifyOptions,
SqlExecutionStackWithDriver,
SqlRuntimeExtensionDescriptor,
TransactionContext,
} from '@prisma-next/sql-runtime';
import {
createExecutionContext,
createRuntime,
createSqlExecutionStack,
withTransaction,
} from '@prisma-next/sql-runtime';
import postgresTarget from '@prisma-next/target-postgres/runtime';
import { type Client, Pool } from 'pg';
Expand All @@ -33,13 +35,19 @@ import {
export type PostgresTargetId = 'postgres';
type OrmClient<TContract extends Contract<SqlStorage>> = ReturnType<typeof ormBuilder<TContract>>;

export interface PostgresTransactionContext<TContract extends Contract<SqlStorage>>
extends TransactionContext {
readonly sql: Db<TContract>;
}

export interface PostgresClient<TContract extends Contract<SqlStorage>> {
readonly sql: Db<TContract>;
readonly orm: OrmClient<TContract>;
readonly context: ExecutionContext<TContract>;
readonly stack: SqlExecutionStackWithDriver<PostgresTargetId>;
connect(bindingInput?: PostgresBindingInput): Promise<Runtime>;
runtime(): Runtime;
transaction<R>(fn: (tx: PostgresTransactionContext<TContract>) => PromiseLike<R>): Promise<R>;
}

export interface PostgresOptionsBase<TContract extends Contract<SqlStorage>> {
Expand Down Expand Up @@ -240,5 +248,15 @@ export default function postgres<TContract extends Contract<SqlStorage>>(
runtime() {
return getRuntime();
},
transaction<R>(fn: (tx: PostgresTransactionContext<TContract>) => PromiseLike<R>): Promise<R> {
return withTransaction(getRuntime(), (txCtx) => {
const txSql: Db<TContract> = sqlBuilder<TContract>({ context });
const tx: PostgresTransactionContext<TContract> = {
...txCtx,
sql: txSql,
};
return fn(tx);
});
},
};
}
Loading
Loading