From 0e29f3eb85d8828521c62b79013c9f37c8f2d708 Mon Sep 17 00:00:00 2001 From: rebelice Date: Thu, 14 May 2026 11:52:27 +0900 Subject: [PATCH] Fix PostgreSQL column default expression analysis --- pg/catalog/alter.go | 86 +++++++---- pg/catalog/alter_identity_test.go | 25 ++++ pg/catalog/alter_validation_test.go | 39 ++++- pg/catalog/analyze.go | 205 +++++++++++++++++++++++++-- pg/catalog/dryrun_validation_test.go | 187 ++++++++++++++++++++++++ pg/catalog/errors.go | 70 ++++----- pg/catalog/tablecmds.go | 39 +++-- 7 files changed, 568 insertions(+), 83 deletions(-) diff --git a/pg/catalog/alter.go b/pg/catalog/alter.go index 82d3946c..6ff55612 100644 --- a/pg/catalog/alter.go +++ b/pg/catalog/alter.go @@ -344,8 +344,20 @@ func (c *Catalog) execAlterTableCmd(schema *Schema, rel *Relation, relName strin if colDef.RawDefault != nil { col := rel.Columns[len(rel.Columns)-1] col.HasDefault = true - if analyzed, err := c.AnalyzeStandaloneExpr(colDef.RawDefault, rel); err == nil && analyzed != nil { - if coerced, cerr := c.coerceToTargetType(analyzed, analyzed.exprType(), col.TypeOID, 'i'); cerr == nil && coerced != nil { + analyzed, err := c.AnalyzeColumnDefaultExpr(colDef.RawDefault, rel) + if err != nil { + delete(rel.colByName, col.Name) + rel.Columns = rel.Columns[:len(rel.Columns)-1] + return err + } + if analyzed != nil { + coerced, cerr := c.coerceToTargetTypeWithFormat(analyzed, analyzed.exprType(), col.TypeOID, 'a', 'i') + if cerr != nil { + delete(rel.colByName, col.Name) + rel.Columns = rel.Columns[:len(rel.Columns)-1] + return cerr + } + if coerced != nil { analyzed = coerced } col.DefaultAnalyzed = analyzed @@ -393,20 +405,29 @@ func (c *Catalog) execAlterTableCmd(schema *Schema, rel *Relation, relName strin } // Analyze the default expression and coerce to column type. // pg: cookDefault uses COERCE_IMPLICIT_CAST ('i') as display format - if analyzed, err := c.AnalyzeStandaloneExpr(rawExpr, rel); err == nil && analyzed != nil { - if idx, exists := rel.colByName[atc.Name]; exists { - if coerced, cerr := c.coerceToTargetType(analyzed, analyzed.exprType(), rel.Columns[idx].TypeOID, 'i'); cerr == nil && coerced != nil { - analyzed = coerced - } - rel.Columns[idx].DefaultAnalyzed = analyzed - c.recordDependencyOnSingleRelExprForObject('r', rel.OID, int32(rel.Columns[idx].AttNum), analyzed, rel.OID, - DepNormal, DepNormal) + idx, col, err := c.columnDefaultTarget(rel, atc.Name) + if err != nil { + return err + } + analyzed, err := c.AnalyzeColumnDefaultExpr(rawExpr, rel) + if err != nil { + return err + } + if analyzed != nil { + coerced, cerr := c.coerceToTargetTypeWithFormat(analyzed, analyzed.exprType(), col.TypeOID, 'a', 'i') + if cerr != nil { + return cerr } - rte := c.buildRelationRTE(rel) - defStr := c.DeparseExpr(analyzed, []*RangeTableEntry{rte}, false) - if err := c.atSetDefault(rel, atc.Name, defStr); err != nil { - return err + if coerced != nil { + analyzed = coerced } + col.DefaultAnalyzed = analyzed + c.recordDependencyOnSingleRelExprForObject('r', rel.OID, int32(col.AttNum), analyzed, rel.OID, + DepNormal, DepNormal) + rte := c.buildRelationRTE(rel) + col.HasDefault = true + col.Default = c.DeparseExpr(analyzed, []*RangeTableEntry{rte}, false) + rel.Columns[idx] = col } return nil } @@ -1331,27 +1352,38 @@ func (c *Catalog) atDropNotNull(rel *Relation, colName string) error { // // pg: src/backend/commands/tablecmds.c — ATExecColumnDefault func (c *Catalog) atSetDefault(rel *Relation, colName, expr string) error { + _, col, err := c.columnDefaultTarget(rel, colName) + if err != nil { + return err + } + + col.HasDefault = true + col.Default = expr + return nil +} + +func (c *Catalog) columnDefaultTarget(rel *Relation, colName string) (int, *Column, error) { idx, exists := rel.colByName[colName] if !exists { - return errUndefinedColumn(colName) + return 0, nil, errUndefinedColumn(colName) } col := rel.Columns[idx] // Cannot set default on identity columns (use SET IDENTITY instead). if col.Identity != 0 { - return errInvalidObjectDefinition(fmt.Sprintf( - "column %q of relation %q is an identity column", colName, rel.Name)) + return 0, nil, &Error{ + Code: CodeSyntaxError, + Message: fmt.Sprintf("column %q of relation %q is an identity column", colName, rel.Name)} } // Cannot set default on generated columns. if col.Generated != 0 { - return errInvalidObjectDefinition(fmt.Sprintf( - "column %q of relation %q is a generated column", colName, rel.Name)) + return 0, nil, &Error{ + Code: CodeSyntaxError, + Message: fmt.Sprintf("column %q of relation %q is a generated column", colName, rel.Name)} } - col.HasDefault = true - col.Default = expr - return nil + return idx, col, nil } // atDropDefault removes a column's default expression. @@ -1366,14 +1398,16 @@ func (c *Catalog) atDropDefault(rel *Relation, colName string) error { // Cannot drop default on identity columns (use DROP IDENTITY instead). if col.Identity != 0 { - return errInvalidObjectDefinition(fmt.Sprintf( - "column %q of relation %q is an identity column", colName, rel.Name)) + return &Error{ + Code: CodeSyntaxError, + Message: fmt.Sprintf("column %q of relation %q is an identity column", colName, rel.Name)} } // Cannot drop default on generated columns (use DROP EXPRESSION instead). if col.Generated != 0 { - return errInvalidObjectDefinition(fmt.Sprintf( - "column %q of relation %q is a generated column", colName, rel.Name)) + return &Error{ + Code: CodeSyntaxError, + Message: fmt.Sprintf("column %q of relation %q is a generated column", colName, rel.Name)} } col.HasDefault = false diff --git a/pg/catalog/alter_identity_test.go b/pg/catalog/alter_identity_test.go index 428ffb8a..c0465f1a 100644 --- a/pg/catalog/alter_identity_test.go +++ b/pg/catalog/alter_identity_test.go @@ -472,6 +472,31 @@ func TestAlterAddColumnIdentityCreatesSequenceAndDefault(t *testing.T) { } } +func TestAlterAddColumnIdentityWithDefaultRejectsBeforeSequenceCreate(t *testing.T) { + c := New() + execIdentitySQL(t, c, `CREATE TABLE t (name text)`) + + results, err := c.Exec(`ALTER TABLE t ADD COLUMN id integer GENERATED ALWAYS AS IDENTITY DEFAULT name`, nil) + if err != nil { + t.Fatalf("Exec parse error: %v", err) + } + if len(results) != 1 || results[0].Error == nil { + t.Fatalf("expected ALTER TABLE error, got results=%v", results) + } + assertErrorCode(t, results[0].Error, CodeSyntaxError) + + rel := c.GetRelation("", "t") + if rel == nil { + t.Fatal("relation t not found") + } + if _, exists := rel.colByName["id"]; exists { + t.Fatal("column id should not be added after identity/default conflict") + } + if _, err := c.findSequence("", "t_id_seq"); err == nil { + t.Fatal("identity sequence public.t_id_seq should not be created after identity/default conflict") + } +} + func TestAlterAddColumnIdentityRejectsRegularInheritanceChildren(t *testing.T) { c := New() execIdentitySQL(t, c, ` diff --git a/pg/catalog/alter_validation_test.go b/pg/catalog/alter_validation_test.go index 9b78878f..32223b4f 100644 --- a/pg/catalog/alter_validation_test.go +++ b/pg/catalog/alter_validation_test.go @@ -21,7 +21,7 @@ func TestSetDefaultOnIdentityColumn(t *testing.T) { err := c.AlterTableStmt(makeAlterTableStmt("", "t", makeATSetDefault("id", "42"), )) - assertErrorCode(t, err, CodeInvalidObjectDefinition) + assertErrorCode(t, err, CodeSyntaxError) } func TestSetDefaultOnGeneratedColumn(t *testing.T) { @@ -37,7 +37,7 @@ func TestSetDefaultOnGeneratedColumn(t *testing.T) { err := c.AlterTableStmt(makeAlterTableStmt("", "t", makeATSetDefault("val", "1"), )) - assertErrorCode(t, err, CodeInvalidObjectDefinition) + assertErrorCode(t, err, CodeSyntaxError) } func TestDropDefaultOnIdentityColumn(t *testing.T) { @@ -53,7 +53,7 @@ func TestDropDefaultOnIdentityColumn(t *testing.T) { err := c.AlterTableStmt(makeAlterTableStmt("", "t", makeATDropDefault("id"), )) - assertErrorCode(t, err, CodeInvalidObjectDefinition) + assertErrorCode(t, err, CodeSyntaxError) } func TestDropDefaultOnGeneratedColumn(t *testing.T) { @@ -69,7 +69,38 @@ func TestDropDefaultOnGeneratedColumn(t *testing.T) { err := c.AlterTableStmt(makeAlterTableStmt("", "t", makeATDropDefault("val"), )) - assertErrorCode(t, err, CodeInvalidObjectDefinition) + assertErrorCode(t, err, CodeSyntaxError) +} + +func TestAddColumnDefaultTypeMismatchRollsBackColumn(t *testing.T) { + c := New() + results, err := c.Exec("CREATE TABLE t(id int)", nil) + if err != nil { + t.Fatalf("setup parse error: %v", err) + } + if len(results) != 1 || results[0].Error != nil { + t.Fatalf("setup error: %v", results) + } + + results, err = c.Exec("ALTER TABLE t ADD COLUMN a int DEFAULT true", nil) + if err != nil { + t.Fatalf("alter parse error: %v", err) + } + if len(results) != 1 || results[0].Error == nil { + t.Fatalf("expected ALTER TABLE error, got results=%v", results) + } + assertErrorCode(t, results[0].Error, CodeDatatypeMismatch) + + rel := c.GetRelation("", "t") + if rel == nil { + t.Fatal("relation t not found") + } + if _, exists := rel.colByName["a"]; exists { + t.Fatal("column a should not be added after default type mismatch") + } + if len(rel.Columns) != 1 { + t.Fatalf("columns length: got %d, want 1", len(rel.Columns)) + } } func TestAlterColumnTypeOnIdentityColumn(t *testing.T) { diff --git a/pg/catalog/analyze.go b/pg/catalog/analyze.go index 15a0f77f..43a4d8aa 100644 --- a/pg/catalog/analyze.go +++ b/pg/catalog/analyze.go @@ -405,12 +405,20 @@ type analyzeCtx struct { query *Query parent *analyzeCtx // for correlated subqueries disallowParentCols bool // true for non-LATERAL FROM subqueries that may see CTEs but not outer columns - domainConstraint bool // true when analyzing a domain CHECK constraint - domainBaseTypeOID uint32 // base type OID for domain VALUE keyword - domainBaseTypMod int32 // base type modifier for domain VALUE keyword - domainBaseCollation uint32 // base type collation for domain VALUE keyword + exprKind analyzeExprKind + domainConstraint bool // true when analyzing a domain CHECK constraint + domainBaseTypeOID uint32 // base type OID for domain VALUE keyword + domainBaseTypMod int32 // base type modifier for domain VALUE keyword + domainBaseCollation uint32 // base type collation for domain VALUE keyword } +type analyzeExprKind int + +const ( + analyzeExprKindStandalone analyzeExprKind = iota + analyzeExprKindColumnDefault +) + // transformFromClauseItem processes a FROM clause item. // // pg: src/backend/parser/parse_clause.c — transformFromClauseItem @@ -1338,6 +1346,12 @@ func (ac *analyzeCtx) transformColumnRef(cr *nodes.ColumnRef) (AnalyzedExpr, err tableName = stringVal(items[len(items)-2]) } + // pg: src/backend/parser/parse_expr.c — transformColumnRef + // EXPR_KIND_COLUMN_DEFAULT rejects ColumnRef before attempting resolution. + if ac.exprKind == analyzeExprKindColumnDefault { + return nil, errColumnReferenceInDefault() + } + // pg: src/backend/commands/typecmds.c — replace_domain_constraint_value // In domain CHECK constraints, "value" is replaced with CoerceToDomainValue. if ac.domainConstraint && tableName == "" && strings.EqualFold(colName, "value") { @@ -1600,6 +1614,9 @@ func (ac *analyzeCtx) transformFuncCall(fc *nodes.FuncCall) (AnalyzedExpr, error if fc.AggStar && funcName == "count" { // If OVER is present, it's a window function. if fc.Over != nil { + if ac.exprKind == analyzeExprKindColumnDefault { + return nil, errWindowFunctionInDefault() + } winRef, err := ac.resolveWindowDef(fc.Over) if err != nil { return nil, err @@ -1621,6 +1638,9 @@ func (ac *analyzeCtx) transformFuncCall(fc *nodes.FuncCall) (AnalyzedExpr, error WinRef: winRef, }, nil } + if ac.exprKind == analyzeExprKindColumnDefault { + return nil, errAggregateInDefault() + } return &AggExpr{ AggFuncOID: 0, AggName: "count", @@ -1657,6 +1677,9 @@ func (ac *analyzeCtx) transformFuncCall(fc *nodes.FuncCall) (AnalyzedExpr, error // Window function: has OVER clause. // pg: src/backend/parser/parse_func.c — ParseFuncOrColumn (window function path) if fc.Over != nil { + if ac.exprKind == analyzeExprKindColumnDefault { + return nil, errWindowFunctionInDefault() + } winRef, err := ac.resolveWindowDef(fc.Over) if err != nil { return nil, err @@ -1690,6 +1713,9 @@ func (ac *analyzeCtx) transformFuncCall(fc *nodes.FuncCall) (AnalyzedExpr, error // Determine if this is an aggregate. if proc.Kind == 'a' { + if ac.exprKind == analyzeExprKindColumnDefault { + return nil, errAggregateInDefault() + } var aggColl uint32 if ac.catalog.typeCollation(retType) != 0 { aggColl = resolveCollation(args...) @@ -1708,6 +1734,10 @@ func (ac *analyzeCtx) transformFuncCall(fc *nodes.FuncCall) (AnalyzedExpr, error }, nil } + if proc.RetSet && ac.exprKind == analyzeExprKindColumnDefault { + return nil, errSetReturningFunctionInDefault() + } + // Determine result collation: if result type is collatable, derive from args. var funcColl uint32 if ac.catalog.typeCollation(retType) != 0 { @@ -2089,6 +2119,10 @@ func (ac *analyzeCtx) transformCoalesceExpr(ce *nodes.CoalesceExpr) (AnalyzedExp // // pg: src/backend/parser/parse_expr.c — transformSubLink func (ac *analyzeCtx) transformSubLink(sl *nodes.SubLink) (AnalyzedExpr, error) { + if ac.exprKind == analyzeExprKindColumnDefault { + return nil, &Error{Code: CodeFeatureNotSupported, Message: "cannot use subquery in DEFAULT expression"} + } + sub, ok := sl.Subselect.(*nodes.SelectStmt) if !ok { return nil, fmt.Errorf("subquery is not a SELECT") @@ -3513,6 +3547,10 @@ func (c *Catalog) typeRelation(typeOID uint32) *Relation { // coerceToTargetType creates a coercion node for the given pathway. func (c *Catalog) coerceToTargetType(expr AnalyzedExpr, srcType, targetType uint32, context byte) (AnalyzedExpr, error) { + return c.coerceToTargetTypeWithFormat(expr, srcType, targetType, context, context) +} + +func (c *Catalog) coerceToTargetTypeWithFormat(expr AnalyzedExpr, srcType, targetType uint32, context, format byte) (AnalyzedExpr, error) { if srcType == targetType { return expr, nil } @@ -3532,7 +3570,7 @@ func (c *Catalog) coerceToTargetType(expr AnalyzedExpr, srcType, targetType uint Arg: expr, ResultType: targetType, TypeMod: -1, - Format: context, + Format: format, }, nil } @@ -3543,7 +3581,7 @@ func (c *Catalog) coerceToTargetType(expr AnalyzedExpr, srcType, targetType uint Arg: expr, ResultType: targetType, TypeMod: -1, - Format: context, + Format: format, }, nil case CoercionFunc: proc := c.procByOID[funcOID] @@ -3557,13 +3595,13 @@ func (c *Catalog) coerceToTargetType(expr AnalyzedExpr, srcType, targetType uint ResultType: targetType, ResultTypMod: -1, Args: []AnalyzedExpr{expr}, - CoerceFormat: context, + CoerceFormat: format, }, nil case CoercionIO: return &CoerceViaIOExpr{ Arg: expr, ResultType: targetType, - Format: context, + Format: format, }, nil default: return nil, errDatatypeMismatch(fmt.Sprintf( @@ -4509,6 +4547,154 @@ func (c *Catalog) AnalyzeStandaloneExpr(expr nodes.Node, rel *Relation) (Analyze return ac.transformExpr(expr) } +// AnalyzeColumnDefaultExpr analyzes a column DEFAULT expression using +// PostgreSQL's EXPR_KIND_COLUMN_DEFAULT semantics. +// +// pg: src/backend/catalog/heap.c — cookDefault +// pg: src/backend/parser/parse_expr.c — transformColumnRef +func (c *Catalog) AnalyzeColumnDefaultExpr(expr nodes.Node, rel *Relation) (AnalyzedExpr, error) { + if expr == nil { + return nil, nil + } + rte := c.buildRelationRTE(rel) + ac := &analyzeCtx{ + catalog: c, + query: &Query{RangeTable: []*RangeTableEntry{rte}}, + exprKind: analyzeExprKindColumnDefault, + } + analyzed, err := ac.transformExpr(expr) + if err != nil { + return nil, err + } + if exprContainsVarExpr(analyzed) { + return nil, errColumnReferenceInDefault() + } + return analyzed, nil +} + +func errColumnReferenceInDefault() error { + return &Error{Code: CodeFeatureNotSupported, Message: "cannot use column reference in DEFAULT expression"} +} + +func errAggregateInDefault() error { + return &Error{Code: CodeGroupingError, Message: "aggregate functions are not allowed in DEFAULT expressions"} +} + +func errWindowFunctionInDefault() error { + return &Error{Code: CodeWindowingError, Message: "window functions are not allowed in DEFAULT expressions"} +} + +func errSetReturningFunctionInDefault() error { + return &Error{Code: CodeFeatureNotSupported, Message: "set-returning functions are not allowed in DEFAULT expressions"} +} + +func errGroupingOperationInDefault() error { + return &Error{Code: CodeGroupingError, Message: "grouping operations are not allowed in DEFAULT expressions"} +} + +func exprContainsVarExpr(expr AnalyzedExpr) bool { + switch v := expr.(type) { + case nil: + return false + case *VarExpr: + return true + case *FuncCallExpr: + return exprListContainsVarExpr(v.Args) + case *AggExpr: + return exprListContainsVarExpr(v.Args) + case *OpExpr: + return exprContainsVarExpr(v.Left) || exprContainsVarExpr(v.Right) + case *RelabelExpr: + return exprContainsVarExpr(v.Arg) + case *CoerceViaIOExpr: + return exprContainsVarExpr(v.Arg) + case *CaseExprQ: + if exprContainsVarExpr(v.Arg) || exprContainsVarExpr(v.Default) { + return true + } + for _, w := range v.When { + if w != nil && (exprContainsVarExpr(w.Condition) || exprContainsVarExpr(w.Result)) { + return true + } + } + case *CoalesceExprQ: + return exprListContainsVarExpr(v.Args) + case *BoolExprQ: + return exprListContainsVarExpr(v.Args) + case *NullTestExpr: + return exprContainsVarExpr(v.Arg) + case *SubLinkExpr: + if exprContainsVarExpr(v.TestExpr) { + return true + } + if v.SubQuery != nil { + for _, te := range v.SubQuery.TargetList { + if te != nil && exprContainsVarExpr(te.Expr) { + return true + } + } + } + case *NullIfExprQ: + return exprListContainsVarExpr(v.Args) + case *MinMaxExprQ: + return exprListContainsVarExpr(v.Args) + case *BooleanTestExpr: + return exprContainsVarExpr(v.Arg) + case *DistinctExprQ: + return exprContainsVarExpr(v.Left) || exprContainsVarExpr(v.Right) + case *ScalarArrayOpExpr: + return exprContainsVarExpr(v.Left) || exprContainsVarExpr(v.Right) + case *ArrayExprQ: + return exprListContainsVarExpr(v.Elements) + case *RowExprQ: + return exprListContainsVarExpr(v.Args) + case *CollateExprQ: + return exprContainsVarExpr(v.Arg) + case *FieldSelectExprQ: + return exprContainsVarExpr(v.Arg) + case *WindowFuncExpr: + return exprListContainsVarExpr(v.Args) || exprContainsVarExpr(v.AggFilter) + case *SubscriptingRefExpr: + return exprContainsVarExpr(v.ContainerExpr) || + exprListContainsVarExpr(v.SubscriptExprs) || + exprListContainsVarExpr(v.LowerExprs) + case *NamedArgExprQ: + return exprContainsVarExpr(v.Arg) + case *ArrayCoerceExprQ: + return exprContainsVarExpr(v.Arg) || exprContainsVarExpr(v.ElemExpr) + case *CoerceToDomainExpr: + return exprContainsVarExpr(v.Arg) + case *RowCompareExprQ: + return exprListContainsVarExpr(v.LArgs) || exprListContainsVarExpr(v.RArgs) + case *ConvertRowtypeExprQ: + return exprContainsVarExpr(v.Arg) + case *FieldStoreExprQ: + return exprContainsVarExpr(v.Arg) || exprListContainsVarExpr(v.NewVals) + case *GroupingFuncExpr: + return exprListContainsVarExpr(v.Args) + case *XmlExprQ: + return exprListContainsVarExpr(v.NamedArgs) || exprListContainsVarExpr(v.Args) + case *JsonConstructorExprQ: + return exprListContainsVarExpr(v.Args) + case *JsonExprQ: + return exprContainsVarExpr(v.Expr) + case *JsonIsPredicateExpr: + return exprContainsVarExpr(v.Expr) + case *JsonValueExprQ: + return exprContainsVarExpr(v.Expr) + } + return false +} + +func exprListContainsVarExpr(exprs []AnalyzedExpr) bool { + for _, expr := range exprs { + if exprContainsVarExpr(expr) { + return true + } + } + return false +} + // AnalyzeDomainExpr analyzes a raw expression node in the context of a domain // type. The "VALUE" keyword is intercepted in transformColumnRef and replaced // with CoerceToDomainValueExpr. @@ -4674,6 +4860,9 @@ func (ac *analyzeCtx) transformGroupingFunc(gf *nodes.GroupingFunc) (AnalyzedExp } } } + if ac.exprKind == analyzeExprKindColumnDefault { + return nil, errGroupingOperationInDefault() + } return &GroupingFuncExpr{ Args: args, Refs: refs, diff --git a/pg/catalog/dryrun_validation_test.go b/pg/catalog/dryrun_validation_test.go index a5bc4e15..f8dfcf0f 100644 --- a/pg/catalog/dryrun_validation_test.go +++ b/pg/catalog/dryrun_validation_test.go @@ -138,6 +138,66 @@ func TestDryRunValidation_CreateTable(t *testing.T) { CodeDuplicateColumn, "specified more than once") }) + t.Run("DefaultDoubleQuotedIdentifierRejected", func(t *testing.T) { + dryRunExpectError(t, "", + `CREATE TABLE t (a text DEFAULT "OTHER")`, + CodeFeatureNotSupported, "cannot use column reference in DEFAULT expression") + }) + + t.Run("DefaultColumnReferenceRejected", func(t *testing.T) { + dryRunExpectError(t, "", + "CREATE TABLE t (a int, b int DEFAULT a)", + CodeFeatureNotSupported, "cannot use column reference in DEFAULT expression") + }) + + t.Run("DefaultStringLiteralOK", func(t *testing.T) { + dryRunExpectOK(t, "CREATE TABLE t (a text DEFAULT 'OTHER')") + }) + + t.Run("DefaultTypeMismatchRejected", func(t *testing.T) { + dryRunExpectError(t, "", + "CREATE TABLE t (a int DEFAULT true)", + CodeDatatypeMismatch, "") + }) + + for _, tc := range []struct { + name string + expr string + code string + msgContains string + }{ + { + name: "DefaultAggregateRejected", + expr: "count(*)", + code: CodeGroupingError, + msgContains: "aggregate functions are not allowed in DEFAULT expressions", + }, + { + name: "DefaultWindowFunctionRejected", + expr: "row_number() OVER ()", + code: CodeWindowingError, + msgContains: "window functions are not allowed in DEFAULT expressions", + }, + { + name: "DefaultSetReturningFunctionRejected", + expr: "generate_series(1, 2)", + code: CodeFeatureNotSupported, + msgContains: "set-returning functions are not allowed in DEFAULT expressions", + }, + { + name: "DefaultGroupingOperationRejected", + expr: "GROUPING(1)", + code: CodeGroupingError, + msgContains: "grouping operations are not allowed in DEFAULT expressions", + }, + } { + t.Run(tc.name, func(t *testing.T) { + dryRunExpectError(t, "", + "CREATE TABLE t (a int DEFAULT "+tc.expr+")", + tc.code, tc.msgContains) + }) + } + t.Run("UndefinedColumnType", func(t *testing.T) { dryRunExpectError(t, "", "CREATE TABLE t(id nosuchtype)", @@ -241,6 +301,59 @@ func TestDryRunValidation_AlterTable(t *testing.T) { CodeUndefinedObject, "") }) + t.Run("AddColumnDefaultDoubleQuotedIdentifierRejected", func(t *testing.T) { + dryRunExpectError(t, + "CREATE TABLE t(id int)", + `ALTER TABLE t ADD COLUMN a text DEFAULT "OTHER"`, + CodeFeatureNotSupported, "cannot use column reference in DEFAULT expression") + }) + + t.Run("AddColumnDefaultTypeMismatchRejected", func(t *testing.T) { + dryRunExpectError(t, + "CREATE TABLE t(id int)", + "ALTER TABLE t ADD COLUMN a int DEFAULT true", + CodeDatatypeMismatch, "") + }) + + for _, tc := range []struct { + name string + expr string + code string + msgContains string + }{ + { + name: "AddColumnDefaultAggregateRejected", + expr: "count(*)", + code: CodeGroupingError, + msgContains: "aggregate functions are not allowed in DEFAULT expressions", + }, + { + name: "AddColumnDefaultWindowFunctionRejected", + expr: "row_number() OVER ()", + code: CodeWindowingError, + msgContains: "window functions are not allowed in DEFAULT expressions", + }, + { + name: "AddColumnDefaultSetReturningFunctionRejected", + expr: "generate_series(1, 2)", + code: CodeFeatureNotSupported, + msgContains: "set-returning functions are not allowed in DEFAULT expressions", + }, + { + name: "AddColumnDefaultGroupingOperationRejected", + expr: "GROUPING(1)", + code: CodeGroupingError, + msgContains: "grouping operations are not allowed in DEFAULT expressions", + }, + } { + t.Run(tc.name, func(t *testing.T) { + dryRunExpectError(t, + "CREATE TABLE t(id int)", + "ALTER TABLE t ADD COLUMN a int DEFAULT "+tc.expr, + tc.code, tc.msgContains) + }) + } + t.Run("AddColumnFKToMissing", func(t *testing.T) { dryRunExpectError(t, "CREATE TABLE t(id int)", @@ -269,6 +382,80 @@ func TestDryRunValidation_AlterTable(t *testing.T) { CodeUndefinedColumn, "") }) + t.Run("SetDefaultColumnNotExistsBeforeDefaultExpr", func(t *testing.T) { + dryRunExpectError(t, + "CREATE TABLE t(id int)", + `ALTER TABLE t ALTER COLUMN nosuch SET DEFAULT "OTHER"`, + CodeUndefinedColumn, "") + }) + + t.Run("SetDefaultDoubleQuotedIdentifierRejected", func(t *testing.T) { + dryRunExpectError(t, + "CREATE TABLE t(a text)", + `ALTER TABLE t ALTER COLUMN a SET DEFAULT "OTHER"`, + CodeFeatureNotSupported, "cannot use column reference in DEFAULT expression") + }) + + t.Run("SetDefaultTypeMismatchRejected", func(t *testing.T) { + dryRunExpectError(t, + "CREATE TABLE t(a int)", + "ALTER TABLE t ALTER COLUMN a SET DEFAULT true", + CodeDatatypeMismatch, "") + }) + + for _, tc := range []struct { + name string + expr string + code string + msgContains string + }{ + { + name: "SetDefaultAggregateRejected", + expr: "count(*)", + code: CodeGroupingError, + msgContains: "aggregate functions are not allowed in DEFAULT expressions", + }, + { + name: "SetDefaultWindowFunctionRejected", + expr: "row_number() OVER ()", + code: CodeWindowingError, + msgContains: "window functions are not allowed in DEFAULT expressions", + }, + { + name: "SetDefaultSetReturningFunctionRejected", + expr: "generate_series(1, 2)", + code: CodeFeatureNotSupported, + msgContains: "set-returning functions are not allowed in DEFAULT expressions", + }, + { + name: "SetDefaultGroupingOperationRejected", + expr: "GROUPING(1)", + code: CodeGroupingError, + msgContains: "grouping operations are not allowed in DEFAULT expressions", + }, + } { + t.Run(tc.name, func(t *testing.T) { + dryRunExpectError(t, + "CREATE TABLE t(a int)", + "ALTER TABLE t ALTER COLUMN a SET DEFAULT "+tc.expr, + tc.code, tc.msgContains) + }) + } + + t.Run("SetDefaultIdentityColumnBeforeDefaultExpr", func(t *testing.T) { + dryRunExpectError(t, + "CREATE TABLE t(id int GENERATED ALWAYS AS IDENTITY, name text)", + "ALTER TABLE t ALTER COLUMN id SET DEFAULT name", + CodeSyntaxError, "identity column") + }) + + t.Run("SetDefaultGeneratedColumnBeforeDefaultExpr", func(t *testing.T) { + dryRunExpectError(t, + "CREATE TABLE t(a int, b int GENERATED ALWAYS AS (a + 1) STORED)", + "ALTER TABLE t ALTER COLUMN b SET DEFAULT a", + CodeSyntaxError, "generated column") + }) + t.Run("SetNotNullColumnNotExists", func(t *testing.T) { dryRunExpectError(t, "CREATE TABLE t(id int)", diff --git a/pg/catalog/errors.go b/pg/catalog/errors.go index 95b2e79a..7bb8d381 100644 --- a/pg/catalog/errors.go +++ b/pg/catalog/errors.go @@ -17,41 +17,43 @@ const ( // SQLSTATE error codes matching PostgreSQL. const ( - CodeDuplicateSchema = "42P06" - CodeDuplicateTable = "42P07" - CodeDuplicateColumn = "42701" - CodeDuplicateObject = "42710" - CodeUndefinedSchema = "3F000" - CodeUndefinedTable = "42P01" - CodeUndefinedColumn = "42703" - CodeUndefinedObject = "42704" - CodeSchemaNotEmpty = "2BP01" - CodeDependentObjects = "2BP01" - CodeWrongObjectType = "42809" - CodeInvalidParameterValue = "22023" - CodeInvalidFK = "42830" - CodeInvalidTableDefinition = "42P16" - CodeDuplicatePKey = "42P16" // same SQLSTATE as InvalidTableDefinition - CodeDatatypeMismatch = "42804" - CodeUndefinedFunction = "42883" - CodeAmbiguousColumn = "42702" - CodeAmbiguousFunction = "42725" - CodeInvalidColumnDefinition = "42611" - CodeTooManyColumns = "54011" - CodeFeatureNotSupported = "0A000" - CodeDuplicateFunction = "42723" - CodeInvalidObjectDefinition = "42P17" - CodeSyntaxError = "42601" - CodeInvalidFunctionDefinition = "42P13" - CodeCheckViolation = "23514" - CodeNotNullViolation = "23502" - CodeForeignKeyViolation = "23503" - CodeUniqueViolation = "23505" - CodeIndeterminateCollation = "42P22" + CodeDuplicateSchema = "42P06" + CodeDuplicateTable = "42P07" + CodeDuplicateColumn = "42701" + CodeDuplicateObject = "42710" + CodeUndefinedSchema = "3F000" + CodeUndefinedTable = "42P01" + CodeUndefinedColumn = "42703" + CodeUndefinedObject = "42704" + CodeSchemaNotEmpty = "2BP01" + CodeDependentObjects = "2BP01" + CodeWrongObjectType = "42809" + CodeInvalidParameterValue = "22023" + CodeInvalidFK = "42830" + CodeInvalidTableDefinition = "42P16" + CodeDuplicatePKey = "42P16" // same SQLSTATE as InvalidTableDefinition + CodeDatatypeMismatch = "42804" + CodeUndefinedFunction = "42883" + CodeAmbiguousColumn = "42702" + CodeAmbiguousFunction = "42725" + CodeGroupingError = "42803" + CodeWindowingError = "42P20" + CodeInvalidColumnDefinition = "42611" + CodeTooManyColumns = "54011" + CodeFeatureNotSupported = "0A000" + CodeDuplicateFunction = "42723" + CodeInvalidObjectDefinition = "42P17" + CodeSyntaxError = "42601" + CodeInvalidFunctionDefinition = "42P13" + CodeCheckViolation = "23514" + CodeNotNullViolation = "23502" + CodeForeignKeyViolation = "23503" + CodeUniqueViolation = "23505" + CodeIndeterminateCollation = "42P22" CodeObjectNotInPrerequisiteState = "55000" - CodeInvalidGrantOperation = "0LP01" - CodeProgramLimitExceeded = "54000" - CodeReservedName = "42939" + CodeInvalidGrantOperation = "0LP01" + CodeProgramLimitExceeded = "54000" + CodeReservedName = "42939" ) // Error represents a PostgreSQL-compatible error with an SQLSTATE code. diff --git a/pg/catalog/tablecmds.go b/pg/catalog/tablecmds.go index ed9de51b..6e3fca4e 100644 --- a/pg/catalog/tablecmds.go +++ b/pg/catalog/tablecmds.go @@ -752,18 +752,28 @@ func (c *Catalog) DefineRelation(stmt *nodes.CreateStmt, relkind byte) error { // pg: src/backend/commands/tablecmds.c — cookDefault / cookConstraint for i, cd := range colDefs { if cd.RawDefault != nil && columns[i].HasDefault { - if analyzed, err := c.AnalyzeStandaloneExpr(cd.RawDefault, rel); err == nil && analyzed != nil { - // pg: cookDefault uses COERCE_IMPLICIT_CAST ('i') as display format - coerced, cerr := c.coerceToTargetType(analyzed, analyzed.exprType(), columns[i].TypeOID, 'i') - if cerr == nil && coerced != nil { - analyzed = coerced - } - columns[i].DefaultAnalyzed = analyzed - c.recordDependencyOnSingleRelExprForObject('r', rel.OID, int32(columns[i].AttNum), analyzed, rel.OID, - DepNormal, DepNormal) - rte := c.buildRelationRTE(rel) - columns[i].Default = c.DeparseExpr(analyzed, []*RangeTableEntry{rte}, false) + analyzed, err := c.AnalyzeColumnDefaultExpr(cd.RawDefault, rel) + if err != nil { + c.removeRelation(schema, relName, rel) + return err + } + if analyzed == nil { + continue + } + // pg: cookDefault uses COERCE_IMPLICIT_CAST ('i') as display format + coerced, cerr := c.coerceToTargetTypeWithFormat(analyzed, analyzed.exprType(), columns[i].TypeOID, 'a', 'i') + if cerr != nil { + c.removeRelation(schema, relName, rel) + return cerr } + if coerced != nil { + analyzed = coerced + } + columns[i].DefaultAnalyzed = analyzed + c.recordDependencyOnSingleRelExprForObject('r', rel.OID, int32(columns[i].AttNum), analyzed, rel.OID, + DepNormal, DepNormal) + rte := c.buildRelationRTE(rel) + columns[i].Default = c.DeparseExpr(analyzed, []*RangeTableEntry{rte}, false) } if cd.RawGenExpr != nil && columns[i].Generated == 's' { if rawExprContainsSubLink(cd.RawGenExpr) { @@ -1053,6 +1063,13 @@ func (c *Catalog) convertColumnDef(cd *nodes.ColumnDef, relName string, schema * } } + if result.Identity != 0 && (result.RawDefault != nil || result.Default != "") { + return ColumnDef{}, nil, &Error{ + Code: CodeSyntaxError, + Message: fmt.Sprintf("both default and identity specified for column %q of table %q", result.Name, relName), + } + } + return result, cons, nil }