From 79f3317bd0e2ca2c35d576c8e22ee203c8b14f8f Mon Sep 17 00:00:00 2001 From: Alexey Orlenko Date: Tue, 31 Mar 2026 13:45:55 +0200 Subject: [PATCH] fix(sql-orm-client): use `ParamRef`s in mutations and having expressions --- .../sql-orm-client/src/collection-contract.ts | 8 + .../sql-orm-client/src/collection.ts | 2 +- .../sql-orm-client/src/filters.ts | 22 +- .../sql-orm-client/src/grouped-collection.ts | 4 +- .../sql-orm-client/src/model-accessor.ts | 30 +- .../sql-orm-client/src/mutation-executor.ts | 19 +- .../src/query-plan-aggregate.ts | 7 - .../sql-orm-client/src/query-plan-select.ts | 45 +- .../3-extensions/sql-orm-client/src/types.ts | 42 +- .../sql-orm-client/src/where-binding.ts | 190 ------ .../sql-orm-client/src/where-interop.ts | 22 +- .../test/collection.state.test.ts | 46 +- .../sql-orm-client/test/filters.test.ts | 118 +++- .../test/integration/group-by.test.ts | 4 +- .../test/model-accessor.test.ts | 71 ++- .../test/mutation-executor.test.ts | 12 +- .../test/query-plan-aggregate.test.ts | 105 ++-- .../test/query-plan-select.test.ts | 53 +- .../test/rich-collection.test.ts | 9 +- .../test/rich-filters-and-where.test.ts | 22 +- .../test/rich-query-plans.test.ts | 17 +- .../sql-orm-client/test/where-binding.test.ts | 549 ------------------ 22 files changed, 415 insertions(+), 982 deletions(-) delete mode 100644 packages/3-extensions/sql-orm-client/src/where-binding.ts delete mode 100644 packages/3-extensions/sql-orm-client/test/where-binding.test.ts diff --git a/packages/3-extensions/sql-orm-client/src/collection-contract.ts b/packages/3-extensions/sql-orm-client/src/collection-contract.ts index f5d97a2278..1f04f42266 100644 --- a/packages/3-extensions/sql-orm-client/src/collection-contract.ts +++ b/packages/3-extensions/sql-orm-client/src/collection-contract.ts @@ -179,6 +179,14 @@ export function resolvePrimaryKeyColumn( return contract.storage.tables[tableName]?.primaryKey?.columns[0] ?? 'id'; } +export function resolveColumnCodecId( + contract: SqlContract, + tableName: string, + columnName: string, +): string | undefined { + return contract.storage.tables[tableName]?.columns[columnName]?.codecId; +} + export function assertReturningCapability(contract: SqlContract, action: string): void { if (hasContractCapability(contract, 'returning')) { return; diff --git a/packages/3-extensions/sql-orm-client/src/collection.ts b/packages/3-extensions/sql-orm-client/src/collection.ts index 9eae94cf6d..97018de4ef 100644 --- a/packages/3-extensions/sql-orm-client/src/collection.ts +++ b/packages/3-extensions/sql-orm-client/src/collection.ts @@ -171,7 +171,7 @@ export class Collection< : isWhereDirectInput(input) ? input : shorthandToWhereExpr(this.ctx.context, this.modelName, input); - const filter = normalizeWhereArg(whereArg, { contract: this.contract }); + const filter = normalizeWhereArg(whereArg); if (!filter) { return this as Collection>; diff --git a/packages/3-extensions/sql-orm-client/src/filters.ts b/packages/3-extensions/sql-orm-client/src/filters.ts index 7040cbc535..8eecf86e6f 100644 --- a/packages/3-extensions/sql-orm-client/src/filters.ts +++ b/packages/3-extensions/sql-orm-client/src/filters.ts @@ -4,11 +4,12 @@ import { type AnyExpression, BinaryExpr, ColumnRef, - LiteralExpr, NullCheckExpr, OrExpr, + ParamRef, } from '@prisma-next/sql-relational-core/ast'; import type { ExecutionContext } from '@prisma-next/sql-relational-core/query-lane-context'; +import { resolveColumnCodecId } from './collection-contract'; import type { ShorthandWhereFilter } from './types'; export function and(...exprs: AnyExpression[]): AndExpr { @@ -62,8 +63,8 @@ export function shorthandToWhereExpr< continue; } - assertFieldHasEqualityTrait(context, tableName, columnName, modelName, fieldName); - exprs.push(BinaryExpr.eq(left, LiteralExpr.of(value))); + const codecId = assertEqualityCodecId(context, tableName, columnName, modelName, fieldName); + exprs.push(BinaryExpr.eq(left, ParamRef.of(value, { name: columnName, codecId }))); } if (exprs.length === 0) { @@ -73,21 +74,24 @@ export function shorthandToWhereExpr< return exprs.length === 1 ? exprs[0] : and(...exprs); } -function assertFieldHasEqualityTrait( +function assertEqualityCodecId( context: ExecutionContext, tableName: string, columnName: string, modelName: string, fieldName: string, -): void { - const tables = context.contract.storage?.tables as - | Record }> - | undefined; - const codecId = tables?.[tableName]?.columns?.[columnName]?.codecId; +): string { + const codecId = resolveColumnCodecId(context.contract, tableName, columnName); const traits = codecId ? context.codecs.traitsOf(codecId) : []; if (!traits.includes('equality')) { throw new Error( `Shorthand filter on "${modelName}.${fieldName}": field does not support equality comparisons`, ); } + if (!codecId) { + throw new Error( + `Shorthand filter on "${modelName}.${fieldName}": column "${columnName}" has no codec`, + ); + } + return codecId; } diff --git a/packages/3-extensions/sql-orm-client/src/grouped-collection.ts b/packages/3-extensions/sql-orm-client/src/grouped-collection.ts index d94bf2f539..f7ccea987a 100644 --- a/packages/3-extensions/sql-orm-client/src/grouped-collection.ts +++ b/packages/3-extensions/sql-orm-client/src/grouped-collection.ts @@ -5,7 +5,7 @@ import { BinaryExpr, type BinaryOp, ColumnRef, - LiteralExpr, + ParamRef, } from '@prisma-next/sql-relational-core/ast'; import { createAggregateBuilder, isAggregateSelector } from './aggregate-builder'; import { mapStorageRowToModelFields } from './collection-runtime'; @@ -156,7 +156,7 @@ function createHavingComparisonMethods( metric: AggregateExpr, ): HavingComparisonMethods { const buildBinaryExpr = (op: BinaryOp, value: unknown): AnyExpression => - new BinaryExpr(op, metric, LiteralExpr.of(value)); + new BinaryExpr(op, metric, ParamRef.of(value)); return { eq(value) { diff --git a/packages/3-extensions/sql-orm-client/src/model-accessor.ts b/packages/3-extensions/sql-orm-client/src/model-accessor.ts index 5d7f072b7a..b9c1890105 100644 --- a/packages/3-extensions/sql-orm-client/src/model-accessor.ts +++ b/packages/3-extensions/sql-orm-client/src/model-accessor.ts @@ -10,6 +10,7 @@ import { TableSource, } from '@prisma-next/sql-relational-core/ast'; import type { ExecutionContext } from '@prisma-next/sql-relational-core/query-lane-context'; +import { resolveColumnCodecId } from './collection-contract'; import { and, not } from './filters'; import { COMPARISON_METHODS_META, @@ -59,40 +60,43 @@ export function createModelAccessor< } const columnName = fieldToColumn[prop] ?? prop; - const traits = resolveFieldTraits(contract, tableName, columnName, context); - return createScalarFieldAccessor(tableName, columnName, traits); + const columnMeta = resolveColumnMeta(contract, tableName, columnName, context); + return createScalarFieldAccessor(tableName, columnName, columnMeta); }, }); } -function resolveFieldTraits( +interface ColumnMeta { + readonly codecId: string; + readonly traits: readonly string[]; +} + +function resolveColumnMeta( contract: SqlContract, tableName: string, columnName: string, context: ExecutionContext, -): readonly string[] { - const tables = contract.storage?.tables as - | Record }> - | undefined; - const codecId = tables?.[tableName]?.columns?.[columnName]?.codecId; - // unknown columns get no trait-gated methods - if (!codecId) return []; - return context.codecs.traitsOf(codecId); +): ColumnMeta | undefined { + const codecId = resolveColumnCodecId(contract, tableName, columnName); + if (!codecId) return undefined; + return { codecId, traits: context.codecs.traitsOf(codecId) }; } function createScalarFieldAccessor( tableName: string, columnName: string, - traits: readonly string[], + columnMeta: ColumnMeta | undefined, ): Partial> { const column = ColumnRef.of(tableName, columnName); const methods: Record = {}; + const traits = columnMeta?.traits ?? []; + const codecId = columnMeta?.codecId ?? ''; for (const [name, meta] of Object.entries(COMPARISON_METHODS_META)) { if (meta.traits.some((t) => !traits.includes(t))) { continue; } - methods[name] = meta.create(column); + methods[name] = meta.create(column, codecId); } return methods as Partial>; diff --git a/packages/3-extensions/sql-orm-client/src/mutation-executor.ts b/packages/3-extensions/sql-orm-client/src/mutation-executor.ts index 73ee223599..dcc67d3802 100644 --- a/packages/3-extensions/sql-orm-client/src/mutation-executor.ts +++ b/packages/3-extensions/sql-orm-client/src/mutation-executor.ts @@ -3,10 +3,14 @@ import { type AnyExpression, BinaryExpr, ColumnRef, - LiteralExpr, + ParamRef, } from '@prisma-next/sql-relational-core/ast'; import type { ExecutionContext } from '@prisma-next/sql-relational-core/query-lane-context'; -import { resolveModelTableName, resolvePrimaryKeyColumn } from './collection-contract'; +import { + resolveColumnCodecId, + resolveModelTableName, + resolvePrimaryKeyColumn, +} from './collection-contract'; import { acquireRuntimeScope, mapModelDataToStorageRow, @@ -486,7 +490,7 @@ async function applyChildOwnedMutation( } if (!mutation.criteria || mutation.criteria.length === 0) { - const parentJoinWhere = buildChildJoinWhere(relation, parentValues); + const parentJoinWhere = buildChildJoinWhere(contract, relation, parentValues); await executeUpdateCount(scope, contract, relation.relatedTableName, setValues, [ parentJoinWhere, ]); @@ -505,7 +509,7 @@ async function applyChildOwnedMutation( ); } - const parentJoinWhere = buildChildJoinWhere(relation, parentValues); + const parentJoinWhere = buildChildJoinWhere(contract, relation, parentValues); await executeUpdateCount(scope, contract, relation.relatedTableName, setValues, [ and(parentJoinWhere, criterionWhere), ]); @@ -542,16 +546,19 @@ function readParentColumnValues( } function buildChildJoinWhere( + contract: SqlContract, relation: RelationDefinition, childValues: Map, ): AnyExpression { const exprs: AnyExpression[] = []; + const tableName = relation.relatedTableName; for (const [childColumn, parentValue] of childValues.entries()) { + const codecId = resolveColumnCodecId(contract, tableName, childColumn); exprs.push( BinaryExpr.eq( - ColumnRef.of(relation.relatedTableName, childColumn), - LiteralExpr.of(parentValue), + ColumnRef.of(tableName, childColumn), + ParamRef.of(parentValue, { name: childColumn, ...(codecId ? { codecId } : {}) }), ), ); } diff --git a/packages/3-extensions/sql-orm-client/src/query-plan-aggregate.ts b/packages/3-extensions/sql-orm-client/src/query-plan-aggregate.ts index 4d44693236..5e02dc1bd3 100644 --- a/packages/3-extensions/sql-orm-client/src/query-plan-aggregate.ts +++ b/packages/3-extensions/sql-orm-client/src/query-plan-aggregate.ts @@ -29,13 +29,9 @@ function toAggregateExpr(tableName: string, selector: AggregateSelector return new AggregateExpr(selector.fn, ColumnRef.of(tableName, selector.column)); } -// ORM HAVING filters use literal binding (values inlined at plan-build time), -// not parameterized binding. ParamRef is rejected because the ORM's grouped -// collection API always produces literal comparisons for having() predicates. function validateGroupedComparable(value: AnyExpression): AnyExpression { switch (value.kind) { case 'param-ref': - throw new Error('ParamRef is not supported in grouped having expressions'); case 'literal': case 'column-ref': case 'identifier-ref': @@ -43,9 +39,6 @@ function validateGroupedComparable(value: AnyExpression): AnyExpression { case 'operation': return value; case 'list': - if (value.values.some((entry) => entry.kind === 'param-ref')) { - throw new Error('ParamRef is not supported in grouped having expressions'); - } return value; default: throw new Error(`Unsupported comparable kind in grouped having: "${value.kind}"`); diff --git a/packages/3-extensions/sql-orm-client/src/query-plan-select.ts b/packages/3-extensions/sql-orm-client/src/query-plan-select.ts index 3af1707dec..d6f6f74f42 100644 --- a/packages/3-extensions/sql-orm-client/src/query-plan-select.ts +++ b/packages/3-extensions/sql-orm-client/src/query-plan-select.ts @@ -12,18 +12,18 @@ import { JsonArrayAggExpr, JsonObjectExpr, ListExpression, - LiteralExpr, OrderByItem, OrExpr, + ParamRef, ProjectionItem, SelectAst, SubqueryExpr, TableSource, } from '@prisma-next/sql-relational-core/ast'; import type { SqlQueryPlan } from '@prisma-next/sql-relational-core/plan'; +import { resolveColumnCodecId } from './collection-contract'; import { buildOrmQueryPlan, deriveParamsFromAst, resolveTableColumns } from './query-plan-meta'; import type { CollectionState, IncludeExpr, OrderExpr } from './types'; -import { bindWhereExpr } from './where-binding'; import { combineWhereExprs } from './where-utils'; type CursorOrderEntry = OrderExpr & { @@ -57,16 +57,31 @@ function toOrderBy( ); } -function createBoundaryExpr(tableName: string, entry: CursorOrderEntry): AnyExpression { +function columnParam( + contract: SqlContract, + tableName: string, + column: string, + value: unknown, +): ParamRef { + const codecId = resolveColumnCodecId(contract, tableName, column); + return ParamRef.of(value, { name: column, ...(codecId ? { codecId } : {}) }); +} + +function createBoundaryExpr( + contract: SqlContract, + tableName: string, + entry: CursorOrderEntry, +): AnyExpression { const comparator: BinaryOp = entry.direction === 'asc' ? 'gt' : 'lt'; return new BinaryExpr( comparator, ColumnRef.of(tableName, entry.column), - LiteralExpr.of(entry.value), + columnParam(contract, tableName, entry.column, entry.value), ); } function buildLexicographicCursorWhere( + contract: SqlContract, tableName: string, entries: readonly CursorOrderEntry[], ): AnyExpression { @@ -77,12 +92,12 @@ function buildLexicographicCursorWhere( branchExprs.push( BinaryExpr.eq( ColumnRef.of(tableName, prefixEntry.column), - LiteralExpr.of(prefixEntry.value), + columnParam(contract, tableName, prefixEntry.column, prefixEntry.value), ), ); } - branchExprs.push(createBoundaryExpr(tableName, entry)); + branchExprs.push(createBoundaryExpr(contract, tableName, entry)); if (branchExprs.length === 1) { return branchExprs[0] as AnyExpression; } @@ -98,6 +113,7 @@ function buildLexicographicCursorWhere( } function buildCursorWhere( + contract: SqlContract, tableName: string, orderBy: readonly OrderExpr[] | undefined, cursor: Readonly> | undefined, @@ -120,10 +136,10 @@ function buildCursorWhere( const firstEntry = entries[0]; if (entries.length === 1 && firstEntry !== undefined) { - return createBoundaryExpr(tableName, firstEntry); + return createBoundaryExpr(contract, tableName, firstEntry); } - return buildLexicographicCursorWhere(tableName, entries); + return buildLexicographicCursorWhere(contract, tableName, entries); } function createTableRefRemapper(fromTable: string, toTable: string): AstRewriter { @@ -153,18 +169,17 @@ function buildStateWhere( ): AnyExpression | undefined { const filterTableName = options?.filterTableName; const cursorTableName = filterTableName ?? tableName; - const cursorWhere = buildCursorWhere(cursorTableName, state.orderBy, state.cursor); + const cursorWhere = buildCursorWhere(contract, cursorTableName, state.orderBy, state.cursor); const remappedFilters = filterTableName && filterTableName !== tableName ? state.filters.map((filter) => filter.rewrite(createTableRefRemapper(filterTableName, tableName)), ) : state.filters; - const boundCursorWhere = cursorWhere ? bindWhereExpr(contract, cursorWhere) : undefined; const remappedCursorWhere = - boundCursorWhere && filterTableName && filterTableName !== tableName - ? boundCursorWhere.rewrite(createTableRefRemapper(filterTableName, tableName)) - : boundCursorWhere; + cursorWhere && filterTableName && filterTableName !== tableName + ? cursorWhere.rewrite(createTableRefRemapper(filterTableName, tableName)) + : cursorWhere; const filters = remappedCursorWhere ? [...remappedFilters, remappedCursorWhere] : remappedFilters; return combineWhereExprs(filters); } @@ -406,7 +421,7 @@ export function compileRelationSelect( ): SqlQueryPlan> { const inFilter: AnyExpression = BinaryExpr.in( ColumnRef.of(relatedTableName, fkColumn), - ListExpression.fromValues(parentPks), + ListExpression.of(parentPks.map((pk) => columnParam(contract, relatedTableName, fkColumn, pk))), ); return compileSelect(contract, relatedTableName, { @@ -414,7 +429,7 @@ export function compileRelationSelect( includes: [], limit: undefined, offset: undefined, - filters: [bindWhereExpr(contract, inFilter), ...nestedState.filters], + filters: [inFilter, ...nestedState.filters], }); } diff --git a/packages/3-extensions/sql-orm-client/src/types.ts b/packages/3-extensions/sql-orm-client/src/types.ts index 8820a7f907..867cd64ba7 100644 --- a/packages/3-extensions/sql-orm-client/src/types.ts +++ b/packages/3-extensions/sql-orm-client/src/types.ts @@ -12,8 +12,8 @@ import { type CodecTrait, type ColumnRef, ListExpression, - LiteralExpr, NullCheckExpr, + ParamRef, } from '@prisma-next/sql-relational-core/ast'; import type { SqlQueryPlan } from '@prisma-next/sql-relational-core/plan'; import type { ExecutionContext } from '@prisma-next/sql-relational-core/query-lane-context'; @@ -181,12 +181,12 @@ export type ComparisonMethods = { // COMPARISON_METHODS_META — single source of truth for traits + factories // --------------------------------------------------------------------------- -function literal(value: unknown): LiteralExpr { - return LiteralExpr.of(value); +function param(column: ColumnRef, codecId: string, value: unknown): ParamRef { + return ParamRef.of(value, { name: column.column, codecId }); } -function listLiteral(values: readonly unknown[]): ListExpression { - return ListExpression.fromValues(values); +function paramList(column: ColumnRef, codecId: string, values: readonly unknown[]): ListExpression { + return ListExpression.of(values.map((v) => param(column, codecId, v))); } function bin(op: BinaryExpr['op'], column: ColumnRef, right: BinaryExpr['right']): BinaryExpr { @@ -196,7 +196,7 @@ function bin(op: BinaryExpr['op'], column: ColumnRef, right: BinaryExpr['right'] // never[] is intentional: factories have heterogeneous signatures (value: unknown, // values: readonly unknown[], pattern: string, etc.) but are only called through // the typed ComparisonMethodFns interface, never through this type directly. -type MethodFactory = (column: ColumnRef) => (...args: never[]) => unknown; +type MethodFactory = (column: ColumnRef, codecId: string) => (...args: never[]) => unknown; type ComparisonMethodMeta = { readonly traits: readonly CodecTrait[]; @@ -212,43 +212,53 @@ type ComparisonMethodMeta = { export const COMPARISON_METHODS_META = { eq: { traits: ['equality'], - create: (column) => (value: unknown) => bin('eq', column, literal(value)), + create: (column, codecId) => (value: unknown) => + bin('eq', column, param(column, codecId, value)), }, neq: { traits: ['equality'], - create: (column) => (value: unknown) => bin('neq', column, literal(value)), + create: (column, codecId) => (value: unknown) => + bin('neq', column, param(column, codecId, value)), }, in: { traits: ['equality'], - create: (column) => (values: readonly unknown[]) => bin('in', column, listLiteral(values)), + create: (column, codecId) => (values: readonly unknown[]) => + bin('in', column, paramList(column, codecId, values)), }, notIn: { traits: ['equality'], - create: (column) => (values: readonly unknown[]) => bin('notIn', column, listLiteral(values)), + create: (column, codecId) => (values: readonly unknown[]) => + bin('notIn', column, paramList(column, codecId, values)), }, gt: { traits: ['order'], - create: (column) => (value: unknown) => bin('gt', column, literal(value)), + create: (column, codecId) => (value: unknown) => + bin('gt', column, param(column, codecId, value)), }, lt: { traits: ['order'], - create: (column) => (value: unknown) => bin('lt', column, literal(value)), + create: (column, codecId) => (value: unknown) => + bin('lt', column, param(column, codecId, value)), }, gte: { traits: ['order'], - create: (column) => (value: unknown) => bin('gte', column, literal(value)), + create: (column, codecId) => (value: unknown) => + bin('gte', column, param(column, codecId, value)), }, lte: { traits: ['order'], - create: (column) => (value: unknown) => bin('lte', column, literal(value)), + create: (column, codecId) => (value: unknown) => + bin('lte', column, param(column, codecId, value)), }, like: { traits: ['textual'], - create: (column) => (pattern: string) => bin('like', column, literal(pattern)), + create: (column, codecId) => (pattern: string) => + bin('like', column, param(column, codecId, pattern)), }, ilike: { traits: ['textual'], - create: (column) => (pattern: string) => bin('ilike', column, literal(pattern)), + create: (column, codecId) => (pattern: string) => + bin('ilike', column, param(column, codecId, pattern)), }, asc: { traits: ['order'], diff --git a/packages/3-extensions/sql-orm-client/src/where-binding.ts b/packages/3-extensions/sql-orm-client/src/where-binding.ts deleted file mode 100644 index 9370711fdd..0000000000 --- a/packages/3-extensions/sql-orm-client/src/where-binding.ts +++ /dev/null @@ -1,190 +0,0 @@ -import type { SqlContract, SqlStorage } from '@prisma-next/sql-contract/types'; -import { - AndExpr, - type AnyExpression, - type AnyFromSource, - BinaryExpr, - type ColumnRef, - DerivedTableSource, - ExistsExpr, - type ExpressionRewriter, - JoinAst, - ListExpression, - NotExpr, - NullCheckExpr, - OrderByItem, - OrExpr, - ParamRef, - type ProjectionExpr, - ProjectionItem, - SelectAst, -} from '@prisma-next/sql-relational-core/ast'; - -export function bindWhereExpr( - contract: SqlContract, - expr: AnyExpression, -): AnyExpression { - return bindWhereExprNode(contract, expr); -} - -function bindWhereExprNode(contract: SqlContract, expr: AnyExpression): AnyExpression { - return expr.accept({ - columnRef(expr) { - return bindExpression(contract, expr); - }, - identifierRef(expr) { - return expr; - }, - subquery(expr) { - return bindExpression(contract, expr); - }, - operation(expr) { - return bindExpression(contract, expr); - }, - aggregate(expr) { - return bindExpression(contract, expr); - }, - jsonObject(expr) { - return bindExpression(contract, expr); - }, - jsonArrayAgg(expr) { - return bindExpression(contract, expr); - }, - literal(expr) { - return expr; - }, - param(expr) { - return expr; - }, - list(expr) { - return bindExpression(contract, expr); - }, - binary(expr) { - const left = bindExpression(contract, expr.left); - const bindingColumn = left.kind === 'column-ref' ? (left as ColumnRef) : undefined; - - return new BinaryExpr(expr.op, left, bindComparable(contract, expr.right, bindingColumn)); - }, - and(expr) { - return AndExpr.of(expr.exprs.map((part) => bindWhereExprNode(contract, part))); - }, - or(expr) { - return OrExpr.of(expr.exprs.map((part) => bindWhereExprNode(contract, part))); - }, - exists(expr) { - return expr.notExists - ? ExistsExpr.notExists(bindSelectAst(contract, expr.subquery)) - : ExistsExpr.exists(bindSelectAst(contract, expr.subquery)); - }, - nullCheck(expr) { - return expr.isNull - ? NullCheckExpr.isNull(bindExpression(contract, expr.expr)) - : NullCheckExpr.isNotNull(bindExpression(contract, expr.expr)); - }, - not(expr) { - return new NotExpr(bindWhereExprNode(contract, expr.expr)); - }, - }); -} - -function bindComparable( - contract: SqlContract, - comparable: AnyExpression, - bindingColumn: ColumnRef | undefined, -): AnyExpression { - if (comparable.kind === 'param-ref' || bindingColumn === undefined) { - return comparable.kind === 'param-ref' - ? comparable - : comparable.kind === 'literal' || comparable.kind === 'list' - ? comparable - : bindExpression(contract, comparable); - } - - if (comparable.kind === 'literal') { - return createParamRef(contract, bindingColumn, comparable.value); - } - - if (comparable.kind === 'list') { - return ListExpression.of( - comparable.values.map((value) => - value.kind === 'literal' ? createParamRef(contract, bindingColumn, value.value) : value, - ), - ); - } - - return bindExpression(contract, comparable); -} - -function createParamRef( - contract: SqlContract, - columnRef: ColumnRef, - value: unknown, -): ParamRef { - const codecId = contract.storage.tables[columnRef.table]?.columns[columnRef.column]?.codecId; - if (!codecId) { - throw new Error(`Unknown column "${columnRef.column}" in table "${columnRef.table}"`); - } - return ParamRef.of(value, { name: columnRef.column, codecId }); -} - -function createExpressionBinder(contract: SqlContract): ExpressionRewriter { - return { - select: (ast) => bindSelectAst(contract, ast), - }; -} - -function bindExpression(contract: SqlContract, expr: AnyExpression): AnyExpression { - return expr.rewrite(createExpressionBinder(contract)); -} - -function bindProjectionExpr( - contract: SqlContract, - expr: ProjectionExpr, -): ProjectionExpr { - return expr.kind === 'literal' ? expr : bindExpression(contract, expr); -} - -function bindOrderByItem(contract: SqlContract, orderItem: OrderByItem): OrderByItem { - return new OrderByItem(bindExpression(contract, orderItem.expr), orderItem.dir); -} - -function bindJoin(contract: SqlContract, join: JoinAst): JoinAst { - return new JoinAst( - join.joinType, - bindFromSource(contract, join.source), - join.on.kind === 'eq-col-join-on' ? join.on : bindWhereExprNode(contract, join.on), - join.lateral, - ); -} - -function bindFromSource(contract: SqlContract, source: AnyFromSource): AnyFromSource { - if (source.kind === 'table-source') { - return source; - } - if (source.kind === 'derived-table-source') { - const derived = source as DerivedTableSource; - return DerivedTableSource.as(derived.alias, bindSelectAst(contract, derived.query)); - } - - return source; -} - -function bindSelectAst(contract: SqlContract, ast: SelectAst): SelectAst { - return new SelectAst({ - from: bindFromSource(contract, ast.from), - joins: ast.joins?.map((join) => bindJoin(contract, join)), - projection: ast.projection.map( - (projection) => - new ProjectionItem(projection.alias, bindProjectionExpr(contract, projection.expr)), - ), - where: ast.where ? bindWhereExprNode(contract, ast.where) : undefined, - orderBy: ast.orderBy?.map((orderItem) => bindOrderByItem(contract, orderItem)), - distinct: ast.distinct, - distinctOn: ast.distinctOn?.map((expr) => bindExpression(contract, expr)), - groupBy: ast.groupBy?.map((expr) => bindExpression(contract, expr)), - having: ast.having ? bindWhereExprNode(contract, ast.having) : undefined, - limit: ast.limit, - offset: ast.offset, - selectAllIntent: ast.selectAllIntent, - }); -} diff --git a/packages/3-extensions/sql-orm-client/src/where-interop.ts b/packages/3-extensions/sql-orm-client/src/where-interop.ts index b9a5bdc134..22cbda87d8 100644 --- a/packages/3-extensions/sql-orm-client/src/where-interop.ts +++ b/packages/3-extensions/sql-orm-client/src/where-interop.ts @@ -1,23 +1,10 @@ -import type { SqlContract, SqlStorage } from '@prisma-next/sql-contract/types'; import type { AnyExpression, ToWhereExpr, WhereArg } from '@prisma-next/sql-relational-core/ast'; import { isWhereExpr } from '@prisma-next/sql-relational-core/ast'; -import { bindWhereExpr } from './where-binding'; - -interface NormalizeWhereArgOptions { - readonly contract?: SqlContract; -} export function normalizeWhereArg(arg: undefined): undefined; -export function normalizeWhereArg(arg: undefined, options: NormalizeWhereArgOptions): undefined; -export function normalizeWhereArg(arg: WhereArg, options?: NormalizeWhereArgOptions): AnyExpression; -export function normalizeWhereArg( - arg: WhereArg | undefined, - options?: NormalizeWhereArgOptions, -): AnyExpression | undefined; -export function normalizeWhereArg( - arg: WhereArg | undefined, - options?: NormalizeWhereArgOptions, -): AnyExpression | undefined { +export function normalizeWhereArg(arg: WhereArg): AnyExpression; +export function normalizeWhereArg(arg: WhereArg | undefined): AnyExpression | undefined; +export function normalizeWhereArg(arg: WhereArg | undefined): AnyExpression | undefined { if (arg === undefined) { return undefined; } @@ -31,9 +18,6 @@ export function normalizeWhereArg( return arg.toWhereExpr(); } - if (options?.contract) { - return bindWhereExpr(options.contract, arg); - } return arg; } diff --git a/packages/3-extensions/sql-orm-client/test/collection.state.test.ts b/packages/3-extensions/sql-orm-client/test/collection.state.test.ts index 937eb307ee..8f3229ec31 100644 --- a/packages/3-extensions/sql-orm-client/test/collection.state.test.ts +++ b/packages/3-extensions/sql-orm-client/test/collection.state.test.ts @@ -2,13 +2,11 @@ import { AndExpr, BinaryExpr, ColumnRef, - LiteralExpr, NullCheckExpr, ParamRef, type ToWhereExpr, } from '@prisma-next/sql-relational-core/ast'; import { describe, expect, it } from 'vitest'; -import { bindWhereExpr } from '../src/where-binding'; import { baseContract, createCollection, @@ -23,22 +21,22 @@ describe('Collection', () => { const filtered = collection.where((user) => user.name.eq('Alice')); expect(filtered.state.filters).toEqual([ - bindWhereExpr( - baseContract, - BinaryExpr.eq(ColumnRef.of('users', 'name'), LiteralExpr.of('Alice')), + BinaryExpr.eq( + ColumnRef.of('users', 'name'), + ParamRef.of('Alice', { name: 'name', codecId: 'pg/text@1' }), ), ]); expect(collection.state.filters).toEqual([]); const chained = filtered.where((user) => user.email.neq('old@example.com')); expect(chained.state.filters).toEqual([ - bindWhereExpr( - baseContract, - BinaryExpr.eq(ColumnRef.of('users', 'name'), LiteralExpr.of('Alice')), + BinaryExpr.eq( + ColumnRef.of('users', 'name'), + ParamRef.of('Alice', { name: 'name', codecId: 'pg/text@1' }), ), - bindWhereExpr( - baseContract, - BinaryExpr.neq(ColumnRef.of('users', 'email'), LiteralExpr.of('old@example.com')), + BinaryExpr.neq( + ColumnRef.of('users', 'email'), + ParamRef.of('old@example.com', { name: 'email', codecId: 'pg/text@1' }), ), ]); }); @@ -88,13 +86,13 @@ describe('Collection', () => { }); expect(filtered.state.filters).toEqual([ - bindWhereExpr( - baseContract, - AndExpr.of([ - BinaryExpr.eq(ColumnRef.of('users', 'name'), LiteralExpr.of('Alice')), - NullCheckExpr.isNull(ColumnRef.of('users', 'email')), - ]), - ), + AndExpr.of([ + BinaryExpr.eq( + ColumnRef.of('users', 'name'), + ParamRef.of('Alice', { name: 'name', codecId: 'pg/text@1' }), + ), + NullCheckExpr.isNull(ColumnRef.of('users', 'email')), + ]), ]); expect(collection.where({})).toBe(collection); @@ -153,9 +151,9 @@ describe('Collection', () => { cardinality: '1:N', }); expect(withPosts.state.includes[0]?.nested.filters).toEqual([ - bindWhereExpr( - baseContract, - BinaryExpr.gt(ColumnRef.of('posts', 'views'), LiteralExpr.of(100)), + BinaryExpr.gt( + ColumnRef.of('posts', 'views'), + ParamRef.of(100, { name: 'views', codecId: 'pg/int4@1' }), ), ]); expect(withPosts.state.includes[0]?.nested.limit).toBe(5); @@ -169,9 +167,9 @@ describe('Collection', () => { fn: 'count', }); expect(withPostCount.state.includes[0]?.scalar?.state.filters).toEqual([ - bindWhereExpr( - baseContract, - BinaryExpr.gt(ColumnRef.of('posts', 'views'), LiteralExpr.of(100)), + BinaryExpr.gt( + ColumnRef.of('posts', 'views'), + ParamRef.of(100, { name: 'views', codecId: 'pg/int4@1' }), ), ]); diff --git a/packages/3-extensions/sql-orm-client/test/filters.test.ts b/packages/3-extensions/sql-orm-client/test/filters.test.ts index 0a4b57b34e..66e8b1873d 100644 --- a/packages/3-extensions/sql-orm-client/test/filters.test.ts +++ b/packages/3-extensions/sql-orm-client/test/filters.test.ts @@ -3,10 +3,10 @@ import { BinaryExpr, ColumnRef, ListExpression, - LiteralExpr, NotExpr, NullCheckExpr, OrExpr, + ParamRef, } from '@prisma-next/sql-relational-core/ast'; import { describe, expect, it } from 'vitest'; import { all, and, not, or, shorthandToWhereExpr } from '../src/filters'; @@ -23,21 +23,38 @@ describe('filters', () => { const andExpr = and(user['name']!.eq('Alice'), user['email']!.neq('bob@example.com')); expect(andExpr).toEqual( AndExpr.of([ - BinaryExpr.eq(ColumnRef.of('users', 'name'), LiteralExpr.of('Alice')), - BinaryExpr.neq(ColumnRef.of('users', 'email'), LiteralExpr.of('bob@example.com')), + BinaryExpr.eq( + ColumnRef.of('users', 'name'), + ParamRef.of('Alice', { name: 'name', codecId: 'pg/text@1' }), + ), + BinaryExpr.neq( + ColumnRef.of('users', 'email'), + ParamRef.of('bob@example.com', { name: 'email', codecId: 'pg/text@1' }), + ), ]), ); const orExpr = or(user['name']!.eq('Alice'), user['name']!.eq('Bob')); expect(orExpr).toEqual( OrExpr.of([ - BinaryExpr.eq(ColumnRef.of('users', 'name'), LiteralExpr.of('Alice')), - BinaryExpr.eq(ColumnRef.of('users', 'name'), LiteralExpr.of('Bob')), + BinaryExpr.eq( + ColumnRef.of('users', 'name'), + ParamRef.of('Alice', { name: 'name', codecId: 'pg/text@1' }), + ), + BinaryExpr.eq( + ColumnRef.of('users', 'name'), + ParamRef.of('Bob', { name: 'name', codecId: 'pg/text@1' }), + ), ]), ); expect(not(user['name']!.eq('Alice'))).toEqual( - new NotExpr(BinaryExpr.eq(ColumnRef.of('users', 'name'), LiteralExpr.of('Alice'))), + new NotExpr( + BinaryExpr.eq( + ColumnRef.of('users', 'name'), + ParamRef.of('Alice', { name: 'name', codecId: 'pg/text@1' }), + ), + ), ); expect(not(user['posts']!.some()).kind).toBe('not'); expect(not(user['email']!.isNull())).toEqual( @@ -48,10 +65,19 @@ describe('filters', () => { ).toEqual( new NotExpr( AndExpr.of([ - BinaryExpr.eq(ColumnRef.of('users', 'name'), LiteralExpr.of('Alice')), + BinaryExpr.eq( + ColumnRef.of('users', 'name'), + ParamRef.of('Alice', { name: 'name', codecId: 'pg/text@1' }), + ), OrExpr.of([ - BinaryExpr.eq(ColumnRef.of('users', 'email'), LiteralExpr.of('a')), - BinaryExpr.eq(ColumnRef.of('users', 'email'), LiteralExpr.of('b')), + BinaryExpr.eq( + ColumnRef.of('users', 'email'), + ParamRef.of('a', { name: 'email', codecId: 'pg/text@1' }), + ), + BinaryExpr.eq( + ColumnRef.of('users', 'email'), + ParamRef.of('b', { name: 'email', codecId: 'pg/text@1' }), + ), ]), ]), ), @@ -63,22 +89,58 @@ describe('filters', () => { const user = createModelAccessor(context, 'User'); expect(not(user['id']!.neq(1))).toEqual( - new NotExpr(BinaryExpr.neq(ColumnRef.of('users', 'id'), LiteralExpr.of(1))), + new NotExpr( + BinaryExpr.neq( + ColumnRef.of('users', 'id'), + ParamRef.of(1, { name: 'id', codecId: 'pg/int4@1' }), + ), + ), ); expect(not(user['id']!.lt(1))).toEqual( - new NotExpr(BinaryExpr.lt(ColumnRef.of('users', 'id'), LiteralExpr.of(1))), + new NotExpr( + BinaryExpr.lt( + ColumnRef.of('users', 'id'), + ParamRef.of(1, { name: 'id', codecId: 'pg/int4@1' }), + ), + ), ); expect(not(user['id']!.gte(1))).toEqual( - new NotExpr(BinaryExpr.gte(ColumnRef.of('users', 'id'), LiteralExpr.of(1))), + new NotExpr( + BinaryExpr.gte( + ColumnRef.of('users', 'id'), + ParamRef.of(1, { name: 'id', codecId: 'pg/int4@1' }), + ), + ), ); expect(not(user['id']!.lte(1))).toEqual( - new NotExpr(BinaryExpr.lte(ColumnRef.of('users', 'id'), LiteralExpr.of(1))), + new NotExpr( + BinaryExpr.lte( + ColumnRef.of('users', 'id'), + ParamRef.of(1, { name: 'id', codecId: 'pg/int4@1' }), + ), + ), ); expect(not(user['id']!.in([1, 2]))).toEqual( - new NotExpr(BinaryExpr.in(ColumnRef.of('users', 'id'), ListExpression.fromValues([1, 2]))), + new NotExpr( + BinaryExpr.in( + ColumnRef.of('users', 'id'), + ListExpression.of([ + ParamRef.of(1, { name: 'id', codecId: 'pg/int4@1' }), + ParamRef.of(2, { name: 'id', codecId: 'pg/int4@1' }), + ]), + ), + ), ); expect(not(user['id']!.notIn([1, 2]))).toEqual( - new NotExpr(BinaryExpr.notIn(ColumnRef.of('users', 'id'), ListExpression.fromValues([1, 2]))), + new NotExpr( + BinaryExpr.notIn( + ColumnRef.of('users', 'id'), + ListExpression.of([ + ParamRef.of(1, { name: 'id', codecId: 'pg/int4@1' }), + ParamRef.of(2, { name: 'id', codecId: 'pg/int4@1' }), + ]), + ), + ), ); }); @@ -86,10 +148,20 @@ describe('filters', () => { const user = createModelAccessor(context, 'User'); expect(not(user['name']!.like('%a%'))).toEqual( - new NotExpr(BinaryExpr.like(ColumnRef.of('users', 'name'), LiteralExpr.of('%a%'))), + new NotExpr( + BinaryExpr.like( + ColumnRef.of('users', 'name'), + ParamRef.of('%a%', { name: 'name', codecId: 'pg/text@1' }), + ), + ), ); expect(not(user['name']!.ilike('%a%'))).toEqual( - new NotExpr(BinaryExpr.ilike(ColumnRef.of('users', 'name'), LiteralExpr.of('%a%'))), + new NotExpr( + BinaryExpr.ilike( + ColumnRef.of('users', 'name'), + ParamRef.of('%a%', { name: 'name', codecId: 'pg/text@1' }), + ), + ), ); }); @@ -102,7 +174,10 @@ describe('filters', () => { expect(expr).toEqual( AndExpr.of([ - BinaryExpr.eq(ColumnRef.of('posts', 'id'), LiteralExpr.of(1)), + BinaryExpr.eq( + ColumnRef.of('posts', 'id'), + ParamRef.of(1, { name: 'id', codecId: 'pg/int4@1' }), + ), NullCheckExpr.isNull(ColumnRef.of('posts', 'user_id')), ]), ); @@ -124,7 +199,12 @@ describe('filters', () => { shorthandToWhereExpr({ ...context, contract: withoutModelToTable } as never, 'User', { email: 'alice@example.com', }), - ).toEqual(BinaryExpr.eq(ColumnRef.of('users', 'email'), LiteralExpr.of('alice@example.com'))); + ).toEqual( + BinaryExpr.eq( + ColumnRef.of('users', 'email'), + ParamRef.of('alice@example.com', { name: 'email', codecId: 'pg/text@1' }), + ), + ); const withoutMappings = { ...contract, diff --git a/packages/3-extensions/sql-orm-client/test/integration/group-by.test.ts b/packages/3-extensions/sql-orm-client/test/integration/group-by.test.ts index 3243ad5b09..5395e8f689 100644 --- a/packages/3-extensions/sql-orm-client/test/integration/group-by.test.ts +++ b/packages/3-extensions/sql-orm-client/test/integration/group-by.test.ts @@ -1,4 +1,4 @@ -import { AggregateExpr, BinaryExpr, LiteralExpr } from '@prisma-next/sql-relational-core/ast'; +import { AggregateExpr, BinaryExpr, ParamRef } from '@prisma-next/sql-relational-core/ast'; import { describe, expect, it } from 'vitest'; import { isSelectAst } from '../helpers'; import { createPostsCollection, timeouts, withCollectionRuntime } from './helpers'; @@ -63,7 +63,7 @@ describe('integration/groupBy', () => { if (!isSelectAst(ast)) { throw new Error('Expected grouped query to emit a select AST plan'); } - expect(ast.having).toEqual(BinaryExpr.gt(AggregateExpr.count(), LiteralExpr.of(1))); + expect(ast.having).toEqual(BinaryExpr.gt(AggregateExpr.count(), ParamRef.of(1))); }); }, timeouts.spinUpPpgDev, diff --git a/packages/3-extensions/sql-orm-client/test/model-accessor.test.ts b/packages/3-extensions/sql-orm-client/test/model-accessor.test.ts index 4e86c9c49e..929152236a 100644 --- a/packages/3-extensions/sql-orm-client/test/model-accessor.test.ts +++ b/packages/3-extensions/sql-orm-client/test/model-accessor.test.ts @@ -7,9 +7,9 @@ import { createCodecRegistry, ExistsExpr, ListExpression, - LiteralExpr, NotExpr, NullCheckExpr, + ParamRef, ProjectionItem, SelectAst, TableSource, @@ -21,45 +21,66 @@ import { getTestContext, getTestContract } from './helpers'; describe('createModelAccessor', () => { const context = getTestContext(); - function expectBinaryLiteral( + function expectBinaryParam( actual: unknown, table: string, column: string, op: BinaryExpr['op'], value: unknown, + codecId: string, ) { - expect(actual).toEqual(new BinaryExpr(op, ColumnRef.of(table, column), LiteralExpr.of(value))); + expect(actual).toEqual( + new BinaryExpr( + op, + ColumnRef.of(table, column), + ParamRef.of(value, { name: column, codecId }), + ), + ); } it('creates scalar comparison operators and maps fields to columns', () => { const user = createModelAccessor(context, 'User'); const post = createModelAccessor(context, 'Post'); - expectBinaryLiteral(user['name']!.eq('Alice'), 'users', 'name', 'eq', 'Alice'); - expectBinaryLiteral( + expectBinaryParam(user['name']!.eq('Alice'), 'users', 'name', 'eq', 'Alice', 'pg/text@1'); + expectBinaryParam( user['email']!.neq('test@example.com'), 'users', 'email', 'neq', 'test@example.com', + 'pg/text@1', ); - expectBinaryLiteral(post['views']!.gt(1000), 'posts', 'views', 'gt', 1000); - expectBinaryLiteral(post['views']!.lt(100), 'posts', 'views', 'lt', 100); - expectBinaryLiteral(post['id']!.gte(5), 'posts', 'id', 'gte', 5); - expectBinaryLiteral(post['id']!.lte(10), 'posts', 'id', 'lte', 10); - expectBinaryLiteral(post['userId']!.eq(42), 'posts', 'user_id', 'eq', 42); - expectBinaryLiteral(user['name']!.like('%Ali%'), 'users', 'name', 'like', '%Ali%'); - expectBinaryLiteral(user['name']!.ilike('%ali%'), 'users', 'name', 'ilike', '%ali%'); + expectBinaryParam(post['views']!.gt(1000), 'posts', 'views', 'gt', 1000, 'pg/int4@1'); + expectBinaryParam(post['views']!.lt(100), 'posts', 'views', 'lt', 100, 'pg/int4@1'); + expectBinaryParam(post['id']!.gte(5), 'posts', 'id', 'gte', 5, 'pg/int4@1'); + expectBinaryParam(post['id']!.lte(10), 'posts', 'id', 'lte', 10, 'pg/int4@1'); + expectBinaryParam(post['userId']!.eq(42), 'posts', 'user_id', 'eq', 42, 'pg/int4@1'); + expectBinaryParam(user['name']!.like('%Ali%'), 'users', 'name', 'like', '%Ali%', 'pg/text@1'); + expectBinaryParam(user['name']!.ilike('%ali%'), 'users', 'name', 'ilike', '%ali%', 'pg/text@1'); }); it('creates list literal, null check, and order directive helpers', () => { const accessor = createModelAccessor(context, 'Post'); expect(accessor['id']!.in([1, 2, 3])).toEqual( - BinaryExpr.in(ColumnRef.of('posts', 'id'), ListExpression.fromValues([1, 2, 3])), + BinaryExpr.in( + ColumnRef.of('posts', 'id'), + ListExpression.of([ + ParamRef.of(1, { name: 'id', codecId: 'pg/int4@1' }), + ParamRef.of(2, { name: 'id', codecId: 'pg/int4@1' }), + ParamRef.of(3, { name: 'id', codecId: 'pg/int4@1' }), + ]), + ), ); expect(accessor['id']!.notIn([4, 5])).toEqual( - BinaryExpr.notIn(ColumnRef.of('posts', 'id'), ListExpression.fromValues([4, 5])), + BinaryExpr.notIn( + ColumnRef.of('posts', 'id'), + ListExpression.of([ + ParamRef.of(4, { name: 'id', codecId: 'pg/int4@1' }), + ParamRef.of(5, { name: 'id', codecId: 'pg/int4@1' }), + ]), + ), ); expect(accessor['id']!.asc()).toEqual({ column: 'id', direction: 'asc' }); expect(accessor['id']!.desc()).toEqual({ column: 'id', direction: 'desc' }); @@ -91,7 +112,10 @@ describe('createModelAccessor', () => { expect(noneExpr.subquery.where).toEqual( AndExpr.of([ BinaryExpr.eq(ColumnRef.of('posts', 'user_id'), ColumnRef.of('users', 'id')), - BinaryExpr.eq(ColumnRef.of('posts', 'views'), LiteralExpr.of(10)), + BinaryExpr.eq( + ColumnRef.of('posts', 'views'), + ParamRef.of(10, { name: 'views', codecId: 'pg/int4@1' }), + ), ]), ); @@ -100,7 +124,12 @@ describe('createModelAccessor', () => { expect(everyExpr.subquery.where).toEqual( AndExpr.of([ BinaryExpr.eq(ColumnRef.of('posts', 'user_id'), ColumnRef.of('users', 'id')), - new NotExpr(BinaryExpr.gt(ColumnRef.of('posts', 'views'), LiteralExpr.of(10))), + new NotExpr( + BinaryExpr.gt( + ColumnRef.of('posts', 'views'), + ParamRef.of(10, { name: 'views', codecId: 'pg/int4@1' }), + ), + ), ]), ); }); @@ -342,8 +371,14 @@ describe('createModelAccessor', () => { AndExpr.of([ BinaryExpr.eq(ColumnRef.of('posts', 'user_id'), ColumnRef.of('users', 'id')), AndExpr.of([ - BinaryExpr.eq(ColumnRef.of('posts', 'title'), LiteralExpr.of('A')), - BinaryExpr.eq(ColumnRef.of('posts', 'views'), LiteralExpr.of(1)), + BinaryExpr.eq( + ColumnRef.of('posts', 'title'), + ParamRef.of('A', { name: 'title', codecId: 'pg/text@1' }), + ), + BinaryExpr.eq( + ColumnRef.of('posts', 'views'), + ParamRef.of(1, { name: 'views', codecId: 'pg/int4@1' }), + ), ]), ]), ); diff --git a/packages/3-extensions/sql-orm-client/test/mutation-executor.test.ts b/packages/3-extensions/sql-orm-client/test/mutation-executor.test.ts index 7f77c87c21..9d56038c5d 100644 --- a/packages/3-extensions/sql-orm-client/test/mutation-executor.test.ts +++ b/packages/3-extensions/sql-orm-client/test/mutation-executor.test.ts @@ -2,7 +2,7 @@ import { type AnyExpression, BinaryExpr, ColumnRef, - LiteralExpr, + ParamRef, } from '@prisma-next/sql-relational-core/ast'; import { describe, expect, it, vi } from 'vitest'; import { @@ -49,9 +49,15 @@ function withConnection(runtime: MockRuntime, onRelease: () => void) { }); } -const postIdFilter: AnyExpression = BinaryExpr.eq(ColumnRef.of('posts', 'id'), LiteralExpr.of(1)); +const postIdFilter: AnyExpression = BinaryExpr.eq( + ColumnRef.of('posts', 'id'), + ParamRef.of(1, { name: 'id', codecId: 'pg/int4@1' }), +); -const userIdFilter: AnyExpression = BinaryExpr.eq(ColumnRef.of('users', 'id'), LiteralExpr.of(1)); +const userIdFilter: AnyExpression = BinaryExpr.eq( + ColumnRef.of('users', 'id'), + ParamRef.of(1, { name: 'id', codecId: 'pg/int4@1' }), +); describe('mutation-executor', () => { it('hasNestedMutationCallbacks() detects callbacks only on relation fields', () => { diff --git a/packages/3-extensions/sql-orm-client/test/query-plan-aggregate.test.ts b/packages/3-extensions/sql-orm-client/test/query-plan-aggregate.test.ts index 5a0d25da0d..c7f129f496 100644 --- a/packages/3-extensions/sql-orm-client/test/query-plan-aggregate.test.ts +++ b/packages/3-extensions/sql-orm-client/test/query-plan-aggregate.test.ts @@ -22,7 +22,6 @@ import { } from '@prisma-next/sql-relational-core/ast'; import { describe, expect, it } from 'vitest'; import { compileAggregate, compileGroupedAggregate } from '../src/query-plan'; -import { bindWhereExpr } from '../src/where-binding'; import { baseContract } from './collection-fixtures'; const defaultAggSpec = { @@ -34,9 +33,9 @@ function compileWithHaving(having: AnyExpression) { } describe('query plan aggregate', () => { - const filteredViews = bindWhereExpr( - baseContract, - BinaryExpr.gte(ColumnRef.of('posts', 'views'), LiteralExpr.of(100)), + const filteredViews = BinaryExpr.gte( + ColumnRef.of('posts', 'views'), + ParamRef.of(100, { name: 'views', codecId: 'pg/int4@1' }), ); it('rejects empty aggregate specs and selectors without required fields', () => { @@ -67,38 +66,61 @@ describe('query plan aggregate', () => { ).toThrow('groupBy().aggregate() requires at least one aggregation selector'); }); - it('validates grouped having expressions before lowering them', () => { - const scalarSubquery = SelectAst.from(TableSource.named('posts')).withProjection([ - ProjectionItem.of('id', ColumnRef.of('posts', 'id')), - ]); + it('parameterizes ParamRef values in HAVING comparables', () => { + const plan = compileGroupedAggregate( + baseContract, + 'posts', + [], + ['user_id'], + { totalViews: { kind: 'aggregate', fn: 'sum', column: 'views' } }, + BinaryExpr.gte( + AggregateExpr.sum(ColumnRef.of('posts', 'views')), + ParamRef.of(1, { name: 'views', codecId: 'pg/int4@1' }), + ), + ); - expect(() => - compileGroupedAggregate( - baseContract, - 'posts', - [], - ['user_id'], - { totalViews: { kind: 'aggregate', fn: 'sum', column: 'views' } }, - BinaryExpr.gte( - AggregateExpr.sum(ColumnRef.of('posts', 'views')), - ParamRef.of(1, { name: 'views', codecId: 'pg/int4@1' }), - ), + const ast = plan.ast as SelectAst; + expect(ast.having).toEqual( + BinaryExpr.gte( + AggregateExpr.sum(ColumnRef.of('posts', 'views')), + ParamRef.of(1, { name: 'views', codecId: 'pg/int4@1' }), ), - ).toThrow('ParamRef is not supported in grouped having expressions'); + ); + expect(plan.params).toEqual([1]); + expect(plan.meta.paramDescriptors).toContainEqual({ + name: 'views', + codecId: 'pg/int4@1', + source: 'dsl', + }); + }); - expect(() => - compileGroupedAggregate( - baseContract, - 'posts', - [], - ['user_id'], - { totalViews: { kind: 'aggregate', fn: 'sum', column: 'views' } }, - BinaryExpr.in( - AggregateExpr.sum(ColumnRef.of('posts', 'views')), - ListExpression.of([ParamRef.of(1, { name: 'views', codecId: 'pg/int4@1' })]), - ), + it('parameterizes ParamRef values inside list HAVING comparables', () => { + const plan = compileGroupedAggregate( + baseContract, + 'posts', + [], + ['user_id'], + { totalViews: { kind: 'aggregate', fn: 'sum', column: 'views' } }, + BinaryExpr.in( + AggregateExpr.sum(ColumnRef.of('posts', 'views')), + ListExpression.of([ParamRef.of(1, { name: 'views', codecId: 'pg/int4@1' })]), ), - ).toThrow('ParamRef is not supported in grouped having expressions'); + ); + + const ast = plan.ast as SelectAst; + expect(ast.having).toEqual( + BinaryExpr.in( + AggregateExpr.sum(ColumnRef.of('posts', 'views')), + ListExpression.of([ParamRef.of(1, { name: 'views', codecId: 'pg/int4@1' })]), + ), + ); + expect(plan.params).toEqual([1]); + }); + + it('rejects EXISTS and non-aggregate metrics in HAVING', () => { + const scalarSubquery = SelectAst.from(TableSource.named('posts')).withProjection([ + ProjectionItem.of('id', ColumnRef.of('posts', 'id')), + ]); expect(() => compileGroupedAggregate( @@ -136,7 +158,7 @@ describe('query plan aggregate', () => { AndExpr.of([ BinaryExpr.in( AggregateExpr.sum(ColumnRef.of('posts', 'views')), - ListExpression.fromValues([1, 2]), + ListExpression.of([ParamRef.of(1), ParamRef.of(2)]), ), NullCheckExpr.isNotNull(AggregateExpr.sum(ColumnRef.of('posts', 'views'))), ]), @@ -149,7 +171,7 @@ describe('query plan aggregate', () => { AndExpr.of([ BinaryExpr.in( AggregateExpr.sum(ColumnRef.of('posts', 'views')), - ListExpression.of([LiteralExpr.of(1), LiteralExpr.of(2)]), + ListExpression.of([ParamRef.of(1), ParamRef.of(2)]), ), NullCheckExpr.isNotNull(AggregateExpr.sum(ColumnRef.of('posts', 'views'))), ]), @@ -171,7 +193,7 @@ describe('query plan aggregate', () => { AggregateExpr.sum(ColumnRef.of('posts', 'views')), ColumnRef.of('posts', 'views'), ), - BinaryExpr.gte(AggregateExpr.count(), LiteralExpr.of(5)), + BinaryExpr.gte(AggregateExpr.count(), ParamRef.of(5)), ]), ); @@ -282,7 +304,7 @@ describe('query plan aggregate', () => { expect(() => compileWithHaving( AndExpr.of([ - BinaryExpr.gte(AggregateExpr.count(), LiteralExpr.of(5)), + BinaryExpr.gte(AggregateExpr.count(), ParamRef.of(5)), ColumnRef.of('posts', 'views'), ]), ), @@ -292,10 +314,7 @@ describe('query plan aggregate', () => { it('rejects invalid expression inside OR', () => { expect(() => compileWithHaving( - OrExpr.of([ - BinaryExpr.gte(AggregateExpr.count(), LiteralExpr.of(5)), - LiteralExpr.of(true), - ]), + OrExpr.of([BinaryExpr.gte(AggregateExpr.count(), ParamRef.of(5)), LiteralExpr.of(true)]), ), ).toThrow('Unsupported grouped having expression kind "literal"'); }); @@ -310,7 +329,7 @@ describe('query plan aggregate', () => { describe('validateGroupedHavingExpr accepts valid predicate expressions', () => { it('accepts NOT wrapping a valid binary', () => { const plan = compileWithHaving( - new NotExpr(BinaryExpr.gte(AggregateExpr.count(), LiteralExpr.of(5))), + new NotExpr(BinaryExpr.gte(AggregateExpr.count(), ParamRef.of(5))), ); expect((plan.ast as SelectAst).having).toBeInstanceOf(NotExpr); }); @@ -326,8 +345,8 @@ describe('query plan aggregate', () => { const plan = compileWithHaving( new NotExpr( AndExpr.of([ - BinaryExpr.gte(AggregateExpr.count(), LiteralExpr.of(1)), - BinaryExpr.lte(AggregateExpr.sum(ColumnRef.of('posts', 'views')), LiteralExpr.of(100)), + BinaryExpr.gte(AggregateExpr.count(), ParamRef.of(1)), + BinaryExpr.lte(AggregateExpr.sum(ColumnRef.of('posts', 'views')), ParamRef.of(100)), ]), ), ); diff --git a/packages/3-extensions/sql-orm-client/test/query-plan-select.test.ts b/packages/3-extensions/sql-orm-client/test/query-plan-select.test.ts index f7b63c6688..acca7b0ee6 100644 --- a/packages/3-extensions/sql-orm-client/test/query-plan-select.test.ts +++ b/packages/3-extensions/sql-orm-client/test/query-plan-select.test.ts @@ -4,8 +4,8 @@ import { ColumnRef, DerivedTableSource, ListExpression, - LiteralExpr, OrExpr, + ParamRef, type SelectAst, SubqueryExpr, } from '@prisma-next/sql-relational-core/ast'; @@ -15,7 +15,6 @@ import { compileSelect, compileSelectWithIncludeStrategy, } from '../src/query-plan-select'; -import { bindWhereExpr } from '../src/where-binding'; import { baseContract, createCollection, createCollectionFor } from './collection-fixtures'; import { isSelectAst } from './helpers'; @@ -62,9 +61,9 @@ describe('compileSelectWithIncludeStrategy', () => { expectSelectAst(plan.ast); expect(plan.ast.where).toEqual( - bindWhereExpr( - baseContract, - BinaryExpr.eq(ColumnRef.of('users', 'name'), LiteralExpr.of('Alice')), + BinaryExpr.eq( + ColumnRef.of('users', 'name'), + ParamRef.of('Alice', { name: 'name', codecId: 'pg/text@1' }), ), ); @@ -78,9 +77,9 @@ describe('compileSelectWithIncludeStrategy', () => { expect(childRowsSource.query.where).toEqual( AndExpr.of([ BinaryExpr.eq(ColumnRef.of('posts', 'user_id'), ColumnRef.of('users', 'id')), - bindWhereExpr( - baseContract, - BinaryExpr.gte(ColumnRef.of('posts', 'views'), LiteralExpr.of(100)), + BinaryExpr.gte( + ColumnRef.of('posts', 'views'), + ParamRef.of(100, { name: 'views', codecId: 'pg/int4@1' }), ), ]), ); @@ -106,17 +105,17 @@ describe('compileSelectWithIncludeStrategy', () => { dslDescriptor('users', 'id'), ]); - const gtName = bindWhereExpr( - baseContract, - BinaryExpr.gt(ColumnRef.of('users', 'name'), LiteralExpr.of('Alice')), + const gtName = BinaryExpr.gt( + ColumnRef.of('users', 'name'), + ParamRef.of('Alice', { name: 'name', codecId: 'pg/text@1' }), ); - const eqName = bindWhereExpr( - baseContract, - BinaryExpr.eq(ColumnRef.of('users', 'name'), LiteralExpr.of('Alice')), + const eqName = BinaryExpr.eq( + ColumnRef.of('users', 'name'), + ParamRef.of('Alice', { name: 'name', codecId: 'pg/text@1' }), ); - const ltId = bindWhereExpr( - baseContract, - BinaryExpr.lt(ColumnRef.of('users', 'id'), LiteralExpr.of(7)), + const ltId = BinaryExpr.lt( + ColumnRef.of('users', 'id'), + ParamRef.of(7, { name: 'id', codecId: 'pg/int4@1' }), ); expect(plan.ast.where).toEqual(OrExpr.of([gtName, AndExpr.of([eqName, ltId])])); @@ -134,7 +133,10 @@ describe('compileSelectWithIncludeStrategy', () => { expect(plan.params).toEqual([9]); expect(plan.meta.paramDescriptors).toEqual([dslDescriptor('users', 'id')]); expect(plan.ast.where).toEqual( - bindWhereExpr(baseContract, BinaryExpr.gt(ColumnRef.of('users', 'id'), LiteralExpr.of(9))), + BinaryExpr.gt( + ColumnRef.of('users', 'id'), + ParamRef.of(9, { name: 'id', codecId: 'pg/int4@1' }), + ), ); const invalidState = { @@ -162,13 +164,16 @@ describe('compileSelectWithIncludeStrategy', () => { dslDescriptor('posts', 'title'), ]); - const inWhere = bindWhereExpr( - baseContract, - BinaryExpr.in(ColumnRef.of('posts', 'user_id'), ListExpression.fromValues([1, 2])), + const inWhere = BinaryExpr.in( + ColumnRef.of('posts', 'user_id'), + ListExpression.of([ + ParamRef.of(1, { name: 'user_id', codecId: 'pg/int4@1' }), + ParamRef.of(2, { name: 'user_id', codecId: 'pg/int4@1' }), + ]), ); - const titleWhere = bindWhereExpr( - baseContract, - BinaryExpr.eq(ColumnRef.of('posts', 'title'), LiteralExpr.of('Hello')), + const titleWhere = BinaryExpr.eq( + ColumnRef.of('posts', 'title'), + ParamRef.of('Hello', { name: 'title', codecId: 'pg/text@1' }), ); expect(plan.ast.where).toEqual(AndExpr.of([inWhere, titleWhere])); diff --git a/packages/3-extensions/sql-orm-client/test/rich-collection.test.ts b/packages/3-extensions/sql-orm-client/test/rich-collection.test.ts index e2dc451a08..5203d3553b 100644 --- a/packages/3-extensions/sql-orm-client/test/rich-collection.test.ts +++ b/packages/3-extensions/sql-orm-client/test/rich-collection.test.ts @@ -7,8 +7,7 @@ import { type ToWhereExpr, } from '@prisma-next/sql-relational-core/ast'; import { describe, expect, it } from 'vitest'; -import { bindWhereExpr } from '../src/where-binding'; -import { baseContract, createCollectionFor } from './collection-fixtures'; +import { createCollectionFor } from './collection-fixtures'; describe('SQL ORM collections with rich AST plans', () => { it('stores direct where expressions and bound where payloads in collection state', () => { @@ -16,9 +15,9 @@ describe('SQL ORM collections with rich AST plans', () => { const direct = collection.where(BinaryExpr.eq(ColumnRef.of('users', 'id'), LiteralExpr.of(1))); expect(direct.state.filters[0]).toBeInstanceOf(BinaryExpr); - expect( - bindWhereExpr(baseContract, BinaryExpr.eq(ColumnRef.of('users', 'id'), LiteralExpr.of(1))), - ).toEqual(direct.state.filters[0]); + expect(direct.state.filters[0]).toEqual( + BinaryExpr.eq(ColumnRef.of('users', 'id'), LiteralExpr.of(1)), + ); const bound = collection.where({ toWhereExpr: () => diff --git a/packages/3-extensions/sql-orm-client/test/rich-filters-and-where.test.ts b/packages/3-extensions/sql-orm-client/test/rich-filters-and-where.test.ts index 75f5abba63..cd599944f3 100644 --- a/packages/3-extensions/sql-orm-client/test/rich-filters-and-where.test.ts +++ b/packages/3-extensions/sql-orm-client/test/rich-filters-and-where.test.ts @@ -13,7 +13,7 @@ import { all, and, not, or } from '../src/filters'; import { createModelAccessor } from '../src/model-accessor'; import { normalizeWhereArg } from '../src/where-interop'; import { combineWhereExprs } from '../src/where-utils'; -import { getTestContext, getTestContract } from './helpers'; +import { getTestContext } from './helpers'; function collectParamValues(expr: AnyExpression): unknown[] { return expr.fold({ @@ -26,7 +26,6 @@ function collectParamValues(expr: AnyExpression): unknown[] { } describe('SQL ORM rich AST filters', () => { - const contract = getTestContract(); const context = getTestContext(); it('builds scalar and relation filters as AST instances', () => { @@ -42,7 +41,7 @@ describe('SQL ORM rich AST filters', () => { expect(nameFilter).toMatchObject({ op: 'eq', left: ColumnRef.of('users', 'name'), - right: LiteralExpr.of('Alice'), + right: ParamRef.of('Alice', { name: 'name', codecId: 'pg/text@1' }), }); expect(postsFilter?.kind).toBe('exists'); @@ -53,16 +52,13 @@ describe('SQL ORM rich AST filters', () => { }); it('normalizes, combines, and negates bound filters', () => { - const normalized = normalizeWhereArg( - { - toWhereExpr: () => - BinaryExpr.eq( - ColumnRef.of('users', 'id'), - ParamRef.of(1, { name: 'id', codecId: 'pg/int4@1' }), - ), - }, - { contract }, - ); + const normalized = normalizeWhereArg({ + toWhereExpr: () => + BinaryExpr.eq( + ColumnRef.of('users', 'id'), + ParamRef.of(1, { name: 'id', codecId: 'pg/int4@1' }), + ), + }); expect(normalized.kind).toBe('binary'); expect(collectParamValues(normalized as BinaryExpr)).toEqual([1]); diff --git a/packages/3-extensions/sql-orm-client/test/rich-query-plans.test.ts b/packages/3-extensions/sql-orm-client/test/rich-query-plans.test.ts index bd9623e011..55aa0ae482 100644 --- a/packages/3-extensions/sql-orm-client/test/rich-query-plans.test.ts +++ b/packages/3-extensions/sql-orm-client/test/rich-query-plans.test.ts @@ -5,7 +5,6 @@ import { ColumnRef, type DerivedTableSource, type InsertAst, - LiteralExpr, ParamRef, type SelectAst, type SubqueryExpr, @@ -87,7 +86,12 @@ describe('SQL ORM rich AST query plans', () => { baseContract, 'users', { email: 'b@example.com' }, - [BinaryExpr.eq(ColumnRef.of('users', 'id'), LiteralExpr.of(1))], + [ + BinaryExpr.eq( + ColumnRef.of('users', 'id'), + ParamRef.of(1, { name: 'id', codecId: 'pg/int4@1' }), + ), + ], ['id'], ); expect(updatePlan.ast.kind).toBe('update'); @@ -96,7 +100,12 @@ describe('SQL ORM rich AST query plans', () => { const deletePlan = compileDeleteReturning( baseContract, 'users', - [BinaryExpr.eq(ColumnRef.of('users', 'id'), LiteralExpr.of(1))], + [ + BinaryExpr.eq( + ColumnRef.of('users', 'id'), + ParamRef.of(1, { name: 'id', codecId: 'pg/int4@1' }), + ), + ], ['id'], ); expect(deletePlan.ast.kind).toBe('delete'); @@ -110,7 +119,7 @@ describe('SQL ORM rich AST query plans', () => { postCount: { kind: 'aggregate', fn: 'count' }, totalViews: { kind: 'aggregate', fn: 'sum', column: 'views' }, }, - BinaryExpr.gt(AggregateExpr.count(), LiteralExpr.of(1)), + BinaryExpr.gt(AggregateExpr.count(), ParamRef.of(1, { name: 'id', codecId: 'pg/int4@1' })), ); expect(groupedPlan.ast.kind).toBe('select'); const groupedAst = groupedPlan.ast as SelectAst; diff --git a/packages/3-extensions/sql-orm-client/test/where-binding.test.ts b/packages/3-extensions/sql-orm-client/test/where-binding.test.ts deleted file mode 100644 index 726d1a6c65..0000000000 --- a/packages/3-extensions/sql-orm-client/test/where-binding.test.ts +++ /dev/null @@ -1,549 +0,0 @@ -import { - AggregateExpr, - AndExpr, - BinaryExpr, - ColumnRef, - DerivedTableSource, - EqColJoinOn, - ExistsExpr, - IdentifierRef, - JoinAst, - JsonArrayAggExpr, - JsonObjectExpr, - ListExpression, - LiteralExpr, - NotExpr, - NullCheckExpr, - OperationExpr, - OrderByItem, - OrExpr, - ParamRef, - ProjectionItem, - SelectAst, - SubqueryExpr, - TableSource, -} from '@prisma-next/sql-relational-core/ast'; -import { describe, expect, it } from 'vitest'; -import { bindWhereExpr } from '../src/where-binding'; -import { getTestContract } from './helpers'; - -const subqueryWithLiteral = () => - SelectAst.from(TableSource.named('posts')) - .withProjection([ProjectionItem.of('id', ColumnRef.of('posts', 'id'))]) - .withWhere(BinaryExpr.eq(ColumnRef.of('posts', 'views'), LiteralExpr.of(100))); - -describe('bindWhereExpr', () => { - const contract = getTestContract(); - - it('binds a simple binary eq with a literal to a parameterized expression', () => { - const expr = BinaryExpr.eq(ColumnRef.of('users', 'email'), LiteralExpr.of('alice@test.com')); - const bound = bindWhereExpr(contract, expr); - - expect(bound.kind).toBe('binary'); - const binary = bound as BinaryExpr; - expect(binary.right.kind).toBe('param-ref'); - const ref = binary.right as ParamRef; - expect(ref.value).toBe('alice@test.com'); - expect(ref.codecId).toBe('pg/text@1'); - }); - - it('binds AND expressions recursively', () => { - const expr = AndExpr.of([ - BinaryExpr.eq(ColumnRef.of('users', 'email'), LiteralExpr.of('a@test.com')), - BinaryExpr.eq(ColumnRef.of('users', 'name'), LiteralExpr.of('Alice')), - ]); - const bound = bindWhereExpr(contract, expr); - - const and = bound as AndExpr; - const andRight0 = (and.exprs[0] as BinaryExpr).right; - const andRight1 = (and.exprs[1] as BinaryExpr).right; - expect(andRight0.kind).toBe('param-ref'); - expect(andRight1.kind).toBe('param-ref'); - expect([(andRight0 as ParamRef).value, (andRight1 as ParamRef).value]).toEqual([ - 'a@test.com', - 'Alice', - ]); - expect((andRight0 as ParamRef).codecId).toBe('pg/text@1'); - expect((andRight1 as ParamRef).codecId).toBe('pg/text@1'); - }); - - it('binds OR expressions recursively', () => { - const expr = OrExpr.of([ - BinaryExpr.eq(ColumnRef.of('users', 'email'), LiteralExpr.of('a@test.com')), - BinaryExpr.eq(ColumnRef.of('users', 'email'), LiteralExpr.of('b@test.com')), - ]); - const bound = bindWhereExpr(contract, expr); - - expect(bound.kind).toBe('or'); - const or = bound as OrExpr; - const orRight0 = (or.exprs[0] as BinaryExpr).right; - const orRight1 = (or.exprs[1] as BinaryExpr).right; - expect(orRight0.kind).toBe('param-ref'); - expect(orRight1.kind).toBe('param-ref'); - expect([(orRight0 as ParamRef).value, (orRight1 as ParamRef).value]).toEqual([ - 'a@test.com', - 'b@test.com', - ]); - }); - - it('binds EXISTS subquery expressions and rebinds inner literals', () => { - const subquery = SelectAst.from(TableSource.named('posts')) - .withProjection([ProjectionItem.of('id', ColumnRef.of('posts', 'id'))]) - .withWhere( - AndExpr.of([ - BinaryExpr.eq(ColumnRef.of('posts', 'user_id'), ColumnRef.of('users', 'id')), - BinaryExpr.gte(ColumnRef.of('posts', 'views'), LiteralExpr.of(100)), - ]), - ); - const expr = ExistsExpr.exists(subquery); - const bound = bindWhereExpr(contract, expr); - - expect(bound.kind).toBe('exists'); - expect((bound as ExistsExpr).notExists).toBe(false); - const innerWhere = ((bound as ExistsExpr).subquery as SelectAst).where as AndExpr; - const viewsRight = (innerWhere.exprs[1] as BinaryExpr).right; - expect(viewsRight.kind).toBe('param-ref'); - expect((viewsRight as ParamRef).value).toBe(100); - expect((viewsRight as ParamRef).codecId).toBe('pg/int4@1'); - }); - - it('binds NOT EXISTS subquery expressions', () => { - const subquery = SelectAst.from(TableSource.named('posts')).withProjection([ - ProjectionItem.of('id', ColumnRef.of('posts', 'id')), - ]); - const expr = ExistsExpr.notExists(subquery); - const bound = bindWhereExpr(contract, expr); - - expect(bound.kind).toBe('exists'); - expect((bound as ExistsExpr).notExists).toBe(true); - }); - - it('binds IS NULL null-check expressions', () => { - const expr = NullCheckExpr.isNull(ColumnRef.of('users', 'email')); - const bound = bindWhereExpr(contract, expr); - - expect(bound.kind).toBe('null-check'); - expect((bound as NullCheckExpr).isNull).toBe(true); - }); - - it('binds IS NOT NULL null-check expressions', () => { - const expr = NullCheckExpr.isNotNull(ColumnRef.of('users', 'email')); - const bound = bindWhereExpr(contract, expr); - - expect(bound.kind).toBe('null-check'); - expect((bound as NullCheckExpr).isNull).toBe(false); - }); - - it('binds IN with list literal values to parameterized refs', () => { - const expr = BinaryExpr.in( - ColumnRef.of('users', 'id'), - ListExpression.of([LiteralExpr.of(1), LiteralExpr.of(2)]), - ); - const bound = bindWhereExpr(contract, expr); - - const binary = bound as BinaryExpr; - expect(binary.right.kind).toBe('list'); - const list = binary.right as ListExpression; - expect(list.values).toMatchObject([{ kind: 'param-ref' }, { kind: 'param-ref' }]); - expect(list.values).toMatchObject([ - { value: 1, codecId: 'pg/int4@1' }, - { value: 2, codecId: 'pg/int4@1' }, - ]); - }); - - it('preserves ParamRef on the right side without rebinding', () => { - const existing = ParamRef.of(42, { name: 'id', codecId: 'pg/int4@1' }); - const expr = BinaryExpr.eq(ColumnRef.of('users', 'id'), existing); - const bound = bindWhereExpr(contract, expr); - - const binary = bound as BinaryExpr; - expect(binary.right).toBe(existing); - }); - - it('binds subquery within a select that has joins, orderBy, and derived sources', () => { - const inner = SelectAst.from(TableSource.named('posts')) - .withProjection([ProjectionItem.of('id', ColumnRef.of('posts', 'id'))]) - .withOrderBy([OrderByItem.asc(ColumnRef.of('posts', 'id'))]) - .withWhere(BinaryExpr.eq(ColumnRef.of('posts', 'user_id'), ColumnRef.of('users', 'id'))); - - const lateral = SelectAst.from(DerivedTableSource.as('p', inner)).withProjection([ - ProjectionItem.of('id', ColumnRef.of('p', 'id')), - ]); - - const main = SelectAst.from(TableSource.named('users')) - .withProjection([ProjectionItem.of('id', ColumnRef.of('users', 'id'))]) - .withJoins([ - JoinAst.left( - DerivedTableSource.as('lat', lateral), - EqColJoinOn.of(ColumnRef.of('users', 'id'), ColumnRef.of('lat', 'id')), - true, - ), - ]); - - const expr = ExistsExpr.exists(main); - const bound = bindWhereExpr(contract, expr); - - expect(bound.kind).toBe('exists'); - }); - - it('handles binary expression with non-column left side and literal right', () => { - const subquery = SelectAst.from(TableSource.named('users')).withProjection([ - ProjectionItem.of('cnt', AggregateExpr.count()), - ]); - const expr = BinaryExpr.gt(SubqueryExpr.of(subquery), LiteralExpr.of(0)); - const bound = bindWhereExpr(contract, expr); - - const binary = bound as BinaryExpr; - expect(binary.right.kind).toBe('literal'); - }); - - it('handles binary expression with non-column left side and column right', () => { - const subquery = SelectAst.from(TableSource.named('users')).withProjection([ - ProjectionItem.of('cnt', AggregateExpr.count()), - ]); - const expr = BinaryExpr.gt(SubqueryExpr.of(subquery), ColumnRef.of('users', 'id')); - const bound = bindWhereExpr(contract, expr); - - const binary = bound as BinaryExpr; - expect(binary.right.kind).toBe('column-ref'); - }); - - it('binds EXISTS with a select that has HAVING, literal projections, and where-expr joins', () => { - const subquery = SelectAst.from(TableSource.named('users')) - .withProjection([ - ProjectionItem.of('email', ColumnRef.of('users', 'email')), - ProjectionItem.of('one', LiteralExpr.of(1)), - ]) - .withGroupBy([ColumnRef.of('users', 'email')]) - .withHaving(BinaryExpr.gt(AggregateExpr.count(), LiteralExpr.of(1))) - .withJoins([ - JoinAst.inner( - TableSource.named('posts'), - BinaryExpr.eq(ColumnRef.of('users', 'id'), ColumnRef.of('posts', 'user_id')), - ), - ]); - - const expr = ExistsExpr.exists(subquery); - const bound = bindWhereExpr(contract, expr); - - expect(bound.kind).toBe('exists'); - }); - - it('passes through ParamRef values inside ListExpression without rebinding', () => { - const existing = ParamRef.of(99, { name: 'id', codecId: 'pg/int4@1' }); - const expr = BinaryExpr.in( - ColumnRef.of('users', 'id'), - ListExpression.of([existing, LiteralExpr.of(42)]), - ); - const bound = bindWhereExpr(contract, expr); - - const binary = bound as BinaryExpr; - const list = binary.right as ListExpression; - expect(list.values).toMatchObject([{ kind: 'param-ref' }, { kind: 'param-ref' }]); - expect(list.values[0]).toBe(existing); - expect(list.values).toMatchObject([{ value: 99 }, { value: 42 }]); - }); - - describe('leaf passthrough', () => { - it('passes through IdentifierRef unchanged', () => { - const expr = IdentifierRef.of('some_name'); - const bound = bindWhereExpr(contract, expr); - - expect(bound).toBe(expr); - }); - - it('passes through top-level LiteralExpr unchanged', () => { - const expr = LiteralExpr.of(42); - const bound = bindWhereExpr(contract, expr); - - expect(bound).toBe(expr); - }); - - it('passes through top-level ParamRef unchanged', () => { - const expr = ParamRef.of('hello', { name: 'x', codecId: 'pg/text@1' }); - const bound = bindWhereExpr(contract, expr); - - expect(bound).toBe(expr); - }); - }); - - describe('composite expression binding', () => { - it('binds inner SelectAst of SubqueryExpr', () => { - const expr = SubqueryExpr.of(subqueryWithLiteral()); - const bound = bindWhereExpr(contract, expr); - - expect(bound.kind).toBe('subquery'); - const innerWhere = ((bound as SubqueryExpr).query as SelectAst).where as BinaryExpr; - expect(innerWhere.right.kind).toBe('param-ref'); - expect((innerWhere.right as ParamRef).value).toBe(100); - expect((innerWhere.right as ParamRef).codecId).toBe('pg/int4@1'); - }); - - it('binds inner expressions of OperationExpr', () => { - const expr = OperationExpr.function({ - method: 'contains', - forTypeId: 'pg/text@1', - self: SubqueryExpr.of(subqueryWithLiteral()), - args: [], - returns: { kind: 'builtin', type: 'boolean' }, - template: 'position({1} in {0}) > 0', - }); - const bound = bindWhereExpr(contract, expr); - - expect(bound.kind).toBe('operation'); - const op = bound as OperationExpr; - const innerQuery = (op.self as SubqueryExpr).query as SelectAst; - const innerWhere = innerQuery.where as BinaryExpr; - expect(innerWhere.right.kind).toBe('param-ref'); - expect((innerWhere.right as ParamRef).codecId).toBe('pg/int4@1'); - }); - - it('binds inner expression of AggregateExpr', () => { - const expr = AggregateExpr.sum(SubqueryExpr.of(subqueryWithLiteral())); - const bound = bindWhereExpr(contract, expr); - - expect(bound.kind).toBe('aggregate'); - const agg = bound as AggregateExpr; - const innerQuery = (agg.expr as SubqueryExpr).query as SelectAst; - const innerWhere = innerQuery.where as BinaryExpr; - expect(innerWhere.right.kind).toBe('param-ref'); - }); - - it('binds inner expressions of JsonObjectExpr', () => { - const expr = JsonObjectExpr.fromEntries([ - JsonObjectExpr.entry('sub', SubqueryExpr.of(subqueryWithLiteral())), - ]); - const bound = bindWhereExpr(contract, expr); - - expect(bound.kind).toBe('json-object'); - const json = bound as JsonObjectExpr; - const innerQuery = (json.entries[0]!.value as SubqueryExpr).query as SelectAst; - const innerWhere = innerQuery.where as BinaryExpr; - expect(innerWhere.right.kind).toBe('param-ref'); - }); - - it('binds inner expression of JsonArrayAggExpr', () => { - const expr = JsonArrayAggExpr.of(SubqueryExpr.of(subqueryWithLiteral())); - const bound = bindWhereExpr(contract, expr); - - expect(bound.kind).toBe('json-array-agg'); - const agg = bound as JsonArrayAggExpr; - const innerQuery = (agg.expr as SubqueryExpr).query as SelectAst; - const innerWhere = innerQuery.where as BinaryExpr; - expect(innerWhere.right.kind).toBe('param-ref'); - }); - - it('binds inner expressions of top-level ListExpression', () => { - const expr = ListExpression.of([SubqueryExpr.of(subqueryWithLiteral())]); - const bound = bindWhereExpr(contract, expr); - - expect(bound.kind).toBe('list'); - const list = bound as ListExpression; - const innerQuery = (list.values[0] as SubqueryExpr).query as SelectAst; - const innerWhere = innerQuery.where as BinaryExpr; - expect(innerWhere.right.kind).toBe('param-ref'); - }); - }); - - describe('NotExpr', () => { - it('binds inner binary expression', () => { - const expr = new NotExpr( - BinaryExpr.eq(ColumnRef.of('users', 'email'), LiteralExpr.of('test@test.com')), - ); - const bound = bindWhereExpr(contract, expr); - - expect(bound.kind).toBe('not'); - const inner = (bound as NotExpr).expr as BinaryExpr; - expect(inner.right.kind).toBe('param-ref'); - expect((inner.right as ParamRef).value).toBe('test@test.com'); - expect((inner.right as ParamRef).codecId).toBe('pg/text@1'); - }); - - it('binds NOT(AND(...)) recursively', () => { - const expr = new NotExpr( - AndExpr.of([ - BinaryExpr.eq(ColumnRef.of('users', 'email'), LiteralExpr.of('a@test.com')), - BinaryExpr.eq(ColumnRef.of('users', 'name'), LiteralExpr.of('Alice')), - ]), - ); - const bound = bindWhereExpr(contract, expr); - - const and = (bound as NotExpr).expr as AndExpr; - expect((and.exprs[0] as BinaryExpr).right.kind).toBe('param-ref'); - expect((and.exprs[1] as BinaryExpr).right.kind).toBe('param-ref'); - }); - }); - - describe('error handling', () => { - it('throws for unknown table', () => { - const expr = BinaryExpr.eq(ColumnRef.of('nonexistent', 'col'), LiteralExpr.of('x')); - expect(() => bindWhereExpr(contract, expr)).toThrow( - 'Unknown column "col" in table "nonexistent"', - ); - }); - - it('throws for unknown column', () => { - const expr = BinaryExpr.eq(ColumnRef.of('users', 'nonexistent'), LiteralExpr.of('x')); - expect(() => bindWhereExpr(contract, expr)).toThrow( - 'Unknown column "nonexistent" in table "users"', - ); - }); - }); - - describe('bindComparable edge cases', () => { - it('preserves column-ref on right when left is a column', () => { - const expr = BinaryExpr.eq(ColumnRef.of('users', 'id'), ColumnRef.of('posts', 'user_id')); - const bound = bindWhereExpr(contract, expr); - - const binary = bound as BinaryExpr; - expect(binary.right.kind).toBe('column-ref'); - }); - - it('rewrites aggregate on right via bindExpression when left is a column', () => { - const aggWithSubquery = AggregateExpr.sum(SubqueryExpr.of(subqueryWithLiteral())); - const expr = BinaryExpr.eq(ColumnRef.of('users', 'id'), aggWithSubquery); - const bound = bindWhereExpr(contract, expr); - - const binary = bound as BinaryExpr; - expect(binary.right.kind).toBe('aggregate'); - const innerQuery = ((binary.right as AggregateExpr).expr as SubqueryExpr).query as SelectAst; - const innerWhere = innerQuery.where as BinaryExpr; - expect(innerWhere.right.kind).toBe('param-ref'); - }); - - it('rewrites non-literal/non-param right via bindExpression when left is not a column', () => { - const aggWithSubquery = AggregateExpr.sum(SubqueryExpr.of(subqueryWithLiteral())); - const expr = BinaryExpr.gt(AggregateExpr.count(), aggWithSubquery); - const bound = bindWhereExpr(contract, expr); - - const binary = bound as BinaryExpr; - expect(binary.right.kind).toBe('aggregate'); - const innerQuery = ((binary.right as AggregateExpr).expr as SubqueryExpr).query as SelectAst; - const innerWhere = innerQuery.where as BinaryExpr; - expect(innerWhere.right.kind).toBe('param-ref'); - }); - }); - - describe('binary operators', () => { - it('neq binds literal to param', () => { - const expr = BinaryExpr.neq(ColumnRef.of('users', 'name'), LiteralExpr.of('Bob')); - const bound = bindWhereExpr(contract, expr) as BinaryExpr; - - expect(bound.op).toBe('neq'); - expect(bound.right.kind).toBe('param-ref'); - expect((bound.right as ParamRef).value).toBe('Bob'); - }); - - it('lt binds literal to param', () => { - const expr = BinaryExpr.lt(ColumnRef.of('posts', 'views'), LiteralExpr.of(50)); - const bound = bindWhereExpr(contract, expr) as BinaryExpr; - - expect(bound.op).toBe('lt'); - expect(bound.right.kind).toBe('param-ref'); - expect((bound.right as ParamRef).codecId).toBe('pg/int4@1'); - }); - - it('lte binds literal to param', () => { - const expr = BinaryExpr.lte(ColumnRef.of('posts', 'views'), LiteralExpr.of(50)); - const bound = bindWhereExpr(contract, expr) as BinaryExpr; - - expect(bound.op).toBe('lte'); - expect(bound.right.kind).toBe('param-ref'); - }); - - it('like binds literal to param', () => { - const expr = BinaryExpr.like(ColumnRef.of('users', 'name'), LiteralExpr.of('%alice%')); - const bound = bindWhereExpr(contract, expr) as BinaryExpr; - - expect(bound.op).toBe('like'); - expect(bound.right.kind).toBe('param-ref'); - expect((bound.right as ParamRef).value).toBe('%alice%'); - }); - - it('ilike binds literal to param', () => { - const expr = BinaryExpr.ilike(ColumnRef.of('users', 'name'), LiteralExpr.of('%alice%')); - const bound = bindWhereExpr(contract, expr) as BinaryExpr; - - expect(bound.op).toBe('ilike'); - expect(bound.right.kind).toBe('param-ref'); - }); - - it('notIn binds list literals to params', () => { - const expr = BinaryExpr.notIn( - ColumnRef.of('users', 'id'), - ListExpression.of([LiteralExpr.of(1), LiteralExpr.of(2)]), - ); - const bound = bindWhereExpr(contract, expr) as BinaryExpr; - - expect(bound.op).toBe('notIn'); - const list = bound.right as ListExpression; - expect(list.values).toMatchObject([ - { kind: 'param-ref', value: 1, codecId: 'pg/int4@1' }, - { kind: 'param-ref', value: 2, codecId: 'pg/int4@1' }, - ]); - }); - }); - - describe('SelectAst binding details', () => { - it('preserves and binds distinctOn expressions', () => { - const subquery = SelectAst.from(TableSource.named('posts')) - .withProjection([ProjectionItem.of('title', ColumnRef.of('posts', 'title'))]) - .withDistinctOn([ColumnRef.of('posts', 'user_id')]) - .withWhere(BinaryExpr.eq(ColumnRef.of('posts', 'views'), LiteralExpr.of(100))); - const expr = ExistsExpr.exists(subquery); - const bound = bindWhereExpr(contract, expr); - - const select = (bound as ExistsExpr).subquery as SelectAst; - expect(select.distinctOn).toHaveLength(1); - expect(select.distinctOn?.[0]?.kind).toBe('column-ref'); - const innerWhere = select.where as BinaryExpr; - expect(innerWhere.right.kind).toBe('param-ref'); - }); - - it('preserves limit and offset', () => { - const subquery = SelectAst.from(TableSource.named('posts')) - .withProjection([ProjectionItem.of('id', ColumnRef.of('posts', 'id'))]) - .withLimit(10) - .withOffset(5) - .withWhere(BinaryExpr.eq(ColumnRef.of('posts', 'views'), LiteralExpr.of(100))); - const expr = ExistsExpr.exists(subquery); - const bound = bindWhereExpr(contract, expr); - - const select = (bound as ExistsExpr).subquery as SelectAst; - expect(select.limit).toBe(10); - expect(select.offset).toBe(5); - }); - }); - - describe('nested logical expressions', () => { - it('binds AND inside OR', () => { - const expr = OrExpr.of([ - AndExpr.of([ - BinaryExpr.eq(ColumnRef.of('users', 'name'), LiteralExpr.of('Alice')), - BinaryExpr.eq(ColumnRef.of('users', 'email'), LiteralExpr.of('a@test.com')), - ]), - BinaryExpr.eq(ColumnRef.of('users', 'id'), LiteralExpr.of(1)), - ]); - const bound = bindWhereExpr(contract, expr); - - const or = bound as OrExpr; - const and = or.exprs[0] as AndExpr; - expect((and.exprs[0] as BinaryExpr).right.kind).toBe('param-ref'); - expect((and.exprs[1] as BinaryExpr).right.kind).toBe('param-ref'); - expect((or.exprs[1] as BinaryExpr).right.kind).toBe('param-ref'); - }); - - it('binds NOT inside AND', () => { - const expr = AndExpr.of([ - new NotExpr(BinaryExpr.eq(ColumnRef.of('users', 'name'), LiteralExpr.of('Bob'))), - BinaryExpr.eq(ColumnRef.of('users', 'email'), LiteralExpr.of('a@test.com')), - ]); - const bound = bindWhereExpr(contract, expr); - - const and = bound as AndExpr; - const not = and.exprs[0] as NotExpr; - expect((not.expr as BinaryExpr).right.kind).toBe('param-ref'); - expect((and.exprs[1] as BinaryExpr).right.kind).toBe('param-ref'); - }); - }); -});