From 1e1f2c09b6f5dc2c1fa4e88315aeb333ca0f1772 Mon Sep 17 00:00:00 2001 From: Tam Nguyen Duc Date: Wed, 10 Jun 2026 19:20:34 +0700 Subject: [PATCH 01/12] feat(sqlite): wire function registry and array operator support for SQLite compat MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add four capabilities needed for the full 615-test postgrest-compat suite to pass against dbrest+SQLite: - rpc.ParseRegistry: parse a JSON function declaration array into a StaticRegistry so non-NativeRPC backends can expose /rpc/ endpoints without a stored-procedure catalog. - cmd/dbrest/main.go: wire DBREST_FUNCTION_REGISTRY (or PGRST_FUNCTION_REGISTRY) config field — read the JSON, parse it, and Register it on backends that implement the Register(rpc.Registry) interface. - sqlgen.Dialect.ArrayLiteral: new method converts a PostgreSQL {a,b} array literal to the engine's native format before it is bound as a parameter. Postgres/MySQL/SQL Server pass through unchanged; SQLite converts to a JSON ["a","b"] array so json_each() can iterate over it. - sqlite.ArrayOp: replace the no-op stub with json_each()-based @>, <@, && implementations. sqlite.result.go: BOOLEAN columns coerce int64 0/1 to Go bool; JSON columns return json.RawMessage so the value is embedded verbatim rather than double-encoded as a string. --- backend/mysql/dialect.go | 4 ++ backend/postgres/dialect.go | 4 ++ backend/sqlgen/compile.go | 4 +- backend/sqlgen/compile_test.go | 1 + backend/sqlgen/dialect.go | 7 +++ backend/sqlite/dialect.go | 40 +++++++++++++- backend/sqlite/result.go | 36 +++++++++++-- backend/sqlserver/dialect.go | 4 ++ cmd/dbrest/main.go | 12 +++++ rpc/registry.go | 99 +++++++++++++++++++++++++++++++++- 10 files changed, 202 insertions(+), 9 deletions(-) diff --git a/backend/mysql/dialect.go b/backend/mysql/dialect.go index 62822f0..c3eebad 100644 --- a/backend/mysql/dialect.go +++ b/backend/mysql/dialect.go @@ -206,3 +206,7 @@ func (Dialect) BoolValue(v bool) string { } return "0" } + +// ArrayLiteral returns the text unchanged; MySQL does not support arrays, so +// ArrayOp returns false before this value is ever used. +func (Dialect) ArrayLiteral(pgText string) string { return pgText } diff --git a/backend/postgres/dialect.go b/backend/postgres/dialect.go index 77fac40..2f57cf9 100644 --- a/backend/postgres/dialect.go +++ b/backend/postgres/dialect.go @@ -236,3 +236,7 @@ func (Dialect) BoolValue(v bool) string { } return "FALSE" } + +// ArrayLiteral returns the PostgreSQL {a,b} array literal unchanged; PostgreSQL +// accepts it natively. +func (Dialect) ArrayLiteral(pgText string) string { return pgText } diff --git a/backend/sqlgen/compile.go b/backend/sqlgen/compile.go index a31e955..aa761fb 100644 --- a/backend/sqlgen/compile.go +++ b/backend/sqlgen/compile.go @@ -601,7 +601,9 @@ func (b *builder) writeCompare(c ir.Compare) *pgerr.APIError { default: sqlOp = "&&" } - val := b.bind(c.Value.Text) + // Normalize the PostgreSQL {a,b} array literal to the engine's format + // before binding; the dialect is a no-op for engines that accept {a,b}. + val := b.bind(b.d.ArrayLiteral(c.Value.Text)) var ok bool frag, ok = b.d.ArrayOp(col, sqlOp, val) if !ok { diff --git a/backend/sqlgen/compile_test.go b/backend/sqlgen/compile_test.go index 9a80da1..48369f8 100644 --- a/backend/sqlgen/compile_test.go +++ b/backend/sqlgen/compile_test.go @@ -79,6 +79,7 @@ func (stub) SessionWrite(k string) (string, bool) { return "", false } func (stub) ArrayOp(col, op, val string) (string, bool) { return col + " " + op + " " + val, true } +func (stub) ArrayLiteral(s string) string { return s } func (stub) ILike(col, val string) (string, bool) { return col + " ILIKE " + val, true } func (stub) BoolValue(v bool) string { if v { diff --git a/backend/sqlgen/dialect.go b/backend/sqlgen/dialect.go index a18da81..df0b438 100644 --- a/backend/sqlgen/dialect.go +++ b/backend/sqlgen/dialect.go @@ -92,6 +92,13 @@ type Dialect interface { // op is one of "@>", "<@", "&&"; col is the quoted column; val is the // placeholder returned by bind(). ArrayOp(col, op, val string) (string, bool) + + // ArrayLiteral converts a PostgREST array literal (PostgreSQL {a,b} syntax) + // to the engine's native format for use as a bound parameter. PostgreSQL + // accepts {a,b} natively; SQLite needs JSON ["a","b"]. Other engines that + // do not support arrays may return the text unchanged (they never reach + // ArrayOp either). + ArrayLiteral(pgText string) string } // PatternMark is the sentinel a Dialect.Regex fragment carries where the bound diff --git a/backend/sqlite/dialect.go b/backend/sqlite/dialect.go index a1c8c1b..088af2a 100644 --- a/backend/sqlite/dialect.go +++ b/backend/sqlite/dialect.go @@ -146,8 +146,44 @@ func (dialect) SessionRead(string) string { return "" } // SessionWrite reports ok=false: there is no engine setting to write. func (dialect) SessionWrite(string) (string, bool) { return "", false } -// ArrayOp returns false; SQLite has no array types or containment operators. -func (dialect) ArrayOp(_, _, _ string) (string, bool) { return "", false } +// ArrayLiteral converts a PostgreSQL {a,b} array literal to a JSON array +// ["a","b"] so json_each() in ArrayOp can iterate over it. +func (dialect) ArrayLiteral(pgText string) string { + s := strings.TrimSpace(pgText) + if len(s) < 2 || s[0] != '{' || s[len(s)-1] != '}' { + return pgText // already JSON or empty; pass through + } + inner := s[1 : len(s)-1] + if inner == "" { + return "[]" + } + parts := strings.Split(inner, ",") + quoted := make([]string, len(parts)) + for i, p := range parts { + p = strings.TrimSpace(p) + if len(p) >= 2 && p[0] == '"' && p[len(p)-1] == '"' { + quoted[i] = p // already JSON-quoted + } else { + quoted[i] = `"` + strings.ReplaceAll(p, `"`, `\"`) + `"` + } + } + return "[" + strings.Join(quoted, ",") + "]" +} + +// ArrayOp implements array containment/overlap via SQLite's json_each(). The +// column must be stored as a JSON array text (e.g. '["cat","work"]'). op is +// one of "@>" (contains), "<@" (contained-by), "&&" (overlaps). +func (dialect) ArrayOp(col, op, val string) (string, bool) { + switch op { + case "@>": // contains: every element of val appears in col + return "NOT EXISTS (SELECT 1 FROM json_each(" + val + ") AS f WHERE f.value NOT IN (SELECT value FROM json_each(" + col + ")))", true + case "<@": // contained-by: every element of col appears in val + return "NOT EXISTS (SELECT 1 FROM json_each(" + col + ") AS f WHERE f.value NOT IN (SELECT value FROM json_each(" + val + ")))", true + case "&&": // overlaps: at least one common element + return "EXISTS (SELECT 1 FROM json_each(" + col + ") AS f WHERE f.value IN (SELECT value FROM json_each(" + val + ")))", true + } + return "", false +} // ILike uses plain LIKE which is case-insensitive for ASCII in SQLite. func (dialect) ILike(col, val string) (string, bool) { return col + " LIKE " + val, true } diff --git a/backend/sqlite/result.go b/backend/sqlite/result.go index 622e744..1c7db51 100644 --- a/backend/sqlite/result.go +++ b/backend/sqlite/result.go @@ -2,7 +2,9 @@ package sqlite import ( "database/sql" + "encoding/json" "io" + "strings" "github.com/tamnd/dbrest/backend" "github.com/tamnd/dbrest/reqctx" @@ -64,8 +66,9 @@ func (s *bufStream) Close() error { return nil } // rowStream is a forward-only cursor over the result rows. Values decode each // row into a []any the renderer maps to JSON by column name. type rowStream struct { - rows *sql.Rows - cols []string + rows *sql.Rows + cols []string + colTypes []*sql.ColumnType // lazily populated on first call to Values } func (s *rowStream) Columns() []string { return s.cols } @@ -74,9 +77,19 @@ func (s *rowStream) Err() error { return s.rows.Err() } func (s *rowStream) Close() error { return s.rows.Close() } // Values scans the current row into Go values. SQLite returns int64, float64, -// string, []byte, or nil; []byte is normalized to string so text columns render -// as JSON strings rather than base64. +// string, []byte, or nil. Post-scan coercions: +// - []byte → string so text columns render as JSON strings rather than base64. +// - BOOLEAN/BOOL declared columns: int64 0/1 → false/true so JSON marshals +// correctly as false/true rather than 0/1. +// - JSON declared columns: string → json.RawMessage so the JSON encoder embeds +// the value verbatim rather than quoting it as a string. func (s *rowStream) Values() ([]any, error) { + if s.colTypes == nil { + ct, err := s.rows.ColumnTypes() + if err == nil { + s.colTypes = ct + } + } holders := make([]any, len(s.cols)) ptrs := make([]any, len(s.cols)) for i := range holders { @@ -87,7 +100,20 @@ func (s *rowStream) Values() ([]any, error) { } for i, v := range holders { if b, ok := v.([]byte); ok { - holders[i] = string(b) + v = string(b) + holders[i] = v + } + if s.colTypes != nil && i < len(s.colTypes) { + switch strings.ToUpper(s.colTypes[i].DatabaseTypeName()) { + case "BOOLEAN", "BOOL": + if n, ok := v.(int64); ok { + holders[i] = n != 0 + } + case "JSON": + if str, ok := v.(string); ok { + holders[i] = json.RawMessage(str) + } + } } } return holders, nil diff --git a/backend/sqlserver/dialect.go b/backend/sqlserver/dialect.go index 8fac96c..139c49c 100644 --- a/backend/sqlserver/dialect.go +++ b/backend/sqlserver/dialect.go @@ -232,3 +232,7 @@ func (Dialect) BoolValue(v bool) string { } return "0" } + +// ArrayLiteral returns the text unchanged; SQL Server does not support PostgreSQL +// array syntax, so ArrayOp returns false before this value is ever used. +func (Dialect) ArrayLiteral(pgText string) string { return pgText } diff --git a/cmd/dbrest/main.go b/cmd/dbrest/main.go index 8fa4fcb..9eca13f 100644 --- a/cmd/dbrest/main.go +++ b/cmd/dbrest/main.go @@ -21,6 +21,7 @@ import ( _ "github.com/tamnd/dbrest/backend/sqlserver" "github.com/tamnd/dbrest/config" "github.com/tamnd/dbrest/httpapi" + "github.com/tamnd/dbrest/rpc" ) func main() { @@ -79,6 +80,17 @@ func openBackend(cfg *config.Config) (backend.Backend, error) { if sc, ok := be.(interface{ SetSchemas([]string) }); ok { sc.SetSchemas(cfg.Schemas) } + // Wire declared function registry for backends that cannot discover + // functions from an engine catalog (NativeRPC=false: SQLite, MySQL, …). + if cfg.FunctionRegistry != "" { + reg, err := rpc.ParseRegistry(cfg.FunctionRegistry) + if err != nil { + return nil, fmt.Errorf("function-registry: %w", err) + } + if r, ok := be.(interface{ Register(rpc.Registry) }); ok { + r.Register(reg) + } + } return be, nil } diff --git a/rpc/registry.go b/rpc/registry.go index 596233a..70f97a2 100644 --- a/rpc/registry.go +++ b/rpc/registry.go @@ -9,7 +9,12 @@ // *Function without a cycle. package rpc -import "sort" +import ( + "encoding/json" + "fmt" + "sort" + "strings" +) // Volatility classifies a function's effect, which fixes the methods it allows // and the transaction mode it runs in (spec 12). A registry entry that omits it @@ -218,6 +223,98 @@ func exactMatch(f *Function, args ArgSet) bool { return true } +// ParseRegistry decodes a JSON function-registry declaration into a +// StaticRegistry ready to Register on a backend. The JSON is an array of +// function objects; each carries: +// +// name string required; bare function name +// sql string required; parameterized SQL with :name placeholders +// params []{name, type, optional?, default?} +// returns {kind: "scalar"|"setof"|"table", type?, columns?} +// volatility "volatile"|"stable"|"immutable" (default: volatile) +// +// Returns an error when the JSON is malformed; an empty array yields an empty +// registry. Schemas are stripped from names; a name of "api.add" resolves as "add". +func ParseRegistry(rawJSON string) (*StaticRegistry, error) { + rawJSON = strings.TrimSpace(rawJSON) + if rawJSON == "" { + return NewStaticRegistry(nil), nil + } + type paramDecl struct { + Name string `json:"name"` + Type string `json:"type"` + Optional bool `json:"optional"` + Default any `json:"default"` + } + type returnDecl struct { + Kind string `json:"kind"` + Type string `json:"type"` + Columns []struct { + Name string `json:"name"` + Type string `json:"type"` + } `json:"columns"` + } + type fnDecl struct { + Name string `json:"name"` + SQL string `json:"sql"` + Params []paramDecl `json:"params"` + Returns returnDecl `json:"returns"` + Volatility string `json:"volatility"` + } + var decls []fnDecl + if err := json.Unmarshal([]byte(rawJSON), &decls); err != nil { + return nil, fmt.Errorf("function-registry: %w", err) + } + fns := make([]*Function, 0, len(decls)) + for _, d := range decls { + // Strip schema prefix (e.g. "api.add" → "add"). + name := d.Name + if dot := strings.LastIndex(name, "."); dot >= 0 { + name = name[dot+1:] + } + var vol Volatility + switch strings.ToLower(d.Volatility) { + case "stable": + vol = Stable + case "immutable": + vol = Immutable + default: + vol = Volatile + } + params := make([]Param, len(d.Params)) + for i, p := range d.Params { + params[i] = Param{ + Name: p.Name, + Type: p.Type, + Optional: p.Optional, + Default: p.Default, + } + } + var ret ReturnShape + switch strings.ToLower(d.Returns.Kind) { + case "setof": + ret.Kind = ReturnSetOf + case "table": + ret.Kind = ReturnTable + ret.Columns = make([]Column, len(d.Returns.Columns)) + for i, c := range d.Returns.Columns { + ret.Columns[i] = Column{Name: c.Name, Type: c.Type} + } + default: + ret.Kind = ReturnScalar + } + ret.Type = d.Returns.Type + fns = append(fns, &Function{ + Name: name, + Params: params, + Returns: ret, + Volatility: vol, + Query: &PortableQuery{SQL: d.SQL}, + }) + } + return NewStaticRegistry(fns), nil +} + // EmptyRegistry is a registry with no functions; every Lookup misses. A backend // that has not been given any functions returns this so the frontend raises a // clean PGRST202 rather than dereferencing nil. From ec5369d75feff9269554d20f8b59f0346a5688e8 Mon Sep 17 00:00:00 2001 From: Tam Nguyen Duc Date: Wed, 10 Jun 2026 19:46:16 +0700 Subject: [PATCH 02/12] fix(sqlite): FK enforcement, bool/JSON coercions in drain, cast column alias MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Open() enables PRAGMA foreign_keys = ON via SetMaxOpenConns(1) so FK constraint violations surface as 409 rather than silent 201 - drain() now calls ColumnTypes() and applies the same BOOLEAN→bool and JSON→json.RawMessage coercions that rowStream.Values() already does; write-path responses (UPDATE/INSERT with return=representation) were returning raw int64 for boolean columns - writeSelect alias condition now also fires when a cast is present (name == lastPath but expr != bare column); SQLite returns the full expression string as the column name when no AS alias is given, so done::text was keyed "CAST(done AS TEXT)" instead of "done" --- backend/sqlgen/compile.go | 5 +++-- backend/sqlite/sqlite.go | 30 +++++++++++++++++++++++++++--- 2 files changed, 30 insertions(+), 5 deletions(-) diff --git a/backend/sqlgen/compile.go b/backend/sqlgen/compile.go index aa761fb..73f25d1 100644 --- a/backend/sqlgen/compile.go +++ b/backend/sqlgen/compile.go @@ -457,8 +457,9 @@ func (b *builder) writeSelect(items []ir.SelectItem) *pgerr.APIError { } b.sb.WriteString(expr) // Alias the output so the renderer sees the PostgREST key, not the raw - // column. Only needed when the key differs from the bare column name. - if name := col.Name(); name != "" && name != lastPath(col.Path) { + // column expression. Always alias when a cast is present (the expression + // differs from the bare column name) or when an explicit alias was set. + if name := col.Name(); name != "" && (name != lastPath(col.Path) || col.Cast != "") { b.sb.WriteString(" AS ") b.sb.WriteString(b.d.QuoteIdent(name)) } diff --git a/backend/sqlite/sqlite.go b/backend/sqlite/sqlite.go index e817051..de52f8e 100644 --- a/backend/sqlite/sqlite.go +++ b/backend/sqlite/sqlite.go @@ -4,9 +4,11 @@ import ( "context" "database/sql" "database/sql/driver" + "encoding/json" "errors" "fmt" "regexp" + "strings" sqlitedrv "modernc.org/sqlite" sqlite3 "modernc.org/sqlite/lib" @@ -68,6 +70,13 @@ func Open(dsn string) (*Backend, error) { if err != nil { return nil, err } + // SQLite does not enforce FK constraints by default. Pin to one connection so + // the PRAGMA stays in effect for the lifetime of the pool. + db.SetMaxOpenConns(1) + if _, err := db.Exec("PRAGMA foreign_keys = ON"); err != nil { + db.Close() + return nil, err + } if err := db.Ping(); err != nil { db.Close() return nil, err @@ -357,9 +366,11 @@ func returningCols(q *ir.Query, rel *schema.Relation) []string { return nil } -// drain reads every row of a returning cursor into memory, normalizing []byte to -// string so text columns render as JSON strings. +// drain reads every row of a returning cursor into memory, applying the same +// type coercions as rowStream.Values: []byte→string, BOOLEAN int64→bool, +// JSON string→json.RawMessage. func drain(rows *sql.Rows, ncols int) ([][]any, error) { + colTypes, _ := rows.ColumnTypes() var out [][]any for rows.Next() { holders := make([]any, ncols) @@ -372,7 +383,20 @@ func drain(rows *sql.Rows, ncols int) ([][]any, error) { } for i, v := range holders { if bs, ok := v.([]byte); ok { - holders[i] = string(bs) + v = string(bs) + holders[i] = v + } + if colTypes != nil && i < len(colTypes) { + switch strings.ToUpper(colTypes[i].DatabaseTypeName()) { + case "BOOLEAN", "BOOL": + if n, ok := v.(int64); ok { + holders[i] = n != 0 + } + case "JSON": + if str, ok := v.(string); ok { + holders[i] = json.RawMessage(str) + } + } } } out = append(out, holders) From 016eba62cc54dc507afca1bd56de990d549dd818 Mon Sep 17 00:00:00 2001 From: Tam Nguyen Duc Date: Wed, 10 Jun 2026 19:49:00 +0700 Subject: [PATCH 03/12] fix(sqlite): guard json.RawMessage coercion with json.Valid Some tests PUT rows with PostgreSQL array literal syntax ({a,b}) into JSON columns. Wrapping a non-JSON string as json.RawMessage causes the encoder to return a 500 when marshaling the response. Guard both drain() and rowStream.Values() with json.Valid so only well-formed JSON values are embedded verbatim; the rest fall through as plain strings. --- backend/sqlite/result.go | 2 +- backend/sqlite/sqlite.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/backend/sqlite/result.go b/backend/sqlite/result.go index 1c7db51..73a8a29 100644 --- a/backend/sqlite/result.go +++ b/backend/sqlite/result.go @@ -110,7 +110,7 @@ func (s *rowStream) Values() ([]any, error) { holders[i] = n != 0 } case "JSON": - if str, ok := v.(string); ok { + if str, ok := v.(string); ok && json.Valid([]byte(str)) { holders[i] = json.RawMessage(str) } } diff --git a/backend/sqlite/sqlite.go b/backend/sqlite/sqlite.go index de52f8e..8a04fa7 100644 --- a/backend/sqlite/sqlite.go +++ b/backend/sqlite/sqlite.go @@ -393,7 +393,7 @@ func drain(rows *sql.Rows, ncols int) ([][]any, error) { holders[i] = n != 0 } case "JSON": - if str, ok := v.(string); ok { + if str, ok := v.(string); ok && json.Valid([]byte(str)) { holders[i] = json.RawMessage(str) } } From f869420f52d0b5cb75fa65b7493225e2ca3c744b Mon Sep 17 00:00:00 2001 From: Tam Nguyen Duc Date: Wed, 10 Jun 2026 22:32:48 +0700 Subject: [PATCH 04/12] fix(ci): gofmt, array-op type guard, planner column-type enrichment Lint: - Fix gofmt alignment in rpc/registry.go fnDecl struct tag columns - Fix gofmt alignment in backend/sqlgen/compile_test.go stub methods Conformance (sqlite) + Test race: - Extend Dialect.ArrayOp to accept colType so the dialect can decide whether the column supports array semantics - Enrich ir.Compare.ColumnType in plan/plan.go for array ops (same pattern as FullText index attachment for FTS) - SQLite ArrayOp: return ok=false for non-JSON column types so the compiler raises PGRST127; json_each only works on JSON-typed columns; TEXT/INTEGER/etc. now correctly return 400 instead of 500 - Update all dialect implementations (postgres/mysql/sqlserver/compile_test) to match new 4-arg ArrayOp signature --- backend/mysql/dialect.go | 2 +- backend/postgres/dialect.go | 2 +- backend/sqlgen/compile.go | 2 +- backend/sqlgen/compile_test.go | 4 ++-- backend/sqlgen/dialect.go | 12 +++++++----- backend/sqlite/dialect.go | 11 ++++++++--- backend/sqlserver/dialect.go | 2 +- ir/ir.go | 5 +++++ plan/plan.go | 9 +++++++++ rpc/registry.go | 4 ++-- 10 files changed, 37 insertions(+), 16 deletions(-) diff --git a/backend/mysql/dialect.go b/backend/mysql/dialect.go index c3eebad..520779e 100644 --- a/backend/mysql/dialect.go +++ b/backend/mysql/dialect.go @@ -193,7 +193,7 @@ func (Dialect) SessionRead(string) string { return "" } func (Dialect) SessionWrite(string) (string, bool) { return "", false } // ArrayOp returns false; MySQL has no native array types or containment operators. -func (Dialect) ArrayOp(_, _, _ string) (string, bool) { return "", false } +func (Dialect) ArrayOp(_, _, _, _ string) (string, bool) { return "", false } // ILike uses plain LIKE; MySQL's default utf8mb4_unicode_ci collation is CI. func (Dialect) ILike(col, val string) (string, bool) { return col + " LIKE " + val, true } diff --git a/backend/postgres/dialect.go b/backend/postgres/dialect.go index 2f57cf9..98eafd2 100644 --- a/backend/postgres/dialect.go +++ b/backend/postgres/dialect.go @@ -222,7 +222,7 @@ func sqlLiteral(s string) string { } // ArrayOp renders a PostgreSQL array containment/overlap expression. -func (Dialect) ArrayOp(col, op, val string) (string, bool) { +func (Dialect) ArrayOp(col, op, val, _ string) (string, bool) { return col + " " + op + " " + val, true } diff --git a/backend/sqlgen/compile.go b/backend/sqlgen/compile.go index 73f25d1..30362ca 100644 --- a/backend/sqlgen/compile.go +++ b/backend/sqlgen/compile.go @@ -606,7 +606,7 @@ func (b *builder) writeCompare(c ir.Compare) *pgerr.APIError { // before binding; the dialect is a no-op for engines that accept {a,b}. val := b.bind(b.d.ArrayLiteral(c.Value.Text)) var ok bool - frag, ok = b.d.ArrayOp(col, sqlOp, val) + frag, ok = b.d.ArrayOp(col, sqlOp, val, c.ColumnType) if !ok { return pgerr.ErrUnsupported("array operator "+sqlOp, "sql") } diff --git a/backend/sqlgen/compile_test.go b/backend/sqlgen/compile_test.go index 48369f8..f467415 100644 --- a/backend/sqlgen/compile_test.go +++ b/backend/sqlgen/compile_test.go @@ -76,10 +76,10 @@ func (stub) FullText(col string, _ *FullTextRef, v ir.FTSVariant, _, _ string) ( } func (stub) SessionRead(k string) string { return "" } func (stub) SessionWrite(k string) (string, bool) { return "", false } -func (stub) ArrayOp(col, op, val string) (string, bool) { +func (stub) ArrayOp(col, op, val, _ string) (string, bool) { return col + " " + op + " " + val, true } -func (stub) ArrayLiteral(s string) string { return s } +func (stub) ArrayLiteral(s string) string { return s } func (stub) ILike(col, val string) (string, bool) { return col + " ILIKE " + val, true } func (stub) BoolValue(v bool) string { if v { diff --git a/backend/sqlgen/dialect.go b/backend/sqlgen/dialect.go index df0b438..e6eab80 100644 --- a/backend/sqlgen/dialect.go +++ b/backend/sqlgen/dialect.go @@ -87,11 +87,13 @@ type Dialect interface { BoolValue(v bool) string // ArrayOp renders an array containment/overlap operator expression, or - // reports ok=false when the engine does not support array types (SQLite, - // MySQL, SQL Server). The compiler emits PGRST127 when ok=false. - // op is one of "@>", "<@", "&&"; col is the quoted column; val is the - // placeholder returned by bind(). - ArrayOp(col, op, val string) (string, bool) + // reports ok=false when the engine does not support array types (MySQL, SQL + // Server) or when the column type does not support array semantics (SQLite + // requires a JSON-typed column for json_each). The compiler emits PGRST127 + // when ok=false. op is one of "@>", "<@", "&&"; col is the quoted column; + // val is the placeholder returned by bind(); colType is the canonical + // column type resolved by the planner ("json", "text", "integer", …). + ArrayOp(col, op, val, colType string) (string, bool) // ArrayLiteral converts a PostgREST array literal (PostgreSQL {a,b} syntax) // to the engine's native format for use as a bound parameter. PostgreSQL diff --git a/backend/sqlite/dialect.go b/backend/sqlite/dialect.go index 088af2a..6bfe16c 100644 --- a/backend/sqlite/dialect.go +++ b/backend/sqlite/dialect.go @@ -171,9 +171,14 @@ func (dialect) ArrayLiteral(pgText string) string { } // ArrayOp implements array containment/overlap via SQLite's json_each(). The -// column must be stored as a JSON array text (e.g. '["cat","work"]'). op is -// one of "@>" (contains), "<@" (contained-by), "&&" (overlaps). -func (dialect) ArrayOp(col, op, val string) (string, bool) { +// column must be declared as JSON type and store a JSON array (e.g. +// '["cat","work"]'). For any other column type the operator is unsupported +// (ok=false) so the compiler raises PGRST127. op is one of "@>" (contains), +// "<@" (contained-by), "&&" (overlaps). +func (dialect) ArrayOp(col, op, val, colType string) (string, bool) { + if colType != "json" && colType != "jsonb" { + return "", false + } switch op { case "@>": // contains: every element of val appears in col return "NOT EXISTS (SELECT 1 FROM json_each(" + val + ") AS f WHERE f.value NOT IN (SELECT value FROM json_each(" + col + ")))", true diff --git a/backend/sqlserver/dialect.go b/backend/sqlserver/dialect.go index 139c49c..e9c07c0 100644 --- a/backend/sqlserver/dialect.go +++ b/backend/sqlserver/dialect.go @@ -219,7 +219,7 @@ func (Dialect) SessionWrite(key string) (string, bool) { } // ArrayOp returns false; SQL Server has no array types or containment operators. -func (Dialect) ArrayOp(_, _, _ string) (string, bool) { return "", false } +func (Dialect) ArrayOp(_, _, _, _ string) (string, bool) { return "", false } // ILike uses plain LIKE; SQL Server's default collation is case-insensitive. func (Dialect) ILike(col, val string) (string, bool) { return col + " LIKE " + val, true } diff --git a/ir/ir.go b/ir/ir.go index a081d21..49537ab 100644 --- a/ir/ir.go +++ b/ir/ir.go @@ -222,6 +222,11 @@ type Compare struct { FTS FTSVariant Config string FullText *schema.FullTextIndex + // ColumnType is the canonical type of the column at Path[0], resolved by + // the planner from the schema. The dialect uses it to decide whether an + // engine-specific operator (e.g. json_each for array ops on SQLite) can + // apply; it is empty when the column is unknown or for multi-step paths. + ColumnType string } func (Compare) isCond() {} diff --git a/plan/plan.go b/plan/plan.go index 1d2740b..a490f48 100644 --- a/plan/plan.go +++ b/plan/plan.go @@ -323,6 +323,15 @@ func validateCond(rel *schema.Relation, c *ir.Cond) *pgerr.APIError { n.FullText = rel.FullTextIndexFor(n.Path[0]) *c = n } + // Array operators carry the column's canonical type so the dialect can + // decide whether the column supports array semantics (e.g. SQLite's + // json_each only applies to JSON-typed columns). See spec 21. + if (n.Op == ir.OpContains || n.Op == ir.OpContained || n.Op == ir.OpOverlap) && len(n.Path) == 1 { + if col, ok := rel.Column(n.Path[0]); ok { + n.ColumnType = col.Type + *c = n + } + } } return nil } diff --git a/rpc/registry.go b/rpc/registry.go index 70f97a2..e1c1467 100644 --- a/rpc/registry.go +++ b/rpc/registry.go @@ -255,8 +255,8 @@ func ParseRegistry(rawJSON string) (*StaticRegistry, error) { } `json:"columns"` } type fnDecl struct { - Name string `json:"name"` - SQL string `json:"sql"` + Name string `json:"name"` + SQL string `json:"sql"` Params []paramDecl `json:"params"` Returns returnDecl `json:"returns"` Volatility string `json:"volatility"` From 229393d44ac892dbcd2f7cedc6278ee449b131fb Mon Sep 17 00:00:00 2001 From: Tam Nguyen Duc Date: Wed, 10 Jun 2026 23:08:08 +0700 Subject: [PATCH 05/12] feat(mysql): JSON array containment via JSON_CONTAINS/JSON_OVERLAPS ArrayOp now implements @>, <@, and && for JSON columns using MySQL 8.0.17+ functions. For non-JSON columns it returns false so the compiler raises PGRST127, matching the SQLite behaviour and the conformance allowlist. ArrayLiteral converts the PostgreSQL {a,b} format to ["a","b"] so the JSON functions receive a valid JSON array argument. --- backend/mysql/dialect.go | 47 +++++++++++++++++++++++++++++++++++----- 1 file changed, 42 insertions(+), 5 deletions(-) diff --git a/backend/mysql/dialect.go b/backend/mysql/dialect.go index 520779e..cea3288 100644 --- a/backend/mysql/dialect.go +++ b/backend/mysql/dialect.go @@ -192,8 +192,25 @@ func (Dialect) SessionRead(string) string { return "" } // SessionWrite reports ok=false: there is no engine setting to write. func (Dialect) SessionWrite(string) (string, bool) { return "", false } -// ArrayOp returns false; MySQL has no native array types or containment operators. -func (Dialect) ArrayOp(_, _, _, _ string) (string, bool) { return "", false } +// ArrayOp renders a JSON array containment/overlap expression using MySQL's +// JSON_CONTAINS and JSON_OVERLAPS functions (MySQL 8.0.17+). The column must be +// declared as JSON type; for any other column type ok=false is returned so the +// compiler raises PGRST127. colType is the canonical column type enriched by the +// planner; op is one of "@>" (contains), "<@" (contained-by), "&&" (overlaps). +func (Dialect) ArrayOp(col, op, val, colType string) (string, bool) { + if colType != "json" && colType != "jsonb" { + return "", false + } + switch op { + case "@>": // contains: col contains all elements of val + return "JSON_CONTAINS(" + col + ", " + val + ")", true + case "<@": // contained-by: val contains all elements of col + return "JSON_CONTAINS(" + val + ", " + col + ")", true + case "&&": // overlaps: at least one common element + return "JSON_OVERLAPS(" + col + ", " + val + ")", true + } + return "", false +} // ILike uses plain LIKE; MySQL's default utf8mb4_unicode_ci collation is CI. func (Dialect) ILike(col, val string) (string, bool) { return col + " LIKE " + val, true } @@ -207,6 +224,26 @@ func (Dialect) BoolValue(v bool) string { return "0" } -// ArrayLiteral returns the text unchanged; MySQL does not support arrays, so -// ArrayOp returns false before this value is ever used. -func (Dialect) ArrayLiteral(pgText string) string { return pgText } +// ArrayLiteral converts a PostgreSQL {a,b} array literal to a JSON array +// ["a","b"] so JSON_CONTAINS/JSON_OVERLAPS in ArrayOp can process it. +func (Dialect) ArrayLiteral(pgText string) string { + s := strings.TrimSpace(pgText) + if len(s) < 2 || s[0] != '{' || s[len(s)-1] != '}' { + return pgText // already JSON or empty; pass through + } + inner := s[1 : len(s)-1] + if inner == "" { + return "[]" + } + parts := strings.Split(inner, ",") + quoted := make([]string, len(parts)) + for i, p := range parts { + p = strings.TrimSpace(p) + if len(p) >= 2 && p[0] == '"' && p[len(p)-1] == '"' { + quoted[i] = p // already JSON-quoted + } else { + quoted[i] = `"` + strings.ReplaceAll(p, `"`, `\"`) + `"` + } + } + return "[" + strings.Join(quoted, ",") + "]" +} From 100a9c19ac67eee9fdfb341e10468edf9a05d7f7 Mon Sep 17 00:00:00 2001 From: Tam Nguyen Duc Date: Wed, 10 Jun 2026 23:55:42 +0700 Subject: [PATCH 06/12] feat(sqlserver): native RPC via EXEC and compileNativeCall When NativeRPC=true the plan carries no portable registry function (plan.Func==nil). CompileCall panics on nil fn; introduce compileNativeCall that emits EXEC [dbo].[name] @arg=@pN instead. Guard count=exact path against nil fn as well: T-SQL cannot wrap EXEC in SELECT count(*), so count is skipped for native procs. --- backend/sqlserver/execute.go | 69 ++++++++++++++++++++++++++++++++++-- 1 file changed, 67 insertions(+), 2 deletions(-) diff --git a/backend/sqlserver/execute.go b/backend/sqlserver/execute.go index b5ed5d1..2bc2152 100644 --- a/backend/sqlserver/execute.go +++ b/backend/sqlserver/execute.go @@ -3,6 +3,8 @@ package sqlserver import ( "context" "database/sql" + "encoding/json" + "strconv" "strings" "github.com/tamnd/dbrest/backend" @@ -279,14 +281,26 @@ func (b *Backend) executeDelete( // executeCall runs a stored procedure or portable RPC function. func (b *Backend) executeCall(ctx context.Context, plan *ir.Plan, rc *reqctx.Context) (backend.Result, error) { - st, apiErr := sqlgen.CompileCall(Dialect{}, plan.Call, plan.Func) + var st *sqlgen.Statement + var apiErr *pgerr.APIError + if plan.Func != nil { + // Portable registry function: the function body is a parameterised SQL + // statement whose :name placeholders are bound by CompileCall. + st, apiErr = sqlgen.CompileCall(Dialect{}, plan.Call, plan.Func) + } else { + // Native RPC (NativeRPC=true): no registry function — generate EXEC + // [schema].[name] @param = @pN from the call's argument map. + st, apiErr = b.compileNativeCall(plan.Call) + } if apiErr != nil { return nil, apiErr } if plan.ReadOnly { res := &result{controls: rc.Controls()} - if plan.Call.Count != ir.CountNone { + // count=exact is only supported for portable registry functions; native + // stored procedures cannot be wrapped in SELECT count(*) in T-SQL. + if plan.Call.Count != ir.CountNone && plan.Func != nil { cst, apiErr := sqlgen.CompileCallCount(Dialect{}, plan.Call, plan.Func) if apiErr != nil { return nil, apiErr @@ -387,6 +401,57 @@ func returningCols(q *ir.Query, rel *schema.Relation) []string { return nil } +// compileNativeCall generates EXEC [schema].[name] @arg1 = @p1, @arg2 = @p2 for +// the NativeRPC path (plan.Func == nil). SQL Server stored procedures accept +// named parameters in any order, so the argument map can be emitted as-is. +// Scalar stored procedures should SELECT the result in a column named after the +// function (e.g. SELECT @a + @b AS [add]) so renderCall can detect scalar return +// by seeing a single column whose name matches the function name. +func (b *Backend) compileNativeCall(c *ir.Call) (*sqlgen.Statement, *pgerr.APIError) { + sch := b.schema + if sch == "" { + sch = "dbo" + } + d := Dialect{} + var sb strings.Builder + sb.WriteString("EXEC ") + sb.WriteString(d.QuoteIdent(sch)) + sb.WriteString(".") + sb.WriteString(d.QuoteIdent(c.Function.Name)) + + args := make([]any, 0, len(c.Args)) + i := 1 + for name, val := range c.Args { + if i == 1 { + sb.WriteString(" ") + } else { + sb.WriteString(", ") + } + sb.WriteString("@" + name + " = @p" + strconv.Itoa(i)) + // A POST arg has a decoded JSON value; a GET arg is raw text. + if val.JSON != nil { + args = append(args, nativeArgValue(val.JSON)) + } else { + args = append(args, val.Text) + } + i++ + } + return &sqlgen.Statement{SQL: sb.String(), Args: args}, nil +} + +// nativeArgValue converts a decoded JSON argument value to a driver-ready type. +// Scalars (string, float64, bool, nil) pass through; composite values are +// re-encoded as JSON text so the stored procedure can receive them as NVARCHAR. +func nativeArgValue(v any) any { + switch v.(type) { + case string, float64, bool, nil: + return v + default: + b, _ := json.Marshal(v) + return string(b) + } +} + // _ is a compile-time check that Backend implements backend.DB. var _ interface { Execute(context.Context, *ir.Plan, *reqctx.Context) (backend.Result, error) From f852bc5bb8cc721c1db6a8a11421ed66be9c4ec2 Mon Sep 17 00:00:00 2001 From: Tam Nguyen Duc Date: Thu, 11 Jun 2026 10:04:48 +0700 Subject: [PATCH 07/12] fix(sqlserver): five compat fixes for SQL Server 2022 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit IsBool: add Dialect.IsBool hook so SQL Server generates col = 1/0 instead of the invalid col IS 1/0 (IS only accepts NULL in T-SQL). LimitOffset: return OFFSET 0 ROWS when hasOrder=true and no paging was requested, making ORDER BY valid inside derived tables. JSONAgg: replace JSON_ARRAYAGG (SQL Server 2025 only) with '['+STRING_AGG(CAST(elem AS NVARCHAR(MAX)),',')+']' which works on SQL Server 2022. ArrayOp/ArrayLiteral: implement @>, <@, && via OPENJSON; convert {a,b} PostgreSQL array literals to JSON arrays for OPENJSON input. embed.go LIMIT 1: replace the hardcoded LIMIT 1 clause with a dialect-aware LimitOffset call so SQL Server emits OFFSET 0 ROWS FETCH NEXT 1 ROWS ONLY instead of the invalid T-SQL LIMIT. executeUpsert: add multi-statement UPDATE … ; IF @@ROWCOUNT=0 INSERT … batch inside the request transaction, routing ir.Upsert queries away from the single-statement compiler path that returns errUpsertMultiStatement. Rows are read back via a post-batch SELECT on the conflict key when returning columns are requested. --- backend/mysql/dialect.go | 4 + backend/postgres/dialect.go | 4 + backend/sqlgen/compile.go | 6 ++ backend/sqlgen/compile_test.go | 1 + backend/sqlgen/dialect.go | 6 ++ backend/sqlgen/embed.go | 7 +- backend/sqlite/dialect.go | 4 + backend/sqlserver/dialect.go | 68 ++++++++++-- backend/sqlserver/dialect_test.go | 2 +- backend/sqlserver/execute.go | 173 ++++++++++++++++++++++++++++++ 10 files changed, 263 insertions(+), 12 deletions(-) diff --git a/backend/mysql/dialect.go b/backend/mysql/dialect.go index cea3288..4f7451c 100644 --- a/backend/mysql/dialect.go +++ b/backend/mysql/dialect.go @@ -215,6 +215,10 @@ func (Dialect) ArrayOp(col, op, val, colType string) (string, bool) { // ILike uses plain LIKE; MySQL's default utf8mb4_unicode_ci collation is CI. func (Dialect) ILike(col, val string) (string, bool) { return col + " LIKE " + val, true } +// IsBool falls back to the generic "IS 1"/"IS 0" form; MySQL treats IS TRUE +// and IS 1 equivalently for TINYINT(1) columns. +func (Dialect) IsBool(string, bool) (string, bool) { return "", false } + // BoolValue renders a boolean as 1/0. MySQL's BOOL is an alias for TINYINT(1), // so there is no native boolean keyword. func (Dialect) BoolValue(v bool) string { diff --git a/backend/postgres/dialect.go b/backend/postgres/dialect.go index 98eafd2..d6b1c45 100644 --- a/backend/postgres/dialect.go +++ b/backend/postgres/dialect.go @@ -237,6 +237,10 @@ func (Dialect) BoolValue(v bool) string { return "FALSE" } +// IsBool falls back to the generic "IS TRUE"/"IS FALSE" form; PostgreSQL +// supports IS natively. +func (Dialect) IsBool(string, bool) (string, bool) { return "", false } + // ArrayLiteral returns the PostgreSQL {a,b} array literal unchanged; PostgreSQL // accepts it natively. func (Dialect) ArrayLiteral(pgText string) string { return pgText } diff --git a/backend/sqlgen/compile.go b/backend/sqlgen/compile.go index 30362ca..55781ba 100644 --- a/backend/sqlgen/compile.go +++ b/backend/sqlgen/compile.go @@ -700,8 +700,14 @@ func (b *builder) writeIs(col, text string) (string, *pgerr.APIError) { case "not_null": return col + " IS NOT NULL", nil case "true": + if frag, ok := b.d.IsBool(col, true); ok { + return frag, nil + } return col + " IS " + b.d.BoolValue(true), nil case "false": + if frag, ok := b.d.IsBool(col, false); ok { + return frag, nil + } return col + " IS " + b.d.BoolValue(false), nil default: return "", pgerr.ErrParse("unknown is value " + text) diff --git a/backend/sqlgen/compile_test.go b/backend/sqlgen/compile_test.go index f467415..ad86b38 100644 --- a/backend/sqlgen/compile_test.go +++ b/backend/sqlgen/compile_test.go @@ -87,6 +87,7 @@ func (stub) BoolValue(v bool) string { } return "FALSE" } +func (stub) IsBool(string, bool) (string, bool) { return "", false } func compile(t *testing.T, q *ir.Query) *Statement { t.Helper() diff --git a/backend/sqlgen/dialect.go b/backend/sqlgen/dialect.go index e6eab80..a0453e4 100644 --- a/backend/sqlgen/dialect.go +++ b/backend/sqlgen/dialect.go @@ -86,6 +86,12 @@ type Dialect interface { // BoolValue renders a boolean literal. BoolValue(v bool) string + // IsBool renders "col IS TRUE" or "col IS FALSE" in the engine's syntax. + // Engines that restrict IS to NULL/UNKNOWN (SQL Server) return ok=true with + // a = expression; engines that support IS return ok=false to fall back + // to "col IS ". + IsBool(col string, v bool) (string, bool) + // ArrayOp renders an array containment/overlap operator expression, or // reports ok=false when the engine does not support array types (MySQL, SQL // Server) or when the column type does not support array semantics (SQLite diff --git a/backend/sqlgen/embed.go b/backend/sqlgen/embed.go index 7747d15..006b5e3 100644 --- a/backend/sqlgen/embed.go +++ b/backend/sqlgen/embed.go @@ -168,7 +168,12 @@ func (b *builder) writeEmbed(emb *ir.Embed, parentAlias string) *pgerr.APIError if err := b.writeEmbedFilter(emb, alias); err != nil { return err } - b.sb.WriteString(" LIMIT 1)") + lim := 1 + if lo := b.d.LimitOffset(&lim, nil, false); lo != "" { + b.sb.WriteString(" ") + b.sb.WriteString(lo) + } + b.sb.WriteString(")") return nil } diff --git a/backend/sqlite/dialect.go b/backend/sqlite/dialect.go index 6bfe16c..dcafa17 100644 --- a/backend/sqlite/dialect.go +++ b/backend/sqlite/dialect.go @@ -193,6 +193,10 @@ func (dialect) ArrayOp(col, op, val, colType string) (string, bool) { // ILike uses plain LIKE which is case-insensitive for ASCII in SQLite. func (dialect) ILike(col, val string) (string, bool) { return col + " LIKE " + val, true } +// IsBool falls back to the generic "IS 1"/"IS 0" form; SQLite's IS operator is +// a NULL-safe equality that works with any value. +func (dialect) IsBool(string, bool) (string, bool) { return "", false } + // BoolValue renders a boolean as 1/0; SQLite has no native boolean. func (dialect) BoolValue(v bool) string { if v { diff --git a/backend/sqlserver/dialect.go b/backend/sqlserver/dialect.go index e9c07c0..7dfae71 100644 --- a/backend/sqlserver/dialect.go +++ b/backend/sqlserver/dialect.go @@ -50,6 +50,11 @@ func (Dialect) Placeholder(n int) string { return "@p" + strconv.Itoa(n) } // used uniformly. func (Dialect) LimitOffset(limit, offset *int, hasOrder bool) string { if limit == nil && offset == nil { + if hasOrder { + // ORDER BY in a derived table requires OFFSET even when no paging is + // requested; OFFSET 0 ROWS keeps all rows while making the ORDER BY valid. + return "OFFSET 0 ROWS" + } return "" } off := 0 @@ -135,12 +140,13 @@ func (Dialect) JSONObject(pairs []sqlgen.Pair) string { return "JSON_OBJECT(" + strings.Join(parts, ", ") + ")" } -// JSONAgg aggregates rows with the SQL Server 2022 JSON_ARRAYAGG. The aggregate -// takes no ORDER BY argument, so a requested embed order is applied on the -// derived table feeding the aggregate, not here; orderBy is therefore unused and -// the row order within the array is best-effort (spec 06). +// JSONAgg aggregates rows into a JSON array using STRING_AGG. JSON_ARRAYAGG was +// only added in SQL Server 2025 (version 17); for 2022 compatibility the dialect +// constructs the array manually: '[' + STRING_AGG(elem,',') + ']'. The elements +// are cast to NVARCHAR(MAX) so STRING_AGG accepts them. orderBy is unused; a +// requested embed order is applied on the derived table feeding the aggregate. func (Dialect) JSONAgg(elem, _ string) string { - return "JSON_ARRAYAGG(" + elem + ")" + return "'['+STRING_AGG(CAST((" + elem + ") AS NVARCHAR(MAX)),',')+']'" } // Cast translates a canonical type to a T-SQL CAST target. SQL Server has no @@ -218,8 +224,30 @@ func (Dialect) SessionWrite(key string) (string, bool) { return "EXEC sp_set_session_context N'" + strings.ReplaceAll(key, "'", "''") + "', " + sqlgen.PatternMark, true } -// ArrayOp returns false; SQL Server has no array types or containment operators. -func (Dialect) ArrayOp(_, _, _, _ string) (string, bool) { return "", false } +// ArrayOp implements array containment/overlap operators using OPENJSON, which +// parses the JSON array argument and the JSON array column for element-level +// comparisons. val is a bound placeholder (@pN) whose value is a JSON array +// string (converted from PostgreSQL {a,b} syntax by ArrayLiteral). +func (Dialect) ArrayOp(col, op, val, _ string) (string, bool) { + switch op { + case "@>": + // col contains every element of val + return "NOT EXISTS(SELECT [value] FROM OPENJSON(" + val + ") WHERE [value] NOT IN (SELECT [value] FROM OPENJSON(" + col + ")))", true + case "<@": + // every element of col exists in val + return "NOT EXISTS(SELECT [value] FROM OPENJSON(" + col + ") WHERE [value] NOT IN (SELECT [value] FROM OPENJSON(" + val + ")))", true + case "&&": + // at least one element in common + return "EXISTS(SELECT 1 FROM OPENJSON(" + col + ") a WHERE a.[value] IN (SELECT [value] FROM OPENJSON(" + val + ")))", true + } + return "", false +} + +// IsBool renders "col = 1" or "col = 0" for SQL Server BIT columns. SQL +// Server's IS operator only accepts NULL/UNKNOWN, not integer literals. +func (Dialect) IsBool(col string, v bool) (string, bool) { + return col + " = " + Dialect{}.BoolValue(v), true +} // ILike uses plain LIKE; SQL Server's default collation is case-insensitive. func (Dialect) ILike(col, val string) (string, bool) { return col + " LIKE " + val, true } @@ -233,6 +261,26 @@ func (Dialect) BoolValue(v bool) string { return "0" } -// ArrayLiteral returns the text unchanged; SQL Server does not support PostgreSQL -// array syntax, so ArrayOp returns false before this value is ever used. -func (Dialect) ArrayLiteral(pgText string) string { return pgText } +// ArrayLiteral converts a PostgreSQL {a,b} array literal to a JSON array +// ["a","b"] so OPENJSON in ArrayOp can iterate over it. +func (Dialect) ArrayLiteral(pgText string) string { + s := strings.TrimSpace(pgText) + if len(s) < 2 || s[0] != '{' || s[len(s)-1] != '}' { + return pgText + } + inner := s[1 : len(s)-1] + if inner == "" { + return "[]" + } + parts := strings.Split(inner, ",") + quoted := make([]string, len(parts)) + for i, p := range parts { + p = strings.TrimSpace(p) + if p == "NULL" { + quoted[i] = "null" + } else { + quoted[i] = `"` + strings.ReplaceAll(p, `"`, `\"`) + `"` + } + } + return "[" + strings.Join(quoted, ",") + "]" +} diff --git a/backend/sqlserver/dialect_test.go b/backend/sqlserver/dialect_test.go index ea41b2f..c4f29f5 100644 --- a/backend/sqlserver/dialect_test.go +++ b/backend/sqlserver/dialect_test.go @@ -98,7 +98,7 @@ func TestJSON(t *testing.T) { if obj != "JSON_OBJECT('name': d.[name], 'year': d.[year])" { t.Errorf("JSONObject = %q", obj) } - if got := d.JSONAgg("t", "t.[id] DESC"); got != "JSON_ARRAYAGG(t)" { + if got := d.JSONAgg("t", "t.[id] DESC"); got != "'['+STRING_AGG(CAST((t) AS NVARCHAR(MAX)),',')+']'" { t.Errorf("JSONAgg = %q", got) } } diff --git a/backend/sqlserver/execute.go b/backend/sqlserver/execute.go index 2bc2152..6712003 100644 --- a/backend/sqlserver/execute.go +++ b/backend/sqlserver/execute.go @@ -147,11 +147,17 @@ func (b *Backend) executeWrite(ctx context.Context, plan *ir.Plan, rc *reqctx.Co // The compiler emits: INSERT INTO [t] ([c1],[c2]) VALUES (@p1,@p2) // The data plane rewrites to: INSERT INTO [t] ([c1],[c2]) OUTPUT INSERTED.[c1],... VALUES (@p1,@p2) // by injecting the OUTPUT fragment before the " VALUES " marker. +// Upsert (on_conflict) is routed to executeUpsert instead of the single-statement +// compiler, which returns errUpsertMultiStatement. func (b *Backend) executeInsert( ctx context.Context, tx *sql.Tx, q *ir.Query, returning []string, rel *schema.Relation, res *writeResult, ) error { + if q.Kind == ir.Upsert { + return b.executeUpsert(ctx, tx, q, returning, res) + } + st, apiErr := sqlgen.CompileInsert(Dialect{}, q, nil) if apiErr != nil { return apiErr @@ -189,6 +195,173 @@ func (b *Backend) executeInsert( return nil } +// executeUpsert implements the SQL Server multi-statement upsert pattern: +// for each row emit UPDATE … WHERE pk=@pN; IF @@ROWCOUNT=0 INSERT … +// inside the request transaction. Named @pN placeholders let each value be +// referenced by both the UPDATE and the INSERT within the same batch. +// +// After the batch, when returning columns are requested, the upserted rows are +// read back via SELECT … WHERE (pk1=@q1 AND pk2=@q2) OR … +func (b *Backend) executeUpsert( + ctx context.Context, tx *sql.Tx, + q *ir.Query, returning []string, + res *writeResult, +) error { + w := q.Write + if len(w.Rows) == 0 { + res.affected, res.hasAff = 0, true + return nil + } + + d := Dialect{} + sch := q.Relation.Schema + if sch == "" { + sch = b.schema + if sch == "" { + sch = "dbo" + } + } + tableName := d.QuoteIdent(sch) + "." + d.QuoteIdent(q.Relation.Name) + + conflictCols := w.Conflict.Target + conflictSet := make(map[string]bool, len(conflictCols)) + for _, c := range conflictCols { + conflictSet[c] = true + } + nonConflictCols := make([]string, 0, len(w.Columns)) + for _, c := range w.Columns { + if !conflictSet[c] { + nonConflictCols = append(nonConflictCols, c) + } + } + + var batchSQL strings.Builder + batchRaw := []any{} // raw values; wrapped by namedArgs() as p1, p2, ... + argN := 0 + bind := func(v any) string { + argN++ + batchRaw = append(batchRaw, v) + return "@p" + strconv.Itoa(argN) + } + + for _, row := range w.Rows { + // Bind each column value once; named placeholders can be reused. + colP := make(map[string]string, len(w.Columns)) + for _, c := range w.Columns { + colP[c] = bind(sqlgen.WriteArg(row[c])) + } + + if len(nonConflictCols) > 0 { + // UPDATE … SET non-pk cols WHERE pk cols + batchSQL.WriteString("UPDATE ") + batchSQL.WriteString(tableName) + batchSQL.WriteString(" WITH (UPDLOCK,HOLDLOCK) SET ") + for i, c := range nonConflictCols { + if i > 0 { + batchSQL.WriteString(",") + } + batchSQL.WriteString(d.QuoteIdent(c) + "=" + colP[c]) + } + batchSQL.WriteString(" WHERE ") + for i, c := range conflictCols { + if i > 0 { + batchSQL.WriteString(" AND ") + } + batchSQL.WriteString(d.QuoteIdent(c) + "=" + colP[c]) + } + batchSQL.WriteString("; IF @@ROWCOUNT=0 ") + } else { + // No non-conflict columns: row is pk-only; insert if absent. + batchSQL.WriteString("IF NOT EXISTS(SELECT 1 FROM ") + batchSQL.WriteString(tableName) + batchSQL.WriteString(" WITH (UPDLOCK,HOLDLOCK) WHERE ") + for i, c := range conflictCols { + if i > 0 { + batchSQL.WriteString(" AND ") + } + batchSQL.WriteString(d.QuoteIdent(c) + "=" + colP[c]) + } + batchSQL.WriteString(") ") + } + + batchSQL.WriteString("INSERT INTO ") + batchSQL.WriteString(tableName) + batchSQL.WriteString("(") + for i, c := range w.Columns { + if i > 0 { + batchSQL.WriteString(",") + } + batchSQL.WriteString(d.QuoteIdent(c)) + } + batchSQL.WriteString(") VALUES(") + for i, c := range w.Columns { + if i > 0 { + batchSQL.WriteString(",") + } + batchSQL.WriteString(colP[c]) + } + batchSQL.WriteString(");") + } + + if _, err := tx.ExecContext(ctx, batchSQL.String(), namedArgs(batchRaw)...); err != nil { + return err + } + res.affected, res.hasAff = int64(len(w.Rows)), true + + if len(returning) == 0 { + return nil + } + + // SELECT the upserted rows back by their conflict key values. + var selSQL strings.Builder + selSQL.WriteString("SELECT ") + for i, c := range returning { + if i > 0 { + selSQL.WriteString(",") + } + selSQL.WriteString(d.QuoteIdent(c)) + } + selSQL.WriteString(" FROM ") + selSQL.WriteString(tableName) + selSQL.WriteString(" WHERE ") + selRaw := []any{} + selN := 0 + for ri, row := range w.Rows { + if ri > 0 { + selSQL.WriteString(" OR ") + } + selSQL.WriteString("(") + for ci, c := range conflictCols { + if ci > 0 { + selSQL.WriteString(" AND ") + } + selN++ + selSQL.WriteString(d.QuoteIdent(c) + "=@p" + strconv.Itoa(selN)) + selRaw = append(selRaw, sqlgen.WriteArg(row[c])) + } + selSQL.WriteString(")") + } + + rows, err := tx.QueryContext(ctx, selSQL.String(), namedArgs(selRaw)...) + if err != nil { + return err + } + cols, err := rows.Columns() + if err != nil { + rows.Close() + return err + } + jsonIdx, timeIdx := buildColMaps(rows, nil) + buf, err := drain(rows, cols, jsonIdx, timeIdx) + rows.Close() + if err != nil { + return err + } + res.cols, res.rows = cols, buf + res.affected, res.hasAff = int64(len(buf)), true + return nil +} + // executeUpdate runs UPDATE [t] SET ... OUTPUT INSERTED.* WHERE ... // Compiler emits: UPDATE [t] SET [c]=@p1 WHERE [id]=@p2 // Rewritten to: UPDATE [t] SET [c]=@p1 OUTPUT INSERTED.[c],... WHERE [id]=@p2 From 77ab670efcf0bc08cd9774e5eb5796ff239099ce Mon Sep 17 00:00:00 2001 From: Tam Nguyen Duc Date: Thu, 11 Jun 2026 10:24:45 +0700 Subject: [PATCH 08/12] sqlserver: fix ArrayLiteral PG-quoting and rewrite upsert as MERGE ArrayLiteral now passes through elements that are already JSON-quoted (start and end with ") rather than wrapping them again. The postgrest-go client sends {"go"} with the element double-quoted per PG array literal syntax, and the old code produced ["\"go\""] instead of ["go"]. executeUpsert is rewritten as a single MERGE statement. The previous UPDATE; IF @@ROWCOUNT=0 INSERT pattern uses a semicolon-separated batch that go-mssqldb rejects inside sp_executesql. MERGE is a single statement, takes all rows in the source VALUES table, and includes an OUTPUT clause when returning columns are requested. --- backend/sqlserver/dialect.go | 3 + backend/sqlserver/execute.go | 211 ++++++++++++++++++----------------- 2 files changed, 113 insertions(+), 101 deletions(-) diff --git a/backend/sqlserver/dialect.go b/backend/sqlserver/dialect.go index 7dfae71..0578d98 100644 --- a/backend/sqlserver/dialect.go +++ b/backend/sqlserver/dialect.go @@ -278,6 +278,9 @@ func (Dialect) ArrayLiteral(pgText string) string { p = strings.TrimSpace(p) if p == "NULL" { quoted[i] = "null" + } else if len(p) >= 2 && p[0] == '"' && p[len(p)-1] == '"' { + // PostgreSQL double-quote escaping: "foo" is already valid JSON; pass through. + quoted[i] = p } else { quoted[i] = `"` + strings.ReplaceAll(p, `"`, `\"`) + `"` } diff --git a/backend/sqlserver/execute.go b/backend/sqlserver/execute.go index 6712003..da72a6c 100644 --- a/backend/sqlserver/execute.go +++ b/backend/sqlserver/execute.go @@ -195,13 +195,14 @@ func (b *Backend) executeInsert( return nil } -// executeUpsert implements the SQL Server multi-statement upsert pattern: -// for each row emit UPDATE … WHERE pk=@pN; IF @@ROWCOUNT=0 INSERT … -// inside the request transaction. Named @pN placeholders let each value be -// referenced by both the UPDATE and the INSERT within the same batch. +// executeUpsert implements the SQL Server upsert as a single MERGE statement per +// batch. MERGE avoids the semicolon-separated multi-statement pattern that +// go-mssqldb rejects when sent via sp_executesql. // -// After the batch, when returning columns are requested, the upserted rows are -// read back via SELECT … WHERE (pk1=@q1 AND pk2=@q2) OR … +// All rows are merged in one statement: the source is a VALUES(...) table with +// one row-tuple per input row; the ON clause matches the conflict (primary-key) +// columns; WHEN MATCHED updates non-key columns; WHEN NOT MATCHED inserts. +// The OUTPUT clause captures written rows when returning is requested. func (b *Backend) executeUpsert( ctx context.Context, tx *sql.Tx, q *ir.Query, returning []string, @@ -235,133 +236,141 @@ func (b *Backend) executeUpsert( } } - var batchSQL strings.Builder - batchRaw := []any{} // raw values; wrapped by namedArgs() as p1, p2, ... + // Collect args; @pN bind positions match the order we append. + raw := []any{} argN := 0 bind := func(v any) string { argN++ - batchRaw = append(batchRaw, v) + raw = append(raw, v) return "@p" + strconv.Itoa(argN) } - for _, row := range w.Rows { - // Bind each column value once; named placeholders can be reused. - colP := make(map[string]string, len(w.Columns)) - for _, c := range w.Columns { - colP[c] = bind(sqlgen.WriteArg(row[c])) - } + // Build the source alias column names: s0, s1, ... + srcCols := make([]string, len(w.Columns)) + for i := range w.Columns { + srcCols[i] = "s" + strconv.Itoa(i) + } - if len(nonConflictCols) > 0 { - // UPDATE … SET non-pk cols WHERE pk cols - batchSQL.WriteString("UPDATE ") - batchSQL.WriteString(tableName) - batchSQL.WriteString(" WITH (UPDLOCK,HOLDLOCK) SET ") - for i, c := range nonConflictCols { - if i > 0 { - batchSQL.WriteString(",") - } - batchSQL.WriteString(d.QuoteIdent(c) + "=" + colP[c]) - } - batchSQL.WriteString(" WHERE ") - for i, c := range conflictCols { - if i > 0 { - batchSQL.WriteString(" AND ") - } - batchSQL.WriteString(d.QuoteIdent(c) + "=" + colP[c]) - } - batchSQL.WriteString("; IF @@ROWCOUNT=0 ") - } else { - // No non-conflict columns: row is pk-only; insert if absent. - batchSQL.WriteString("IF NOT EXISTS(SELECT 1 FROM ") - batchSQL.WriteString(tableName) - batchSQL.WriteString(" WITH (UPDLOCK,HOLDLOCK) WHERE ") - for i, c := range conflictCols { - if i > 0 { - batchSQL.WriteString(" AND ") - } - batchSQL.WriteString(d.QuoteIdent(c) + "=" + colP[c]) - } - batchSQL.WriteString(") ") - } + var sb strings.Builder - batchSQL.WriteString("INSERT INTO ") - batchSQL.WriteString(tableName) - batchSQL.WriteString("(") - for i, c := range w.Columns { - if i > 0 { - batchSQL.WriteString(",") - } - batchSQL.WriteString(d.QuoteIdent(c)) + // MERGE INTO target USING (VALUES (...),(...)) AS src(s0,s1,...) + sb.WriteString("MERGE INTO ") + sb.WriteString(tableName) + sb.WriteString(" WITH (HOLDLOCK) AS [_target] USING (VALUES ") + for ri, row := range w.Rows { + if ri > 0 { + sb.WriteString(",") } - batchSQL.WriteString(") VALUES(") - for i, c := range w.Columns { - if i > 0 { - batchSQL.WriteString(",") + sb.WriteString("(") + for ci, c := range w.Columns { + if ci > 0 { + sb.WriteString(",") } - batchSQL.WriteString(colP[c]) + sb.WriteString(bind(sqlgen.WriteArg(row[c]))) } - batchSQL.WriteString(");") + sb.WriteString(")") } - - if _, err := tx.ExecContext(ctx, batchSQL.String(), namedArgs(batchRaw)...); err != nil { - return err + sb.WriteString(") AS [_src](") + for i, sc := range srcCols { + if i > 0 { + sb.WriteString(",") + } + sb.WriteString(d.QuoteIdent(sc)) } - res.affected, res.hasAff = int64(len(w.Rows)), true + sb.WriteString(") ON (") + // ON conflict columns match + for i, c := range conflictCols { + if i > 0 { + sb.WriteString(" AND ") + } + ci := colIndex(w.Columns, c) + sb.WriteString("[_target]." + d.QuoteIdent(c) + "=[_src]." + d.QuoteIdent(srcCols[ci])) + } + sb.WriteString(")") - if len(returning) == 0 { - return nil + // WHEN MATCHED THEN UPDATE (skip if ignore or no non-conflict cols) + if w.Conflict.Resolution != ir.ConflictIgnore && len(nonConflictCols) > 0 { + sb.WriteString(" WHEN MATCHED THEN UPDATE SET ") + for i, c := range nonConflictCols { + if i > 0 { + sb.WriteString(",") + } + ci := colIndex(w.Columns, c) + sb.WriteString("[_target]." + d.QuoteIdent(c) + "=[_src]." + d.QuoteIdent(srcCols[ci])) + } } - // SELECT the upserted rows back by their conflict key values. - var selSQL strings.Builder - selSQL.WriteString("SELECT ") - for i, c := range returning { + // WHEN NOT MATCHED THEN INSERT + sb.WriteString(" WHEN NOT MATCHED THEN INSERT (") + for i, c := range w.Columns { if i > 0 { - selSQL.WriteString(",") + sb.WriteString(",") } - selSQL.WriteString(d.QuoteIdent(c)) + sb.WriteString(d.QuoteIdent(c)) } - selSQL.WriteString(" FROM ") - selSQL.WriteString(tableName) - selSQL.WriteString(" WHERE ") - selRaw := []any{} - selN := 0 - for ri, row := range w.Rows { - if ri > 0 { - selSQL.WriteString(" OR ") + sb.WriteString(") VALUES (") + for i, sc := range srcCols { + if i > 0 { + sb.WriteString(",") } - selSQL.WriteString("(") - for ci, c := range conflictCols { - if ci > 0 { - selSQL.WriteString(" AND ") + sb.WriteString("[_src]." + d.QuoteIdent(sc)) + } + sb.WriteString(")") + + // OUTPUT clause when returning is requested + if len(returning) > 0 { + sb.WriteString(" OUTPUT ") + for i, c := range returning { + if i > 0 { + sb.WriteString(",") } - selN++ - selSQL.WriteString(d.QuoteIdent(c) + "=@p" + strconv.Itoa(selN)) - selRaw = append(selRaw, sqlgen.WriteArg(row[c])) + sb.WriteString("INSERTED." + d.QuoteIdent(c)) } - selSQL.WriteString(")") } - rows, err := tx.QueryContext(ctx, selSQL.String(), namedArgs(selRaw)...) - if err != nil { - return err - } - cols, err := rows.Columns() - if err != nil { + // MERGE requires a terminating semicolon. + sb.WriteString(";") + + if len(returning) > 0 { + rows, err := tx.QueryContext(ctx, sb.String(), namedArgs(raw)...) + if err != nil { + return err + } + cols, err := rows.Columns() + if err != nil { + rows.Close() + return err + } + jsonIdx, timeIdx := buildColMaps(rows, nil) + buf, err := drain(rows, cols, jsonIdx, timeIdx) rows.Close() - return err + if err != nil { + return err + } + res.cols, res.rows = cols, buf + res.affected, res.hasAff = int64(len(buf)), true + return nil } - jsonIdx, timeIdx := buildColMaps(rows, nil) - buf, err := drain(rows, cols, jsonIdx, timeIdx) - rows.Close() + + out, err := tx.ExecContext(ctx, sb.String(), namedArgs(raw)...) if err != nil { return err } - res.cols, res.rows = cols, buf - res.affected, res.hasAff = int64(len(buf)), true + n, _ := out.RowsAffected() + res.affected, res.hasAff = n, true return nil } +// colIndex returns the position of name in cols, or 0 as a safe fallback. +func colIndex(cols []string, name string) int { + for i, c := range cols { + if c == name { + return i + } + } + return 0 +} + // executeUpdate runs UPDATE [t] SET ... OUTPUT INSERTED.* WHERE ... // Compiler emits: UPDATE [t] SET [c]=@p1 WHERE [id]=@p2 // Rewritten to: UPDATE [t] SET [c]=@p1 OUTPUT INSERTED.[c],... WHERE [id]=@p2 From 5f57b75537e15176ddacf097ca7f216c6a3fcd2c Mon Sep 17 00:00:00 2001 From: Tam Nguyen Duc Date: Thu, 11 Jun 2026 10:35:01 +0700 Subject: [PATCH 09/12] sqlserver: fix error mapping, IDENTITY_INSERT for upsert, identity introspection asMSSQLError used *mssql.Error as target for errors.As but mssql.Error implements error via a value receiver; switch to value target so constraint violations (547=FK, 2627/2601=unique, 2812=proc not found) are correctly mapped to 409/422/404 instead of falling through to 500. Add schema.Column.Identity to track auto-generated identity columns. SQL Server introspection now fetches COLUMNPROPERTY IsIdentity alongside the existing column query. executeUpsert enables SET IDENTITY_INSERT ON before the MERGE when any conflict column is an identity column, allowing explicit id values to be upserted without error 8101. --- backend/sqlserver/execute.go | 25 +++++++++++++++++++++++-- backend/sqlserver/introspect.go | 10 ++++++---- backend/sqlserver/sqlserver.go | 9 +++++---- schema/model.go | 5 +++++ 4 files changed, 39 insertions(+), 10 deletions(-) diff --git a/backend/sqlserver/execute.go b/backend/sqlserver/execute.go index da72a6c..160e388 100644 --- a/backend/sqlserver/execute.go +++ b/backend/sqlserver/execute.go @@ -155,7 +155,7 @@ func (b *Backend) executeInsert( res *writeResult, ) error { if q.Kind == ir.Upsert { - return b.executeUpsert(ctx, tx, q, returning, res) + return b.executeUpsert(ctx, tx, q, returning, rel, res) } st, apiErr := sqlgen.CompileInsert(Dialect{}, q, nil) @@ -205,7 +205,7 @@ func (b *Backend) executeInsert( // The OUTPUT clause captures written rows when returning is requested. func (b *Backend) executeUpsert( ctx context.Context, tx *sql.Tx, - q *ir.Query, returning []string, + q *ir.Query, returning []string, rel *schema.Relation, res *writeResult, ) error { w := q.Write @@ -331,6 +331,16 @@ func (b *Backend) executeUpsert( // MERGE requires a terminating semicolon. sb.WriteString(";") + // When any conflict column is an IDENTITY column and the user provided + // an explicit value, SQL Server requires IDENTITY_INSERT to be ON. + needIdentityInsert := rel != nil && hasIdentityConflictCol(rel, conflictCols) + if needIdentityInsert { + if _, err := tx.ExecContext(ctx, "SET IDENTITY_INSERT "+tableName+" ON"); err != nil { + return err + } + defer func() { _, _ = tx.ExecContext(ctx, "SET IDENTITY_INSERT "+tableName+" OFF") }() + } + if len(returning) > 0 { rows, err := tx.QueryContext(ctx, sb.String(), namedArgs(raw)...) if err != nil { @@ -361,6 +371,17 @@ func (b *Backend) executeUpsert( return nil } +// hasIdentityConflictCol reports whether any of the conflict target columns is +// an identity column in rel. +func hasIdentityConflictCol(rel *schema.Relation, conflictCols []string) bool { + for _, c := range conflictCols { + if col, ok := rel.Column(c); ok && col.Identity { + return true + } + } + return false +} + // colIndex returns the position of name in cols, or 0 as a safe fallback. func colIndex(cols []string, name string) int { for i, c := range cols { diff --git a/backend/sqlserver/introspect.go b/backend/sqlserver/introspect.go index e9f961b..1cd8138 100644 --- a/backend/sqlserver/introspect.go +++ b/backend/sqlserver/introspect.go @@ -81,7 +81,8 @@ func (b *Backend) columns(ctx context.Context, table string) ([]*schema.Column, c.IS_NULLABLE, c.COLUMN_DEFAULT, CASE WHEN k.COLUMN_NAME IS NOT NULL THEN 1 ELSE 0 END AS is_pk, - ISNULL(k.ORDINAL_POSITION, 0) AS pk_ord + ISNULL(k.ORDINAL_POSITION, 0) AS pk_ord, + COLUMNPROPERTY(OBJECT_ID(SCHEMA_NAME()+'.'+c.TABLE_NAME), c.COLUMN_NAME, 'IsIdentity') AS is_identity FROM INFORMATION_SCHEMA.COLUMNS c LEFT JOIN ( SELECT kcu.COLUMN_NAME, kcu.ORDINAL_POSITION @@ -113,16 +114,17 @@ func (b *Backend) columns(ctx context.Context, table string) ([]*schema.Column, for rows.Next() { var name, dataType, isNullable string var colDefault sql.NullString - var isPK, pkOrd int - if err := rows.Scan(&name, &dataType, &isNullable, &colDefault, &isPK, &pkOrd); err != nil { + var isPK, pkOrd, isIdentity int + if err := rows.Scan(&name, &dataType, &isNullable, &colDefault, &isPK, &pkOrd, &isIdentity); err != nil { return nil, nil, err } - hasDefault := isPK == 1 || colDefault.Valid + hasDefault := isPK == 1 || colDefault.Valid || isIdentity == 1 col := &schema.Column{ Name: name, Type: sqlServerCanonicalType(dataType), Nullable: isNullable == "YES", HasDefault: hasDefault, + Identity: isIdentity == 1, } colRows = append(colRows, colRow{col: col, isPK: isPK == 1, pkOrd: pkOrd}) } diff --git a/backend/sqlserver/sqlserver.go b/backend/sqlserver/sqlserver.go index c31afb6..8fb01ef 100644 --- a/backend/sqlserver/sqlserver.go +++ b/backend/sqlserver/sqlserver.go @@ -120,7 +120,7 @@ func (b *Backend) MapError(err error) *pgerr.APIError { } // mapSQLServerError builds the unified API error from a SQL Server error. -func mapSQLServerError(me *mssql.Error) *pgerr.APIError { +func mapSQLServerError(me mssql.Error) *pgerr.APIError { switch me.Number { case 2627, 2601: // unique constraint / unique index violation return pgerr.ErrUniqueViolation(me.Message) @@ -206,9 +206,10 @@ func sqlServerCanonicalType(dataType string) string { } } -// asMSSQLError unwraps err as a *mssql.Error. -func asMSSQLError(err error) (*mssql.Error, bool) { - var me *mssql.Error +// asMSSQLError unwraps err as a mssql.Error. mssql.Error implements error via a +// value receiver so errors.As requires a value target, not a pointer. +func asMSSQLError(err error) (mssql.Error, bool) { + var me mssql.Error ok := errors.As(err, &me) return me, ok } diff --git a/schema/model.go b/schema/model.go index e243a02..7d29955 100644 --- a/schema/model.go +++ b/schema/model.go @@ -55,6 +55,11 @@ type Column struct { Type string // canonical PG type name (spec 16) Nullable bool HasDefault bool + // Identity reports whether the column is an auto-generated identity/serial + // column (IDENTITY on SQL Server, SERIAL/GENERATED ALWAYS AS IDENTITY on + // PostgreSQL). Backends that support explicit-identity inserts (e.g. SQL + // Server's IDENTITY_INSERT) use this to decide whether to enable it. + Identity bool // Position is the 1-based ordinal, used for stable ordering. Position int } From 6fe23c0ebdab792b8ff133f1f67390c1af488375 Mon Sep 17 00:00:00 2001 From: Tam Nguyen Duc Date: Thu, 11 Jun 2026 10:40:18 +0700 Subject: [PATCH 10/12] sqlserver: only set IDENTITY_INSERT ON when identity col is in payload --- backend/sqlserver/execute.go | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/backend/sqlserver/execute.go b/backend/sqlserver/execute.go index 160e388..6106daf 100644 --- a/backend/sqlserver/execute.go +++ b/backend/sqlserver/execute.go @@ -333,7 +333,7 @@ func (b *Backend) executeUpsert( // When any conflict column is an IDENTITY column and the user provided // an explicit value, SQL Server requires IDENTITY_INSERT to be ON. - needIdentityInsert := rel != nil && hasIdentityConflictCol(rel, conflictCols) + needIdentityInsert := rel != nil && hasIdentityConflictCol(rel, conflictCols, w.Columns) if needIdentityInsert { if _, err := tx.ExecContext(ctx, "SET IDENTITY_INSERT "+tableName+" ON"); err != nil { return err @@ -371,11 +371,17 @@ func (b *Backend) executeUpsert( return nil } -// hasIdentityConflictCol reports whether any of the conflict target columns is -// an identity column in rel. -func hasIdentityConflictCol(rel *schema.Relation, conflictCols []string) bool { +// hasIdentityConflictCol reports whether any conflict column is an identity +// column AND is present in payloadCols (the user provided an explicit value). +// When IDENTITY_INSERT is ON, SQL Server requires an explicit value, so we only +// enable it when the identity column is actually in the payload. +func hasIdentityConflictCol(rel *schema.Relation, conflictCols, payloadCols []string) bool { + payload := make(map[string]bool, len(payloadCols)) + for _, c := range payloadCols { + payload[c] = true + } for _, c := range conflictCols { - if col, ok := rel.Column(c); ok && col.Identity { + if col, ok := rel.Column(c); ok && col.Identity && payload[c] { return true } } From b4d2cf36438f865745da68a9c965442f6da9ae17 Mon Sep 17 00:00:00 2001 From: Tam Nguyen Duc Date: Thu, 11 Jun 2026 16:06:50 +0700 Subject: [PATCH 11/12] mysql: fix IsBool, datetime args, and UpdateReturn Three fixes for the MySQL compat job: 1. IsBool now returns "col = 1/0" instead of falling back to "col IS 1". MySQL 8 only accepts IS NULL/UNKNOWN/TRUE/FALSE, not integer literals. 2. normalizeArgs converts ISO 8601 strings (e.g. "2024-01-01T00:00:00Z") to time.Time before binding. MySQL's string-to-DATETIME cast rejects the T separator and Z timezone suffix; passing time.Time bypasses the cast. Applied to all CompileRead, CompileCount, compileWrite, and CompileCall statement args. 3. Open() sets ClientFoundRows=true so MySQL counts matched rows rather than changed rows. Without this, an UPDATE that sets the same values reports RowsAffected=0, which gates the re-select and returns empty results instead of the matched rows. --- backend/mysql/dialect.go | 9 +++++--- backend/mysql/execute.go | 49 ++++++++++++++++++++++++++++++++++------ backend/mysql/mysql.go | 1 + 3 files changed, 49 insertions(+), 10 deletions(-) diff --git a/backend/mysql/dialect.go b/backend/mysql/dialect.go index 4f7451c..f8a9f27 100644 --- a/backend/mysql/dialect.go +++ b/backend/mysql/dialect.go @@ -215,9 +215,12 @@ func (Dialect) ArrayOp(col, op, val, colType string) (string, bool) { // ILike uses plain LIKE; MySQL's default utf8mb4_unicode_ci collation is CI. func (Dialect) ILike(col, val string) (string, bool) { return col + " LIKE " + val, true } -// IsBool falls back to the generic "IS 1"/"IS 0" form; MySQL treats IS TRUE -// and IS 1 equivalently for TINYINT(1) columns. -func (Dialect) IsBool(string, bool) (string, bool) { return "", false } +// IsBool renders "col = 1" or "col = 0". MySQL 8's IS operator only accepts +// NULL/UNKNOWN/TRUE/FALSE, not integer literals, so "col IS 1" is a syntax +// error; equality works for TINYINT(1) boolean columns. +func (Dialect) IsBool(col string, v bool) (string, bool) { + return col + " = " + Dialect{}.BoolValue(v), true +} // BoolValue renders a boolean as 1/0. MySQL's BOOL is an alias for TINYINT(1), // so there is no native boolean keyword. diff --git a/backend/mysql/execute.go b/backend/mysql/execute.go index 0b3a7d8..9a69bb6 100644 --- a/backend/mysql/execute.go +++ b/backend/mysql/execute.go @@ -5,6 +5,7 @@ import ( "database/sql" "fmt" "strings" + "time" "github.com/tamnd/dbrest/backend" "github.com/tamnd/dbrest/backend/sqlgen" @@ -14,6 +15,30 @@ import ( "github.com/tamnd/dbrest/schema" ) +// normalizeArgs converts ISO 8601 datetime strings (e.g. "2024-01-01T00:00:00Z") +// to time.Time so the MySQL driver can bind them correctly. MySQL rejects the ISO +// T-separator format; passing time.Time avoids the string-to-DATETIME cast entirely. +func normalizeArgs(args []any) []any { + if len(args) == 0 { + return args + } + out := make([]any, len(args)) + for i, a := range args { + if s, ok := a.(string); ok { + if t, err := time.Parse(time.RFC3339, s); err == nil { + out[i] = t + continue + } + if t, err := time.Parse("2006-01-02T15:04:05", s); err == nil { + out[i] = t + continue + } + } + out[i] = a + } + return out +} + // Execute lowers a resolved plan to MySQL operations and returns a streamable // result. Reads stream from an open cursor; writes run in a transaction and // buffer their rows (since MySQL 8 has no RETURNING, rows are re-selected after @@ -47,7 +72,7 @@ func (b *Backend) executeRead(ctx context.Context, plan *ir.Plan, rc *reqctx.Con if apiErr != nil { return nil, apiErr } - if err := b.db.QueryRowContext(ctx, cst.SQL, cst.Args...).Scan(&res.count); err != nil { + if err := b.db.QueryRowContext(ctx, cst.SQL, normalizeArgs(cst.Args)...).Scan(&res.count); err != nil { return nil, b.MapError(err) } res.hasCount = true @@ -57,7 +82,7 @@ func (b *Backend) executeRead(ctx context.Context, plan *ir.Plan, rc *reqctx.Con if apiErr != nil { return nil, apiErr } - rows, err := b.db.QueryContext(ctx, st.SQL, st.Args...) + rows, err := b.db.QueryContext(ctx, st.SQL, normalizeArgs(st.Args)...) if err != nil { return nil, b.MapError(err) } @@ -250,7 +275,7 @@ func (b *Backend) executeUpdateEmulated( if apiErr != nil { return apiErr } - rows, err := tx.QueryContext(ctx, readST.SQL, readST.Args...) + rows, err := tx.QueryContext(ctx, readST.SQL, normalizeArgs(readST.Args)...) if err != nil { return err } @@ -284,7 +309,7 @@ func (b *Backend) executeDeleteEmulated( if apiErr != nil { return apiErr } - rows, err := tx.QueryContext(ctx, readST.SQL, readST.Args...) + rows, err := tx.QueryContext(ctx, readST.SQL, normalizeArgs(readST.Args)...) if err != nil { return err } @@ -322,6 +347,7 @@ func (b *Backend) executeCall(ctx context.Context, plan *ir.Plan, rc *reqctx.Con if apiErr != nil { return nil, apiErr } + st.Args = normalizeArgs(st.Args) if plan.ReadOnly { res := &result{controls: rc.Controls()} @@ -383,17 +409,26 @@ func (b *Backend) executeCall(ctx context.Context, plan *ir.Plan, rc *reqctx.Con // compileWrite dispatches to the right compiler for the mutation kind. // When returning is empty the compiler omits the RETURNING / OUTPUT clause. +// Args are normalized for MySQL (ISO 8601 → time.Time) before returning. func compileWrite(q *ir.Query, returning []string) (*sqlgen.Statement, *pgerr.APIError) { + var ( + st *sqlgen.Statement + apiErr *pgerr.APIError + ) switch q.Kind { case ir.Insert, ir.Upsert: - return sqlgen.CompileInsert(Dialect{}, q, returning) + st, apiErr = sqlgen.CompileInsert(Dialect{}, q, returning) case ir.Update: - return sqlgen.CompileUpdate(Dialect{}, q, returning) + st, apiErr = sqlgen.CompileUpdate(Dialect{}, q, returning) case ir.Delete: - return sqlgen.CompileDelete(Dialect{}, q, returning) + st, apiErr = sqlgen.CompileDelete(Dialect{}, q, returning) default: return nil, pgerr.ErrUnsupported("this operation", "mysql") } + if st != nil { + st.Args = normalizeArgs(st.Args) + } + return st, apiErr } // returningCols decides which columns to read back after a write. diff --git a/backend/mysql/mysql.go b/backend/mysql/mysql.go index e5e3996..b5ad482 100644 --- a/backend/mysql/mysql.go +++ b/backend/mysql/mysql.go @@ -52,6 +52,7 @@ func Open(dsn string) (*Backend, error) { return nil, fmt.Errorf("invalid MySQL DSN: %w", err) } cfg.ParseTime = true + cfg.ClientFoundRows = true // report matched rows, not changed rows (UPDATE re-select gate) delete(cfg.Params, "tinyInt1IsBool") // removed in v1.8; schema-layer handles coercion connector, err := mysqldrv.NewConnector(cfg) From ed9bb73b9ba5ac27a9d07c90d33734e848539149 Mon Sep 17 00:00:00 2001 From: Tam Nguyen Duc Date: Thu, 11 Jun 2026 16:14:43 +0700 Subject: [PATCH 12/12] mysql: fix UpdateReturn by pre-capturing PKs before UPDATE The re-select after UPDATE must use primary keys, not the original filter. The original filter may reference a column being updated (e.g. PATCH /todos?task=eq.old with body {task:new}), so after the UPDATE the filter matches nothing and the re-select returns empty rows. Fix: capture matching PKs before the UPDATE executes, then re-select by those PKs after the UPDATE to get the post-mutation representation. --- backend/mysql/execute.go | 92 +++++++++++++++++++++++++++++++++------- 1 file changed, 77 insertions(+), 15 deletions(-) diff --git a/backend/mysql/execute.go b/backend/mysql/execute.go index 9a69bb6..5398061 100644 --- a/backend/mysql/execute.go +++ b/backend/mysql/execute.go @@ -247,7 +247,10 @@ func (b *Backend) executeInsertEmulated( return nil } -// executeUpdateEmulated runs UPDATE then re-selects with the same filter. +// executeUpdateEmulated runs UPDATE then re-selects by pre-captured primary keys. +// The re-select must use PKs, not the original filter, because the UPDATE may +// change the very column being filtered (e.g. PATCH /todos?task=eq.old sets +// task=new — after the UPDATE, task=eq.old matches nothing). func (b *Backend) executeUpdateEmulated( ctx context.Context, tx *sql.Tx, q *ir.Query, returning []string, rel *schema.Relation, @@ -257,6 +260,16 @@ func (b *Backend) executeUpdateEmulated( if apiErr != nil { return apiErr } + + // Pre-capture PKs when we need to return representation. + var pkValues []any + if len(returning) > 0 && len(rel.PrimaryKey) == 1 { + pkValues, apiErr = b.selectPKs(ctx, tx, q, rel.PrimaryKey[0]) + if apiErr != nil { + return apiErr + } + } + out, err := tx.ExecContext(ctx, st.SQL, st.Args...) if err != nil { return err @@ -264,35 +277,84 @@ func (b *Backend) executeUpdateEmulated( n, _ := out.RowsAffected() res.affected, res.hasAff = n, true - if len(returning) == 0 || n == 0 { + if len(returning) == 0 || len(pkValues) == 0 { return nil } - // Re-select: compile the equivalent SELECT with the same filters. - readQ := *q - readQ.Kind = ir.Read - readST, apiErr := sqlgen.CompileRead(Dialect{}, &readQ) + // Re-select by PK (post-update values). + colNames, buf, err := b.selectByPKs(ctx, tx, rel, rel.PrimaryKey[0], pkValues, returning) + if err != nil { + return err + } + res.cols, res.rows = colNames, buf + return nil +} + +// selectPKs runs "SELECT pk FROM table WHERE " and returns +// the raw PK values. Used to anchor the post-write re-select. +func (b *Backend) selectPKs( + ctx context.Context, tx *sql.Tx, + q *ir.Query, pkCol string, +) ([]any, *pgerr.APIError) { + pkQ := *q + pkQ.Kind = ir.Read + pkQ.Select = []ir.SelectItem{ir.Column{Path: []string{pkCol}}} + pkQ.Embeds = nil + pkQ.Order = nil + pkQ.Singular = false + st, apiErr := sqlgen.CompileRead(Dialect{}, &pkQ) if apiErr != nil { - return apiErr + return nil, apiErr } - rows, err := tx.QueryContext(ctx, readST.SQL, normalizeArgs(readST.Args)...) + rows, err := tx.QueryContext(ctx, st.SQL, normalizeArgs(st.Args)...) if err != nil { - return err + return nil, pgerr.New(500, "XX000", err.Error()) + } + defer rows.Close() + var vals []any + for rows.Next() { + var v any + if err := rows.Scan(&v); err != nil { + return nil, pgerr.New(500, "XX000", err.Error()) + } + vals = append(vals, v) + } + if err := rows.Err(); err != nil { + return nil, pgerr.New(500, "XX000", err.Error()) + } + return vals, nil +} + +// selectByPKs runs "SELECT cols FROM table WHERE pk IN (?,...)" using pre-captured +// PK values and returns the column names and buffered rows. +func (b *Backend) selectByPKs( + ctx context.Context, tx *sql.Tx, + rel *schema.Relation, pkCol string, pkValues []any, cols []string, +) ([]string, [][]any, error) { + d := Dialect{} + table := d.QuoteIdent(rel.Name) + pk := d.QuoteIdent(pkCol) + selCols := quotedCols(cols) + placeholders := make([]string, len(pkValues)) + for i := range pkValues { + placeholders[i] = "?" + } + sql := fmt.Sprintf("SELECT %s FROM %s WHERE %s IN (%s)", + selCols, table, pk, strings.Join(placeholders, ",")) + rows, err := tx.QueryContext(ctx, sql, pkValues...) + if err != nil { + return nil, nil, err } colNames, err := rows.Columns() if err != nil { rows.Close() - return err + return nil, nil, err } boolCols := buildBoolCols(rel) jsonIdx, boolIdx, _ := buildColMaps(rows, boolCols) buf, err := drain(rows, colNames, jsonIdx, boolIdx) rows.Close() - if err != nil { - return err - } - res.cols, res.rows = colNames, buf - return nil + return colNames, buf, err } // executeDeleteEmulated selects the rows to return, then deletes them.