diff --git a/backend/postgres/execute.go b/backend/postgres/execute.go index b063a80..02531b5 100644 --- a/backend/postgres/execute.go +++ b/backend/postgres/execute.go @@ -2,6 +2,8 @@ package postgres import ( "context" + "encoding/json" + "strconv" "strings" "github.com/jackc/pgx/v5" @@ -36,75 +38,52 @@ func (b *Backend) Execute(ctx context.Context, plan *ir.Plan, rc *reqctx.Context } } -// executeRead compiles and runs the windowed read. The entire request is sent as -// a single pgx.Batch: [BEGIN, session setup, count (if needed), query, ROLLBACK]. -// One network write to PostgreSQL covers all round trips, matching PostgREST's -// hasql pipeline behaviour. Rows stream from within the open batch; Close drains -// the trailing ROLLBACK item and releases the connection. +// executeRead compiles and runs the windowed read in a read-only transaction. +// Session setup (SET LOCAL ROLE, search_path, GUCs) is applied via applySession +// before the main query is sent so the PostgreSQL planner sees the correct role +// at parse time. Rows stream from within the open transaction; Close commits it. +// +// Note: a single-batch approach (BEGIN + session + query + ROLLBACK in one +// pipeline) would let pgx pre-parse the main SELECT while the connection is still +// authenticator (NOINHERIT, no schema USAGE), causing a 42501 error. applySession +// completes its batch before the main query is issued, so Parse runs as the +// request role, which has the required privileges. func (b *Backend) executeRead(ctx context.Context, plan *ir.Plan, rc *reqctx.Context) (backend.Result, error) { - conn, err := b.pool.Acquire(ctx) + tx, err := b.pool.BeginTx(ctx, pgx.TxOptions{AccessMode: pgx.ReadOnly}) if err != nil { return nil, b.MapError(err) } - release := func() { conn.Release() } + rollback := func() { _ = tx.Rollback(ctx) } - // Build the single batch: BEGIN → session → [count] → query → ROLLBACK. - batch := &pgx.Batch{} - batch.Queue("BEGIN TRANSACTION READ ONLY") - sessionN := queueSessionItems(batch, b, rc) + if err := applySession(ctx, tx, b, rc); err != nil { + rollback() + return nil, b.MapError(err) + } + + res := &streamResult{ctx: ctx, tx: tx, controls: rc.Controls()} - hasCount := plan.Query.Count != ir.CountNone - var cst *sqlgen.Statement - if hasCount { - var apiErr *pgerr.APIError - cst, apiErr = sqlgen.CompileCount(Dialect{}, plan.Query) + if plan.Query.Count != ir.CountNone { + cst, apiErr := sqlgen.CompileCount(Dialect{}, plan.Query) if apiErr != nil { - release() + rollback() return nil, apiErr } - batch.Queue(cst.SQL, cst.Args...) + if err := tx.QueryRow(ctx, cst.SQL, cst.Args...).Scan(&res.count); err != nil { + rollback() + return nil, b.MapError(err) + } + res.hasCount = true } st, apiErr := sqlgen.CompileRead(Dialect{}, plan.Query) if apiErr != nil { - release() + rollback() return nil, apiErr } - batch.Queue(st.SQL, st.Args...) - batch.Queue("ROLLBACK") - - br := conn.SendBatch(ctx, batch) - - abort := func(e error) (backend.Result, error) { - _ = br.Close() - release() - return nil, e - } - - // Drain BEGIN. - if _, err := br.Exec(); err != nil { - return abort(b.MapError(err)) - } - // Drain session setup items. - for range sessionN { - if _, err := br.Exec(); err != nil { - return abort(b.MapError(err)) - } - } - - res := &batchStreamResult{ctx: ctx, conn: conn, br: br, controls: rc.Controls()} - - if hasCount { - _ = cst // already queued - if err := br.QueryRow().Scan(&res.count); err != nil { - return abort(b.MapError(err)) - } - res.hasCount = true - } - - rows, err := br.Query() + rows, err := tx.Query(ctx, st.SQL, st.Args...) if err != nil { - return abort(b.MapError(err)) + rollback() + return nil, b.MapError(err) } res.rows = rows res.cols = fieldNames(rows) @@ -324,10 +303,11 @@ func (b *Backend) executeCallRead(ctx context.Context, plan *ir.Plan, rc *reqctx // compileNativeCall generates the PostgreSQL function-call SQL for the native // RPC path (NativeRPC=true), where there is no declared function registry. It -// renders SELECT * FROM schema.fn(arg := $1, ...) using the search path's first -// schema as the function schema. Arguments come from the call's parsed arg map; -// they are bound as named parameters (fn_name := $N) which is how PostgREST -// calls PG functions. When no args are supplied the call has an empty arg list. +// renders SELECT * FROM schema.fn(arg := , ...) with values embedded +// as SQL literals so PostgreSQL infers the parameter types from the function +// signature and the call does not depend on pgx OID mapping. String values are +// single-quote escaped; numeric JSON values are written as numeric literals; +// booleans become TRUE/FALSE; null or absent values become NULL. func (b *Backend) compileNativeCall(c *ir.Call) (*sqlgen.Statement, *pgerr.APIError) { schema := "public" if len(b.searchPath) > 0 { @@ -342,7 +322,6 @@ func (b *Backend) compileNativeCall(c *ir.Call) (*sqlgen.Statement, *pgerr.APIEr sb.WriteString(d.QuoteIdent(c.Function.Name)) sb.WriteString("(") - args := make([]any, 0, len(c.Args)) i := 0 for name, val := range c.Args { if i > 0 { @@ -350,12 +329,51 @@ func (b *Backend) compileNativeCall(c *ir.Call) (*sqlgen.Statement, *pgerr.APIEr } sb.WriteString(d.QuoteIdent(name)) sb.WriteString(" := ") - sb.WriteString(d.Placeholder(i + 1)) - args = append(args, val.Text) + appendNativeArg(&sb, val) i++ } sb.WriteString(")") - return &sqlgen.Statement{SQL: sb.String(), Args: args}, nil + return &sqlgen.Statement{SQL: sb.String()}, nil +} + +// appendNativeArg writes one function argument as a safe SQL literal. Numbers +// are written unquoted so PostgreSQL resolves their type from context; strings +// use single-quote escaping; booleans are TRUE/FALSE; anything else (including +// absent values) becomes NULL. Objects and arrays are JSON-quoted. +func appendNativeArg(sb *strings.Builder, val ir.Value) { + if val.JSON != nil { + switch v := val.JSON.(type) { + case string: + sb.WriteString("'") + sb.WriteString(strings.ReplaceAll(v, "'", "''")) + sb.WriteString("'") + case json.Number: + // json.Number from dec.UseNumber() — write as-is; it is a valid SQL numeric literal. + sb.WriteString(v.String()) + case float64: + sb.WriteString(strconv.FormatFloat(v, 'f', -1, 64)) + case bool: + if v { + sb.WriteString("TRUE") + } else { + sb.WriteString("FALSE") + } + default: + // JSON object / array: pass as json literal. + enc, _ := json.Marshal(v) + sb.WriteString("'") + sb.WriteString(strings.ReplaceAll(string(enc), "'", "''")) + sb.WriteString("'::json") + } + return + } + if val.Text != "" { + sb.WriteString("'") + sb.WriteString(strings.ReplaceAll(val.Text, "'", "''")) + sb.WriteString("'") + return + } + sb.WriteString("NULL") } // compileWrite dispatches to the right compiler for the mutation kind. @@ -407,6 +425,50 @@ func fieldNames(rows pgx.Rows) []string { return names } +// ExplainRead runs EXPLAIN (FORMAT JSON) on the read query and returns the raw +// JSON plan from PostgreSQL. When analyze is true EXPLAIN ANALYZE is used +// instead, which also executes the query and includes timing. The request runs +// in a read-only transaction with the full session setup (role + GUCs) so the +// planner sees the same context as a real request. +func (b *Backend) ExplainRead(ctx context.Context, p *ir.Plan, rc *reqctx.Context, analyze bool) ([]byte, error) { + tx, err := b.pool.BeginTx(ctx, pgx.TxOptions{AccessMode: pgx.ReadOnly}) + if err != nil { + return nil, b.MapError(err) + } + defer func() { _ = tx.Rollback(ctx) }() + + if err := applySession(ctx, tx, b, rc); err != nil { + return nil, b.MapError(err) + } + + st, apiErr := sqlgen.CompileRead(Dialect{}, p.Query) + if apiErr != nil { + return nil, apiErr + } + + var prefix string + if analyze { + prefix = "EXPLAIN (ANALYZE, FORMAT JSON) " + } else { + prefix = "EXPLAIN (FORMAT JSON) " + } + rows, err := tx.Query(ctx, prefix+st.SQL, st.Args...) + if err != nil { + return nil, b.MapError(err) + } + defer rows.Close() + var plan []byte + for rows.Next() { + if err := rows.Scan(&plan); err != nil { + return nil, b.MapError(err) + } + } + if err := rows.Err(); err != nil { + return nil, b.MapError(err) + } + return plan, nil +} + // drainRows reads every row of a pgx cursor into memory, normalizing values so // json/jsonb, bytea, and date columns render correctly. The rows are closed by // drainRows; the caller must not close them again. diff --git a/backend/postgres/postgres.go b/backend/postgres/postgres.go index 06ccc30..202a0cb 100644 --- a/backend/postgres/postgres.go +++ b/backend/postgres/postgres.go @@ -199,8 +199,8 @@ func statusForSQLState(code string) int { return 404 case "42P01": // undefined_table return 404 - case "42501": // insufficient_privilege - return 403 + case "42501": // insufficient_privilege → 401 matching PostgREST + return 401 case "42P17": // infinite_recursion return 500 } diff --git a/backend/postgres/postgres_test.go b/backend/postgres/postgres_test.go index d53b46c..214a407 100644 --- a/backend/postgres/postgres_test.go +++ b/backend/postgres/postgres_test.go @@ -81,7 +81,7 @@ func TestStatusForSQLState(t *testing.T) { {"25006", 405}, {"42883", 404}, {"42P01", 404}, - {"42501", 403}, + {"42501", 401}, // matches PostgREST: insufficient_privilege → 401 // PTxxx convention {"PT403", 403}, {"PT201", 201}, diff --git a/backend/postgres/result.go b/backend/postgres/result.go index 30e39ff..a222df8 100644 --- a/backend/postgres/result.go +++ b/backend/postgres/result.go @@ -8,7 +8,6 @@ import ( "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgxpool" "github.com/tamnd/dbrest/backend" "github.com/tamnd/dbrest/reqctx" @@ -77,64 +76,6 @@ func (s *streamRows) Close() error { return s.tx.Commit(s.ctx) } -// batchStreamResult adapts an in-flight pgx.BatchResults to the backend.Result -// contract for a read. The entire request (BEGIN + session setup + query + -// ROLLBACK) was sent in one pgx.Batch network write; the caller has already -// consumed the non-row items and positioned br at the query result. Streaming -// rows through the open BatchResults and draining ROLLBACK at Close reduces the -// read path to a single PostgreSQL round trip. -type batchStreamResult struct { - ctx context.Context - conn *pgxpool.Conn - br pgx.BatchResults - rows pgx.Rows - cols []string - controls *reqctx.ResponseControls - count int64 - hasCount bool -} - -func (r *batchStreamResult) Body() io.Reader { return nil } -func (r *batchStreamResult) Rows() backend.RowStream { - return &batchStreamRows{ctx: r.ctx, conn: r.conn, br: r.br, rows: r.rows, cols: r.cols} -} -func (r *batchStreamResult) Count() (int64, bool) { return r.count, r.hasCount } -func (r *batchStreamResult) Affected() (int64, bool) { return 0, false } -func (r *batchStreamResult) ResponseControls() *reqctx.ResponseControls { return r.controls } - -// batchStreamRows streams rows from within an open pgx.BatchResults. On Close -// it drains the remaining ROLLBACK item, closes the batch, and releases the -// connection back to the pool. -type batchStreamRows struct { - ctx context.Context - conn *pgxpool.Conn - br pgx.BatchResults - rows pgx.Rows - cols []string -} - -func (s *batchStreamRows) Columns() []string { return s.cols } -func (s *batchStreamRows) Next() bool { return s.rows.Next() } -func (s *batchStreamRows) Err() error { return s.rows.Err() } - -func (s *batchStreamRows) Values() ([]any, error) { - vals, err := s.rows.Values() - if err != nil { - return nil, err - } - return normalizeValues(vals, s.rows.FieldDescriptions()), nil -} - -// Close drains the ROLLBACK batch item and releases the connection. -func (s *batchStreamRows) Close() error { - s.rows.Close() - rowErr := s.rows.Err() - s.br.Exec() //nolint:errcheck // ROLLBACK; ignore error, it's cleanup - _ = s.br.Close() - s.conn.Release() - return rowErr -} - // bufResult holds the buffered outcome of a write or a function call. A write // runs inside a transaction that must commit (or roll back, under tx=rollback) // before the response is sent, and a function call's response headers and status diff --git a/backend/spi.go b/backend/spi.go index 43e6400..8573a29 100644 --- a/backend/spi.go +++ b/backend/spi.go @@ -60,6 +60,16 @@ type Result interface { ResponseControls() *reqctx.ResponseControls } +// Explainer is an optional backend capability for the vnd.pgrst.plan+json +// Accept type. Backends that support EXPLAIN implement this interface; +// the frontend type-asserts to it and falls back to 406 when absent. +type Explainer interface { + // ExplainRead runs EXPLAIN on the read query and returns raw JSON from the + // engine's query planner. If analyze is true the engine also executes and + // times the query (EXPLAIN ANALYZE equivalent). + ExplainRead(ctx context.Context, p *ir.Plan, rc *reqctx.Context, analyze bool) ([]byte, error) +} + // RowStream is a forward-only cursor over result rows. The renderer drives it to // assemble the response body when the backend does not assemble JSON itself. type RowStream interface { diff --git a/backend/sqlgen/compile.go b/backend/sqlgen/compile.go index 769a935..a31e955 100644 --- a/backend/sqlgen/compile.go +++ b/backend/sqlgen/compile.go @@ -554,12 +554,20 @@ func (b *builder) writeCompare(c ir.Compare) *pgerr.APIError { frag = col + " " + binaryOp(c.Op) + " " + b.bind(c.Value.Text) } case ir.OpGt, ir.OpGte, ir.OpLt, ir.OpLte, ir.OpLike: - frag = col + " " + binaryOp(c.Op) + " " + b.bind(c.Value.Text) + if c.Quant != ir.QNone { + frag, err = b.writeLikeQuantified(col, ir.OpLike, c.Quant, c.Value.List) + } else { + frag = col + " " + binaryOp(c.Op) + " " + b.bind(c.Value.Text) + } case ir.OpILike: - var ok bool - frag, ok = b.d.ILike(col, b.bind(c.Value.Text)) - if !ok { - return pgerr.ErrUnsupported("case-insensitive LIKE", "sql") + if c.Quant != ir.QNone { + frag, err = b.writeLikeQuantified(col, ir.OpILike, c.Quant, c.Value.List) + } else { + var ok bool + frag, ok = b.d.ILike(col, b.bind(c.Value.Text)) + if !ok { + return pgerr.ErrUnsupported("case-insensitive LIKE", "sql") + } } case ir.OpIn: frag, err = b.writeIn(col, c.Value.List) @@ -648,6 +656,40 @@ func (b *builder) writeIn(col string, list []string) (string, *pgerr.APIError) { return col + " IN (" + strings.Join(parts, ", ") + ")", nil } +// writeLikeQuantified expands like(any)/{...} and like(all)/{...} into a +// conjunction or disjunction of individual LIKE / ILIKE predicates. An empty +// list generates a no-match literal (1 = 0) for ANY and always-match (1 = 1) +// for ALL, consistent with SQL ANY/ALL semantics over an empty set. +func (b *builder) writeLikeQuantified(col string, op ir.Op, q ir.Quant, list []string) (string, *pgerr.APIError) { + if len(list) == 0 { + if q == ir.QAny { + return "1 = 0", nil + } + return "1 = 1", nil + } + sep := " OR " + if q == ir.QAll { + sep = " AND " + } + parts := make([]string, len(list)) + for i, pat := range list { + bound := b.bind(pat) + if op == ir.OpILike { + expr, ok := b.d.ILike(col, bound) + if !ok { + return "", pgerr.ErrUnsupported("case-insensitive LIKE", "sql") + } + parts[i] = expr + } else { + parts[i] = col + " LIKE " + bound + } + } + if len(parts) == 1 { + return parts[0], nil + } + return "(" + strings.Join(parts, sep) + ")", nil +} + func (b *builder) writeIs(col, text string) (string, *pgerr.APIError) { switch text { case "null": diff --git a/backend/sqlgen/cond_test.go b/backend/sqlgen/cond_test.go index 154adb4..c0e5b5e 100644 --- a/backend/sqlgen/cond_test.go +++ b/backend/sqlgen/cond_test.go @@ -77,6 +77,36 @@ func TestCompileEveryInfixOperator(t *testing.T) { } } +// like(any) expands a {pat1,pat2} list into col LIKE $1 OR col LIKE $2. +func TestCompileLikeAny(t *testing.T) { + where := ir.Cond(ir.Compare{ + Path: []string{"task"}, + Op: ir.OpLike, + Quant: ir.QAny, + Value: ir.Value{List: []string{"%cat%", "%laundry%"}}, + }) + st := compile(t, &ir.Query{Relation: ir.Ref{Name: "todos"}, Where: &where}) + want := `SELECT * FROM "todos" WHERE ("task" LIKE $1 OR "task" LIKE $2)` + if st.SQL != want { + t.Errorf("SQL = %q, want %q", st.SQL, want) + } +} + +// like(all) expands a {pat1,pat2} list into col LIKE $1 AND col LIKE $2. +func TestCompileLikeAll(t *testing.T) { + where := ir.Cond(ir.Compare{ + Path: []string{"task"}, + Op: ir.OpLike, + Quant: ir.QAll, + Value: ir.Value{List: []string{"%A%", "%o%"}}, + }) + st := compile(t, &ir.Query{Relation: ir.Ref{Name: "todos"}, Where: &where}) + want := `SELECT * FROM "todos" WHERE ("task" LIKE $1 AND "task" LIKE $2)` + if st.SQL != want { + t.Errorf("SQL = %q, want %q", st.SQL, want) + } +} + // indexFTSDialect models an engine whose full text needs a covering index: it // quotes the index reference into the emitted expression, and reports ok=false // when the planner attached none. diff --git a/backend/sqlgen/embed.go b/backend/sqlgen/embed.go index e5fe91c..7747d15 100644 --- a/backend/sqlgen/embed.go +++ b/backend/sqlgen/embed.go @@ -290,6 +290,10 @@ func (b *builder) embedObject(emb *ir.Embed, alias string) (string, *pgerr.APIEr return "", err } pairs = append(pairs, Pair{Key: nested.OutKey, Value: b.d.Cast(sub, "json")}) + case ir.Aggregate: + if v.Func == ir.AggCount && v.Arg == nil { + pairs = append(pairs, Pair{Key: "count", Value: "count(*)"}) + } default: return "", pgerr.ErrUnsupported("aggregates in embedded resources", "sql") } diff --git a/httpapi/negotiate.go b/httpapi/negotiate.go index db4eb9f..95a07ae 100644 --- a/httpapi/negotiate.go +++ b/httpapi/negotiate.go @@ -14,12 +14,13 @@ const ( mediaJSON = "application/json" mediaArray = "application/vnd.pgrst.array+json" mediaObject = "application/vnd.pgrst.object+json" + mediaPlan = "application/vnd.pgrst.plan+json" mediaCSV = "text/csv" mediaOctet = "application/octet-stream" mediaText = "text/plain" ) -var supportedMedia = []string{mediaJSON, mediaArray, mediaObject, mediaCSV, mediaOctet, mediaText} +var supportedMedia = []string{mediaJSON, mediaArray, mediaObject, mediaPlan, mediaCSV, mediaOctet, mediaText} // mediaRange is one parsed entry of an Accept header: a type/subtype pair, its // quality value, and its position in the header for stable tie-breaking. @@ -62,6 +63,33 @@ func parseAccept(headers []string) []mediaRange { return ranges } +// planAnalyze reports whether the Accept header for vnd.pgrst.plan+json carries +// "options=analyze", which asks for EXPLAIN ANALYZE rather than plain EXPLAIN. +func planAnalyze(headers []string) bool { + for _, h := range headers { + for part := range strings.SplitSeq(h, ",") { + part = strings.TrimSpace(part) + segs := strings.Split(part, ";") + typ, sub, ok := strings.Cut(strings.TrimSpace(segs[0]), "/") + if !ok { + continue + } + if strings.ToLower(typ)+"/"+strings.ToLower(sub) != "application/vnd.pgrst.plan+json" { + continue + } + for _, p := range segs[1:] { + p = strings.TrimSpace(p) + if v, ok := strings.CutPrefix(strings.ToLower(p), "options="); ok { + if strings.Contains(v, "analyze") { + return true + } + } + } + } + } + return false +} + // negotiate picks the best supported response media type for the Accept header. // An absent or fully wildcard Accept yields application/json. The second return // is false when no listed media type can be produced, which the caller turns diff --git a/httpapi/render.go b/httpapi/render.go index a33392b..080333f 100644 --- a/httpapi/render.go +++ b/httpapi/render.go @@ -57,9 +57,20 @@ func renderFor(media string, res backend.Result, rawCols map[string]bool) (*rend // scalar media). A scalar return is the bare value; a setof-scalar return is a // JSON array of bare values. The object media type asks for a single value and // enforces the zero-or-many rule, so a setof function with one row can satisfy a -// singular request. -func renderCall(media string, res backend.Result, fn *rpc.Function) (*rendered, *pgerr.APIError) { - if fn == nil || fn.Returns.Kind == rpc.ReturnTable { +// singular request. fnName is the bare function name; it is used for native-RPC +// heuristic detection when fn is nil. +func renderCall(media string, res backend.Result, fn *rpc.Function, fnName string) (*rendered, *pgerr.APIError) { + if fn == nil { + // Native RPC: detect scalar vs table by inspecting column names. + // res.Rows().Columns() does not advance the cursor; the stream remains + // fully readable for the render path below. + cols := res.Rows().Columns() + if len(cols) == 1 && cols[0] == fnName { + fn = &rpc.Function{Returns: rpc.ReturnShape{Kind: rpc.ReturnScalar}} + } else { + return renderFor(media, res, nil) + } + } else if fn.Returns.Kind == rpc.ReturnTable { return renderFor(media, res, nil) } switch media { diff --git a/httpapi/server.go b/httpapi/server.go index 416283f..d4fca75 100644 --- a/httpapi/server.go +++ b/httpapi/server.go @@ -274,7 +274,7 @@ func (s *Server) handleRPC(w http.ResponseWriter, r *http.Request, fn string, id return } - out, apiErr := renderCall(media, res, planned.Func) + out, apiErr := renderCall(media, res, planned.Func, fn) if apiErr != nil { writeError(w, apiErr) return @@ -318,9 +318,10 @@ func (s *Server) handleRead(w http.ResponseWriter, r *http.Request, id identity) return } - media, ok := negotiate(r.Header.Values("Accept")) + acceptHdrs := r.Header.Values("Accept") + media, ok := negotiate(acceptHdrs) if !ok { - writeError(w, pgerr.ErrNotAcceptable(strings.Join(r.Header.Values("Accept"), ", "))) + writeError(w, pgerr.ErrNotAcceptable(strings.Join(acceptHdrs, ", "))) return } @@ -332,18 +333,17 @@ func (s *Server) handleRead(w http.ResponseWriter, r *http.Request, id identity) q.Singular = media == mediaObject // Range: header overrides ?limit=&offset= and marks the request as a - // Range request so the server can return 206 Partial Content, matching - // PostgREST's behaviour: 206 only comes from a Range header (or from - // count=exact showing the window is partial). - if rangeUnit := r.Header.Get("Range-Unit"); rangeUnit == "items" { - if rangeHdr := r.Header.Get("Range"); rangeHdr != "" { - if off, lim, ok := parseRangeHeader(rangeHdr); ok { - q.Offset = &off - if lim >= 0 { - l := lim - q.Limit = &l - q.FromRange = true // bounded Range → eligible for 206 - } + // Range request so the server can return 206 Partial Content. PostgREST + // accepts Range: 0-9 (item range) without requiring Range-Unit: items. + // Only treat Range as item pagination when it has no unit prefix (i.e. + // not "bytes=0-9" form), matching PostgREST's parsing behaviour. + if rangeHdr := r.Header.Get("Range"); rangeHdr != "" && !strings.Contains(rangeHdr, "=") { + if off, lim, ok := parseRangeHeader(rangeHdr); ok { + q.Offset = &off + if lim >= 0 { + l := lim + q.Limit = &l + q.FromRange = true // bounded Range → eligible for 206 } } } @@ -361,6 +361,24 @@ func (s *Server) handleRead(w http.ResponseWriter, r *http.Request, id identity) return } + // vnd.pgrst.plan+json: return EXPLAIN JSON when the backend supports it. + if media == mediaPlan { + exp, supported := s.backend.(backend.Explainer) + if !supported { + writeError(w, pgerr.ErrNotAcceptable(mediaPlan)) + return + } + planJSON, err := exp.ExplainRead(r.Context(), planned, rc, planAnalyze(acceptHdrs)) + if err != nil { + writeError(w, mapExecError(s.backend, err, id.anonymous)) + return + } + w.Header().Set("Content-Type", mediaPlan) + w.WriteHeader(http.StatusOK) + w.Write(planJSON) + return + } + res, err := s.backend.Execute(r.Context(), planned, rc) if err != nil { writeError(w, mapExecError(s.backend, err, id.anonymous)) @@ -445,6 +463,13 @@ func (s *Server) writeWrite(w http.ResponseWriter, r *http.Request, q *ir.Query, representation := q.Write.Return == ir.ReturnRepresentation if !representation { + // When count=exact was requested, include Content-Range: */ so the + // client knows how many rows were affected, matching PostgREST's wire. + if q.Count == ir.CountExact { + if n, ok := res.Affected(); ok { + w.Header().Set("Content-Range", fmt.Sprintf("*/%d", n)) + } + } w.WriteHeader(applyControls(w, ctrl, writeStatus(r.Method, q.Kind, false, ctrl))) return } @@ -456,7 +481,16 @@ func (s *Server) writeWrite(w http.ResponseWriter, r *http.Request, q *ir.Query, } w.Header().Set("Content-Type", out.contentType) if !q.Singular { - w.Header().Set("Content-Range", contentRange(0, out.nRows, 0, false)) + // For writes with count=exact, include the total in Content-Range. + if q.Count == ir.CountExact { + if n, ok := res.Affected(); ok { + w.Header().Set("Content-Range", contentRange(0, out.nRows, n, true)) + } else { + w.Header().Set("Content-Range", contentRange(0, out.nRows, 0, false)) + } + } else { + w.Header().Set("Content-Range", contentRange(0, out.nRows, 0, false)) + } } w.WriteHeader(applyControls(w, ctrl, writeStatus(r.Method, q.Kind, true, ctrl))) if r.Method != http.MethodHead { @@ -635,10 +669,13 @@ func asAPIError(b backend.Backend, err error) *pgerr.APIError { // (insufficient_privilege) error to an anonymous request is 401 (authentication // required), not 403 (forbidden). An authenticated request that is denied // remains 403 so the caller knows to authenticate, not just retry. +// The original PostgreSQL message is preserved to match PostgREST wire behavior. func mapExecError(b backend.Backend, err error, anonymous bool) *pgerr.APIError { e := asAPIError(b, err) if anonymous && e.Code == pgerr.CodeInsufficientPrivilege { - e = pgerr.ErrPermissionDenied("", anonymous) + lifted := *e + lifted.HTTPStatus = http.StatusUnauthorized + return &lifted } return e } diff --git a/ir/parse.go b/ir/parse.go index 3a7dc32..08f65dd 100644 --- a/ir/parse.go +++ b/ir/parse.go @@ -420,7 +420,14 @@ func buildInsert(objs []map[string]any, columnsParam string, header []string) ([ var cols []string switch { case columnsParam != "": - cols = splitComma(columnsParam) + raw := splitComma(columnsParam) + cols = make([]string, len(raw)) + for i, c := range raw { + if len(c) >= 2 && c[0] == '"' && c[len(c)-1] == '"' { + c = c[1 : len(c)-1] + } + cols[i] = c + } case header != nil: cols = header case len(objs) > 0: @@ -475,6 +482,12 @@ func parseSelect(s string) ([]SelectItem, []Embed, *pgerr.APIError) { embeds = append(embeds, emb) continue } + // PostgREST supports a bare "count" inside an embed select as a virtual + // aggregate that maps to count(*) in the JSON output. + if raw == "count" { + items = append(items, Aggregate{Func: AggCount}) + continue + } col, perr := parseColumnItem(raw) if perr != nil { return nil, nil, perr @@ -823,7 +836,19 @@ func parseCompare(path []string, raw string) (Compare, *pgerr.APIError) { } case OpLike, OpILike: // PostgREST maps * to % in LIKE/ILIKE patterns so URL-friendly wildcards work. - c.Value = Value{Text: strings.ReplaceAll(operand, "*", "%")} + if c.Quant != QNone { + // like(any)/{*cat*,*laundry*} — expand {…} into a list, * → % in each. + list, perr := parseLikeList(operand) + if perr != nil { + return Compare{}, perr + } + for i, p := range list { + list[i] = strings.ReplaceAll(p, "*", "%") + } + c.Value = Value{List: list} + } else { + c.Value = Value{Text: strings.ReplaceAll(operand, "*", "%")} + } default: c.Value = Value{Text: operand} } @@ -855,6 +880,26 @@ func parseInList(raw string) ([]string, *pgerr.APIError) { return out, nil } +// parseLikeList parses a {pat1,pat2,...} literal (PostgREST quantified-LIKE +// syntax) into a slice of raw pattern strings. No wildcard translation is done +// here; the caller applies * → % after parsing. +func parseLikeList(raw string) ([]string, *pgerr.APIError) { + raw = strings.TrimSpace(raw) + if len(raw) < 2 || raw[0] != '{' || raw[len(raw)-1] != '}' { + return nil, pgerr.ErrParse("like(any/all) expects a {…} list") + } + inner := raw[1 : len(raw)-1] + if inner == "" { + return []string{}, nil + } + parts := strings.Split(inner, ",") + out := make([]string, len(parts)) + for i, p := range parts { + out[i] = strings.TrimSpace(p) + } + return out, nil +} + // ftsVariant maps a full-text operator token to its IR variant. The four tokens // share the single OpFTS op and differ only in the query grammar a backend lowers // them to (spec 21). diff --git a/plan/plan.go b/plan/plan.go index badefcf..1d2740b 100644 --- a/plan/plan.go +++ b/plan/plan.go @@ -278,6 +278,9 @@ func validateSelect(rel *schema.Relation, items []ir.SelectItem) *pgerr.APIError // Aggregates and embeds are checked by their subsystems; leave them. continue } + if isStarPath(col.Path) { + continue + } if err := checkColumn(rel, col.Path); err != nil { return err }