diff --git a/backend/mysql/dialect.go b/backend/mysql/dialect.go index 62822f0..f8a9f27 100644 --- a/backend/mysql/dialect.go +++ b/backend/mysql/dialect.go @@ -192,12 +192,36 @@ func (Dialect) SessionRead(string) string { return "" } // SessionWrite reports ok=false: there is no engine setting to write. func (Dialect) SessionWrite(string) (string, bool) { return "", false } -// ArrayOp returns false; MySQL has no native array types or containment operators. -func (Dialect) ArrayOp(_, _, _ string) (string, bool) { return "", false } +// ArrayOp renders a JSON array containment/overlap expression using MySQL's +// JSON_CONTAINS and JSON_OVERLAPS functions (MySQL 8.0.17+). The column must be +// declared as JSON type; for any other column type ok=false is returned so the +// compiler raises PGRST127. colType is the canonical column type enriched by the +// planner; op is one of "@>" (contains), "<@" (contained-by), "&&" (overlaps). +func (Dialect) ArrayOp(col, op, val, colType string) (string, bool) { + if colType != "json" && colType != "jsonb" { + return "", false + } + switch op { + case "@>": // contains: col contains all elements of val + return "JSON_CONTAINS(" + col + ", " + val + ")", true + case "<@": // contained-by: val contains all elements of col + return "JSON_CONTAINS(" + val + ", " + col + ")", true + case "&&": // overlaps: at least one common element + return "JSON_OVERLAPS(" + col + ", " + val + ")", true + } + return "", false +} // ILike uses plain LIKE; MySQL's default utf8mb4_unicode_ci collation is CI. func (Dialect) ILike(col, val string) (string, bool) { return col + " LIKE " + val, true } +// IsBool renders "col = 1" or "col = 0". MySQL 8's IS operator only accepts +// NULL/UNKNOWN/TRUE/FALSE, not integer literals, so "col IS 1" is a syntax +// error; equality works for TINYINT(1) boolean columns. +func (Dialect) IsBool(col string, v bool) (string, bool) { + return col + " = " + Dialect{}.BoolValue(v), true +} + // BoolValue renders a boolean as 1/0. MySQL's BOOL is an alias for TINYINT(1), // so there is no native boolean keyword. func (Dialect) BoolValue(v bool) string { @@ -206,3 +230,27 @@ func (Dialect) BoolValue(v bool) string { } return "0" } + +// ArrayLiteral converts a PostgreSQL {a,b} array literal to a JSON array +// ["a","b"] so JSON_CONTAINS/JSON_OVERLAPS in ArrayOp can process it. +func (Dialect) ArrayLiteral(pgText string) string { + s := strings.TrimSpace(pgText) + if len(s) < 2 || s[0] != '{' || s[len(s)-1] != '}' { + return pgText // already JSON or empty; pass through + } + inner := s[1 : len(s)-1] + if inner == "" { + return "[]" + } + parts := strings.Split(inner, ",") + quoted := make([]string, len(parts)) + for i, p := range parts { + p = strings.TrimSpace(p) + if len(p) >= 2 && p[0] == '"' && p[len(p)-1] == '"' { + quoted[i] = p // already JSON-quoted + } else { + quoted[i] = `"` + strings.ReplaceAll(p, `"`, `\"`) + `"` + } + } + return "[" + strings.Join(quoted, ",") + "]" +} diff --git a/backend/mysql/execute.go b/backend/mysql/execute.go index 0b3a7d8..5398061 100644 --- a/backend/mysql/execute.go +++ b/backend/mysql/execute.go @@ -5,6 +5,7 @@ import ( "database/sql" "fmt" "strings" + "time" "github.com/tamnd/dbrest/backend" "github.com/tamnd/dbrest/backend/sqlgen" @@ -14,6 +15,30 @@ import ( "github.com/tamnd/dbrest/schema" ) +// normalizeArgs converts ISO 8601 datetime strings (e.g. "2024-01-01T00:00:00Z") +// to time.Time so the MySQL driver can bind them correctly. MySQL rejects the ISO +// T-separator format; passing time.Time avoids the string-to-DATETIME cast entirely. +func normalizeArgs(args []any) []any { + if len(args) == 0 { + return args + } + out := make([]any, len(args)) + for i, a := range args { + if s, ok := a.(string); ok { + if t, err := time.Parse(time.RFC3339, s); err == nil { + out[i] = t + continue + } + if t, err := time.Parse("2006-01-02T15:04:05", s); err == nil { + out[i] = t + continue + } + } + out[i] = a + } + return out +} + // Execute lowers a resolved plan to MySQL operations and returns a streamable // result. Reads stream from an open cursor; writes run in a transaction and // buffer their rows (since MySQL 8 has no RETURNING, rows are re-selected after @@ -47,7 +72,7 @@ func (b *Backend) executeRead(ctx context.Context, plan *ir.Plan, rc *reqctx.Con if apiErr != nil { return nil, apiErr } - if err := b.db.QueryRowContext(ctx, cst.SQL, cst.Args...).Scan(&res.count); err != nil { + if err := b.db.QueryRowContext(ctx, cst.SQL, normalizeArgs(cst.Args)...).Scan(&res.count); err != nil { return nil, b.MapError(err) } res.hasCount = true @@ -57,7 +82,7 @@ func (b *Backend) executeRead(ctx context.Context, plan *ir.Plan, rc *reqctx.Con if apiErr != nil { return nil, apiErr } - rows, err := b.db.QueryContext(ctx, st.SQL, st.Args...) + rows, err := b.db.QueryContext(ctx, st.SQL, normalizeArgs(st.Args)...) if err != nil { return nil, b.MapError(err) } @@ -222,7 +247,10 @@ func (b *Backend) executeInsertEmulated( return nil } -// executeUpdateEmulated runs UPDATE then re-selects with the same filter. +// executeUpdateEmulated runs UPDATE then re-selects by pre-captured primary keys. +// The re-select must use PKs, not the original filter, because the UPDATE may +// change the very column being filtered (e.g. PATCH /todos?task=eq.old sets +// task=new — after the UPDATE, task=eq.old matches nothing). func (b *Backend) executeUpdateEmulated( ctx context.Context, tx *sql.Tx, q *ir.Query, returning []string, rel *schema.Relation, @@ -232,6 +260,16 @@ func (b *Backend) executeUpdateEmulated( if apiErr != nil { return apiErr } + + // Pre-capture PKs when we need to return representation. + var pkValues []any + if len(returning) > 0 && len(rel.PrimaryKey) == 1 { + pkValues, apiErr = b.selectPKs(ctx, tx, q, rel.PrimaryKey[0]) + if apiErr != nil { + return apiErr + } + } + out, err := tx.ExecContext(ctx, st.SQL, st.Args...) if err != nil { return err @@ -239,35 +277,84 @@ func (b *Backend) executeUpdateEmulated( n, _ := out.RowsAffected() res.affected, res.hasAff = n, true - if len(returning) == 0 || n == 0 { + if len(returning) == 0 || len(pkValues) == 0 { return nil } - // Re-select: compile the equivalent SELECT with the same filters. - readQ := *q - readQ.Kind = ir.Read - readST, apiErr := sqlgen.CompileRead(Dialect{}, &readQ) + // Re-select by PK (post-update values). + colNames, buf, err := b.selectByPKs(ctx, tx, rel, rel.PrimaryKey[0], pkValues, returning) + if err != nil { + return err + } + res.cols, res.rows = colNames, buf + return nil +} + +// selectPKs runs "SELECT pk FROM table WHERE " and returns +// the raw PK values. Used to anchor the post-write re-select. +func (b *Backend) selectPKs( + ctx context.Context, tx *sql.Tx, + q *ir.Query, pkCol string, +) ([]any, *pgerr.APIError) { + pkQ := *q + pkQ.Kind = ir.Read + pkQ.Select = []ir.SelectItem{ir.Column{Path: []string{pkCol}}} + pkQ.Embeds = nil + pkQ.Order = nil + pkQ.Singular = false + st, apiErr := sqlgen.CompileRead(Dialect{}, &pkQ) if apiErr != nil { - return apiErr + return nil, apiErr } - rows, err := tx.QueryContext(ctx, readST.SQL, readST.Args...) + rows, err := tx.QueryContext(ctx, st.SQL, normalizeArgs(st.Args)...) if err != nil { - return err + return nil, pgerr.New(500, "XX000", err.Error()) + } + defer rows.Close() + var vals []any + for rows.Next() { + var v any + if err := rows.Scan(&v); err != nil { + return nil, pgerr.New(500, "XX000", err.Error()) + } + vals = append(vals, v) + } + if err := rows.Err(); err != nil { + return nil, pgerr.New(500, "XX000", err.Error()) + } + return vals, nil +} + +// selectByPKs runs "SELECT cols FROM table WHERE pk IN (?,...)" using pre-captured +// PK values and returns the column names and buffered rows. +func (b *Backend) selectByPKs( + ctx context.Context, tx *sql.Tx, + rel *schema.Relation, pkCol string, pkValues []any, cols []string, +) ([]string, [][]any, error) { + d := Dialect{} + table := d.QuoteIdent(rel.Name) + pk := d.QuoteIdent(pkCol) + selCols := quotedCols(cols) + placeholders := make([]string, len(pkValues)) + for i := range pkValues { + placeholders[i] = "?" + } + sql := fmt.Sprintf("SELECT %s FROM %s WHERE %s IN (%s)", + selCols, table, pk, strings.Join(placeholders, ",")) + rows, err := tx.QueryContext(ctx, sql, pkValues...) + if err != nil { + return nil, nil, err } colNames, err := rows.Columns() if err != nil { rows.Close() - return err + return nil, nil, err } boolCols := buildBoolCols(rel) jsonIdx, boolIdx, _ := buildColMaps(rows, boolCols) buf, err := drain(rows, colNames, jsonIdx, boolIdx) rows.Close() - if err != nil { - return err - } - res.cols, res.rows = colNames, buf - return nil + return colNames, buf, err } // executeDeleteEmulated selects the rows to return, then deletes them. @@ -284,7 +371,7 @@ func (b *Backend) executeDeleteEmulated( if apiErr != nil { return apiErr } - rows, err := tx.QueryContext(ctx, readST.SQL, readST.Args...) + rows, err := tx.QueryContext(ctx, readST.SQL, normalizeArgs(readST.Args)...) if err != nil { return err } @@ -322,6 +409,7 @@ func (b *Backend) executeCall(ctx context.Context, plan *ir.Plan, rc *reqctx.Con if apiErr != nil { return nil, apiErr } + st.Args = normalizeArgs(st.Args) if plan.ReadOnly { res := &result{controls: rc.Controls()} @@ -383,17 +471,26 @@ func (b *Backend) executeCall(ctx context.Context, plan *ir.Plan, rc *reqctx.Con // compileWrite dispatches to the right compiler for the mutation kind. // When returning is empty the compiler omits the RETURNING / OUTPUT clause. +// Args are normalized for MySQL (ISO 8601 → time.Time) before returning. func compileWrite(q *ir.Query, returning []string) (*sqlgen.Statement, *pgerr.APIError) { + var ( + st *sqlgen.Statement + apiErr *pgerr.APIError + ) switch q.Kind { case ir.Insert, ir.Upsert: - return sqlgen.CompileInsert(Dialect{}, q, returning) + st, apiErr = sqlgen.CompileInsert(Dialect{}, q, returning) case ir.Update: - return sqlgen.CompileUpdate(Dialect{}, q, returning) + st, apiErr = sqlgen.CompileUpdate(Dialect{}, q, returning) case ir.Delete: - return sqlgen.CompileDelete(Dialect{}, q, returning) + st, apiErr = sqlgen.CompileDelete(Dialect{}, q, returning) default: return nil, pgerr.ErrUnsupported("this operation", "mysql") } + if st != nil { + st.Args = normalizeArgs(st.Args) + } + return st, apiErr } // returningCols decides which columns to read back after a write. diff --git a/backend/mysql/mysql.go b/backend/mysql/mysql.go index e5e3996..b5ad482 100644 --- a/backend/mysql/mysql.go +++ b/backend/mysql/mysql.go @@ -52,6 +52,7 @@ func Open(dsn string) (*Backend, error) { return nil, fmt.Errorf("invalid MySQL DSN: %w", err) } cfg.ParseTime = true + cfg.ClientFoundRows = true // report matched rows, not changed rows (UPDATE re-select gate) delete(cfg.Params, "tinyInt1IsBool") // removed in v1.8; schema-layer handles coercion connector, err := mysqldrv.NewConnector(cfg) diff --git a/backend/postgres/dialect.go b/backend/postgres/dialect.go index 77fac40..d6b1c45 100644 --- a/backend/postgres/dialect.go +++ b/backend/postgres/dialect.go @@ -222,7 +222,7 @@ func sqlLiteral(s string) string { } // ArrayOp renders a PostgreSQL array containment/overlap expression. -func (Dialect) ArrayOp(col, op, val string) (string, bool) { +func (Dialect) ArrayOp(col, op, val, _ string) (string, bool) { return col + " " + op + " " + val, true } @@ -236,3 +236,11 @@ func (Dialect) BoolValue(v bool) string { } return "FALSE" } + +// IsBool falls back to the generic "IS TRUE"/"IS FALSE" form; PostgreSQL +// supports IS natively. +func (Dialect) IsBool(string, bool) (string, bool) { return "", false } + +// ArrayLiteral returns the PostgreSQL {a,b} array literal unchanged; PostgreSQL +// accepts it natively. +func (Dialect) ArrayLiteral(pgText string) string { return pgText } diff --git a/backend/sqlgen/compile.go b/backend/sqlgen/compile.go index a31e955..55781ba 100644 --- a/backend/sqlgen/compile.go +++ b/backend/sqlgen/compile.go @@ -457,8 +457,9 @@ func (b *builder) writeSelect(items []ir.SelectItem) *pgerr.APIError { } b.sb.WriteString(expr) // Alias the output so the renderer sees the PostgREST key, not the raw - // column. Only needed when the key differs from the bare column name. - if name := col.Name(); name != "" && name != lastPath(col.Path) { + // column expression. Always alias when a cast is present (the expression + // differs from the bare column name) or when an explicit alias was set. + if name := col.Name(); name != "" && (name != lastPath(col.Path) || col.Cast != "") { b.sb.WriteString(" AS ") b.sb.WriteString(b.d.QuoteIdent(name)) } @@ -601,9 +602,11 @@ func (b *builder) writeCompare(c ir.Compare) *pgerr.APIError { default: sqlOp = "&&" } - val := b.bind(c.Value.Text) + // Normalize the PostgreSQL {a,b} array literal to the engine's format + // before binding; the dialect is a no-op for engines that accept {a,b}. + val := b.bind(b.d.ArrayLiteral(c.Value.Text)) var ok bool - frag, ok = b.d.ArrayOp(col, sqlOp, val) + frag, ok = b.d.ArrayOp(col, sqlOp, val, c.ColumnType) if !ok { return pgerr.ErrUnsupported("array operator "+sqlOp, "sql") } @@ -697,8 +700,14 @@ func (b *builder) writeIs(col, text string) (string, *pgerr.APIError) { case "not_null": return col + " IS NOT NULL", nil case "true": + if frag, ok := b.d.IsBool(col, true); ok { + return frag, nil + } return col + " IS " + b.d.BoolValue(true), nil case "false": + if frag, ok := b.d.IsBool(col, false); ok { + return frag, nil + } return col + " IS " + b.d.BoolValue(false), nil default: return "", pgerr.ErrParse("unknown is value " + text) diff --git a/backend/sqlgen/compile_test.go b/backend/sqlgen/compile_test.go index 9a80da1..ad86b38 100644 --- a/backend/sqlgen/compile_test.go +++ b/backend/sqlgen/compile_test.go @@ -76,9 +76,10 @@ func (stub) FullText(col string, _ *FullTextRef, v ir.FTSVariant, _, _ string) ( } func (stub) SessionRead(k string) string { return "" } func (stub) SessionWrite(k string) (string, bool) { return "", false } -func (stub) ArrayOp(col, op, val string) (string, bool) { +func (stub) ArrayOp(col, op, val, _ string) (string, bool) { return col + " " + op + " " + val, true } +func (stub) ArrayLiteral(s string) string { return s } func (stub) ILike(col, val string) (string, bool) { return col + " ILIKE " + val, true } func (stub) BoolValue(v bool) string { if v { @@ -86,6 +87,7 @@ func (stub) BoolValue(v bool) string { } return "FALSE" } +func (stub) IsBool(string, bool) (string, bool) { return "", false } func compile(t *testing.T, q *ir.Query) *Statement { t.Helper() diff --git a/backend/sqlgen/dialect.go b/backend/sqlgen/dialect.go index a18da81..a0453e4 100644 --- a/backend/sqlgen/dialect.go +++ b/backend/sqlgen/dialect.go @@ -86,12 +86,27 @@ type Dialect interface { // BoolValue renders a boolean literal. BoolValue(v bool) string + // IsBool renders "col IS TRUE" or "col IS FALSE" in the engine's syntax. + // Engines that restrict IS to NULL/UNKNOWN (SQL Server) return ok=true with + // a = expression; engines that support IS return ok=false to fall back + // to "col IS ". + IsBool(col string, v bool) (string, bool) + // ArrayOp renders an array containment/overlap operator expression, or - // reports ok=false when the engine does not support array types (SQLite, - // MySQL, SQL Server). The compiler emits PGRST127 when ok=false. - // op is one of "@>", "<@", "&&"; col is the quoted column; val is the - // placeholder returned by bind(). - ArrayOp(col, op, val string) (string, bool) + // reports ok=false when the engine does not support array types (MySQL, SQL + // Server) or when the column type does not support array semantics (SQLite + // requires a JSON-typed column for json_each). The compiler emits PGRST127 + // when ok=false. op is one of "@>", "<@", "&&"; col is the quoted column; + // val is the placeholder returned by bind(); colType is the canonical + // column type resolved by the planner ("json", "text", "integer", …). + ArrayOp(col, op, val, colType string) (string, bool) + + // ArrayLiteral converts a PostgREST array literal (PostgreSQL {a,b} syntax) + // to the engine's native format for use as a bound parameter. PostgreSQL + // accepts {a,b} natively; SQLite needs JSON ["a","b"]. Other engines that + // do not support arrays may return the text unchanged (they never reach + // ArrayOp either). + ArrayLiteral(pgText string) string } // PatternMark is the sentinel a Dialect.Regex fragment carries where the bound diff --git a/backend/sqlgen/embed.go b/backend/sqlgen/embed.go index 7747d15..006b5e3 100644 --- a/backend/sqlgen/embed.go +++ b/backend/sqlgen/embed.go @@ -168,7 +168,12 @@ func (b *builder) writeEmbed(emb *ir.Embed, parentAlias string) *pgerr.APIError if err := b.writeEmbedFilter(emb, alias); err != nil { return err } - b.sb.WriteString(" LIMIT 1)") + lim := 1 + if lo := b.d.LimitOffset(&lim, nil, false); lo != "" { + b.sb.WriteString(" ") + b.sb.WriteString(lo) + } + b.sb.WriteString(")") return nil } diff --git a/backend/sqlite/dialect.go b/backend/sqlite/dialect.go index a1c8c1b..dcafa17 100644 --- a/backend/sqlite/dialect.go +++ b/backend/sqlite/dialect.go @@ -146,12 +146,57 @@ func (dialect) SessionRead(string) string { return "" } // SessionWrite reports ok=false: there is no engine setting to write. func (dialect) SessionWrite(string) (string, bool) { return "", false } -// ArrayOp returns false; SQLite has no array types or containment operators. -func (dialect) ArrayOp(_, _, _ string) (string, bool) { return "", false } +// ArrayLiteral converts a PostgreSQL {a,b} array literal to a JSON array +// ["a","b"] so json_each() in ArrayOp can iterate over it. +func (dialect) ArrayLiteral(pgText string) string { + s := strings.TrimSpace(pgText) + if len(s) < 2 || s[0] != '{' || s[len(s)-1] != '}' { + return pgText // already JSON or empty; pass through + } + inner := s[1 : len(s)-1] + if inner == "" { + return "[]" + } + parts := strings.Split(inner, ",") + quoted := make([]string, len(parts)) + for i, p := range parts { + p = strings.TrimSpace(p) + if len(p) >= 2 && p[0] == '"' && p[len(p)-1] == '"' { + quoted[i] = p // already JSON-quoted + } else { + quoted[i] = `"` + strings.ReplaceAll(p, `"`, `\"`) + `"` + } + } + return "[" + strings.Join(quoted, ",") + "]" +} + +// ArrayOp implements array containment/overlap via SQLite's json_each(). The +// column must be declared as JSON type and store a JSON array (e.g. +// '["cat","work"]'). For any other column type the operator is unsupported +// (ok=false) so the compiler raises PGRST127. op is one of "@>" (contains), +// "<@" (contained-by), "&&" (overlaps). +func (dialect) ArrayOp(col, op, val, colType string) (string, bool) { + if colType != "json" && colType != "jsonb" { + return "", false + } + switch op { + case "@>": // contains: every element of val appears in col + return "NOT EXISTS (SELECT 1 FROM json_each(" + val + ") AS f WHERE f.value NOT IN (SELECT value FROM json_each(" + col + ")))", true + case "<@": // contained-by: every element of col appears in val + return "NOT EXISTS (SELECT 1 FROM json_each(" + col + ") AS f WHERE f.value NOT IN (SELECT value FROM json_each(" + val + ")))", true + case "&&": // overlaps: at least one common element + return "EXISTS (SELECT 1 FROM json_each(" + col + ") AS f WHERE f.value IN (SELECT value FROM json_each(" + val + ")))", true + } + return "", false +} // ILike uses plain LIKE which is case-insensitive for ASCII in SQLite. func (dialect) ILike(col, val string) (string, bool) { return col + " LIKE " + val, true } +// IsBool falls back to the generic "IS 1"/"IS 0" form; SQLite's IS operator is +// a NULL-safe equality that works with any value. +func (dialect) IsBool(string, bool) (string, bool) { return "", false } + // BoolValue renders a boolean as 1/0; SQLite has no native boolean. func (dialect) BoolValue(v bool) string { if v { diff --git a/backend/sqlite/result.go b/backend/sqlite/result.go index 622e744..73a8a29 100644 --- a/backend/sqlite/result.go +++ b/backend/sqlite/result.go @@ -2,7 +2,9 @@ package sqlite import ( "database/sql" + "encoding/json" "io" + "strings" "github.com/tamnd/dbrest/backend" "github.com/tamnd/dbrest/reqctx" @@ -64,8 +66,9 @@ func (s *bufStream) Close() error { return nil } // rowStream is a forward-only cursor over the result rows. Values decode each // row into a []any the renderer maps to JSON by column name. type rowStream struct { - rows *sql.Rows - cols []string + rows *sql.Rows + cols []string + colTypes []*sql.ColumnType // lazily populated on first call to Values } func (s *rowStream) Columns() []string { return s.cols } @@ -74,9 +77,19 @@ func (s *rowStream) Err() error { return s.rows.Err() } func (s *rowStream) Close() error { return s.rows.Close() } // Values scans the current row into Go values. SQLite returns int64, float64, -// string, []byte, or nil; []byte is normalized to string so text columns render -// as JSON strings rather than base64. +// string, []byte, or nil. Post-scan coercions: +// - []byte → string so text columns render as JSON strings rather than base64. +// - BOOLEAN/BOOL declared columns: int64 0/1 → false/true so JSON marshals +// correctly as false/true rather than 0/1. +// - JSON declared columns: string → json.RawMessage so the JSON encoder embeds +// the value verbatim rather than quoting it as a string. func (s *rowStream) Values() ([]any, error) { + if s.colTypes == nil { + ct, err := s.rows.ColumnTypes() + if err == nil { + s.colTypes = ct + } + } holders := make([]any, len(s.cols)) ptrs := make([]any, len(s.cols)) for i := range holders { @@ -87,7 +100,20 @@ func (s *rowStream) Values() ([]any, error) { } for i, v := range holders { if b, ok := v.([]byte); ok { - holders[i] = string(b) + v = string(b) + holders[i] = v + } + if s.colTypes != nil && i < len(s.colTypes) { + switch strings.ToUpper(s.colTypes[i].DatabaseTypeName()) { + case "BOOLEAN", "BOOL": + if n, ok := v.(int64); ok { + holders[i] = n != 0 + } + case "JSON": + if str, ok := v.(string); ok && json.Valid([]byte(str)) { + holders[i] = json.RawMessage(str) + } + } } } return holders, nil diff --git a/backend/sqlite/sqlite.go b/backend/sqlite/sqlite.go index e817051..8a04fa7 100644 --- a/backend/sqlite/sqlite.go +++ b/backend/sqlite/sqlite.go @@ -4,9 +4,11 @@ import ( "context" "database/sql" "database/sql/driver" + "encoding/json" "errors" "fmt" "regexp" + "strings" sqlitedrv "modernc.org/sqlite" sqlite3 "modernc.org/sqlite/lib" @@ -68,6 +70,13 @@ func Open(dsn string) (*Backend, error) { if err != nil { return nil, err } + // SQLite does not enforce FK constraints by default. Pin to one connection so + // the PRAGMA stays in effect for the lifetime of the pool. + db.SetMaxOpenConns(1) + if _, err := db.Exec("PRAGMA foreign_keys = ON"); err != nil { + db.Close() + return nil, err + } if err := db.Ping(); err != nil { db.Close() return nil, err @@ -357,9 +366,11 @@ func returningCols(q *ir.Query, rel *schema.Relation) []string { return nil } -// drain reads every row of a returning cursor into memory, normalizing []byte to -// string so text columns render as JSON strings. +// drain reads every row of a returning cursor into memory, applying the same +// type coercions as rowStream.Values: []byte→string, BOOLEAN int64→bool, +// JSON string→json.RawMessage. func drain(rows *sql.Rows, ncols int) ([][]any, error) { + colTypes, _ := rows.ColumnTypes() var out [][]any for rows.Next() { holders := make([]any, ncols) @@ -372,7 +383,20 @@ func drain(rows *sql.Rows, ncols int) ([][]any, error) { } for i, v := range holders { if bs, ok := v.([]byte); ok { - holders[i] = string(bs) + v = string(bs) + holders[i] = v + } + if colTypes != nil && i < len(colTypes) { + switch strings.ToUpper(colTypes[i].DatabaseTypeName()) { + case "BOOLEAN", "BOOL": + if n, ok := v.(int64); ok { + holders[i] = n != 0 + } + case "JSON": + if str, ok := v.(string); ok && json.Valid([]byte(str)) { + holders[i] = json.RawMessage(str) + } + } } } out = append(out, holders) diff --git a/backend/sqlserver/dialect.go b/backend/sqlserver/dialect.go index 8fac96c..0578d98 100644 --- a/backend/sqlserver/dialect.go +++ b/backend/sqlserver/dialect.go @@ -50,6 +50,11 @@ func (Dialect) Placeholder(n int) string { return "@p" + strconv.Itoa(n) } // used uniformly. func (Dialect) LimitOffset(limit, offset *int, hasOrder bool) string { if limit == nil && offset == nil { + if hasOrder { + // ORDER BY in a derived table requires OFFSET even when no paging is + // requested; OFFSET 0 ROWS keeps all rows while making the ORDER BY valid. + return "OFFSET 0 ROWS" + } return "" } off := 0 @@ -135,12 +140,13 @@ func (Dialect) JSONObject(pairs []sqlgen.Pair) string { return "JSON_OBJECT(" + strings.Join(parts, ", ") + ")" } -// JSONAgg aggregates rows with the SQL Server 2022 JSON_ARRAYAGG. The aggregate -// takes no ORDER BY argument, so a requested embed order is applied on the -// derived table feeding the aggregate, not here; orderBy is therefore unused and -// the row order within the array is best-effort (spec 06). +// JSONAgg aggregates rows into a JSON array using STRING_AGG. JSON_ARRAYAGG was +// only added in SQL Server 2025 (version 17); for 2022 compatibility the dialect +// constructs the array manually: '[' + STRING_AGG(elem,',') + ']'. The elements +// are cast to NVARCHAR(MAX) so STRING_AGG accepts them. orderBy is unused; a +// requested embed order is applied on the derived table feeding the aggregate. func (Dialect) JSONAgg(elem, _ string) string { - return "JSON_ARRAYAGG(" + elem + ")" + return "'['+STRING_AGG(CAST((" + elem + ") AS NVARCHAR(MAX)),',')+']'" } // Cast translates a canonical type to a T-SQL CAST target. SQL Server has no @@ -218,8 +224,30 @@ func (Dialect) SessionWrite(key string) (string, bool) { return "EXEC sp_set_session_context N'" + strings.ReplaceAll(key, "'", "''") + "', " + sqlgen.PatternMark, true } -// ArrayOp returns false; SQL Server has no array types or containment operators. -func (Dialect) ArrayOp(_, _, _ string) (string, bool) { return "", false } +// ArrayOp implements array containment/overlap operators using OPENJSON, which +// parses the JSON array argument and the JSON array column for element-level +// comparisons. val is a bound placeholder (@pN) whose value is a JSON array +// string (converted from PostgreSQL {a,b} syntax by ArrayLiteral). +func (Dialect) ArrayOp(col, op, val, _ string) (string, bool) { + switch op { + case "@>": + // col contains every element of val + return "NOT EXISTS(SELECT [value] FROM OPENJSON(" + val + ") WHERE [value] NOT IN (SELECT [value] FROM OPENJSON(" + col + ")))", true + case "<@": + // every element of col exists in val + return "NOT EXISTS(SELECT [value] FROM OPENJSON(" + col + ") WHERE [value] NOT IN (SELECT [value] FROM OPENJSON(" + val + ")))", true + case "&&": + // at least one element in common + return "EXISTS(SELECT 1 FROM OPENJSON(" + col + ") a WHERE a.[value] IN (SELECT [value] FROM OPENJSON(" + val + ")))", true + } + return "", false +} + +// IsBool renders "col = 1" or "col = 0" for SQL Server BIT columns. SQL +// Server's IS operator only accepts NULL/UNKNOWN, not integer literals. +func (Dialect) IsBool(col string, v bool) (string, bool) { + return col + " = " + Dialect{}.BoolValue(v), true +} // ILike uses plain LIKE; SQL Server's default collation is case-insensitive. func (Dialect) ILike(col, val string) (string, bool) { return col + " LIKE " + val, true } @@ -232,3 +260,30 @@ func (Dialect) BoolValue(v bool) string { } return "0" } + +// ArrayLiteral converts a PostgreSQL {a,b} array literal to a JSON array +// ["a","b"] so OPENJSON in ArrayOp can iterate over it. +func (Dialect) ArrayLiteral(pgText string) string { + s := strings.TrimSpace(pgText) + if len(s) < 2 || s[0] != '{' || s[len(s)-1] != '}' { + return pgText + } + inner := s[1 : len(s)-1] + if inner == "" { + return "[]" + } + parts := strings.Split(inner, ",") + quoted := make([]string, len(parts)) + for i, p := range parts { + p = strings.TrimSpace(p) + if p == "NULL" { + quoted[i] = "null" + } else if len(p) >= 2 && p[0] == '"' && p[len(p)-1] == '"' { + // PostgreSQL double-quote escaping: "foo" is already valid JSON; pass through. + quoted[i] = p + } else { + quoted[i] = `"` + strings.ReplaceAll(p, `"`, `\"`) + `"` + } + } + return "[" + strings.Join(quoted, ",") + "]" +} diff --git a/backend/sqlserver/dialect_test.go b/backend/sqlserver/dialect_test.go index ea41b2f..c4f29f5 100644 --- a/backend/sqlserver/dialect_test.go +++ b/backend/sqlserver/dialect_test.go @@ -98,7 +98,7 @@ func TestJSON(t *testing.T) { if obj != "JSON_OBJECT('name': d.[name], 'year': d.[year])" { t.Errorf("JSONObject = %q", obj) } - if got := d.JSONAgg("t", "t.[id] DESC"); got != "JSON_ARRAYAGG(t)" { + if got := d.JSONAgg("t", "t.[id] DESC"); got != "'['+STRING_AGG(CAST((t) AS NVARCHAR(MAX)),',')+']'" { t.Errorf("JSONAgg = %q", got) } } diff --git a/backend/sqlserver/execute.go b/backend/sqlserver/execute.go index b5ed5d1..6106daf 100644 --- a/backend/sqlserver/execute.go +++ b/backend/sqlserver/execute.go @@ -3,6 +3,8 @@ package sqlserver import ( "context" "database/sql" + "encoding/json" + "strconv" "strings" "github.com/tamnd/dbrest/backend" @@ -145,11 +147,17 @@ func (b *Backend) executeWrite(ctx context.Context, plan *ir.Plan, rc *reqctx.Co // The compiler emits: INSERT INTO [t] ([c1],[c2]) VALUES (@p1,@p2) // The data plane rewrites to: INSERT INTO [t] ([c1],[c2]) OUTPUT INSERTED.[c1],... VALUES (@p1,@p2) // by injecting the OUTPUT fragment before the " VALUES " marker. +// Upsert (on_conflict) is routed to executeUpsert instead of the single-statement +// compiler, which returns errUpsertMultiStatement. func (b *Backend) executeInsert( ctx context.Context, tx *sql.Tx, q *ir.Query, returning []string, rel *schema.Relation, res *writeResult, ) error { + if q.Kind == ir.Upsert { + return b.executeUpsert(ctx, tx, q, returning, rel, res) + } + st, apiErr := sqlgen.CompileInsert(Dialect{}, q, nil) if apiErr != nil { return apiErr @@ -187,6 +195,209 @@ func (b *Backend) executeInsert( return nil } +// executeUpsert implements the SQL Server upsert as a single MERGE statement per +// batch. MERGE avoids the semicolon-separated multi-statement pattern that +// go-mssqldb rejects when sent via sp_executesql. +// +// All rows are merged in one statement: the source is a VALUES(...) table with +// one row-tuple per input row; the ON clause matches the conflict (primary-key) +// columns; WHEN MATCHED updates non-key columns; WHEN NOT MATCHED inserts. +// The OUTPUT clause captures written rows when returning is requested. +func (b *Backend) executeUpsert( + ctx context.Context, tx *sql.Tx, + q *ir.Query, returning []string, rel *schema.Relation, + res *writeResult, +) error { + w := q.Write + if len(w.Rows) == 0 { + res.affected, res.hasAff = 0, true + return nil + } + + d := Dialect{} + sch := q.Relation.Schema + if sch == "" { + sch = b.schema + if sch == "" { + sch = "dbo" + } + } + tableName := d.QuoteIdent(sch) + "." + d.QuoteIdent(q.Relation.Name) + + conflictCols := w.Conflict.Target + conflictSet := make(map[string]bool, len(conflictCols)) + for _, c := range conflictCols { + conflictSet[c] = true + } + nonConflictCols := make([]string, 0, len(w.Columns)) + for _, c := range w.Columns { + if !conflictSet[c] { + nonConflictCols = append(nonConflictCols, c) + } + } + + // Collect args; @pN bind positions match the order we append. + raw := []any{} + argN := 0 + bind := func(v any) string { + argN++ + raw = append(raw, v) + return "@p" + strconv.Itoa(argN) + } + + // Build the source alias column names: s0, s1, ... + srcCols := make([]string, len(w.Columns)) + for i := range w.Columns { + srcCols[i] = "s" + strconv.Itoa(i) + } + + var sb strings.Builder + + // MERGE INTO target USING (VALUES (...),(...)) AS src(s0,s1,...) + sb.WriteString("MERGE INTO ") + sb.WriteString(tableName) + sb.WriteString(" WITH (HOLDLOCK) AS [_target] USING (VALUES ") + for ri, row := range w.Rows { + if ri > 0 { + sb.WriteString(",") + } + sb.WriteString("(") + for ci, c := range w.Columns { + if ci > 0 { + sb.WriteString(",") + } + sb.WriteString(bind(sqlgen.WriteArg(row[c]))) + } + sb.WriteString(")") + } + sb.WriteString(") AS [_src](") + for i, sc := range srcCols { + if i > 0 { + sb.WriteString(",") + } + sb.WriteString(d.QuoteIdent(sc)) + } + sb.WriteString(") ON (") + // ON conflict columns match + for i, c := range conflictCols { + if i > 0 { + sb.WriteString(" AND ") + } + ci := colIndex(w.Columns, c) + sb.WriteString("[_target]." + d.QuoteIdent(c) + "=[_src]." + d.QuoteIdent(srcCols[ci])) + } + sb.WriteString(")") + + // WHEN MATCHED THEN UPDATE (skip if ignore or no non-conflict cols) + if w.Conflict.Resolution != ir.ConflictIgnore && len(nonConflictCols) > 0 { + sb.WriteString(" WHEN MATCHED THEN UPDATE SET ") + for i, c := range nonConflictCols { + if i > 0 { + sb.WriteString(",") + } + ci := colIndex(w.Columns, c) + sb.WriteString("[_target]." + d.QuoteIdent(c) + "=[_src]." + d.QuoteIdent(srcCols[ci])) + } + } + + // WHEN NOT MATCHED THEN INSERT + sb.WriteString(" WHEN NOT MATCHED THEN INSERT (") + for i, c := range w.Columns { + if i > 0 { + sb.WriteString(",") + } + sb.WriteString(d.QuoteIdent(c)) + } + sb.WriteString(") VALUES (") + for i, sc := range srcCols { + if i > 0 { + sb.WriteString(",") + } + sb.WriteString("[_src]." + d.QuoteIdent(sc)) + } + sb.WriteString(")") + + // OUTPUT clause when returning is requested + if len(returning) > 0 { + sb.WriteString(" OUTPUT ") + for i, c := range returning { + if i > 0 { + sb.WriteString(",") + } + sb.WriteString("INSERTED." + d.QuoteIdent(c)) + } + } + + // MERGE requires a terminating semicolon. + sb.WriteString(";") + + // When any conflict column is an IDENTITY column and the user provided + // an explicit value, SQL Server requires IDENTITY_INSERT to be ON. + needIdentityInsert := rel != nil && hasIdentityConflictCol(rel, conflictCols, w.Columns) + if needIdentityInsert { + if _, err := tx.ExecContext(ctx, "SET IDENTITY_INSERT "+tableName+" ON"); err != nil { + return err + } + defer func() { _, _ = tx.ExecContext(ctx, "SET IDENTITY_INSERT "+tableName+" OFF") }() + } + + if len(returning) > 0 { + rows, err := tx.QueryContext(ctx, sb.String(), namedArgs(raw)...) + if err != nil { + return err + } + cols, err := rows.Columns() + if err != nil { + rows.Close() + return err + } + jsonIdx, timeIdx := buildColMaps(rows, nil) + buf, err := drain(rows, cols, jsonIdx, timeIdx) + rows.Close() + if err != nil { + return err + } + res.cols, res.rows = cols, buf + res.affected, res.hasAff = int64(len(buf)), true + return nil + } + + out, err := tx.ExecContext(ctx, sb.String(), namedArgs(raw)...) + if err != nil { + return err + } + n, _ := out.RowsAffected() + res.affected, res.hasAff = n, true + return nil +} + +// hasIdentityConflictCol reports whether any conflict column is an identity +// column AND is present in payloadCols (the user provided an explicit value). +// When IDENTITY_INSERT is ON, SQL Server requires an explicit value, so we only +// enable it when the identity column is actually in the payload. +func hasIdentityConflictCol(rel *schema.Relation, conflictCols, payloadCols []string) bool { + payload := make(map[string]bool, len(payloadCols)) + for _, c := range payloadCols { + payload[c] = true + } + for _, c := range conflictCols { + if col, ok := rel.Column(c); ok && col.Identity && payload[c] { + return true + } + } + return false +} + +// colIndex returns the position of name in cols, or 0 as a safe fallback. +func colIndex(cols []string, name string) int { + for i, c := range cols { + if c == name { + return i + } + } + return 0 +} + // executeUpdate runs UPDATE [t] SET ... OUTPUT INSERTED.* WHERE ... // Compiler emits: UPDATE [t] SET [c]=@p1 WHERE [id]=@p2 // Rewritten to: UPDATE [t] SET [c]=@p1 OUTPUT INSERTED.[c],... WHERE [id]=@p2 @@ -279,14 +490,26 @@ func (b *Backend) executeDelete( // executeCall runs a stored procedure or portable RPC function. func (b *Backend) executeCall(ctx context.Context, plan *ir.Plan, rc *reqctx.Context) (backend.Result, error) { - st, apiErr := sqlgen.CompileCall(Dialect{}, plan.Call, plan.Func) + var st *sqlgen.Statement + var apiErr *pgerr.APIError + if plan.Func != nil { + // Portable registry function: the function body is a parameterised SQL + // statement whose :name placeholders are bound by CompileCall. + st, apiErr = sqlgen.CompileCall(Dialect{}, plan.Call, plan.Func) + } else { + // Native RPC (NativeRPC=true): no registry function — generate EXEC + // [schema].[name] @param = @pN from the call's argument map. + st, apiErr = b.compileNativeCall(plan.Call) + } if apiErr != nil { return nil, apiErr } if plan.ReadOnly { res := &result{controls: rc.Controls()} - if plan.Call.Count != ir.CountNone { + // count=exact is only supported for portable registry functions; native + // stored procedures cannot be wrapped in SELECT count(*) in T-SQL. + if plan.Call.Count != ir.CountNone && plan.Func != nil { cst, apiErr := sqlgen.CompileCallCount(Dialect{}, plan.Call, plan.Func) if apiErr != nil { return nil, apiErr @@ -387,6 +610,57 @@ func returningCols(q *ir.Query, rel *schema.Relation) []string { return nil } +// compileNativeCall generates EXEC [schema].[name] @arg1 = @p1, @arg2 = @p2 for +// the NativeRPC path (plan.Func == nil). SQL Server stored procedures accept +// named parameters in any order, so the argument map can be emitted as-is. +// Scalar stored procedures should SELECT the result in a column named after the +// function (e.g. SELECT @a + @b AS [add]) so renderCall can detect scalar return +// by seeing a single column whose name matches the function name. +func (b *Backend) compileNativeCall(c *ir.Call) (*sqlgen.Statement, *pgerr.APIError) { + sch := b.schema + if sch == "" { + sch = "dbo" + } + d := Dialect{} + var sb strings.Builder + sb.WriteString("EXEC ") + sb.WriteString(d.QuoteIdent(sch)) + sb.WriteString(".") + sb.WriteString(d.QuoteIdent(c.Function.Name)) + + args := make([]any, 0, len(c.Args)) + i := 1 + for name, val := range c.Args { + if i == 1 { + sb.WriteString(" ") + } else { + sb.WriteString(", ") + } + sb.WriteString("@" + name + " = @p" + strconv.Itoa(i)) + // A POST arg has a decoded JSON value; a GET arg is raw text. + if val.JSON != nil { + args = append(args, nativeArgValue(val.JSON)) + } else { + args = append(args, val.Text) + } + i++ + } + return &sqlgen.Statement{SQL: sb.String(), Args: args}, nil +} + +// nativeArgValue converts a decoded JSON argument value to a driver-ready type. +// Scalars (string, float64, bool, nil) pass through; composite values are +// re-encoded as JSON text so the stored procedure can receive them as NVARCHAR. +func nativeArgValue(v any) any { + switch v.(type) { + case string, float64, bool, nil: + return v + default: + b, _ := json.Marshal(v) + return string(b) + } +} + // _ is a compile-time check that Backend implements backend.DB. var _ interface { Execute(context.Context, *ir.Plan, *reqctx.Context) (backend.Result, error) diff --git a/backend/sqlserver/introspect.go b/backend/sqlserver/introspect.go index e9f961b..1cd8138 100644 --- a/backend/sqlserver/introspect.go +++ b/backend/sqlserver/introspect.go @@ -81,7 +81,8 @@ func (b *Backend) columns(ctx context.Context, table string) ([]*schema.Column, c.IS_NULLABLE, c.COLUMN_DEFAULT, CASE WHEN k.COLUMN_NAME IS NOT NULL THEN 1 ELSE 0 END AS is_pk, - ISNULL(k.ORDINAL_POSITION, 0) AS pk_ord + ISNULL(k.ORDINAL_POSITION, 0) AS pk_ord, + COLUMNPROPERTY(OBJECT_ID(SCHEMA_NAME()+'.'+c.TABLE_NAME), c.COLUMN_NAME, 'IsIdentity') AS is_identity FROM INFORMATION_SCHEMA.COLUMNS c LEFT JOIN ( SELECT kcu.COLUMN_NAME, kcu.ORDINAL_POSITION @@ -113,16 +114,17 @@ func (b *Backend) columns(ctx context.Context, table string) ([]*schema.Column, for rows.Next() { var name, dataType, isNullable string var colDefault sql.NullString - var isPK, pkOrd int - if err := rows.Scan(&name, &dataType, &isNullable, &colDefault, &isPK, &pkOrd); err != nil { + var isPK, pkOrd, isIdentity int + if err := rows.Scan(&name, &dataType, &isNullable, &colDefault, &isPK, &pkOrd, &isIdentity); err != nil { return nil, nil, err } - hasDefault := isPK == 1 || colDefault.Valid + hasDefault := isPK == 1 || colDefault.Valid || isIdentity == 1 col := &schema.Column{ Name: name, Type: sqlServerCanonicalType(dataType), Nullable: isNullable == "YES", HasDefault: hasDefault, + Identity: isIdentity == 1, } colRows = append(colRows, colRow{col: col, isPK: isPK == 1, pkOrd: pkOrd}) } diff --git a/backend/sqlserver/sqlserver.go b/backend/sqlserver/sqlserver.go index c31afb6..8fb01ef 100644 --- a/backend/sqlserver/sqlserver.go +++ b/backend/sqlserver/sqlserver.go @@ -120,7 +120,7 @@ func (b *Backend) MapError(err error) *pgerr.APIError { } // mapSQLServerError builds the unified API error from a SQL Server error. -func mapSQLServerError(me *mssql.Error) *pgerr.APIError { +func mapSQLServerError(me mssql.Error) *pgerr.APIError { switch me.Number { case 2627, 2601: // unique constraint / unique index violation return pgerr.ErrUniqueViolation(me.Message) @@ -206,9 +206,10 @@ func sqlServerCanonicalType(dataType string) string { } } -// asMSSQLError unwraps err as a *mssql.Error. -func asMSSQLError(err error) (*mssql.Error, bool) { - var me *mssql.Error +// asMSSQLError unwraps err as a mssql.Error. mssql.Error implements error via a +// value receiver so errors.As requires a value target, not a pointer. +func asMSSQLError(err error) (mssql.Error, bool) { + var me mssql.Error ok := errors.As(err, &me) return me, ok } diff --git a/cmd/dbrest/main.go b/cmd/dbrest/main.go index 8fa4fcb..9eca13f 100644 --- a/cmd/dbrest/main.go +++ b/cmd/dbrest/main.go @@ -21,6 +21,7 @@ import ( _ "github.com/tamnd/dbrest/backend/sqlserver" "github.com/tamnd/dbrest/config" "github.com/tamnd/dbrest/httpapi" + "github.com/tamnd/dbrest/rpc" ) func main() { @@ -79,6 +80,17 @@ func openBackend(cfg *config.Config) (backend.Backend, error) { if sc, ok := be.(interface{ SetSchemas([]string) }); ok { sc.SetSchemas(cfg.Schemas) } + // Wire declared function registry for backends that cannot discover + // functions from an engine catalog (NativeRPC=false: SQLite, MySQL, …). + if cfg.FunctionRegistry != "" { + reg, err := rpc.ParseRegistry(cfg.FunctionRegistry) + if err != nil { + return nil, fmt.Errorf("function-registry: %w", err) + } + if r, ok := be.(interface{ Register(rpc.Registry) }); ok { + r.Register(reg) + } + } return be, nil } diff --git a/ir/ir.go b/ir/ir.go index a081d21..49537ab 100644 --- a/ir/ir.go +++ b/ir/ir.go @@ -222,6 +222,11 @@ type Compare struct { FTS FTSVariant Config string FullText *schema.FullTextIndex + // ColumnType is the canonical type of the column at Path[0], resolved by + // the planner from the schema. The dialect uses it to decide whether an + // engine-specific operator (e.g. json_each for array ops on SQLite) can + // apply; it is empty when the column is unknown or for multi-step paths. + ColumnType string } func (Compare) isCond() {} diff --git a/plan/plan.go b/plan/plan.go index 1d2740b..a490f48 100644 --- a/plan/plan.go +++ b/plan/plan.go @@ -323,6 +323,15 @@ func validateCond(rel *schema.Relation, c *ir.Cond) *pgerr.APIError { n.FullText = rel.FullTextIndexFor(n.Path[0]) *c = n } + // Array operators carry the column's canonical type so the dialect can + // decide whether the column supports array semantics (e.g. SQLite's + // json_each only applies to JSON-typed columns). See spec 21. + if (n.Op == ir.OpContains || n.Op == ir.OpContained || n.Op == ir.OpOverlap) && len(n.Path) == 1 { + if col, ok := rel.Column(n.Path[0]); ok { + n.ColumnType = col.Type + *c = n + } + } } return nil } diff --git a/rpc/registry.go b/rpc/registry.go index 596233a..e1c1467 100644 --- a/rpc/registry.go +++ b/rpc/registry.go @@ -9,7 +9,12 @@ // *Function without a cycle. package rpc -import "sort" +import ( + "encoding/json" + "fmt" + "sort" + "strings" +) // Volatility classifies a function's effect, which fixes the methods it allows // and the transaction mode it runs in (spec 12). A registry entry that omits it @@ -218,6 +223,98 @@ func exactMatch(f *Function, args ArgSet) bool { return true } +// ParseRegistry decodes a JSON function-registry declaration into a +// StaticRegistry ready to Register on a backend. The JSON is an array of +// function objects; each carries: +// +// name string required; bare function name +// sql string required; parameterized SQL with :name placeholders +// params []{name, type, optional?, default?} +// returns {kind: "scalar"|"setof"|"table", type?, columns?} +// volatility "volatile"|"stable"|"immutable" (default: volatile) +// +// Returns an error when the JSON is malformed; an empty array yields an empty +// registry. Schemas are stripped from names; a name of "api.add" resolves as "add". +func ParseRegistry(rawJSON string) (*StaticRegistry, error) { + rawJSON = strings.TrimSpace(rawJSON) + if rawJSON == "" { + return NewStaticRegistry(nil), nil + } + type paramDecl struct { + Name string `json:"name"` + Type string `json:"type"` + Optional bool `json:"optional"` + Default any `json:"default"` + } + type returnDecl struct { + Kind string `json:"kind"` + Type string `json:"type"` + Columns []struct { + Name string `json:"name"` + Type string `json:"type"` + } `json:"columns"` + } + type fnDecl struct { + Name string `json:"name"` + SQL string `json:"sql"` + Params []paramDecl `json:"params"` + Returns returnDecl `json:"returns"` + Volatility string `json:"volatility"` + } + var decls []fnDecl + if err := json.Unmarshal([]byte(rawJSON), &decls); err != nil { + return nil, fmt.Errorf("function-registry: %w", err) + } + fns := make([]*Function, 0, len(decls)) + for _, d := range decls { + // Strip schema prefix (e.g. "api.add" → "add"). + name := d.Name + if dot := strings.LastIndex(name, "."); dot >= 0 { + name = name[dot+1:] + } + var vol Volatility + switch strings.ToLower(d.Volatility) { + case "stable": + vol = Stable + case "immutable": + vol = Immutable + default: + vol = Volatile + } + params := make([]Param, len(d.Params)) + for i, p := range d.Params { + params[i] = Param{ + Name: p.Name, + Type: p.Type, + Optional: p.Optional, + Default: p.Default, + } + } + var ret ReturnShape + switch strings.ToLower(d.Returns.Kind) { + case "setof": + ret.Kind = ReturnSetOf + case "table": + ret.Kind = ReturnTable + ret.Columns = make([]Column, len(d.Returns.Columns)) + for i, c := range d.Returns.Columns { + ret.Columns[i] = Column{Name: c.Name, Type: c.Type} + } + default: + ret.Kind = ReturnScalar + } + ret.Type = d.Returns.Type + fns = append(fns, &Function{ + Name: name, + Params: params, + Returns: ret, + Volatility: vol, + Query: &PortableQuery{SQL: d.SQL}, + }) + } + return NewStaticRegistry(fns), nil +} + // EmptyRegistry is a registry with no functions; every Lookup misses. A backend // that has not been given any functions returns this so the frontend raises a // clean PGRST202 rather than dereferencing nil. diff --git a/schema/model.go b/schema/model.go index e243a02..7d29955 100644 --- a/schema/model.go +++ b/schema/model.go @@ -55,6 +55,11 @@ type Column struct { Type string // canonical PG type name (spec 16) Nullable bool HasDefault bool + // Identity reports whether the column is an auto-generated identity/serial + // column (IDENTITY on SQL Server, SERIAL/GENERATED ALWAYS AS IDENTITY on + // PostgreSQL). Backends that support explicit-identity inserts (e.g. SQL + // Server's IDENTITY_INSERT) use this to decide whether to enable it. + Identity bool // Position is the 1-based ordinal, used for stable ordering. Position int }