Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 50 additions & 2 deletions backend/mysql/dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -192,12 +192,36 @@ func (Dialect) SessionRead(string) string { return "" }
// SessionWrite reports ok=false: there is no engine setting to write.
func (Dialect) SessionWrite(string) (string, bool) { return "", false }

// ArrayOp returns false; MySQL has no native array types or containment operators.
func (Dialect) ArrayOp(_, _, _ string) (string, bool) { return "", false }
// ArrayOp renders a JSON array containment/overlap expression using MySQL's
// JSON_CONTAINS and JSON_OVERLAPS functions (MySQL 8.0.17+). The column must be
// declared as JSON type; for any other column type ok=false is returned so the
// compiler raises PGRST127. colType is the canonical column type enriched by the
// planner; op is one of "@>" (contains), "<@" (contained-by), "&&" (overlaps).
func (Dialect) ArrayOp(col, op, val, colType string) (string, bool) {
if colType != "json" && colType != "jsonb" {
return "", false
}
switch op {
case "@>": // contains: col contains all elements of val
return "JSON_CONTAINS(" + col + ", " + val + ")", true
case "<@": // contained-by: val contains all elements of col
return "JSON_CONTAINS(" + val + ", " + col + ")", true
case "&&": // overlaps: at least one common element
return "JSON_OVERLAPS(" + col + ", " + val + ")", true
}
return "", false
}

// ILike uses plain LIKE; MySQL's default utf8mb4_unicode_ci collation is CI.
func (Dialect) ILike(col, val string) (string, bool) { return col + " LIKE " + val, true }

// IsBool renders "col = 1" or "col = 0". MySQL 8's IS operator only accepts
// NULL/UNKNOWN/TRUE/FALSE, not integer literals, so "col IS 1" is a syntax
// error; equality works for TINYINT(1) boolean columns.
func (Dialect) IsBool(col string, v bool) (string, bool) {
return col + " = " + Dialect{}.BoolValue(v), true
}

// BoolValue renders a boolean as 1/0. MySQL's BOOL is an alias for TINYINT(1),
// so there is no native boolean keyword.
func (Dialect) BoolValue(v bool) string {
Expand All @@ -206,3 +230,27 @@ func (Dialect) BoolValue(v bool) string {
}
return "0"
}

// ArrayLiteral converts a PostgreSQL {a,b} array literal to a JSON array
// ["a","b"] so JSON_CONTAINS/JSON_OVERLAPS in ArrayOp can process it.
func (Dialect) ArrayLiteral(pgText string) string {
s := strings.TrimSpace(pgText)
if len(s) < 2 || s[0] != '{' || s[len(s)-1] != '}' {
return pgText // already JSON or empty; pass through
}
inner := s[1 : len(s)-1]
if inner == "" {
return "[]"
}
parts := strings.Split(inner, ",")
quoted := make([]string, len(parts))
for i, p := range parts {
p = strings.TrimSpace(p)
if len(p) >= 2 && p[0] == '"' && p[len(p)-1] == '"' {
quoted[i] = p // already JSON-quoted
} else {
quoted[i] = `"` + strings.ReplaceAll(p, `"`, `\"`) + `"`
}
}
return "[" + strings.Join(quoted, ",") + "]"
}
139 changes: 118 additions & 21 deletions backend/mysql/execute.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"database/sql"
"fmt"
"strings"
"time"

"github.com/tamnd/dbrest/backend"
"github.com/tamnd/dbrest/backend/sqlgen"
Expand All @@ -14,6 +15,30 @@ import (
"github.com/tamnd/dbrest/schema"
)

// normalizeArgs converts ISO 8601 datetime strings (e.g. "2024-01-01T00:00:00Z")
// to time.Time so the MySQL driver can bind them correctly. MySQL rejects the ISO
// T-separator format; passing time.Time avoids the string-to-DATETIME cast entirely.
func normalizeArgs(args []any) []any {
if len(args) == 0 {
return args
}
out := make([]any, len(args))
for i, a := range args {
if s, ok := a.(string); ok {
if t, err := time.Parse(time.RFC3339, s); err == nil {
out[i] = t
continue
}
if t, err := time.Parse("2006-01-02T15:04:05", s); err == nil {
out[i] = t
continue
}
}
out[i] = a
}
return out
}

// Execute lowers a resolved plan to MySQL operations and returns a streamable
// result. Reads stream from an open cursor; writes run in a transaction and
// buffer their rows (since MySQL 8 has no RETURNING, rows are re-selected after
Expand Down Expand Up @@ -47,7 +72,7 @@ func (b *Backend) executeRead(ctx context.Context, plan *ir.Plan, rc *reqctx.Con
if apiErr != nil {
return nil, apiErr
}
if err := b.db.QueryRowContext(ctx, cst.SQL, cst.Args...).Scan(&res.count); err != nil {
if err := b.db.QueryRowContext(ctx, cst.SQL, normalizeArgs(cst.Args)...).Scan(&res.count); err != nil {
return nil, b.MapError(err)
}
res.hasCount = true
Expand All @@ -57,7 +82,7 @@ func (b *Backend) executeRead(ctx context.Context, plan *ir.Plan, rc *reqctx.Con
if apiErr != nil {
return nil, apiErr
}
rows, err := b.db.QueryContext(ctx, st.SQL, st.Args...)
rows, err := b.db.QueryContext(ctx, st.SQL, normalizeArgs(st.Args)...)
if err != nil {
return nil, b.MapError(err)
}
Expand Down Expand Up @@ -222,7 +247,10 @@ func (b *Backend) executeInsertEmulated(
return nil
}

// executeUpdateEmulated runs UPDATE then re-selects with the same filter.
// executeUpdateEmulated runs UPDATE then re-selects by pre-captured primary keys.
// The re-select must use PKs, not the original filter, because the UPDATE may
// change the very column being filtered (e.g. PATCH /todos?task=eq.old sets
// task=new — after the UPDATE, task=eq.old matches nothing).
func (b *Backend) executeUpdateEmulated(
ctx context.Context, tx *sql.Tx,
q *ir.Query, returning []string, rel *schema.Relation,
Expand All @@ -232,42 +260,101 @@ func (b *Backend) executeUpdateEmulated(
if apiErr != nil {
return apiErr
}

// Pre-capture PKs when we need to return representation.
var pkValues []any
if len(returning) > 0 && len(rel.PrimaryKey) == 1 {
pkValues, apiErr = b.selectPKs(ctx, tx, q, rel.PrimaryKey[0])
if apiErr != nil {
return apiErr
}
}

out, err := tx.ExecContext(ctx, st.SQL, st.Args...)
if err != nil {
return err
}
n, _ := out.RowsAffected()
res.affected, res.hasAff = n, true

if len(returning) == 0 || n == 0 {
if len(returning) == 0 || len(pkValues) == 0 {
return nil
}

// Re-select: compile the equivalent SELECT with the same filters.
readQ := *q
readQ.Kind = ir.Read
readST, apiErr := sqlgen.CompileRead(Dialect{}, &readQ)
// Re-select by PK (post-update values).
colNames, buf, err := b.selectByPKs(ctx, tx, rel, rel.PrimaryKey[0], pkValues, returning)
if err != nil {
return err
}
res.cols, res.rows = colNames, buf
return nil
}

// selectPKs runs "SELECT pk FROM table WHERE <original filter>" and returns
// the raw PK values. Used to anchor the post-write re-select.
func (b *Backend) selectPKs(
ctx context.Context, tx *sql.Tx,
q *ir.Query, pkCol string,
) ([]any, *pgerr.APIError) {
pkQ := *q
pkQ.Kind = ir.Read
pkQ.Select = []ir.SelectItem{ir.Column{Path: []string{pkCol}}}
pkQ.Embeds = nil
pkQ.Order = nil
pkQ.Singular = false
st, apiErr := sqlgen.CompileRead(Dialect{}, &pkQ)
if apiErr != nil {
return apiErr
return nil, apiErr
}
rows, err := tx.QueryContext(ctx, readST.SQL, readST.Args...)
rows, err := tx.QueryContext(ctx, st.SQL, normalizeArgs(st.Args)...)
if err != nil {
return err
return nil, pgerr.New(500, "XX000", err.Error())
}
defer rows.Close()
var vals []any
for rows.Next() {
var v any
if err := rows.Scan(&v); err != nil {
return nil, pgerr.New(500, "XX000", err.Error())
}
vals = append(vals, v)
}
if err := rows.Err(); err != nil {
return nil, pgerr.New(500, "XX000", err.Error())
}
return vals, nil
}

// selectByPKs runs "SELECT cols FROM table WHERE pk IN (?,...)" using pre-captured
// PK values and returns the column names and buffered rows.
func (b *Backend) selectByPKs(
ctx context.Context, tx *sql.Tx,
rel *schema.Relation, pkCol string, pkValues []any, cols []string,
) ([]string, [][]any, error) {
d := Dialect{}
table := d.QuoteIdent(rel.Name)
pk := d.QuoteIdent(pkCol)
selCols := quotedCols(cols)
placeholders := make([]string, len(pkValues))
for i := range pkValues {
placeholders[i] = "?"
}
sql := fmt.Sprintf("SELECT %s FROM %s WHERE %s IN (%s)",
selCols, table, pk, strings.Join(placeholders, ","))
rows, err := tx.QueryContext(ctx, sql, pkValues...)
if err != nil {
return nil, nil, err
}
colNames, err := rows.Columns()
if err != nil {
rows.Close()
return err
return nil, nil, err
}
boolCols := buildBoolCols(rel)
jsonIdx, boolIdx, _ := buildColMaps(rows, boolCols)
buf, err := drain(rows, colNames, jsonIdx, boolIdx)
rows.Close()
if err != nil {
return err
}
res.cols, res.rows = colNames, buf
return nil
return colNames, buf, err
}

// executeDeleteEmulated selects the rows to return, then deletes them.
Expand All @@ -284,7 +371,7 @@ func (b *Backend) executeDeleteEmulated(
if apiErr != nil {
return apiErr
}
rows, err := tx.QueryContext(ctx, readST.SQL, readST.Args...)
rows, err := tx.QueryContext(ctx, readST.SQL, normalizeArgs(readST.Args)...)
if err != nil {
return err
}
Expand Down Expand Up @@ -322,6 +409,7 @@ func (b *Backend) executeCall(ctx context.Context, plan *ir.Plan, rc *reqctx.Con
if apiErr != nil {
return nil, apiErr
}
st.Args = normalizeArgs(st.Args)

if plan.ReadOnly {
res := &result{controls: rc.Controls()}
Expand Down Expand Up @@ -383,17 +471,26 @@ func (b *Backend) executeCall(ctx context.Context, plan *ir.Plan, rc *reqctx.Con

// compileWrite dispatches to the right compiler for the mutation kind.
// When returning is empty the compiler omits the RETURNING / OUTPUT clause.
// Args are normalized for MySQL (ISO 8601 → time.Time) before returning.
func compileWrite(q *ir.Query, returning []string) (*sqlgen.Statement, *pgerr.APIError) {
var (
st *sqlgen.Statement
apiErr *pgerr.APIError
)
switch q.Kind {
case ir.Insert, ir.Upsert:
return sqlgen.CompileInsert(Dialect{}, q, returning)
st, apiErr = sqlgen.CompileInsert(Dialect{}, q, returning)
case ir.Update:
return sqlgen.CompileUpdate(Dialect{}, q, returning)
st, apiErr = sqlgen.CompileUpdate(Dialect{}, q, returning)
case ir.Delete:
return sqlgen.CompileDelete(Dialect{}, q, returning)
st, apiErr = sqlgen.CompileDelete(Dialect{}, q, returning)
default:
return nil, pgerr.ErrUnsupported("this operation", "mysql")
}
if st != nil {
st.Args = normalizeArgs(st.Args)
}
return st, apiErr
}

// returningCols decides which columns to read back after a write.
Expand Down
1 change: 1 addition & 0 deletions backend/mysql/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ func Open(dsn string) (*Backend, error) {
return nil, fmt.Errorf("invalid MySQL DSN: %w", err)
}
cfg.ParseTime = true
cfg.ClientFoundRows = true // report matched rows, not changed rows (UPDATE re-select gate)
delete(cfg.Params, "tinyInt1IsBool") // removed in v1.8; schema-layer handles coercion

connector, err := mysqldrv.NewConnector(cfg)
Expand Down
10 changes: 9 additions & 1 deletion backend/postgres/dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ func sqlLiteral(s string) string {
}

// ArrayOp renders a PostgreSQL array containment/overlap expression.
func (Dialect) ArrayOp(col, op, val string) (string, bool) {
func (Dialect) ArrayOp(col, op, val, _ string) (string, bool) {
return col + " " + op + " " + val, true
}

Expand All @@ -236,3 +236,11 @@ func (Dialect) BoolValue(v bool) string {
}
return "FALSE"
}

// IsBool falls back to the generic "IS TRUE"/"IS FALSE" form; PostgreSQL
// supports IS <bool> natively.
func (Dialect) IsBool(string, bool) (string, bool) { return "", false }

// ArrayLiteral returns the PostgreSQL {a,b} array literal unchanged; PostgreSQL
// accepts it natively.
func (Dialect) ArrayLiteral(pgText string) string { return pgText }
17 changes: 13 additions & 4 deletions backend/sqlgen/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -457,8 +457,9 @@ func (b *builder) writeSelect(items []ir.SelectItem) *pgerr.APIError {
}
b.sb.WriteString(expr)
// Alias the output so the renderer sees the PostgREST key, not the raw
// column. Only needed when the key differs from the bare column name.
if name := col.Name(); name != "" && name != lastPath(col.Path) {
// column expression. Always alias when a cast is present (the expression
// differs from the bare column name) or when an explicit alias was set.
if name := col.Name(); name != "" && (name != lastPath(col.Path) || col.Cast != "") {
b.sb.WriteString(" AS ")
b.sb.WriteString(b.d.QuoteIdent(name))
}
Expand Down Expand Up @@ -601,9 +602,11 @@ func (b *builder) writeCompare(c ir.Compare) *pgerr.APIError {
default:
sqlOp = "&&"
}
val := b.bind(c.Value.Text)
// Normalize the PostgreSQL {a,b} array literal to the engine's format
// before binding; the dialect is a no-op for engines that accept {a,b}.
val := b.bind(b.d.ArrayLiteral(c.Value.Text))
var ok bool
frag, ok = b.d.ArrayOp(col, sqlOp, val)
frag, ok = b.d.ArrayOp(col, sqlOp, val, c.ColumnType)
if !ok {
return pgerr.ErrUnsupported("array operator "+sqlOp, "sql")
}
Expand Down Expand Up @@ -697,8 +700,14 @@ func (b *builder) writeIs(col, text string) (string, *pgerr.APIError) {
case "not_null":
return col + " IS NOT NULL", nil
case "true":
if frag, ok := b.d.IsBool(col, true); ok {
return frag, nil
}
return col + " IS " + b.d.BoolValue(true), nil
case "false":
if frag, ok := b.d.IsBool(col, false); ok {
return frag, nil
}
return col + " IS " + b.d.BoolValue(false), nil
default:
return "", pgerr.ErrParse("unknown is value " + text)
Expand Down
4 changes: 3 additions & 1 deletion backend/sqlgen/compile_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,16 +76,18 @@ func (stub) FullText(col string, _ *FullTextRef, v ir.FTSVariant, _, _ string) (
}
func (stub) SessionRead(k string) string { return "" }
func (stub) SessionWrite(k string) (string, bool) { return "", false }
func (stub) ArrayOp(col, op, val string) (string, bool) {
func (stub) ArrayOp(col, op, val, _ string) (string, bool) {
return col + " " + op + " " + val, true
}
func (stub) ArrayLiteral(s string) string { return s }
func (stub) ILike(col, val string) (string, bool) { return col + " ILIKE " + val, true }
func (stub) BoolValue(v bool) string {
if v {
return "TRUE"
}
return "FALSE"
}
func (stub) IsBool(string, bool) (string, bool) { return "", false }

func compile(t *testing.T, q *ir.Query) *Statement {
t.Helper()
Expand Down
Loading
Loading