Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 60 additions & 26 deletions pg/catalog/alter.go
Original file line number Diff line number Diff line change
Expand Up @@ -344,8 +344,20 @@ func (c *Catalog) execAlterTableCmd(schema *Schema, rel *Relation, relName strin
if colDef.RawDefault != nil {
col := rel.Columns[len(rel.Columns)-1]
col.HasDefault = true
if analyzed, err := c.AnalyzeStandaloneExpr(colDef.RawDefault, rel); err == nil && analyzed != nil {
if coerced, cerr := c.coerceToTargetType(analyzed, analyzed.exprType(), col.TypeOID, 'i'); cerr == nil && coerced != nil {
analyzed, err := c.AnalyzeColumnDefaultExpr(colDef.RawDefault, rel)
if err != nil {
delete(rel.colByName, col.Name)
rel.Columns = rel.Columns[:len(rel.Columns)-1]
return err
}
if analyzed != nil {
coerced, cerr := c.coerceToTargetTypeWithFormat(analyzed, analyzed.exprType(), col.TypeOID, 'a', 'i')
if cerr != nil {
delete(rel.colByName, col.Name)
rel.Columns = rel.Columns[:len(rel.Columns)-1]
return cerr
}
if coerced != nil {
analyzed = coerced
}
col.DefaultAnalyzed = analyzed
Expand Down Expand Up @@ -393,20 +405,29 @@ func (c *Catalog) execAlterTableCmd(schema *Schema, rel *Relation, relName strin
}
// Analyze the default expression and coerce to column type.
// pg: cookDefault uses COERCE_IMPLICIT_CAST ('i') as display format
if analyzed, err := c.AnalyzeStandaloneExpr(rawExpr, rel); err == nil && analyzed != nil {
if idx, exists := rel.colByName[atc.Name]; exists {
if coerced, cerr := c.coerceToTargetType(analyzed, analyzed.exprType(), rel.Columns[idx].TypeOID, 'i'); cerr == nil && coerced != nil {
analyzed = coerced
}
rel.Columns[idx].DefaultAnalyzed = analyzed
c.recordDependencyOnSingleRelExprForObject('r', rel.OID, int32(rel.Columns[idx].AttNum), analyzed, rel.OID,
DepNormal, DepNormal)
idx, col, err := c.columnDefaultTarget(rel, atc.Name)
if err != nil {
return err
}
analyzed, err := c.AnalyzeColumnDefaultExpr(rawExpr, rel)
if err != nil {
return err
}
if analyzed != nil {
coerced, cerr := c.coerceToTargetTypeWithFormat(analyzed, analyzed.exprType(), col.TypeOID, 'a', 'i')
if cerr != nil {
return cerr
}
rte := c.buildRelationRTE(rel)
defStr := c.DeparseExpr(analyzed, []*RangeTableEntry{rte}, false)
if err := c.atSetDefault(rel, atc.Name, defStr); err != nil {
return err
if coerced != nil {
analyzed = coerced
}
col.DefaultAnalyzed = analyzed
c.recordDependencyOnSingleRelExprForObject('r', rel.OID, int32(col.AttNum), analyzed, rel.OID,
DepNormal, DepNormal)
rte := c.buildRelationRTE(rel)
col.HasDefault = true
col.Default = c.DeparseExpr(analyzed, []*RangeTableEntry{rte}, false)
rel.Columns[idx] = col
}
return nil
}
Expand Down Expand Up @@ -1331,27 +1352,38 @@ func (c *Catalog) atDropNotNull(rel *Relation, colName string) error {
//
// pg: src/backend/commands/tablecmds.c — ATExecColumnDefault
func (c *Catalog) atSetDefault(rel *Relation, colName, expr string) error {
_, col, err := c.columnDefaultTarget(rel, colName)
if err != nil {
return err
}

col.HasDefault = true
col.Default = expr
return nil
}

func (c *Catalog) columnDefaultTarget(rel *Relation, colName string) (int, *Column, error) {
idx, exists := rel.colByName[colName]
if !exists {
return errUndefinedColumn(colName)
return 0, nil, errUndefinedColumn(colName)
}
col := rel.Columns[idx]

// Cannot set default on identity columns (use SET IDENTITY instead).
if col.Identity != 0 {
return errInvalidObjectDefinition(fmt.Sprintf(
"column %q of relation %q is an identity column", colName, rel.Name))
return 0, nil, &Error{
Code: CodeSyntaxError,
Message: fmt.Sprintf("column %q of relation %q is an identity column", colName, rel.Name)}
}

// Cannot set default on generated columns.
if col.Generated != 0 {
return errInvalidObjectDefinition(fmt.Sprintf(
"column %q of relation %q is a generated column", colName, rel.Name))
return 0, nil, &Error{
Code: CodeSyntaxError,
Message: fmt.Sprintf("column %q of relation %q is a generated column", colName, rel.Name)}
}

col.HasDefault = true
col.Default = expr
return nil
return idx, col, nil
}

// atDropDefault removes a column's default expression.
Expand All @@ -1366,14 +1398,16 @@ func (c *Catalog) atDropDefault(rel *Relation, colName string) error {

// Cannot drop default on identity columns (use DROP IDENTITY instead).
if col.Identity != 0 {
return errInvalidObjectDefinition(fmt.Sprintf(
"column %q of relation %q is an identity column", colName, rel.Name))
return &Error{
Code: CodeSyntaxError,
Message: fmt.Sprintf("column %q of relation %q is an identity column", colName, rel.Name)}
}

// Cannot drop default on generated columns (use DROP EXPRESSION instead).
if col.Generated != 0 {
return errInvalidObjectDefinition(fmt.Sprintf(
"column %q of relation %q is a generated column", colName, rel.Name))
return &Error{
Code: CodeSyntaxError,
Message: fmt.Sprintf("column %q of relation %q is a generated column", colName, rel.Name)}
}

col.HasDefault = false
Expand Down
25 changes: 25 additions & 0 deletions pg/catalog/alter_identity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,31 @@ func TestAlterAddColumnIdentityCreatesSequenceAndDefault(t *testing.T) {
}
}

func TestAlterAddColumnIdentityWithDefaultRejectsBeforeSequenceCreate(t *testing.T) {
c := New()
execIdentitySQL(t, c, `CREATE TABLE t (name text)`)

results, err := c.Exec(`ALTER TABLE t ADD COLUMN id integer GENERATED ALWAYS AS IDENTITY DEFAULT name`, nil)
if err != nil {
t.Fatalf("Exec parse error: %v", err)
}
if len(results) != 1 || results[0].Error == nil {
t.Fatalf("expected ALTER TABLE error, got results=%v", results)
}
assertErrorCode(t, results[0].Error, CodeSyntaxError)

rel := c.GetRelation("", "t")
if rel == nil {
t.Fatal("relation t not found")
}
if _, exists := rel.colByName["id"]; exists {
t.Fatal("column id should not be added after identity/default conflict")
}
if _, err := c.findSequence("", "t_id_seq"); err == nil {
t.Fatal("identity sequence public.t_id_seq should not be created after identity/default conflict")
}
}

func TestAlterAddColumnIdentityRejectsRegularInheritanceChildren(t *testing.T) {
c := New()
execIdentitySQL(t, c, `
Expand Down
39 changes: 35 additions & 4 deletions pg/catalog/alter_validation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ func TestSetDefaultOnIdentityColumn(t *testing.T) {
err := c.AlterTableStmt(makeAlterTableStmt("", "t",
makeATSetDefault("id", "42"),
))
assertErrorCode(t, err, CodeInvalidObjectDefinition)
assertErrorCode(t, err, CodeSyntaxError)
}

func TestSetDefaultOnGeneratedColumn(t *testing.T) {
Expand All @@ -37,7 +37,7 @@ func TestSetDefaultOnGeneratedColumn(t *testing.T) {
err := c.AlterTableStmt(makeAlterTableStmt("", "t",
makeATSetDefault("val", "1"),
))
assertErrorCode(t, err, CodeInvalidObjectDefinition)
assertErrorCode(t, err, CodeSyntaxError)
}

func TestDropDefaultOnIdentityColumn(t *testing.T) {
Expand All @@ -53,7 +53,7 @@ func TestDropDefaultOnIdentityColumn(t *testing.T) {
err := c.AlterTableStmt(makeAlterTableStmt("", "t",
makeATDropDefault("id"),
))
assertErrorCode(t, err, CodeInvalidObjectDefinition)
assertErrorCode(t, err, CodeSyntaxError)
}

func TestDropDefaultOnGeneratedColumn(t *testing.T) {
Expand All @@ -69,7 +69,38 @@ func TestDropDefaultOnGeneratedColumn(t *testing.T) {
err := c.AlterTableStmt(makeAlterTableStmt("", "t",
makeATDropDefault("val"),
))
assertErrorCode(t, err, CodeInvalidObjectDefinition)
assertErrorCode(t, err, CodeSyntaxError)
}

func TestAddColumnDefaultTypeMismatchRollsBackColumn(t *testing.T) {
c := New()
results, err := c.Exec("CREATE TABLE t(id int)", nil)
if err != nil {
t.Fatalf("setup parse error: %v", err)
}
if len(results) != 1 || results[0].Error != nil {
t.Fatalf("setup error: %v", results)
}

results, err = c.Exec("ALTER TABLE t ADD COLUMN a int DEFAULT true", nil)
if err != nil {
t.Fatalf("alter parse error: %v", err)
}
if len(results) != 1 || results[0].Error == nil {
t.Fatalf("expected ALTER TABLE error, got results=%v", results)
}
assertErrorCode(t, results[0].Error, CodeDatatypeMismatch)

rel := c.GetRelation("", "t")
if rel == nil {
t.Fatal("relation t not found")
}
if _, exists := rel.colByName["a"]; exists {
t.Fatal("column a should not be added after default type mismatch")
}
if len(rel.Columns) != 1 {
t.Fatalf("columns length: got %d, want 1", len(rel.Columns))
}
}

func TestAlterColumnTypeOnIdentityColumn(t *testing.T) {
Expand Down
Loading
Loading