Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,14 @@ export function resolvePrimaryKeyColumn(
return contract.storage.tables[tableName]?.primaryKey?.columns[0] ?? 'id';
}

export function resolveColumnCodecId(
contract: SqlContract<SqlStorage>,
tableName: string,
columnName: string,
): string | undefined {
return contract.storage.tables[tableName]?.columns[columnName]?.codecId;
}

export function assertReturningCapability(contract: SqlContract<SqlStorage>, action: string): void {
if (hasContractCapability(contract, 'returning')) {
return;
Expand Down
2 changes: 1 addition & 1 deletion packages/3-extensions/sql-orm-client/src/collection.ts
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ export class Collection<
: isWhereDirectInput(input)
? input
: shorthandToWhereExpr(this.ctx.context, this.modelName, input);
const filter = normalizeWhereArg(whereArg, { contract: this.contract });
const filter = normalizeWhereArg(whereArg);

if (!filter) {
return this as Collection<TContract, ModelName, Row, WithWhereState<State>>;
Expand Down
22 changes: 13 additions & 9 deletions packages/3-extensions/sql-orm-client/src/filters.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@ import {
type AnyExpression,
BinaryExpr,
ColumnRef,
LiteralExpr,
NullCheckExpr,
OrExpr,
ParamRef,
} from '@prisma-next/sql-relational-core/ast';
import type { ExecutionContext } from '@prisma-next/sql-relational-core/query-lane-context';
import { resolveColumnCodecId } from './collection-contract';
import type { ShorthandWhereFilter } from './types';

export function and(...exprs: AnyExpression[]): AndExpr {
Expand Down Expand Up @@ -62,8 +63,8 @@ export function shorthandToWhereExpr<
continue;
}

assertFieldHasEqualityTrait(context, tableName, columnName, modelName, fieldName);
exprs.push(BinaryExpr.eq(left, LiteralExpr.of(value)));
const codecId = assertEqualityCodecId(context, tableName, columnName, modelName, fieldName);
exprs.push(BinaryExpr.eq(left, ParamRef.of(value, { name: columnName, codecId })));
}

if (exprs.length === 0) {
Expand All @@ -73,21 +74,24 @@ export function shorthandToWhereExpr<
return exprs.length === 1 ? exprs[0] : and(...exprs);
}

function assertFieldHasEqualityTrait(
function assertEqualityCodecId(
context: ExecutionContext,
tableName: string,
columnName: string,
modelName: string,
fieldName: string,
): void {
const tables = context.contract.storage?.tables as
| Record<string, { columns?: Record<string, { codecId?: string }> }>
| undefined;
const codecId = tables?.[tableName]?.columns?.[columnName]?.codecId;
): string {
const codecId = resolveColumnCodecId(context.contract, tableName, columnName);
const traits = codecId ? context.codecs.traitsOf(codecId) : [];
if (!traits.includes('equality')) {
throw new Error(
`Shorthand filter on "${modelName}.${fieldName}": field does not support equality comparisons`,
);
}
if (!codecId) {
throw new Error(
`Shorthand filter on "${modelName}.${fieldName}": column "${columnName}" has no codec`,
);
}
return codecId;
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import {
BinaryExpr,
type BinaryOp,
ColumnRef,
LiteralExpr,
ParamRef,
} from '@prisma-next/sql-relational-core/ast';
import { createAggregateBuilder, isAggregateSelector } from './aggregate-builder';
import { mapStorageRowToModelFields } from './collection-runtime';
Expand Down Expand Up @@ -156,7 +156,7 @@ function createHavingComparisonMethods<T extends number | null>(
metric: AggregateExpr,
): HavingComparisonMethods<T> {
const buildBinaryExpr = (op: BinaryOp, value: unknown): AnyExpression =>
new BinaryExpr(op, metric, LiteralExpr.of(value));
new BinaryExpr(op, metric, ParamRef.of(value));

return {
eq(value) {
Expand Down
30 changes: 17 additions & 13 deletions packages/3-extensions/sql-orm-client/src/model-accessor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import {
TableSource,
} from '@prisma-next/sql-relational-core/ast';
import type { ExecutionContext } from '@prisma-next/sql-relational-core/query-lane-context';
import { resolveColumnCodecId } from './collection-contract';
import { and, not } from './filters';
import {
COMPARISON_METHODS_META,
Expand Down Expand Up @@ -59,40 +60,43 @@ export function createModelAccessor<
}

const columnName = fieldToColumn[prop] ?? prop;
const traits = resolveFieldTraits(contract, tableName, columnName, context);
return createScalarFieldAccessor(tableName, columnName, traits);
const columnMeta = resolveColumnMeta(contract, tableName, columnName, context);
return createScalarFieldAccessor(tableName, columnName, columnMeta);
},
});
}

function resolveFieldTraits(
interface ColumnMeta {
readonly codecId: string;
readonly traits: readonly string[];
}

function resolveColumnMeta(
contract: SqlContract<SqlStorage>,
tableName: string,
columnName: string,
context: ExecutionContext,
): readonly string[] {
const tables = contract.storage?.tables as
| Record<string, { columns?: Record<string, { codecId?: string }> }>
| undefined;
const codecId = tables?.[tableName]?.columns?.[columnName]?.codecId;
// unknown columns get no trait-gated methods
if (!codecId) return [];
return context.codecs.traitsOf(codecId);
): ColumnMeta | undefined {
const codecId = resolveColumnCodecId(contract, tableName, columnName);
if (!codecId) return undefined;
return { codecId, traits: context.codecs.traitsOf(codecId) };
}

function createScalarFieldAccessor(
tableName: string,
columnName: string,
traits: readonly string[],
columnMeta: ColumnMeta | undefined,
): Partial<ComparisonMethodFns<unknown>> {
const column = ColumnRef.of(tableName, columnName);
const methods: Record<string, unknown> = {};
const traits = columnMeta?.traits ?? [];
const codecId = columnMeta?.codecId ?? '';

for (const [name, meta] of Object.entries(COMPARISON_METHODS_META)) {
if (meta.traits.some((t) => !traits.includes(t))) {
continue;
}
methods[name] = meta.create(column);
methods[name] = meta.create(column, codecId);
}

return methods as Partial<ComparisonMethodFns<unknown>>;
Expand Down
19 changes: 13 additions & 6 deletions packages/3-extensions/sql-orm-client/src/mutation-executor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,14 @@ import {
type AnyExpression,
BinaryExpr,
ColumnRef,
LiteralExpr,
ParamRef,
} from '@prisma-next/sql-relational-core/ast';
import type { ExecutionContext } from '@prisma-next/sql-relational-core/query-lane-context';
import { resolveModelTableName, resolvePrimaryKeyColumn } from './collection-contract';
import {
resolveColumnCodecId,
resolveModelTableName,
resolvePrimaryKeyColumn,
} from './collection-contract';
import {
acquireRuntimeScope,
mapModelDataToStorageRow,
Expand Down Expand Up @@ -486,7 +490,7 @@ async function applyChildOwnedMutation(
}

if (!mutation.criteria || mutation.criteria.length === 0) {
const parentJoinWhere = buildChildJoinWhere(relation, parentValues);
const parentJoinWhere = buildChildJoinWhere(contract, relation, parentValues);
await executeUpdateCount(scope, contract, relation.relatedTableName, setValues, [
parentJoinWhere,
]);
Expand All @@ -505,7 +509,7 @@ async function applyChildOwnedMutation(
);
}

const parentJoinWhere = buildChildJoinWhere(relation, parentValues);
const parentJoinWhere = buildChildJoinWhere(contract, relation, parentValues);
await executeUpdateCount(scope, contract, relation.relatedTableName, setValues, [
and(parentJoinWhere, criterionWhere),
]);
Expand Down Expand Up @@ -542,16 +546,19 @@ function readParentColumnValues(
}

function buildChildJoinWhere(
contract: SqlContract<SqlStorage>,
relation: RelationDefinition,
childValues: Map<string, unknown>,
): AnyExpression {
const exprs: AnyExpression[] = [];
const tableName = relation.relatedTableName;

for (const [childColumn, parentValue] of childValues.entries()) {
const codecId = resolveColumnCodecId(contract, tableName, childColumn);
exprs.push(
BinaryExpr.eq(
ColumnRef.of(relation.relatedTableName, childColumn),
LiteralExpr.of(parentValue),
ColumnRef.of(tableName, childColumn),
ParamRef.of(parentValue, { name: childColumn, ...(codecId ? { codecId } : {}) }),
),
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,23 +29,16 @@ function toAggregateExpr(tableName: string, selector: AggregateSelector<unknown>
return new AggregateExpr(selector.fn, ColumnRef.of(tableName, selector.column));
}

// ORM HAVING filters use literal binding (values inlined at plan-build time),
// not parameterized binding. ParamRef is rejected because the ORM's grouped
// collection API always produces literal comparisons for having() predicates.
function validateGroupedComparable(value: AnyExpression): AnyExpression {
switch (value.kind) {
case 'param-ref':
throw new Error('ParamRef is not supported in grouped having expressions');
case 'literal':
case 'column-ref':
case 'identifier-ref':
case 'aggregate':
case 'operation':
return value;
case 'list':
if (value.values.some((entry) => entry.kind === 'param-ref')) {
throw new Error('ParamRef is not supported in grouped having expressions');
}
return value;
default:
throw new Error(`Unsupported comparable kind in grouped having: "${value.kind}"`);
Expand Down
45 changes: 30 additions & 15 deletions packages/3-extensions/sql-orm-client/src/query-plan-select.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,18 @@ import {
JsonArrayAggExpr,
JsonObjectExpr,
ListExpression,
LiteralExpr,
OrderByItem,
OrExpr,
ParamRef,
ProjectionItem,
SelectAst,
SubqueryExpr,
TableSource,
} from '@prisma-next/sql-relational-core/ast';
import type { SqlQueryPlan } from '@prisma-next/sql-relational-core/plan';
import { resolveColumnCodecId } from './collection-contract';
import { buildOrmQueryPlan, deriveParamsFromAst, resolveTableColumns } from './query-plan-meta';
import type { CollectionState, IncludeExpr, OrderExpr } from './types';
import { bindWhereExpr } from './where-binding';
import { combineWhereExprs } from './where-utils';

type CursorOrderEntry = OrderExpr & {
Expand Down Expand Up @@ -57,16 +57,31 @@ function toOrderBy(
);
}

function createBoundaryExpr(tableName: string, entry: CursorOrderEntry): AnyExpression {
function columnParam(
contract: SqlContract<SqlStorage>,
tableName: string,
column: string,
value: unknown,
): ParamRef {
const codecId = resolveColumnCodecId(contract, tableName, column);
return ParamRef.of(value, { name: column, ...(codecId ? { codecId } : {}) });
}

function createBoundaryExpr(
contract: SqlContract<SqlStorage>,
tableName: string,
entry: CursorOrderEntry,
): AnyExpression {
const comparator: BinaryOp = entry.direction === 'asc' ? 'gt' : 'lt';
return new BinaryExpr(
comparator,
ColumnRef.of(tableName, entry.column),
LiteralExpr.of(entry.value),
columnParam(contract, tableName, entry.column, entry.value),
);
}

function buildLexicographicCursorWhere(
contract: SqlContract<SqlStorage>,
tableName: string,
entries: readonly CursorOrderEntry[],
): AnyExpression {
Expand All @@ -77,12 +92,12 @@ function buildLexicographicCursorWhere(
branchExprs.push(
BinaryExpr.eq(
ColumnRef.of(tableName, prefixEntry.column),
LiteralExpr.of(prefixEntry.value),
columnParam(contract, tableName, prefixEntry.column, prefixEntry.value),
),
);
}

branchExprs.push(createBoundaryExpr(tableName, entry));
branchExprs.push(createBoundaryExpr(contract, tableName, entry));
if (branchExprs.length === 1) {
return branchExprs[0] as AnyExpression;
}
Expand All @@ -98,6 +113,7 @@ function buildLexicographicCursorWhere(
}

function buildCursorWhere(
contract: SqlContract<SqlStorage>,
tableName: string,
orderBy: readonly OrderExpr[] | undefined,
cursor: Readonly<Record<string, unknown>> | undefined,
Expand All @@ -120,10 +136,10 @@ function buildCursorWhere(

const firstEntry = entries[0];
if (entries.length === 1 && firstEntry !== undefined) {
return createBoundaryExpr(tableName, firstEntry);
return createBoundaryExpr(contract, tableName, firstEntry);
}

return buildLexicographicCursorWhere(tableName, entries);
return buildLexicographicCursorWhere(contract, tableName, entries);
}

function createTableRefRemapper(fromTable: string, toTable: string): AstRewriter {
Expand Down Expand Up @@ -153,18 +169,17 @@ function buildStateWhere(
): AnyExpression | undefined {
const filterTableName = options?.filterTableName;
const cursorTableName = filterTableName ?? tableName;
const cursorWhere = buildCursorWhere(cursorTableName, state.orderBy, state.cursor);
const cursorWhere = buildCursorWhere(contract, cursorTableName, state.orderBy, state.cursor);
const remappedFilters =
filterTableName && filterTableName !== tableName
? state.filters.map((filter) =>
filter.rewrite(createTableRefRemapper(filterTableName, tableName)),
)
: state.filters;
const boundCursorWhere = cursorWhere ? bindWhereExpr(contract, cursorWhere) : undefined;
const remappedCursorWhere =
boundCursorWhere && filterTableName && filterTableName !== tableName
? boundCursorWhere.rewrite(createTableRefRemapper(filterTableName, tableName))
: boundCursorWhere;
cursorWhere && filterTableName && filterTableName !== tableName
? cursorWhere.rewrite(createTableRefRemapper(filterTableName, tableName))
: cursorWhere;
const filters = remappedCursorWhere ? [...remappedFilters, remappedCursorWhere] : remappedFilters;
return combineWhereExprs(filters);
}
Expand Down Expand Up @@ -406,15 +421,15 @@ export function compileRelationSelect(
): SqlQueryPlan<Record<string, unknown>> {
const inFilter: AnyExpression = BinaryExpr.in(
ColumnRef.of(relatedTableName, fkColumn),
ListExpression.fromValues(parentPks),
ListExpression.of(parentPks.map((pk) => columnParam(contract, relatedTableName, fkColumn, pk))),
);

return compileSelect(contract, relatedTableName, {
...nestedState,
includes: [],
limit: undefined,
offset: undefined,
filters: [bindWhereExpr(contract, inFilter), ...nestedState.filters],
filters: [inFilter, ...nestedState.filters],
});
}

Expand Down
Loading
Loading