diff --git a/adminapi/adminapi.go b/adminapi/adminapi.go new file mode 100644 index 0000000..66babd5 --- /dev/null +++ b/adminapi/adminapi.go @@ -0,0 +1,147 @@ +// Package adminapi is the admin listener PostgREST runs next to the API when +// admin-server-port is set: GET /live and /ready for orchestrator probes, +// GET /schema_cache for the loaded cache, and GET /metrics in Prometheus text +// format. The endpoints, paths, and status codes mirror PostgREST v14's +// admin server (PostgREST.Admin): /live is 200 while the API listener accepts +// connections and 500 otherwise; /ready adds the backend and schema cache +// health and degrades to 503; any other path is 404 with an empty body. +// +// Spec 20 sketches a POST /schema_cache for an on-demand reload; upstream has +// no such endpoint (reload is SIGUSR1 or NOTIFY), so it is not served here. +// If the reload entry point lands it belongs next to the signal handler, not +// in this package's GET-only surface. +package adminapi + +import ( + "context" + "fmt" + "net/http" + "sort" + "strings" + "sync" + "time" +) + +// Server serves the admin endpoints. The health checks are injected by the +// command, which knows the API listener address and owns the backend; the +// zero value of any field degrades gracefully (a nil check reports healthy, +// a nil SchemaCache serves an empty body, matching upstream's "no cache yet"). +type Server struct { + // Live reports whether the API listener accepts connections. PostgREST + // implements this as a TCP dial of its own socket; the command wires the + // same here. + Live func(ctx context.Context) error + + // Ready reports whether the backend connection and the schema cache are + // usable. It is consulted in addition to Live. + Ready func(ctx context.Context) error + + // SchemaCache returns the loaded schema cache rendered as JSON. + SchemaCache func() ([]byte, error) + + // Metrics holds the counters rendered at /metrics. + Metrics *Metrics +} + +func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { + switch strings.TrimSuffix(r.URL.Path, "/") { + case "/live": + w.WriteHeader(s.liveStatus(r.Context())) + case "/ready": + w.WriteHeader(s.readyStatus(r.Context())) + case "/schema_cache": + body, err := s.schemaCacheJSON() + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.WriteHeader(http.StatusOK) + w.Write(body) + case "/metrics": + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.WriteHeader(http.StatusOK) + if s.Metrics != nil { + w.Write([]byte(s.Metrics.Text())) + } + default: + w.WriteHeader(http.StatusNotFound) + } +} + +func (s *Server) liveStatus(ctx context.Context) int { + if s.Live != nil && s.Live(ctx) != nil { + return http.StatusInternalServerError + } + return http.StatusOK +} + +func (s *Server) readyStatus(ctx context.Context) int { + if status := s.liveStatus(ctx); status != http.StatusOK { + return status + } + if s.Ready != nil && s.Ready(ctx) != nil { + return http.StatusServiceUnavailable + } + return http.StatusOK +} + +func (s *Server) schemaCacheJSON() ([]byte, error) { + if s.SchemaCache == nil { + return nil, nil + } + return s.SchemaCache() +} + +// Metrics is a small Prometheus-text registry covering what dbrest measures +// today: schema cache loads and the configured pool ceiling. The names follow +// PostgREST's metric names where the concept matches. +type Metrics struct { + mu sync.Mutex + loads map[string]int64 // by status label: SUCCESS / FAIL + lastLoadSeconds float64 + poolMax int +} + +// NewMetrics builds a registry; poolMax is the db-pool setting. +func NewMetrics(poolMax int) *Metrics { + return &Metrics{loads: map[string]int64{}, poolMax: poolMax} +} + +// ObserveSchemaCacheLoad records one schema cache load attempt. +func (m *Metrics) ObserveSchemaCacheLoad(d time.Duration, err error) { + m.mu.Lock() + defer m.mu.Unlock() + status := "SUCCESS" + if err != nil { + status = "FAIL" + } + m.loads[status]++ + if err == nil { + m.lastLoadSeconds = d.Seconds() + } +} + +// Text renders the registry in the Prometheus text exposition format. +func (m *Metrics) Text() string { + m.mu.Lock() + defer m.mu.Unlock() + var b strings.Builder + b.WriteString("# HELP pgrst_schema_cache_query_time_seconds The query time in seconds of the last schema cache load\n") + b.WriteString("# TYPE pgrst_schema_cache_query_time_seconds gauge\n") + fmt.Fprintf(&b, "pgrst_schema_cache_query_time_seconds %g\n", m.lastLoadSeconds) + b.WriteString("# HELP pgrst_schema_cache_loads_total The total number of schema cache loads\n") + b.WriteString("# TYPE pgrst_schema_cache_loads_total counter\n") + statuses := make([]string, 0, len(m.loads)) + for status := range m.loads { + statuses = append(statuses, status) + } + sort.Strings(statuses) + for _, status := range statuses { + fmt.Fprintf(&b, "pgrst_schema_cache_loads_total{status=%q} %d\n", status, m.loads[status]) + } + b.WriteString("# HELP pgrst_db_pool_max Max pool connections\n") + b.WriteString("# TYPE pgrst_db_pool_max gauge\n") + fmt.Fprintf(&b, "pgrst_db_pool_max %d\n", m.poolMax) + return b.String() +} diff --git a/adminapi/adminapi_test.go b/adminapi/adminapi_test.go new file mode 100644 index 0000000..1f48901 --- /dev/null +++ b/adminapi/adminapi_test.go @@ -0,0 +1,132 @@ +package adminapi + +import ( + "context" + "encoding/json" + "errors" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" +) + +func get(t *testing.T, s *Server, path string) *http.Response { + t.Helper() + rec := httptest.NewRecorder() + s.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, path, nil)) + return rec.Result() +} + +// TestLive covers both sides of the liveness probe: 200 while the API socket +// answers, 500 once it does not, matching the PostgREST admin server. +func TestLive(t *testing.T) { + up := &Server{Live: func(context.Context) error { return nil }} + if resp := get(t, up, "/live"); resp.StatusCode != http.StatusOK { + t.Errorf("live up: status = %d, want 200", resp.StatusCode) + } + down := &Server{Live: func(context.Context) error { return errors.New("refused") }} + if resp := get(t, down, "/live"); resp.StatusCode != http.StatusInternalServerError { + t.Errorf("live down: status = %d, want 500", resp.StatusCode) + } +} + +// TestReady covers the three readiness answers: 500 when the API is not +// reachable, 503 when it is up but the backend is not usable, 200 otherwise. +func TestReady(t *testing.T) { + ok := func(context.Context) error { return nil } + bad := func(context.Context) error { return errors.New("down") } + + cases := []struct { + name string + srv *Server + want int + }{ + {"loaded", &Server{Live: ok, Ready: ok}, http.StatusOK}, + {"backend pending", &Server{Live: ok, Ready: bad}, http.StatusServiceUnavailable}, + {"api unreachable", &Server{Live: bad, Ready: ok}, http.StatusInternalServerError}, + } + for _, tc := range cases { + if resp := get(t, tc.srv, "/ready"); resp.StatusCode != tc.want { + t.Errorf("%s: status = %d, want %d", tc.name, resp.StatusCode, tc.want) + } + } +} + +// TestSchemaCache checks the dump is served as JSON, and that a failing dump +// degrades to 500 rather than half a body. +func TestSchemaCache(t *testing.T) { + srv := &Server{SchemaCache: func() ([]byte, error) { + return json.Marshal(map[string]any{"relations": []string{"films"}}) + }} + resp := get(t, srv, "/schema_cache") + if resp.StatusCode != http.StatusOK { + t.Fatalf("status = %d, want 200", resp.StatusCode) + } + if ct := resp.Header.Get("Content-Type"); !strings.HasPrefix(ct, "application/json") { + t.Errorf("Content-Type = %q, want application/json", ct) + } + var body map[string]any + if err := json.NewDecoder(resp.Body).Decode(&body); err != nil { + t.Fatalf("decode: %v", err) + } + + broken := &Server{SchemaCache: func() ([]byte, error) { return nil, errors.New("nope") }} + if resp := get(t, broken, "/schema_cache"); resp.StatusCode != http.StatusInternalServerError { + t.Errorf("broken dump: status = %d, want 500", resp.StatusCode) + } +} + +// TestMetrics checks the Prometheus text rendering: content type, the load +// counters by status, the last query time, and the pool gauge. +func TestMetrics(t *testing.T) { + m := NewMetrics(10) + m.ObserveSchemaCacheLoad(250*time.Millisecond, nil) + m.ObserveSchemaCacheLoad(0, errors.New("introspect failed")) + srv := &Server{Metrics: m} + + resp := get(t, srv, "/metrics") + if resp.StatusCode != http.StatusOK { + t.Fatalf("status = %d, want 200", resp.StatusCode) + } + if ct := resp.Header.Get("Content-Type"); !strings.HasPrefix(ct, "text/plain") { + t.Errorf("Content-Type = %q, want text/plain", ct) + } + raw, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("read body: %v", err) + } + body := string(raw) + for _, want := range []string{ + `pgrst_schema_cache_loads_total{status="SUCCESS"} 1`, + `pgrst_schema_cache_loads_total{status="FAIL"} 1`, + "pgrst_schema_cache_query_time_seconds 0.25", + "pgrst_db_pool_max 10", + } { + if !strings.Contains(body, want) { + t.Errorf("metrics body missing %q\n%s", want, body) + } + } +} + +// TestUnknownPathIs404 checks the fall-through, including the root. +func TestUnknownPathIs404(t *testing.T) { + srv := &Server{} + for _, path := range []string{"/", "/config", "/live/extra"} { + if resp := get(t, srv, path); resp.StatusCode != http.StatusNotFound { + t.Errorf("%s: status = %d, want 404", path, resp.StatusCode) + } + } +} + +// TestNilChecksDegradeGracefully checks the zero-value server: health reports +// up (nothing to check), the cache dump is empty, metrics body is empty. +func TestNilChecksDegradeGracefully(t *testing.T) { + srv := &Server{} + for _, path := range []string{"/live", "/ready", "/schema_cache", "/metrics"} { + if resp := get(t, srv, path); resp.StatusCode != http.StatusOK { + t.Errorf("%s: status = %d, want 200", path, resp.StatusCode) + } + } +} diff --git a/auth/auth.go b/auth/auth.go index 9434b28..7fb4fd8 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -1,7 +1,7 @@ // Package auth verifies JSON Web Tokens and resolves the request role, the // single piece of PostgREST's stateless auth model that lives in the frontend // (spec 13). It is backend-agnostic: the signature and algorithm checks, the -// exp/nbf/iat/aud validation, the role resolution, and the PGRST301/PGRST302 +// exp/nbf/iat/aud validation, the role resolution, and the PGRST301/302/303 // codes are produced here and are byte-identical on every engine. Only the // unobservable role switch differs per backend, which this package never touches. // @@ -13,9 +13,11 @@ import ( "crypto/ecdsa" "crypto/rsa" "crypto/x509" + "encoding/base64" "encoding/json" "encoding/pem" "errors" + "fmt" "strings" "time" @@ -39,15 +41,23 @@ type Config struct { // Empty derives the set from the configured keys. "none" is never accepted and // is rejected if listed explicitly. AllowedAlgs []string - // Secret is the shared HMAC secret. When set it must be at least 32 bytes. + // Secret is the jwt-secret value. As in PostgREST it is read three ways: a + // JWK Set JSON, a single JWK JSON, or a plain text HMAC secret (which must + // be at least 32 bytes). Secret []byte + // JWKSet is an explicit JWK Set (or single JWK) JSON. Unlike Secret it has + // no text fallback: an unparseable value is a startup error. + JWKSet string // PublicKeyPEM is a PEM-encoded RSA or ECDSA public key (the static key source // for the asymmetric families). PublicKeyPEM string // Audience, when set, must appear in the token's aud claim. Audience string - // RoleClaimKey names the claim the request role is read from; default "role". - // A leading-dot dotted path (".app_metadata.role") reads a nested claim. + // RoleClaimKey names the claim the request role is read from; default ".role". + // The value is a JSPath expression: dotted keys (".app_metadata.role"), quoted + // keys (."https://example.com/role"), array indexes (".roles[0]"), and a + // trailing filter (".roles[?(@ == \"admin\")]"). An invalid value is a + // startup error. RoleClaimKey string // AnonRole is the role an unauthenticated or role-less request runs as. Empty // means such requests are refused rather than run as the connection identity. @@ -77,13 +87,11 @@ type Result struct { // itself. type Verifier struct { validMethods []string - hmac []byte - rsa *rsa.PublicKey - ecdsa *ecdsa.PublicKey + keys []verKey hasKeys bool audience string - roleKeyPath []string + roleKeyPath []jsPathExp anonRole string permitted map[string]bool skew time.Duration @@ -110,21 +118,32 @@ func NewVerifier(cfg Config) (*Verifier, error) { for _, r := range cfg.PermittedRoles { v.permitted[r] = true } - v.roleKeyPath = parseRoleKey(cfg.RoleClaimKey) + roleKey, err := parseJSPath(cfg.RoleClaimKey) + if err != nil { + return nil, err + } + v.roleKeyPath = roleKey if len(cfg.Secret) > 0 { - if len(cfg.Secret) < minHMACSecret { - return nil, errors.New("jwt-secret must be at least 32 characters") + keys, err := parseSecretKeys(cfg.Secret) + if err != nil { + return nil, err } - v.hmac = cfg.Secret - v.hasKeys = true + v.keys = append(v.keys, keys...) + } + if cfg.JWKSet != "" { + keys, err := parseJWKSet(cfg.JWKSet) + if err != nil { + return nil, fmt.Errorf("jwk-set: %w", err) + } + v.keys = append(v.keys, keys...) } if cfg.PublicKeyPEM != "" { if err := v.loadPublicKey(cfg.PublicKeyPEM); err != nil { return nil, err } - v.hasKeys = true } + v.hasKeys = len(v.keys) > 0 methods, err := v.resolveMethods(cfg.AllowedAlgs) if err != nil { @@ -139,21 +158,30 @@ func NewVerifier(cfg Config) (*Verifier, error) { } // Authenticate resolves the identity of a request from its Authorization header -// value. No bearer token runs as anon; an expired token is PGRST301; any other -// verification failure is PGRST302; a valid token naming a forbidden role is 403. -// When no key material is configured, verification is disabled and every request -// runs as anon, matching PostgREST with no jwt-secret. +// value. No bearer token runs as anon; a token that cannot be decoded is +// PGRST301; a decoded token failing claims validation is PGRST303; a valid +// token naming a forbidden role is 403. +// When no key material is configured the server fails closed, as PostgREST +// does: a presented token is a 500 PGRST300, never silently accepted. func (v *Verifier) Authenticate(authHeader string) (*Result, *pgerr.APIError) { raw, ok := bearer(authHeader) - if !ok || !v.hasKeys { + if !ok { return v.anon() } + if raw == "" { + // A bearer scheme with nothing after it is a malformed credential, not + // an anonymous request: PostgREST answers PGRST301 with this message. + return nil, pgerr.ErrJWTDecode("Empty JWT is sent in Authorization header") + } + if !v.hasKeys { + return nil, pgerr.ErrJWTSecretMissing() + } if v.cache != nil { if claims, hit := v.cache.get(raw); hit { - // A cached entry never extends a token's life: the time claims are - // re-checked against the live clock on every request (spec 13). - if apiErr := v.checkTime(claims); apiErr != nil { + // A cached entry never extends a token's life: the claims are + // re-validated against the live clock on every request (spec 13). + if apiErr := v.validateClaims(claims); apiErr != nil { return nil, apiErr } return v.resolve(claims) @@ -171,81 +199,234 @@ func (v *Verifier) Authenticate(authHeader string) (*Result, *pgerr.APIError) { } // verify checks the signature, the pinned algorithm, and the time and audience -// claims with skew, returning the claim set or a JWT error. The error message is -// fixed text: the token and the secret are never reflected back to the client. +// claims with skew, returning the claim set or a JWT error. The error messages +// are PostgREST's fixed texts: the token and the secret are never reflected +// back to the client. func (v *Verifier) verify(raw string) (map[string]any, *pgerr.APIError) { + if n := strings.Count(raw, ".") + 1; n != 3 { + return nil, pgerr.ErrJWTDecode(fmt.Sprintf("Expected 3 parts in JWT; got %d", n)) + } + if apiErr := v.checkAlg(raw); apiErr != nil { + return nil, apiErr + } claims := jwt.MapClaims{} + // Claims validation is done by validateClaims below, not by the library: + // PostgREST's rules differ (an absent or empty aud passes, iat is checked, + // and the type errors carry their own PGRST303 messages). opts := []jwt.ParserOption{ jwt.WithValidMethods(v.validMethods), - jwt.WithLeeway(v.skew), - jwt.WithTimeFunc(v.now), - } - if v.audience != "" { - opts = append(opts, jwt.WithAudience(v.audience)) + jwt.WithoutClaimsValidation(), } if _, err := jwt.NewParser(opts...).ParseWithClaims(raw, claims, v.keyfunc); err != nil { - if errors.Is(err, jwt.ErrTokenExpired) { - return nil, pgerr.ErrJWTExpired() - } - return nil, pgerr.ErrJWTInvalid("JWT invalid") + return nil, mapJWTError(err) + } + if apiErr := v.validateClaims(claims); apiErr != nil { + return nil, apiErr } return map[string]any(claims), nil } -// keyfunc returns the verification key for the token's algorithm family. The -// allowed-methods parser option already blocks a disallowed alg before this runs, -// so the algorithm-confusion swap (an RS token verified against an HMAC secret) +// checkAlg reads the unverified alg header of a compact JWT and rejects a value +// outside the pinned method set before any cryptography runs. The three failure +// shapes carry PostgREST's exact messages: an unsecured token, an alg the +// library does not know, and a known alg with no matching key. +func (v *Verifier) checkAlg(raw string) *pgerr.APIError { + headerPart := raw[:strings.IndexByte(raw, '.')] + headerJSON, err := base64.RawURLEncoding.DecodeString(headerPart) + if err != nil { + return pgerr.ErrJWTDecode("JWT cryptographic operation failed") + } + var header struct { + Alg string `json:"alg"` + } + if err := json.Unmarshal(headerJSON, &header); err != nil { + return pgerr.ErrJWTDecode("JWT cryptographic operation failed") + } + if strings.EqualFold(header.Alg, "none") { + return pgerr.ErrJWTDecode("Wrong or unsupported encoding algorithm"). + WithDetails("JWT is unsecured but expected 'alg' was not 'none'") + } + if jwt.GetSigningMethod(header.Alg) == nil { + return pgerr.ErrJWTDecode("JWT cryptographic operation failed") + } + for _, m := range v.validMethods { + if header.Alg == m { + return nil + } + } + return pgerr.ErrJWTDecode("No suitable key or wrong key type"). + WithDetails("No suitable key was found to decode the JWT") +} + +// mapJWTError translates a golang-jwt failure onto the v14 code split: claim +// validation failures are PGRST303, everything that prevented decoding or +// verifying the token is PGRST301. The messages are PostgREST's own. +func mapJWTError(err error) *pgerr.APIError { + switch { + case errors.Is(err, jwt.ErrTokenExpired): + return pgerr.ErrJWTClaims("JWT expired") + case errors.Is(err, jwt.ErrTokenNotValidYet): + return pgerr.ErrJWTClaims("JWT not yet valid") + case errors.Is(err, jwt.ErrTokenUsedBeforeIssued): + return pgerr.ErrJWTClaims("JWT issued at future") + case errors.Is(err, jwt.ErrTokenInvalidAudience): + return pgerr.ErrJWTClaims("JWT not in audience") + case errors.Is(err, jwt.ErrTokenInvalidClaims): + return pgerr.ErrJWTClaims("Parsing claims failed") + case errors.Is(err, jwt.ErrTokenSignatureInvalid): + return pgerr.ErrJWTDecode("No suitable key or wrong key type"). + WithDetails("None of the keys was able to decode the JWT") + case errors.Is(err, jwt.ErrTokenUnverifiable): + return pgerr.ErrJWTDecode("No suitable key or wrong key type"). + WithDetails("No suitable key was found to decode the JWT") + default: + return pgerr.ErrJWTDecode("JWT cryptographic operation failed") + } +} + +// keyfunc selects the verification keys for a token. A kid header narrows the +// set to the keys carrying that kid; a kid-less token tries every key of the +// right family in turn, as the upstream jose library does. The allowed-methods +// parser option already blocks a disallowed alg before this runs, so the +// algorithm-confusion swap (an RS token verified against an HMAC secret) // cannot reach a key of the wrong family. func (v *Verifier) keyfunc(t *jwt.Token) (any, error) { - switch t.Method.(type) { - case *jwt.SigningMethodHMAC: - if v.hmac == nil { - return nil, errors.New("no HMAC key configured") + kid, _ := t.Header["kid"].(string) + set := jwt.VerificationKeySet{} + for _, k := range v.keys { + if kid != "" && k.kid != kid { + continue } - return v.hmac, nil - case *jwt.SigningMethodRSA: - if v.rsa == nil { - return nil, errors.New("no RSA key configured") + if k.alg != "" && k.alg != t.Method.Alg() { + continue } - return v.rsa, nil - case *jwt.SigningMethodECDSA: - if v.ecdsa == nil { - return nil, errors.New("no ECDSA key configured") + if !methodMatchesKey(t.Method, k.key) { + continue } - return v.ecdsa, nil + set.Keys = append(set.Keys, k.key) + } + switch len(set.Keys) { + case 0: + return nil, errors.New("no suitable key was found to decode the JWT") + case 1: + return set.Keys[0], nil default: - return nil, errors.New("unsupported signing method") + return set, nil + } +} + +// methodMatchesKey reports whether a verification key belongs to the family of +// a signing method. +func methodMatchesKey(method jwt.SigningMethod, key any) bool { + switch method.(type) { + case *jwt.SigningMethodHMAC: + _, ok := key.([]byte) + return ok + case *jwt.SigningMethodRSA, *jwt.SigningMethodRSAPSS: + _, ok := key.(*rsa.PublicKey) + return ok + case *jwt.SigningMethodECDSA: + _, ok := key.(*ecdsa.PublicKey) + return ok } + return false } -// checkTime re-validates the exp and nbf claims against the live clock with the -// configured skew. It runs on a cache hit so a cached verification can never -// resurrect an expired token. -func (v *Verifier) checkTime(claims map[string]any) *pgerr.APIError { - now := v.now() - if exp, ok := numClaim(claims, "exp"); ok { - if now.After(time.Unix(exp, 0).Add(v.skew)) { - return pgerr.ErrJWTExpired() +// validateClaims applies PostgREST's claim checks in its order: exp, nbf, iat, +// then aud, each with the 30 second skew. An absent or null claim passes; a +// present claim of the wrong type is its own PGRST303 error. It runs on every +// request, including cache hits, so a cached verification can never resurrect +// an expired token. +func (v *Verifier) validateClaims(claims map[string]any) *pgerr.APIError { + now := v.now().Unix() + skew := int64(v.skew / time.Second) + + if val, ok := presentClaim(claims, "exp"); ok { + exp, isNum := claimNumber(val) + if !isNum { + return pgerr.ErrJWTClaims("The JWT 'exp' claim must be a number") + } + if now-skew > exp { + return pgerr.ErrJWTClaims("JWT expired") + } + } + if val, ok := presentClaim(claims, "nbf"); ok { + nbf, isNum := claimNumber(val) + if !isNum { + return pgerr.ErrJWTClaims("The JWT 'nbf' claim must be a number") + } + if now+skew < nbf { + return pgerr.ErrJWTClaims("JWT not yet valid") + } + } + if val, ok := presentClaim(claims, "iat"); ok { + iat, isNum := claimNumber(val) + if !isNum { + return pgerr.ErrJWTClaims("The JWT 'iat' claim must be a number") + } + if now+skew < iat { + return pgerr.ErrJWTClaims("JWT issued at future") } } - if nbf, ok := numClaim(claims, "nbf"); ok { - if now.Before(time.Unix(nbf, 0).Add(-v.skew)) { - return pgerr.ErrJWTInvalid("JWT invalid") + if val, ok := presentClaim(claims, "aud"); ok { + if apiErr := v.checkAud(val); apiErr != nil { + return apiErr + } + } + return nil +} + +// checkAud validates the aud claim the PostgREST way: a string must match the +// configured audience, an array passes when empty or when any element matches, +// and anything else is a type error. With no jwt-aud configured every audience +// matches. +func (v *Verifier) checkAud(val any) *pgerr.APIError { + switch aud := val.(type) { + case string: + if !v.audMatches(aud) { + return pgerr.ErrJWTClaims("JWT not in audience") + } + case []any: + matched := len(aud) == 0 + for _, el := range aud { + s, isStr := el.(string) + if !isStr { + return pgerr.ErrJWTClaims("The JWT 'aud' claim must be a string or an array of strings") + } + if v.audMatches(s) { + matched = true + } + } + if !matched { + return pgerr.ErrJWTClaims("JWT not in audience") } + default: + return pgerr.ErrJWTClaims("The JWT 'aud' claim must be a string or an array of strings") } return nil } +// audMatches reports whether a token audience satisfies the configured jwt-aud. +// An unset jwt-aud accepts every audience. +func (v *Verifier) audMatches(aud string) bool { + return v.audience == "" || v.audience == aud +} + // resolve reads the role from the claims and applies the anon fallback and the -// permitted-role check. A valid token that resolves to no role and has no anon -// fallback is refused; a role outside the permitted set is a 403. +// permitted-role check. Only a genuinely absent role claim falls back to the +// anonymous role: a present claim of any other type is rendered to text and +// used as the role name, exactly as PostgREST does (the engine or the authz +// registry then denies a role that does not exist, rather than the client +// being silently downgraded to anonymous data). A valid token that resolves +// to no role and has no anon fallback is refused; a role outside the +// permitted set is a 403. func (v *Verifier) resolve(claims map[string]any) (*Result, *pgerr.APIError) { - role := roleFromClaims(claims, v.roleKeyPath) - if role == "" { + role, present := roleFromClaims(claims, v.roleKeyPath) + if !present { role = v.anonRole - } - if role == "" { - return nil, errAnonDisabled() + if role == "" { + return nil, errAnonDisabled() + } } if apiErr := v.checkPermitted(role); apiErr != nil { return nil, apiErr @@ -272,10 +453,10 @@ func (v *Verifier) anon() (*Result, *pgerr.APIError) { } // errAnonDisabled is the 401 a request gets when it presents no usable identity -// and no anon role is configured, so it cannot be run as anyone. +// and no anon role is configured, so it cannot be run as anyone. The message is +// PostgREST's exact PGRST302 text. func errAnonDisabled() *pgerr.APIError { - return pgerr.ErrJWTInvalid("anonymous access is disabled"). - WithMessage("no JWT was sent and no anonymous role is configured") + return pgerr.ErrJWTRequired() } // loadPublicKey parses a PEM-encoded RSA or ECDSA public key into the verifier. @@ -289,10 +470,8 @@ func (v *Verifier) loadPublicKey(pemText string) error { return errors.New("jwt public key is not a valid PKIX key") } switch k := key.(type) { - case *rsa.PublicKey: - v.rsa = k - case *ecdsa.PublicKey: - v.ecdsa = k + case *rsa.PublicKey, *ecdsa.PublicKey: + v.keys = append(v.keys, verKey{key: k}) default: return errors.New("jwt public key is neither RSA nor ECDSA") } @@ -310,64 +489,90 @@ func (v *Verifier) resolveMethods(allowed []string) ([]string, error) { } return allowed, nil } + var hmac, rsaKey, ecdsaKey bool + for _, k := range v.keys { + switch k.key.(type) { + case []byte: + hmac = true + case *rsa.PublicKey: + rsaKey = true + case *ecdsa.PublicKey: + ecdsaKey = true + } + } var methods []string - if v.hmac != nil { + if hmac { methods = append(methods, "HS256", "HS384", "HS512") } - if v.rsa != nil { - methods = append(methods, "RS256", "RS384", "RS512") + if rsaKey { + methods = append(methods, "RS256", "RS384", "RS512", "PS256", "PS384", "PS512") } - if v.ecdsa != nil { + if ecdsaKey { methods = append(methods, "ES256", "ES384", "ES512") } return methods, nil } -// bearer extracts the token from an Authorization header value, accepting the -// "Bearer" scheme case-insensitively. It reports false for any other header. +// bearer extracts the token from an Authorization header value, mirroring the +// wai-extra extractBearerAuth PostgREST uses: the first whitespace ends the +// scheme word, the comparison is case-insensitive, and the token is whatever +// follows the leading whitespace, possibly empty. It reports false only when +// the credentials are not a bearer scheme at all, which is the anonymous +// path; "Bearer" with an empty token reports true so the caller can answer +// PGRST301 instead of downgrading the client to anon. func bearer(header string) (string, bool) { - const scheme = "bearer " - if len(header) < len(scheme) || !strings.EqualFold(header[:len(scheme)], scheme) { + scheme, rest := header, "" + if i := strings.IndexAny(header, " \t"); i >= 0 { + scheme, rest = header[:i], header[i+1:] + } + if !strings.EqualFold(scheme, "Bearer") { return "", false } - tok := strings.TrimSpace(header[len(scheme):]) - return tok, tok != "" + return strings.TrimLeft(rest, " \t"), true } -// parseRoleKey splits a role-claim key into a path of map keys. A leading dot is -// optional; an empty key defaults to the single segment "role". -func parseRoleKey(key string) []string { - key = strings.TrimPrefix(strings.TrimSpace(key), ".") - if key == "" { - return []string{"role"} +// roleFromClaims walks the role-claim JSPath over the claim set, reporting +// whether the path resolved to a value at all. A resolved value is rendered +// the way PostgREST renders a claim where text is expected: a string is taken +// bare, anything else (a number, bool, null, array, or object) becomes its +// compact JSON text and is used as the role name verbatim. +func roleFromClaims(claims map[string]any, path []jsPathExp) (string, bool) { + val, ok := walkJSPath(claims, path) + if !ok { + return "", false } - return strings.Split(key, ".") + return unquoted(val), true } -// roleFromClaims walks the claim path and returns the string value at its end, -// or "" if any segment is missing or the value is not a string. -func roleFromClaims(claims map[string]any, path []string) string { - var cur any = claims - for _, seg := range path { - m, ok := cur.(map[string]any) - if !ok { - return "" - } - cur, ok = m[seg] - if !ok { - return "" - } - } - if s, ok := cur.(string); ok { +// unquoted renders a claim value as the text PostgREST would use it as: a +// string stays bare, every other JSON value is its compact rendering ("null", +// "42", "true", "[\"a\"]"). +func unquoted(val any) string { + if s, ok := val.(string); ok { return s } - return "" + b, err := json.Marshal(val) + if err != nil { + return fmt.Sprint(val) + } + return string(b) +} + +// presentClaim reports a claim's value when it is present and non-null. An +// absent or null claim is skipped by every check, as upstream. +func presentClaim(claims map[string]any, name string) (any, bool) { + val, ok := claims[name] + if !ok || val == nil { + return nil, false + } + return val, true } -// numClaim reads a numeric claim as a Unix-seconds int64, handling the float64 -// and json.Number forms a decoded claim set can carry. -func numClaim(claims map[string]any, name string) (int64, bool) { - switch t := claims[name].(type) { +// claimNumber reads a numeric claim value as Unix seconds, handling the +// float64 and json.Number forms a decoded claim set can carry. A non-number is +// reported false and becomes the claim's PGRST303 type error. +func claimNumber(val any) (int64, bool) { + switch t := val.(type) { case float64: return int64(t), true case int64: @@ -376,6 +581,9 @@ func numClaim(claims map[string]any, name string) (int64, bool) { if n, err := t.Int64(); err == nil { return n, true } + if f, err := t.Float64(); err == nil { + return int64(f), true + } } return 0, false } diff --git a/auth/auth_test.go b/auth/auth_test.go index c6389c8..5d1cb9d 100644 --- a/auth/auth_test.go +++ b/auth/auth_test.go @@ -90,6 +90,89 @@ func TestBearerSchemeCaseInsensitive(t *testing.T) { } } +func TestNonBearerSchemeRunsAnon(t *testing.T) { + // Credentials of another scheme are not bearer tokens at all; PostgREST + // ignores them and the request runs anonymous. + v := hmacVerifier(t, Config{}) + res, err := v.Authenticate("Basic d2ViX3VzZXI6cHc=") + if err != nil { + t.Fatalf("Authenticate: %v", err) + } + if res.Role != anonRole || !res.Anonymous { + t.Fatalf("non-bearer credentials = %+v, want anon", res) + } +} + +func TestEmptyBearerIs301(t *testing.T) { + // "Authorization: Bearer" with no token is a malformed credential, not an + // anonymous request: PostgREST answers 401 PGRST301 with this exact + // message and the invalid_token challenge. + v := hmacVerifier(t, Config{}) + for _, header := range []string{"Bearer", "Bearer ", "Bearer ", "bearer\t"} { + _, err := v.Authenticate(header) + if err == nil || err.Code != "PGRST301" || err.HTTPStatus != 401 { + t.Fatalf("Authenticate(%q) = %v, want 401 PGRST301", header, err) + } + if err.Message != "Empty JWT is sent in Authorization header" { + t.Errorf("Authenticate(%q) message = %q, want the exact PostgREST text", header, err.Message) + } + want := `Bearer error="invalid_token", error_description="Empty JWT is sent in Authorization header"` + if err.WWWAuthenticate != want { + t.Errorf("Authenticate(%q) WWW-Authenticate = %q, want %q", header, err.WWWAuthenticate, want) + } + } +} + +func TestEmptyBearerBeatsMissingSecret(t *testing.T) { + // The empty-credential check answers before the key-material check: an + // empty bearer is the client's malformed request either way. + v, err := NewVerifier(Config{AnonRole: anonRole}) // no keys + if err != nil { + t.Fatalf("NewVerifier: %v", err) + } + _, aerr := v.Authenticate("Bearer ") + if aerr == nil || aerr.Code != "PGRST301" { + t.Fatalf("empty bearer with no keys = %v, want PGRST301", aerr) + } +} + +func TestNonStringRoleClaimUsedVerbatim(t *testing.T) { + // PostgREST renders a non-string role claim to its compact JSON text and + // uses it as the role name; the engine (or the authz registry) then denies + // a role that does not exist. The client is never silently downgraded to + // the anonymous role. Verified against postgrest/14.12: role 123 yields + // `role "123" does not exist`, role null yields `role "null" does not + // exist`, and so on. + cases := []struct { + name string + claim any + want string + }{ + {"number", 123, "123"}, + {"float", 12.5, "12.5"}, + {"bool", true, "true"}, + {"null", nil, "null"}, + {"array", []any{"web_user"}, `["web_user"]`}, + {"object", map[string]any{"a": 1}, `{"a":1}`}, + } + v := hmacVerifier(t, Config{}) + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + tok := signHS(t, jwt.MapClaims{"role": c.claim}) + res, err := v.Authenticate("Bearer " + tok) + if err != nil { + t.Fatalf("Authenticate: %v", err) + } + if res.Role != c.want { + t.Errorf("role = %q, want %q (no anon downgrade)", res.Role, c.want) + } + if res.Anonymous { + t.Error("a present role claim must not be anonymous") + } + }) + } +} + func TestTokenWithNoRoleFallsBackToAnon(t *testing.T) { v := hmacVerifier(t, Config{}) tok := signHS(t, jwt.MapClaims{"sub": "123"}) @@ -105,19 +188,23 @@ func TestTokenWithNoRoleFallsBackToAnon(t *testing.T) { } } -func TestExpiredTokenIs301(t *testing.T) { +func TestExpiredTokenIs303(t *testing.T) { v := hmacVerifier(t, Config{}) tok := signHS(t, jwt.MapClaims{ "role": "web_user", "exp": clockNow.Add(-time.Hour).Unix(), }) _, err := v.Authenticate("Bearer " + tok) - if err == nil || err.Code != "PGRST301" { - t.Fatalf("want PGRST301, got %v", err) + if err == nil || err.Code != "PGRST303" || err.Message != "JWT expired" { + t.Fatalf("want PGRST303 JWT expired, got %v", err) } if err.HTTPStatus != 401 { t.Errorf("status = %d, want 401", err.HTTPStatus) } + want := `Bearer error="invalid_token", error_description="JWT expired"` + if err.WWWAuthenticate != want { + t.Errorf("WWW-Authenticate = %q, want %q", err.WWWAuthenticate, want) + } } func TestExpiryWithinSkewStillValid(t *testing.T) { @@ -132,34 +219,55 @@ func TestExpiryWithinSkewStillValid(t *testing.T) { } } -func TestNotBeforeIs302(t *testing.T) { +func TestNotBeforeIs303(t *testing.T) { v := hmacVerifier(t, Config{}) tok := signHS(t, jwt.MapClaims{ "role": "web_user", "nbf": clockNow.Add(time.Hour).Unix(), }) _, err := v.Authenticate("Bearer " + tok) - if err == nil || err.Code != "PGRST302" { - t.Fatalf("want PGRST302, got %v", err) + if err == nil || err.Code != "PGRST303" || err.Message != "JWT not yet valid" { + t.Fatalf("want PGRST303 JWT not yet valid, got %v", err) } } -func TestBadSignatureIs302(t *testing.T) { +func TestBadSignatureIs301(t *testing.T) { v := hmacVerifier(t, Config{}) tok := signHS(t, jwt.MapClaims{"role": "web_user"}) // flip the last character of the signature. bad := tok[:len(tok)-1] + flip(tok[len(tok)-1]) _, err := v.Authenticate("Bearer " + bad) - if err == nil || err.Code != "PGRST302" { - t.Fatalf("want PGRST302, got %v", err) + if err == nil || err.Code != "PGRST301" || err.Message != "No suitable key or wrong key type" { + t.Fatalf("want PGRST301 No suitable key or wrong key type, got %v", err) + } + if err.Details == nil || *err.Details != "None of the keys was able to decode the JWT" { + t.Errorf("details = %v, want the none-of-the-keys detail", err.Details) } } -func TestMalformedTokenIs302(t *testing.T) { +func TestMalformedTokenIs301(t *testing.T) { v := hmacVerifier(t, Config{}) _, err := v.Authenticate("Bearer not.a.jwt") - if err == nil || err.Code != "PGRST302" { - t.Fatalf("want PGRST302, got %v", err) + if err == nil || err.Code != "PGRST301" || err.Message != "JWT cryptographic operation failed" { + t.Fatalf("want PGRST301 JWT cryptographic operation failed, got %v", err) + } +} + +func TestWrongPartCountMessage(t *testing.T) { + v := hmacVerifier(t, Config{}) + cases := []struct { + token string + want string + }{ + {"justonepart", "Expected 3 parts in JWT; got 1"}, + {"two.parts", "Expected 3 parts in JWT; got 2"}, + {"a.b.c.d", "Expected 3 parts in JWT; got 4"}, + } + for _, c := range cases { + _, err := v.Authenticate("Bearer " + c.token) + if err == nil || err.Code != "PGRST301" || err.Message != c.want { + t.Errorf("token %q: want PGRST301 %q, got %v", c.token, c.want, err) + } } } @@ -171,8 +279,11 @@ func TestNoneAlgorithmRejected(t *testing.T) { t.Fatalf("sign none: %v", err) } _, aerr := v.Authenticate("Bearer " + s) - if aerr == nil || aerr.Code != "PGRST302" { - t.Fatalf("the none alg must be rejected, got %v", aerr) + if aerr == nil || aerr.Code != "PGRST301" || aerr.Message != "Wrong or unsupported encoding algorithm" { + t.Fatalf("the none alg must be rejected with PGRST301, got %v", aerr) + } + if aerr.Details == nil || *aerr.Details != "JWT is unsecured but expected 'alg' was not 'none'" { + t.Errorf("details = %v, want the unsecured-token detail", aerr.Details) } } @@ -197,8 +308,34 @@ func TestAudienceEnforced(t *testing.T) { t.Fatalf("matching aud must verify: %v", err) } bad := signHS(t, jwt.MapClaims{"role": "web_user", "aud": "other"}) - if _, err := v.Authenticate("Bearer " + bad); err == nil || err.Code != "PGRST302" { - t.Fatalf("wrong aud must be PGRST302, got %v", err) + if _, err := v.Authenticate("Bearer " + bad); err == nil || err.Code != "PGRST303" || err.Message != "JWT not in audience" { + t.Fatalf("wrong aud must be PGRST303 JWT not in audience, got %v", err) + } +} + +func TestTokenWithoutAudAccepted(t *testing.T) { + // "If the aud key is not present ... allowed for all audiences": a token + // with no aud claim passes even when jwt-aud is configured. + v := hmacVerifier(t, Config{Audience: testAud}) + tok := signHS(t, jwt.MapClaims{"role": "web_user"}) + res, err := v.Authenticate("Bearer " + tok) + if err != nil { + t.Fatalf("a token without aud must verify: %v", err) + } + if res.Role != "web_user" { + t.Fatalf("role = %q", res.Role) + } +} + +func TestFutureIssuedAtIs303(t *testing.T) { + v := hmacVerifier(t, Config{}) + tok := signHS(t, jwt.MapClaims{ + "role": "web_user", + "iat": clockNow.Add(time.Hour).Unix(), + }) + _, err := v.Authenticate("Bearer " + tok) + if err == nil || err.Code != "PGRST303" || err.Message != "JWT issued at future" { + t.Fatalf("want PGRST303 JWT issued at future, got %v", err) } } @@ -242,23 +379,48 @@ func TestAnonDisabledWithoutToken(t *testing.T) { t.Fatalf("NewVerifier: %v", err) } _, aerr := v.Authenticate("") - if aerr == nil || aerr.HTTPStatus != 401 { - t.Fatalf("no anon role + no token must be 401, got %v", aerr) + if aerr == nil || aerr.HTTPStatus != 401 || aerr.Code != "PGRST302" { + t.Fatalf("no anon role + no token must be 401 PGRST302, got %v", aerr) + } + if aerr.Message != "Anonymous access is disabled" { + t.Errorf("message = %q, want the exact PGRST302 text", aerr.Message) + } + if aerr.WWWAuthenticate != "Bearer" { + t.Errorf("WWW-Authenticate = %q, want Bearer", aerr.WWWAuthenticate) } } -func TestNoKeysDisablesVerification(t *testing.T) { +func TestNoKeysNoTokenRunsAnon(t *testing.T) { v, err := NewVerifier(Config{AnonRole: anonRole}) if err != nil { t.Fatalf("NewVerifier: %v", err) } - // A token is presented but no key is configured: it runs as anon. - res, aerr := v.Authenticate("Bearer anything.at.all") + res, aerr := v.Authenticate("") if aerr != nil { t.Fatalf("Authenticate: %v", aerr) } - if res.Role != anonRole { - t.Fatalf("role = %q, want anon when verification is off", res.Role) + if res.Role != anonRole || !res.Anonymous { + t.Fatalf("res = %+v, want anon", res) + } +} + +func TestNoKeysWithTokenIs500(t *testing.T) { + // A token presented to a server without key material is a server + // misconfiguration, not an anonymous request: PostgREST fails closed + // with 500 PGRST300. + v, err := NewVerifier(Config{AnonRole: anonRole}) + if err != nil { + t.Fatalf("NewVerifier: %v", err) + } + _, aerr := v.Authenticate("Bearer anything.at.all") + if aerr == nil || aerr.HTTPStatus != 500 || aerr.Code != "PGRST300" { + t.Fatalf("token with no keys must be 500 PGRST300, got %v", aerr) + } + if aerr.Message != "Server lacks JWT secret" { + t.Errorf("message = %q, want Server lacks JWT secret", aerr.Message) + } + if aerr.WWWAuthenticate != "" { + t.Errorf("a PGRST300 must not carry a challenge, got %q", aerr.WWWAuthenticate) } } @@ -298,9 +460,13 @@ func TestAlgConfusionRejected(t *testing.T) { if err != nil { t.Fatalf("sign: %v", err) } - if _, aerr := v.Authenticate("Bearer " + signed); aerr == nil { + _, aerr := v.Authenticate("Bearer " + signed) + if aerr == nil { t.Fatal("an HS256 token must not verify against an RSA-only verifier") } + if aerr.Code != "PGRST301" || aerr.Message != "No suitable key or wrong key type" { + t.Errorf("want PGRST301 No suitable key or wrong key type, got %v", aerr) + } } func TestECDSAVerification(t *testing.T) { diff --git a/auth/cache_test.go b/auth/cache_test.go index 6549410..1b8cd9d 100644 --- a/auth/cache_test.go +++ b/auth/cache_test.go @@ -75,8 +75,8 @@ func TestCacheHitSkipsVerifyButRechecksExpiry(t *testing.T) { // Advance the clock past exp: the cached entry must not extend its life. v.now = fixedClock(clockNow.Add(2 * time.Minute)) _, err := v.Authenticate("Bearer " + tok) - if err == nil || err.Code != "PGRST301" { - t.Fatalf("a cached but now-expired token must be PGRST301, got %v", err) + if err == nil || err.Code != "PGRST303" { + t.Fatalf("a cached but now-expired token must be PGRST303, got %v", err) } } diff --git a/auth/checktime_test.go b/auth/checktime_test.go deleted file mode 100644 index 9749ef4..0000000 --- a/auth/checktime_test.go +++ /dev/null @@ -1,84 +0,0 @@ -package auth - -import ( - "encoding/json" - "testing" - "time" -) - -// checkTime re-validates the exp and nbf claims on a cache hit, so a cached -// verification can never resurrect a token that has since expired or is not yet -// valid. It is reached directly here because the cache-hit revalidation is hard -// to provoke through the public verify path without racing the cache. -func TestCheckTime(t *testing.T) { - v := hmacVerifier(t, Config{}) - now := clockNow.Unix() - - cases := []struct { - name string - claims map[string]any - expErr string // "" means it must pass - }{ - {"valid window", map[string]any{"exp": float64(now + 60), "nbf": float64(now - 60)}, ""}, - {"no time claims", map[string]any{}, ""}, - {"expired", map[string]any{"exp": float64(now - 60)}, "PGRST301"}, - {"expired within skew", map[string]any{"exp": float64(now - 10)}, ""}, // 30s skew - {"not yet valid", map[string]any{"nbf": float64(now + 60)}, "PGRST302"}, - {"not-before within skew", map[string]any{"nbf": float64(now + 10)}, ""}, - } - for _, c := range cases { - t.Run(c.name, func(t *testing.T) { - err := v.checkTime(c.claims) - if c.expErr == "" { - if err != nil { - t.Fatalf("want pass, got %v", err) - } - return - } - if err == nil || err.Code != c.expErr { - t.Fatalf("want %s, got %v", c.expErr, err) - } - }) - } -} - -// numClaim reads a numeric time claim across the forms a decoded claim set can -// carry, and reports false for an absent claim or one whose type or value is not -// a usable integer. -func TestNumClaim(t *testing.T) { - cases := []struct { - name string - in any - want int64 - wantO bool - }{ - {"float64", float64(1700), 1700, true}, - {"int64", int64(1700), 1700, true}, - {"json.Number", json.Number("1700"), 1700, true}, - {"json.Number non-integer", json.Number("1.5e3"), 0, false}, - {"absent", nil, 0, false}, - {"wrong type", "1700", 0, false}, - } - for _, c := range cases { - t.Run(c.name, func(t *testing.T) { - claims := map[string]any{} - if c.in != nil { - claims["exp"] = c.in - } - got, ok := numClaim(claims, "exp") - if got != c.want || ok != c.wantO { - t.Errorf("numClaim = (%d, %v), want (%d, %v)", got, ok, c.want, c.wantO) - } - }) - } -} - -// A guard that the test clock and skew defaults are what the time cases assume, -// so a future change to either is caught here rather than silently shifting the -// windows above. -func TestCheckTimeAssumptions(t *testing.T) { - v := hmacVerifier(t, Config{}) - if v.skew != 30*time.Second { - t.Fatalf("skew = %v, want the 30s default the cases assume", v.skew) - } -} diff --git a/auth/claims_test.go b/auth/claims_test.go new file mode 100644 index 0000000..afc1acb --- /dev/null +++ b/auth/claims_test.go @@ -0,0 +1,143 @@ +package auth + +import ( + "encoding/json" + "testing" + "time" +) + +// validateClaims applies PostgREST's exp/nbf/iat/aud rules with the 30s skew. +// It is reached directly here because it also guards the cache-hit path, which +// is hard to provoke through the public verify path without racing the cache. +func TestValidateClaims(t *testing.T) { + v := hmacVerifier(t, Config{}) + now := clockNow.Unix() + + cases := []struct { + name string + claims map[string]any + expErr string // "" means it must pass + message string // when non-empty the exact PGRST message + }{ + {"valid window", map[string]any{"exp": float64(now + 60), "nbf": float64(now - 60)}, "", ""}, + {"no claims at all", map[string]any{}, "", ""}, + {"expired", map[string]any{"exp": float64(now - 60)}, "PGRST303", "JWT expired"}, + {"expired within skew", map[string]any{"exp": float64(now - 10)}, "", ""}, + {"not yet valid", map[string]any{"nbf": float64(now + 60)}, "PGRST303", "JWT not yet valid"}, + {"not-before within skew", map[string]any{"nbf": float64(now + 10)}, "", ""}, + {"issued in future", map[string]any{"iat": float64(now + 60)}, "PGRST303", "JWT issued at future"}, + {"issued-at within skew", map[string]any{"iat": float64(now + 10)}, "", ""}, + {"issued in past is fine", map[string]any{"iat": float64(now - 3600)}, "", ""}, + {"exp not a number", map[string]any{"exp": "soon"}, "PGRST303", "The JWT 'exp' claim must be a number"}, + {"nbf not a number", map[string]any{"nbf": true}, "PGRST303", "The JWT 'nbf' claim must be a number"}, + {"iat not a number", map[string]any{"iat": "x"}, "PGRST303", "The JWT 'iat' claim must be a number"}, + {"null time claim passes", map[string]any{"exp": nil}, "", ""}, + {"exp checked before nbf", map[string]any{ + "exp": float64(now - 60), "nbf": float64(now + 60), + }, "PGRST303", "JWT expired"}, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + err := v.validateClaims(c.claims) + if c.expErr == "" { + if err != nil { + t.Fatalf("want pass, got %v", err) + } + return + } + if err == nil || err.Code != c.expErr { + t.Fatalf("want %s, got %v", c.expErr, err) + } + if c.message != "" && err.Message != c.message { + t.Errorf("message = %q, want %q", err.Message, c.message) + } + }) + } +} + +// The aud rules with a configured jwt-aud: absent, null, and empty-array +// audiences pass (the token is valid for all audiences), a matching string or +// array element passes, and a wrong type is its own PGRST303. +func TestValidateClaimsAudience(t *testing.T) { + v := hmacVerifier(t, Config{Audience: testAud}) + + cases := []struct { + name string + aud any + present bool + expErr string + message string + }{ + {"absent aud passes", nil, false, "", ""}, + {"null aud passes", nil, true, "", ""}, + {"empty array passes", []any{}, true, "", ""}, + {"matching string", testAud, true, "", ""}, + {"matching array element", []any{"other", testAud}, true, "", ""}, + {"wrong string", "other", true, "PGRST303", "JWT not in audience"}, + {"no array element matches", []any{"a", "b"}, true, "PGRST303", "JWT not in audience"}, + {"non-string element", []any{42}, true, "PGRST303", "The JWT 'aud' claim must be a string or an array of strings"}, + {"number aud", float64(7), true, "PGRST303", "The JWT 'aud' claim must be a string or an array of strings"}, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + claims := map[string]any{} + if c.present { + claims["aud"] = c.aud + } + err := v.validateClaims(claims) + if c.expErr == "" { + if err != nil { + t.Fatalf("want pass, got %v", err) + } + return + } + if err == nil || err.Code != c.expErr || err.Message != c.message { + t.Fatalf("want %s %q, got %v", c.expErr, c.message, err) + } + }) + } +} + +// With no jwt-aud configured every audience is accepted. +func TestAudienceUncheckedWhenUnconfigured(t *testing.T) { + v := hmacVerifier(t, Config{}) + if err := v.validateClaims(map[string]any{"aud": "anything"}); err != nil { + t.Fatalf("aud must be ignored with no jwt-aud, got %v", err) + } +} + +// claimNumber reads a numeric claim across the forms a decoded claim set can +// carry, and reports false for a value whose type is not a number. +func TestClaimNumber(t *testing.T) { + cases := []struct { + name string + in any + want int64 + wantO bool + }{ + {"float64", float64(1700), 1700, true}, + {"int64", int64(1700), 1700, true}, + {"json.Number", json.Number("1700"), 1700, true}, + {"json.Number scientific", json.Number("1.5e3"), 1500, true}, + {"wrong type", "1700", 0, false}, + {"bool", true, 0, false}, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + got, ok := claimNumber(c.in) + if got != c.want || ok != c.wantO { + t.Errorf("claimNumber = (%d, %v), want (%d, %v)", got, ok, c.want, c.wantO) + } + }) + } +} + +// A guard that the test clock and skew defaults are what the time cases assume, +// so a future change to either is caught here rather than silently shifting the +// windows above. +func TestValidateClaimsAssumptions(t *testing.T) { + v := hmacVerifier(t, Config{}) + if v.skew != 30*time.Second { + t.Fatalf("skew = %v, want the 30s default the cases assume", v.skew) + } +} diff --git a/auth/jspath.go b/auth/jspath.go new file mode 100644 index 0000000..089ddf3 --- /dev/null +++ b/auth/jspath.go @@ -0,0 +1,260 @@ +package auth + +// This file implements the JSPath subset PostgREST accepts for +// jwt-role-claim-key (spec 13): dotted keys (".role"), quoted keys +// (."https://example.com/role"), array indexing (".roles[0]"), and a single +// trailing filter ([?(@ == "admin")]) with the five operators ==, !=, ^== +// (prefix), ==^ (suffix), and *== (contains). The grammar mirrors PostgREST's +// Config.JSPath parser: an invalid value is a startup error, never a silent +// fallback to anon. + +import ( + "fmt" + "strconv" + "strings" +) + +// jsPathExp is one step of a parsed role-claim path. +type jsPathExp struct { + kind jsPathKind + + key string // jspKey: the object key to descend into + idx int // jspIdx: the array index to descend into + op string // jspFilter: one of == != ^== ==^ *== + value string // jspFilter: the quoted comparison value +} + +type jsPathKind int + +const ( + jspKey jsPathKind = iota + jspIdx + jspFilter +) + +// defaultRoleKey is the path used when jwt-role-claim-key is unset: the +// top-level "role" claim. +var defaultRoleKey = []jsPathExp{{kind: jspKey, key: "role"}} + +// parseJSPath parses a jwt-role-claim-key value. An empty value yields the +// default ".role" path; anything else must match the DSL exactly, including +// the leading dot PostgREST requires. +func parseJSPath(s string) ([]jsPathExp, error) { + if strings.TrimSpace(s) == "" { + return defaultRoleKey, nil + } + p := &jsPathParser{src: s} + path, err := p.parse() + if err != nil { + return nil, fmt.Errorf("failed to parse role-claim-key value (%s): %s", s, err) + } + return path, nil +} + +// jsPathParser is a single-pass scanner over the role-claim-key text. +type jsPathParser struct { + src string + pos int +} + +// parse consumes one or more path expressions up to the end of input. A filter +// is only legal as the final expression, as in PostgREST. +func (p *jsPathParser) parse() ([]jsPathExp, error) { + var path []jsPathExp + for p.pos < len(p.src) { + exp, err := p.parseExp() + if err != nil { + return nil, err + } + path = append(path, *exp) + if exp.kind == jspFilter && p.pos < len(p.src) { + return nil, fmt.Errorf("a filter must be the last path element") + } + } + if len(path) == 0 { + return nil, fmt.Errorf("empty path") + } + return path, nil +} + +// parseExp reads the next key, index, or filter expression. +func (p *jsPathParser) parseExp() (*jsPathExp, error) { + switch { + case p.peek() == '.': + return p.parseKey() + case strings.HasPrefix(p.src[p.pos:], "[?("): + return p.parseFilter() + case p.peek() == '[': + return p.parseIdx() + default: + return nil, fmt.Errorf("expected '.', '[n]', or '[?(' at position %d", p.pos) + } +} + +// parseKey reads ".name" (alphanumerics plus _$@) or a quoted ."any text" key. +func (p *jsPathParser) parseKey() (*jsPathExp, error) { + p.pos++ // consume '.' + if p.peek() == '"' { + val, err := p.parseQuoted() + if err != nil { + return nil, err + } + return &jsPathExp{kind: jspKey, key: val}, nil + } + start := p.pos + for p.pos < len(p.src) && isKeyChar(p.src[p.pos]) { + p.pos++ + } + if p.pos == start { + return nil, fmt.Errorf("expected a key after '.' at position %d", start) + } + return &jsPathExp{kind: jspKey, key: p.src[start:p.pos]}, nil +} + +// parseIdx reads "[n]" with a non-negative decimal index. +func (p *jsPathParser) parseIdx() (*jsPathExp, error) { + p.pos++ // consume '[' + start := p.pos + for p.pos < len(p.src) && p.src[p.pos] >= '0' && p.src[p.pos] <= '9' { + p.pos++ + } + if p.pos == start { + return nil, fmt.Errorf("expected digits after '[' at position %d", start) + } + n, err := strconv.Atoi(p.src[start:p.pos]) + if err != nil { + return nil, fmt.Errorf("bad array index: %s", err) + } + if p.peek() != ']' { + return nil, fmt.Errorf("expected ']' at position %d", p.pos) + } + p.pos++ + return &jsPathExp{kind: jspIdx, idx: n}, nil +} + +// parseFilter reads `[?(@ "value")]`. The operators are tried in +// PostgREST's order so "==^" wins over "==". +func (p *jsPathParser) parseFilter() (*jsPathExp, error) { + p.pos += len("[?(") + if p.peek() != '@' { + return nil, fmt.Errorf("expected '@' at position %d", p.pos) + } + p.pos++ + p.skipSpaces() + var op string + for _, candidate := range []string{"==^", "==", "!=", "^==", "*=="} { + if strings.HasPrefix(p.src[p.pos:], candidate) { + op = candidate + break + } + } + if op == "" { + return nil, fmt.Errorf("expected a filter operator at position %d", p.pos) + } + p.pos += len(op) + p.skipSpaces() + val, err := p.parseQuoted() + if err != nil { + return nil, err + } + if !strings.HasPrefix(p.src[p.pos:], ")]") { + return nil, fmt.Errorf("expected ')]' at position %d", p.pos) + } + p.pos += len(")]") + return &jsPathExp{kind: jspFilter, op: op, value: val}, nil +} + +// parseQuoted reads a double-quoted string with no escape processing, matching +// the upstream grammar. +func (p *jsPathParser) parseQuoted() (string, error) { + if p.peek() != '"' { + return "", fmt.Errorf("expected '\"' at position %d", p.pos) + } + p.pos++ + start := p.pos + for p.pos < len(p.src) && p.src[p.pos] != '"' { + p.pos++ + } + if p.pos == len(p.src) { + return "", fmt.Errorf("unterminated quoted value at position %d", start) + } + val := p.src[start:p.pos] + p.pos++ // consume closing quote + return val, nil +} + +// peek returns the current byte, or 0 at end of input. +func (p *jsPathParser) peek() byte { + if p.pos >= len(p.src) { + return 0 + } + return p.src[p.pos] +} + +// skipSpaces advances over spaces inside a filter condition. +func (p *jsPathParser) skipSpaces() { + for p.pos < len(p.src) && p.src[p.pos] == ' ' { + p.pos++ + } +} + +// isKeyChar reports whether c may appear in an unquoted key: alphanumerics +// plus the _, $, and @ PostgREST allows. +func isKeyChar(c byte) bool { + return c >= 'a' && c <= 'z' || c >= 'A' && c <= 'Z' || c >= '0' && c <= '9' || + c == '_' || c == '$' || c == '@' +} + +// walkJSPath descends the decoded claim set along a parsed path. A key step +// requires an object, an index step an array, and a filter step an array whose +// first matching string element is the result. Any mismatch resolves to no +// value, the same as a missing claim. +func walkJSPath(cur any, path []jsPathExp) (any, bool) { + for _, e := range path { + switch e.kind { + case jspKey: + m, ok := cur.(map[string]any) + if !ok { + return nil, false + } + if cur, ok = m[e.key]; !ok { + return nil, false + } + case jspIdx: + ar, ok := cur.([]any) + if !ok || e.idx >= len(ar) { + return nil, false + } + cur = ar[e.idx] + case jspFilter: + ar, ok := cur.([]any) + if !ok { + return nil, false + } + for _, el := range ar { + if s, ok := el.(string); ok && matchFilter(e.op, e.value, s) { + return s, true + } + } + return nil, false + } + } + return cur, true +} + +// matchFilter applies one filter operator to a candidate array element. +func matchFilter(op, pattern, candidate string) bool { + switch op { + case "==": + return candidate == pattern + case "!=": + return candidate != pattern + case "^==": + return strings.HasPrefix(candidate, pattern) + case "==^": + return strings.HasSuffix(candidate, pattern) + case "*==": + return strings.Contains(candidate, pattern) + } + return false +} diff --git a/auth/jspath_test.go b/auth/jspath_test.go new file mode 100644 index 0000000..535b214 --- /dev/null +++ b/auth/jspath_test.go @@ -0,0 +1,124 @@ +package auth + +import ( + "testing" + "time" + + "github.com/golang-jwt/jwt/v5" +) + +// The accepted forms, straight from the v14 jwt-role-claim-key documentation: +// plain keys, nested keys, quoted namespaced keys, array indexes, and the five +// filter operators. +func TestParseJSPathAcceptedForms(t *testing.T) { + cases := []string{ + ".role", + ".roles[0]", + ".app_metadata.role", + `."https://example.com/role"`, + `.realm_access.roles[?(@ == "client_admin")]`, + `.roles[?(@ != "user")]`, + `.roles[?(@ ^== "adm")]`, + `.roles[?(@ ==^ "min")]`, + `.roles[?(@ *== "dmi")]`, + `.roles[?(@=="compact")]`, + ".a.b[2].c", + "[0].role", + } + for _, c := range cases { + if _, err := parseJSPath(c); err != nil { + t.Errorf("parseJSPath(%q): %v", c, err) + } + } +} + +// An empty key falls back to the default ".role" path rather than erroring. +func TestParseJSPathEmptyDefaultsToRole(t *testing.T) { + path, err := parseJSPath("") + if err != nil { + t.Fatalf("parseJSPath(\"\"): %v", err) + } + if len(path) != 1 || path[0].kind != jspKey || path[0].key != "role" { + t.Fatalf("default path = %+v, want [.role]", path) + } +} + +// The rejected forms: a missing leading dot, unterminated brackets and quotes, +// an unknown operator, and a filter that is not the final element. PostgREST +// refuses these at config load, so the verifier must refuse them at startup. +func TestParseJSPathRejectedForms(t *testing.T) { + cases := []string{ + "role", + ".", + ".roles[", + ".roles[1", + ".roles[x]", + `.roles[?(@ = "x")]`, + `.roles[?(@ == "x")`, + `.roles[?(@ == "x)]`, + `.roles[?(@ == "x")].more`, + `."unterminated`, + ".role extra", + } + for _, c := range cases { + if _, err := parseJSPath(c); err == nil { + t.Errorf("parseJSPath(%q): want error, got none", c) + } + } +} + +// An invalid jwt-role-claim-key is a startup error on NewVerifier, never a +// silently broken verifier. +func TestInvalidRoleClaimKeyRefusedAtStartup(t *testing.T) { + _, err := NewVerifier(Config{Secret: hmacKey, RoleClaimKey: "role"}) + if err == nil { + t.Fatal("a role-claim-key without a leading dot must be refused") + } +} + +// walkJSPath drives role resolution end to end through Authenticate for each +// documented form. +func TestRoleClaimKeyForms(t *testing.T) { + cases := []struct { + name string + key string + claims jwt.MapClaims + want string + }{ + {"array index", ".roles[1]", + jwt.MapClaims{"roles": []any{"alpha", "beta"}}, "beta"}, + {"quoted namespaced key", `."https://example.com/role"`, + jwt.MapClaims{"https://example.com/role": "web_user"}, "web_user"}, + {"keycloak filter", `.realm_access.roles[?(@ == "client_admin")]`, + jwt.MapClaims{"realm_access": map[string]any{ + "roles": []any{"offline_access", "client_admin"}, + }}, "client_admin"}, + {"not-equals filter takes first non-match", `.roles[?(@ != "user")]`, + jwt.MapClaims{"roles": []any{"user", "editor", "admin"}}, "editor"}, + {"prefix filter", `.roles[?(@ ^== "web_")]`, + jwt.MapClaims{"roles": []any{"admin", "web_user"}}, "web_user"}, + {"suffix filter", `.roles[?(@ ==^ "_user")]`, + jwt.MapClaims{"roles": []any{"admin", "web_user"}}, "web_user"}, + {"contains filter", `.roles[?(@ *== "b_us")]`, + jwt.MapClaims{"roles": []any{"admin", "web_user"}}, "web_user"}, + {"index out of range falls back to anon", ".roles[5]", + jwt.MapClaims{"roles": []any{"only"}}, anonRole}, + {"filter with no match falls back to anon", `.roles[?(@ == "nope")]`, + jwt.MapClaims{"roles": []any{"admin"}}, anonRole}, + {"filter over non-array falls back to anon", `.role[?(@ == "x")]`, + jwt.MapClaims{"role": "admin"}, anonRole}, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + v := hmacVerifier(t, Config{RoleClaimKey: c.key}) + c.claims["exp"] = clockNow.Add(time.Hour).Unix() + res, err := v.Authenticate("Bearer " + signHS(t, c.claims)) + if err != nil { + t.Fatalf("Authenticate: %v", err) + } + if res.Role != c.want { + t.Errorf("role = %q, want %q", res.Role, c.want) + } + }) + } +} diff --git a/auth/jwk.go b/auth/jwk.go new file mode 100644 index 0000000..aadb998 --- /dev/null +++ b/auth/jwk.go @@ -0,0 +1,183 @@ +package auth + +// This file implements the three ways PostgREST accepts jwt-secret (spec 13): +// a literal JWK Set JSON, a single JWK JSON, or a plain text HMAC secret. The +// parsed result is always a list of verification keys; at verify time a +// token's kid selects its key and a kid-less token tries every key in turn, +// the same try-all behavior the upstream jose library applies. + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rsa" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "math/big" + "strings" +) + +// verKey is one verification key: HMAC bytes, an RSA public key, or an ECDSA +// public key, with the optional JWK kid and alg restrictions. +type verKey struct { + kid string + alg string + key any // []byte, *rsa.PublicKey, or *ecdsa.PublicKey +} + +// jwk is the wire form of a JSON Web Key, covering the symmetric (oct), RSA, +// and EC key types. +type jwk struct { + Kty string `json:"kty"` + Kid string `json:"kid"` + Alg string `json:"alg"` + K string `json:"k"` // oct: the key bytes + N string `json:"n"` // RSA: modulus + E string `json:"e"` // RSA: exponent + Crv string `json:"crv"` // EC: curve name + X string `json:"x"` // EC: x coordinate + Y string `json:"y"` // EC: y coordinate +} + +// parseSecretKeys parses a jwt-secret value the way PostgREST does: first as a +// JWK Set, then as a single JWK, and finally as a plain text HMAC secret. Only +// the text form carries the 32-character minimum; a malformed JSON value falls +// through to the text interpretation, matching upstream. +func parseSecretKeys(secret []byte) ([]verKey, error) { + if keys, ok := tryJWKSet(secret); ok { + return keys, nil + } + if key, ok := tryJWK(secret); ok { + return []verKey{*key}, nil + } + if len(secret) < minHMACSecret { + return nil, errors.New("jwt-secret must be at least 32 characters") + } + return []verKey{{key: append([]byte(nil), secret...)}}, nil +} + +// parseJWKSet parses the jwk-set configuration value. Unlike jwt-secret there +// is no text fallback: the value names a key set and must be a JWK Set or a +// single JWK, otherwise startup fails rather than silently disabling auth. +func parseJWKSet(text string) ([]verKey, error) { + b := []byte(text) + if keys, ok := tryJWKSet(b); ok { + return keys, nil + } + if key, ok := tryJWK(b); ok { + return []verKey{*key}, nil + } + return nil, errors.New("not a valid JWK or JWK Set") +} + +// tryJWKSet attempts to read the bytes as a {"keys": [...]} JWK Set. It only +// succeeds when every listed key is usable, the all-or-nothing reading the +// upstream JSON decoder applies. +func tryJWKSet(b []byte) ([]verKey, bool) { + var set struct { + Keys []json.RawMessage `json:"keys"` + } + if err := json.Unmarshal(b, &set); err != nil || set.Keys == nil { + return nil, false + } + keys := make([]verKey, 0, len(set.Keys)) + for _, raw := range set.Keys { + var w jwk + if err := json.Unmarshal(raw, &w); err != nil { + return nil, false + } + key, err := w.toKey() + if err != nil { + return nil, false + } + keys = append(keys, *key) + } + return keys, true +} + +// tryJWK attempts to read the bytes as a single JWK. +func tryJWK(b []byte) (*verKey, bool) { + var w jwk + if err := json.Unmarshal(b, &w); err != nil { + return nil, false + } + key, err := w.toKey() + if err != nil { + return nil, false + } + return key, true +} + +// toKey materializes the wire-form JWK into a verification key. +func (w jwk) toKey() (*verKey, error) { + switch w.Kty { + case "oct": + k, err := b64urlDecode(w.K) + if err != nil || len(k) == 0 { + return nil, errors.New("oct key: bad k value") + } + return &verKey{kid: w.Kid, alg: w.Alg, key: k}, nil + case "RSA": + n, err := b64urlDecode(w.N) + if err != nil || len(n) == 0 { + return nil, errors.New("RSA key: bad n value") + } + e, err := b64urlDecode(w.E) + if err != nil || len(e) == 0 { + return nil, errors.New("RSA key: bad e value") + } + pub := &rsa.PublicKey{ + N: new(big.Int).SetBytes(n), + E: int(new(big.Int).SetBytes(e).Int64()), + } + return &verKey{kid: w.Kid, alg: w.Alg, key: pub}, nil + case "EC": + var curve elliptic.Curve + switch w.Crv { + case "P-256": + curve = elliptic.P256() + case "P-384": + curve = elliptic.P384() + case "P-521": + curve = elliptic.P521() + default: + return nil, fmt.Errorf("EC key: unsupported curve %q", w.Crv) + } + x, err := b64urlDecode(w.X) + if err != nil || len(x) == 0 { + return nil, errors.New("EC key: bad x value") + } + y, err := b64urlDecode(w.Y) + if err != nil || len(y) == 0 { + return nil, errors.New("EC key: bad y value") + } + pub := &ecdsa.PublicKey{ + Curve: curve, + X: new(big.Int).SetBytes(x), + Y: new(big.Int).SetBytes(y), + } + // ECDH validates the point lies on the curve (and is not the identity), + // the modern replacement for the deprecated Curve.IsOnCurve. + if _, err := pub.ECDH(); err != nil { + return nil, errors.New("EC key: point not on curve") + } + return &verKey{kid: w.Kid, alg: w.Alg, key: pub}, nil + default: + return nil, fmt.Errorf("unsupported key type %q", w.Kty) + } +} + +// b64urlDecode decodes the unpadded URL-safe base64 JWK fields use. +func b64urlDecode(s string) ([]byte, error) { + return base64.RawURLEncoding.DecodeString(strings.TrimRight(s, "=")) +} + +// DecodeBase64Secret decodes a jwt-secret marked with jwt-secret-is-base64. +// PostgREST replaces the URL-safe alphabet (_ to /, - to +, . to =) and strips +// whitespace before a standard base64 decode; an undecodable value is a +// startup error. +func DecodeBase64Secret(s string) ([]byte, error) { + replaced := strings.NewReplacer("_", "/", "-", "+", ".", "=").Replace(s) + return base64.StdEncoding.DecodeString(strings.TrimSpace(replaced)) +} diff --git a/auth/jwk_test.go b/auth/jwk_test.go new file mode 100644 index 0000000..91dd410 --- /dev/null +++ b/auth/jwk_test.go @@ -0,0 +1,205 @@ +package auth + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "encoding/base64" + "encoding/json" + "fmt" + "math/big" + "strings" + "testing" + "time" + + "github.com/golang-jwt/jwt/v5" +) + +// octJWK renders an HMAC secret as a symmetric JWK with an optional kid. +func octJWK(secret []byte, kid string) string { + w := map[string]string{ + "kty": "oct", + "k": base64.RawURLEncoding.EncodeToString(secret), + } + if kid != "" { + w["kid"] = kid + } + b, _ := json.Marshal(w) + return string(b) +} + +// signWithKid mints an HS256 token carrying a kid header. +func signWithKid(t *testing.T, secret []byte, kid string, claims jwt.MapClaims) string { + t.Helper() + tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + if kid != "" { + tok.Header["kid"] = kid + } + s, err := tok.SignedString(secret) + if err != nil { + t.Fatalf("sign: %v", err) + } + return s +} + +// jwt-secret can hold a single JWK: the symmetric key inside verifies HS256. +func TestSecretAsSingleJWK(t *testing.T) { + secret := []byte("jwk-borne-secret-thats-not-the-text!") + v := hmacVerifier(t, Config{Secret: []byte(octJWK(secret, ""))}) + tok := signWithKid(t, secret, "", jwt.MapClaims{"role": "web_user"}) + res, err := v.Authenticate("Bearer " + tok) + if err != nil { + t.Fatalf("Authenticate: %v", err) + } + if res.Role != "web_user" { + t.Fatalf("role = %q", res.Role) + } +} + +// jwt-secret can hold a JWK Set; the token's kid picks its key, a kid-less +// token tries each key, and an unknown kid is a PGRST301. +func TestSecretAsJWKSetWithKid(t *testing.T) { + k1 := []byte("first-shared-secret-32-bytes-long!!!") + k2 := []byte("second-shared-secret-32-bytes-long!!") + set := fmt.Sprintf(`{"keys":[%s,%s]}`, octJWK(k1, "one"), octJWK(k2, "two")) + v := hmacVerifier(t, Config{Secret: []byte(set)}) + + // kid selects the second key. + tok := signWithKid(t, k2, "two", jwt.MapClaims{"role": "web_user"}) + res, err := v.Authenticate("Bearer " + tok) + if err != nil || res.Role != "web_user" { + t.Fatalf("kid-selected key: %+v, %v", res, err) + } + + // a kid signed with the wrong key fails the signature check. + tok = signWithKid(t, k2, "one", jwt.MapClaims{"role": "web_user"}) + if _, err := v.Authenticate("Bearer " + tok); err == nil || err.Code != "PGRST301" { + t.Fatalf("wrong key for kid must be PGRST301, got %v", err) + } + + // an unknown kid leaves no candidate keys. + tok = signWithKid(t, k1, "ghost", jwt.MapClaims{"role": "web_user"}) + _, aerr := v.Authenticate("Bearer " + tok) + if aerr == nil || aerr.Code != "PGRST301" || aerr.Message != "No suitable key or wrong key type" { + t.Fatalf("unknown kid must be a no-suitable-key PGRST301, got %v", aerr) + } + + // a kid-less token tries every key and verifies with the second. + tok = signWithKid(t, k2, "", jwt.MapClaims{"role": "web_user"}) + res, err = v.Authenticate("Bearer " + tok) + if err != nil || res.Role != "web_user" { + t.Fatalf("kid-less try-all: %+v, %v", res, err) + } +} + +// jwt-secret can hold an RSA JWK and an EC JWK; RS256 and ES256 tokens verify +// against them. +func TestSecretAsAsymmetricJWK(t *testing.T) { + rsaKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("genkey: %v", err) + } + ecKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatalf("genkey: %v", err) + } + rsaJWK, _ := json.Marshal(map[string]string{ + "kty": "RSA", + "n": base64.RawURLEncoding.EncodeToString(rsaKey.N.Bytes()), + "e": base64.RawURLEncoding.EncodeToString(big.NewInt(int64(rsaKey.E)).Bytes()), + }) + ecJWK, _ := json.Marshal(map[string]string{ + "kty": "EC", + "crv": "P-256", + "x": base64.RawURLEncoding.EncodeToString(ecKey.X.Bytes()), + "y": base64.RawURLEncoding.EncodeToString(ecKey.Y.Bytes()), + }) + set := fmt.Sprintf(`{"keys":[%s,%s]}`, rsaJWK, ecJWK) + v := hmacVerifier(t, Config{Secret: []byte(set)}) + + rsTok := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{"role": "web_user"}) + signed, err := rsTok.SignedString(rsaKey) + if err != nil { + t.Fatalf("sign rs: %v", err) + } + if res, aerr := v.Authenticate("Bearer " + signed); aerr != nil || res.Role != "web_user" { + t.Fatalf("RS256 against RSA JWK: %+v, %v", res, aerr) + } + + esTok := jwt.NewWithClaims(jwt.SigningMethodES256, jwt.MapClaims{"role": "web_user"}) + signed, err = esTok.SignedString(ecKey) + if err != nil { + t.Fatalf("sign es: %v", err) + } + if res, aerr := v.Authenticate("Bearer " + signed); aerr != nil || res.Role != "web_user" { + t.Fatalf("ES256 against EC JWK: %+v, %v", res, aerr) + } +} + +// A JSON value that is neither a JWK nor a JWK Set falls through to the text +// secret interpretation, the same reading PostgREST applies. +func TestMalformedJWKFallsBackToText(t *testing.T) { + secret := []byte(`{"not_a_jwk": "but long enough to be a passphrase"}`) + v := hmacVerifier(t, Config{Secret: secret}) + tok := signWithKid(t, secret, "", jwt.MapClaims{"role": "web_user"}) + res, err := v.Authenticate("Bearer " + tok) + if err != nil || res.Role != "web_user" { + t.Fatalf("text fallback: %+v, %v", res, err) + } +} + +// The explicit jwk-set value has no text fallback: configuring it with an +// unusable value is a startup error, never a silently auth-less server. +func TestJWKSetConfigRefusedWhenUnusable(t *testing.T) { + _, err := NewVerifier(Config{JWKSet: "this is not a key set", AnonRole: anonRole}) + if err == nil { + t.Fatal("an unparseable jwk-set must fail startup") + } + _, err = NewVerifier(Config{JWKSet: `{"keys":[{"kty":"alien"}]}`, AnonRole: anonRole}) + if err == nil { + t.Fatal("a jwk-set with an unsupported key must fail startup") + } +} + +// The jwk-set value wires real keys: a token verifies against it even with no +// jwt-secret configured. +func TestJWKSetConfigVerifies(t *testing.T) { + secret := []byte("set-borne-secret-32-bytes-long-okay!") + cfg := Config{JWKSet: octJWK(secret, ""), AnonRole: anonRole} + v, err := NewVerifier(cfg) + if err != nil { + t.Fatalf("NewVerifier: %v", err) + } + v.now = fixedClock(clockNow) + tok := signWithKid(t, secret, "", jwt.MapClaims{ + "role": "web_user", + "exp": clockNow.Add(time.Hour).Unix(), + }) + res, aerr := v.Authenticate("Bearer " + tok) + if aerr != nil || res.Role != "web_user" { + t.Fatalf("jwk-set verification: %+v, %v", res, aerr) + } +} + +// DecodeBase64Secret applies PostgREST's URL-safe character replacement before +// the standard decode and refuses undecodable values. +func TestDecodeBase64Secret(t *testing.T) { + raw := []byte("a-secret-with-bytes-needing-urlsafe-chars???>>>") + std := base64.StdEncoding.EncodeToString(raw) + urlSafe := base64.URLEncoding.EncodeToString(raw) + urlSafe = strings.ReplaceAll(urlSafe, "=", ".") + + for _, enc := range []string{std, urlSafe, " " + std + "\n"} { + got, err := DecodeBase64Secret(enc) + if err != nil { + t.Fatalf("DecodeBase64Secret(%q): %v", enc, err) + } + if string(got) != string(raw) { + t.Errorf("decoded %q, want %q", got, raw) + } + } + if _, err := DecodeBase64Secret("!!! not base64 !!!"); err == nil { + t.Error("an undecodable value must error") + } +} diff --git a/authz/authz.go b/authz/authz.go index 6ae7664..2fd1a1d 100644 --- a/authz/authz.go +++ b/authz/authz.go @@ -26,7 +26,6 @@ import ( "github.com/tamnd/dbrest/ir" "github.com/tamnd/dbrest/pgerr" "github.com/tamnd/dbrest/reqctx" - "github.com/tamnd/dbrest/schema" ) // Action is a privilege verb a role may be granted on a relation. @@ -146,8 +145,8 @@ func NewRegistry(grants []Grant, policies []Policy) *Registry { // Authorize gates a planned query for the request's role and injects any RLS // policy. It runs after planning and before execution, mutating the plan's query -// in place: it rejects a request the role may not make, narrows or rejects a -// projection by column privilege, and AND-s the policy predicate onto the filter +// in place: it rejects a request the role may not make, rejects a projection +// outside the column privilege, and AND-s the policy predicate onto the filter // tree. An RPC plan carries no relation query and is passed through unchanged. func (r *Registry) Authorize(rc *reqctx.Context, p *ir.Plan) *pgerr.APIError { if p == nil || p.Query == nil { @@ -166,9 +165,9 @@ func (r *Registry) Authorize(rc *reqctx.Context, p *ir.Plan) *pgerr.APIError { // The column gate. A read always projects; a write projects only when it // returns the representation. Either way the projection is gated against the - // SELECT column grant, and a star/empty projection is narrowed to it. + // SELECT column grant. if q.Kind == ir.Read || (q.Write != nil && q.Write.Return == ir.ReturnRepresentation) { - if err := r.gateSelect(role, rel, q, p.Rel, rc.Anonymous); err != nil { + if err := r.gateSelect(role, rel, q, rc.Anonymous); err != nil { return err } } @@ -210,11 +209,14 @@ func (r *Registry) Authorize(rc *reqctx.Context, p *ir.Plan) *pgerr.APIError { return nil } -// gateSelect enforces the SELECT column grant on a projection. An explicitly -// named forbidden column rejects the request; a star or empty projection is -// narrowed to the granted columns, matching how PostgreSQL drops what a role may -// not read while refusing an explicit ungranted column. -func (r *Registry) gateSelect(role, rel string, q *ir.Query, relSchema *schema.Relation, anon bool) *pgerr.APIError { +// gateSelect enforces the SELECT column grant on a projection. Under a +// column-limited grant, a forbidden named column rejects the request, and so +// does a star or empty projection: both mean SELECT every column, which the +// grant does not cover. That matches PostgreSQL, where SELECT * raises 42501 +// under partial column grants, and PostgREST, whose maintainers explicitly +// rejected narrowing * to the granted set (issue #1732); the client must name +// the columns it may read. +func (r *Registry) gateSelect(role, rel string, q *ir.Query, anon bool) *pgerr.APIError { g, ok := r.grants[grantKey{role, rel, Select}] if !ok { return pgerr.ErrPermissionDenied(rel, anon) @@ -222,23 +224,18 @@ func (r *Registry) gateSelect(role, rel string, q *ir.Query, relSchema *schema.R if g.all { return nil } - hasStar := len(q.Select) == 0 + if len(q.Select) == 0 { + return pgerr.ErrPermissionDenied(rel, anon) + } for _, it := range q.Select { c, isCol := it.(ir.Column) if !isCol { continue } - if isStar(c) { - hasStar = true - continue - } - if !g.cols[baseColumn(c)] { + if isStar(c) || !g.cols[baseColumn(c)] { return pgerr.ErrPermissionDenied(rel, anon) } } - if hasStar && relSchema != nil { - q.Select = narrowProjection(q.Select, g.cols, relSchema) - } return nil } @@ -265,35 +262,6 @@ func (r *Registry) gateWriteColumns(role, rel string, w *ir.WriteSpec, action Ac return nil } -// narrowProjection rewrites a star or empty projection to the granted columns in -// relation order, keeping any embed references and any already-allowed explicit -// columns, with duplicates removed. -func narrowProjection(items []ir.SelectItem, allowed map[string]bool, rel *schema.Relation) []ir.SelectItem { - out := make([]ir.SelectItem, 0, len(rel.Columns)) - seen := map[string]bool{} - add := func(name string) { - if seen[name] || !allowed[name] { - return - } - seen[name] = true - out = append(out, ir.Column{Path: []string{name}}) - } - for _, c := range rel.Columns { - add(c.Name) - } - for _, it := range items { - switch v := it.(type) { - case ir.Column: - if !isStar(v) { - add(baseColumn(v)) - } - default: - out = append(out, it) - } - } - return out -} - // usingConds turns a USING predicate into IR conditions with each claim resolved // to a literal. A term whose claim is missing becomes an always-false condition // (an empty IN), so an absent claim denies every row rather than leaking them. diff --git a/authz/authz_test.go b/authz/authz_test.go index 99da83a..1ba981d 100644 --- a/authz/authz_test.go +++ b/authz/authz_test.go @@ -96,33 +96,69 @@ func TestExplicitForbiddenColumnDenied(t *testing.T) { } } -func TestStarNarrowedToGrantedColumns(t *testing.T) { +// A star projection under a column-limited grant means SELECT every column, +// which the grant does not cover. PostgreSQL raises 42501 for that and the +// PostgREST maintainers rejected narrowing * to the granted set (issue #1732), +// so the request is denied; the client must name the granted columns. +func TestStarRejectedUnderColumnLimitedGrant(t *testing.T) { reg := authz.NewRegistry( []authz.Grant{{Role: "web_user", Relation: "films", Action: authz.Select, Columns: []string{"id", "title"}}}, nil, ) p := readPlan(star()) - if err := reg.Authorize(&reqctx.Context{Role: "web_user"}, p); err != nil { - t.Fatalf("Authorize: %v", err) + err := reg.Authorize(&reqctx.Context{Role: "web_user"}, p) + if err == nil { + t.Fatal("star projection under a column-limited grant was allowed") } - got := projectedNames(p.Query.Select) - want := []string{"id", "title"} - if !equal(got, want) { - t.Errorf("narrowed projection = %v, want %v", got, want) + if err.HTTPStatus != http.StatusForbidden { + t.Errorf("status = %d, want 403", err.HTTPStatus) + } + if err.Code != "42501" { + t.Errorf("code = %q, want 42501", err.Code) } } -func TestEmptyProjectionNarrowedToGrantedColumns(t *testing.T) { +func TestEmptyProjectionRejectedUnderColumnLimitedGrant(t *testing.T) { reg := authz.NewRegistry( []authz.Grant{{Role: "web_user", Relation: "films", Action: authz.Select, Columns: []string{"title"}}}, nil, ) - p := readPlan() // no select items: whole-row projection + p := readPlan() // no select items: whole-row projection, same as * + err := reg.Authorize(&reqctx.Context{Role: "web_user"}, p) + if err == nil { + t.Fatal("whole-row projection under a column-limited grant was allowed") + } + if err.HTTPStatus != http.StatusForbidden { + t.Errorf("status = %d, want 403", err.HTTPStatus) + } +} + +func TestStarRejectedForAnonIs401(t *testing.T) { + reg := authz.NewRegistry( + []authz.Grant{{Role: "anon", Relation: "films", Action: authz.Select, Columns: []string{"id"}}}, + nil, + ) + p := readPlan(star()) + err := reg.Authorize(&reqctx.Context{Role: "anon", Anonymous: true}, p) + if err == nil { + t.Fatal("anon star projection under a column-limited grant was allowed") + } + if err.HTTPStatus != http.StatusUnauthorized { + t.Errorf("status = %d, want 401", err.HTTPStatus) + } +} + +func TestGrantedColumnsProjectFine(t *testing.T) { + reg := authz.NewRegistry( + []authz.Grant{{Role: "web_user", Relation: "films", Action: authz.Select, Columns: []string{"id", "title"}}}, + nil, + ) + p := readPlan(col("id"), col("title")) if err := reg.Authorize(&reqctx.Context{Role: "web_user"}, p); err != nil { - t.Fatalf("Authorize: %v", err) + t.Fatalf("Authorize denied a fully granted projection: %v", err) } - if got := projectedNames(p.Query.Select); !equal(got, []string{"title"}) { - t.Errorf("narrowed projection = %v, want [title]", got) + if got := projectedNames(p.Query.Select); !equal(got, []string{"id", "title"}) { + t.Errorf("projection = %v, want untouched [id title]", got) } } @@ -459,7 +495,7 @@ func BenchmarkAuthorizeReadWithPolicy(b *testing.B) { Query: &ir.Query{ Kind: ir.Read, Relation: ir.Ref{Schema: "public", Name: "films"}, - Select: []ir.SelectItem{star()}, + Select: []ir.SelectItem{col("id"), col("title")}, }, } if err := reg.Authorize(rc, p); err != nil { diff --git a/authz/parse.go b/authz/parse.go new file mode 100644 index 0000000..915077e --- /dev/null +++ b/authz/parse.go @@ -0,0 +1,271 @@ +package authz + +// This file parses the policy-registry configuration value (spec 14, spec 20) +// into the Registry the authorization gate consults. The registry is the +// security boundary on the emulated backends, so parsing fails closed: any +// unknown key, unknown action, or unparseable predicate is a startup error, +// never a silently ignored rule. + +import ( + "bytes" + "encoding/json" + "fmt" + "strconv" + "strings" +) + +// grantDecl is one declared privilege: a role may perform the listed actions +// on a relation, optionally narrowed to a column set (empty means every +// column). It expands to one Grant per action. +type grantDecl struct { + Role string `json:"role"` + Relation string `json:"relation"` + Actions []string `json:"actions"` + Columns []string `json:"columns"` +} + +// policyDecl is one declared Row Level Security policy. The predicates use the +// declaration syntax from spec 14: terms of the form `column = rhs` or +// `column != rhs` joined with `and`, where rhs is a claim reference +// (request.jwt.claims.tenant) or a literal ('open', 42, true). +type policyDecl struct { + Role string `json:"role"` + Relation string `json:"relation"` + Using string `json:"using"` + WithCheck string `json:"with_check"` +} + +// registryDecl is the top-level policy-registry document. +type registryDecl struct { + Grants []grantDecl `json:"grants"` + Policies []policyDecl `json:"policies"` +} + +// validActions maps the declared action names onto the privilege verbs. +var validActions = map[string]Action{ + "select": Select, + "insert": Insert, + "update": Update, + "delete": Delete, +} + +// ParseRegistry decodes a JSON policy-registry declaration into a Registry. +// The document is an object with two lists: +// +// grants [{role, relation, actions: ["select", ...], columns?: [...]}] +// policies [{role, relation, using?: "", with_check?: ""}] +// +// A predicate is one or more `column = rhs` / `column != rhs` terms joined +// with `and`; rhs is a request.jwt.claims reference or a literal. Once a +// registry is configured, the absence of a grant is a denial, so a declaration +// this function cannot fully understand is an error: nothing is skipped. +func ParseRegistry(raw string) (*Registry, error) { + dec := json.NewDecoder(bytes.NewReader([]byte(raw))) + dec.DisallowUnknownFields() + var doc registryDecl + if err := dec.Decode(&doc); err != nil { + return nil, fmt.Errorf("policy-registry: %w", err) + } + if dec.More() { + return nil, fmt.Errorf("policy-registry: trailing data after the document") + } + + grants := make([]Grant, 0, len(doc.Grants)) + for i, g := range doc.Grants { + if g.Role == "" || g.Relation == "" { + return nil, fmt.Errorf("policy-registry: grant %d: role and relation are required", i) + } + if len(g.Actions) == 0 { + return nil, fmt.Errorf("policy-registry: grant %d (%s on %s): actions is required", i, g.Role, g.Relation) + } + for _, a := range g.Actions { + action, ok := validActions[strings.ToLower(strings.TrimSpace(a))] + if !ok { + return nil, fmt.Errorf("policy-registry: grant %d (%s on %s): unknown action %q", i, g.Role, g.Relation, a) + } + grants = append(grants, Grant{ + Role: g.Role, + Relation: g.Relation, + Action: action, + Columns: g.Columns, + }) + } + } + + policies := make([]Policy, 0, len(doc.Policies)) + for i, p := range doc.Policies { + if p.Role == "" || p.Relation == "" { + return nil, fmt.Errorf("policy-registry: policy %d: role and relation are required", i) + } + if p.Using == "" && p.WithCheck == "" { + return nil, fmt.Errorf("policy-registry: policy %d (%s on %s): at least one of using and with_check is required", i, p.Role, p.Relation) + } + using, err := parsePredicate(p.Using) + if err != nil { + return nil, fmt.Errorf("policy-registry: policy %d (%s on %s): using: %w", i, p.Role, p.Relation, err) + } + check, err := parsePredicate(p.WithCheck) + if err != nil { + return nil, fmt.Errorf("policy-registry: policy %d (%s on %s): with_check: %w", i, p.Role, p.Relation, err) + } + policies = append(policies, Policy{ + Role: p.Role, + Relation: p.Relation, + Using: using, + WithCheck: check, + }) + } + + return NewRegistry(grants, policies), nil +} + +// claimPrefixes are the accepted spellings of a claim reference. The canonical +// form matches PostgREST's GUC vocabulary (request.jwt.claims.); the +// singular spelling appears in spec 14's examples and is accepted as the same +// thing. +var claimPrefixes = []string{"request.jwt.claims.", "request.jwt.claim."} + +// parsePredicate parses the declared predicate syntax into a conjunction of +// terms. An empty declaration is the always-true predicate (a policy may set +// only one of using/with_check). +func parsePredicate(src string) (Predicate, error) { + if strings.TrimSpace(src) == "" { + return Predicate{}, nil + } + var terms []Term + for _, part := range splitAnd(src) { + t, err := parseTerm(part) + if err != nil { + return Predicate{}, err + } + terms = append(terms, t) + } + return Predicate{Terms: terms}, nil +} + +// splitAnd splits a predicate on the `and` keyword outside quotes. +func splitAnd(src string) []string { + var parts []string + var cur strings.Builder + inQuote := false + i := 0 + for i < len(src) { + c := src[i] + if c == '\'' { + inQuote = !inQuote + } + if !inQuote && !inWord(src, i) && hasWordAt(src, i, "and") { + parts = append(parts, cur.String()) + cur.Reset() + i += len("and") + continue + } + cur.WriteByte(c) + i++ + } + parts = append(parts, cur.String()) + return parts +} + +// hasWordAt reports whether the keyword appears at position i as a whole word. +func hasWordAt(src string, i int, word string) bool { + if !strings.HasPrefix(strings.ToLower(src[i:]), word) { + return false + } + end := i + len(word) + before := i == 0 || src[i-1] == ' ' || src[i-1] == '\t' + after := end == len(src) || src[end] == ' ' || src[end] == '\t' + return before && after +} + +// inWord reports whether position i continues an identifier started earlier, +// so "band = 1" does not split on its inner "and". +func inWord(src string, i int) bool { + return i > 0 && isIdentChar(src[i-1]) +} + +// parseTerm parses one `column rhs` comparison. +func parseTerm(src string) (Term, error) { + s := strings.TrimSpace(src) + if s == "" { + return Term{}, fmt.Errorf("empty term") + } + + // The operator: != before = so the longer token wins. + var op Op + var lhs, rhs string + if i := strings.Index(s, "!="); i >= 0 { + op, lhs, rhs = OpNeq, s[:i], s[i+2:] + } else if i := strings.Index(s, "="); i >= 0 { + op, lhs, rhs = OpEq, s[:i], s[i+1:] + } else { + return Term{}, fmt.Errorf("term %q: expected = or !=", s) + } + + col := strings.TrimSpace(lhs) + if col == "" || !isIdent(col) { + return Term{}, fmt.Errorf("term %q: %q is not a column name", s, col) + } + + t := Term{Column: col, Op: op} + val := strings.TrimSpace(rhs) + switch { + case val == "": + return Term{}, fmt.Errorf("term %q: missing right-hand side", s) + case isClaimRef(val): + t.Claim = claimPath(val) + if t.Claim == "" { + return Term{}, fmt.Errorf("term %q: empty claim path", s) + } + case val[0] == '\'': + if len(val) < 2 || val[len(val)-1] != '\'' { + return Term{}, fmt.Errorf("term %q: unterminated string literal", s) + } + t.Literal = val[1 : len(val)-1] + case val == "true" || val == "false": + t.Literal = val == "true" + default: + if _, err := strconv.ParseFloat(val, 64); err != nil { + return Term{}, fmt.Errorf("term %q: %q is not a claim reference, string, number, or boolean", s, val) + } + t.Literal = json.Number(val) + } + return t, nil +} + +// isClaimRef reports whether a right-hand side is a request.jwt claim +// reference. +func isClaimRef(val string) bool { + for _, p := range claimPrefixes { + if strings.HasPrefix(val, p) { + return true + } + } + return false +} + +// claimPath strips the claim-reference prefix, leaving the dotted path into +// the claim set. +func claimPath(val string) string { + for _, p := range claimPrefixes { + if strings.HasPrefix(val, p) { + return val[len(p):] + } + } + return "" +} + +// isIdent reports whether s is a plain identifier (a column name). +func isIdent(s string) bool { + for i := 0; i < len(s); i++ { + if !isIdentChar(s[i]) { + return false + } + } + return len(s) > 0 +} + +// isIdentChar is the identifier alphabet for columns in a predicate. +func isIdentChar(c byte) bool { + return c >= 'a' && c <= 'z' || c >= 'A' && c <= 'Z' || c >= '0' && c <= '9' || c == '_' +} diff --git a/authz/parse_test.go b/authz/parse_test.go new file mode 100644 index 0000000..533c536 --- /dev/null +++ b/authz/parse_test.go @@ -0,0 +1,145 @@ +package authz + +import ( + "encoding/json" + "strings" + "testing" +) + +func TestParseRegistryGrantsAndPolicies(t *testing.T) { + reg, err := ParseRegistry(`{ + "grants": [ + {"role": "web_user", "relation": "films", "actions": ["select", "insert"], "columns": ["id", "title"]}, + {"role": "web_anon", "relation": "films", "actions": ["select"]} + ], + "policies": [ + {"role": "web_user", "relation": "films", + "using": "owner = request.jwt.claims.sub", + "with_check": "owner = request.jwt.claims.sub"} + ] + }`) + if err != nil { + t.Fatalf("ParseRegistry: %v", err) + } + + sel, ok := reg.grants[grantKey{"web_user", "films", Select}] + if !ok || sel.all || !sel.cols["id"] || !sel.cols["title"] || sel.cols["secret"] { + t.Errorf("web_user select grant = %+v, want columns id,title", sel) + } + ins, ok := reg.grants[grantKey{"web_user", "films", Insert}] + if !ok || ins.all { + t.Errorf("web_user insert grant = %+v, want the same column set", ins) + } + anon, ok := reg.grants[grantKey{"web_anon", "films", Select}] + if !ok || !anon.all { + t.Errorf("web_anon select grant = %+v, want all columns", anon) + } + if _, ok := reg.grants[grantKey{"web_user", "films", Delete}]; ok { + t.Error("an undeclared action must not be granted") + } + + pol, ok := reg.policies[polKey{"web_user", "films"}] + if !ok { + t.Fatal("policy not registered") + } + wantTerm := Term{Column: "owner", Op: OpEq, Claim: "sub"} + if len(pol.Using.Terms) != 1 || pol.Using.Terms[0] != wantTerm { + t.Errorf("using = %+v, want [%+v]", pol.Using.Terms, wantTerm) + } + if len(pol.WithCheck.Terms) != 1 || pol.WithCheck.Terms[0] != wantTerm { + t.Errorf("with_check = %+v, want [%+v]", pol.WithCheck.Terms, wantTerm) + } +} + +func TestParseRegistryEmptyDocumentDeniesAll(t *testing.T) { + // An explicitly empty registry is a deliberate deny-all: it parses, and + // the gate then refuses every request for lack of a grant. + reg, err := ParseRegistry(`{}`) + if err != nil { + t.Fatalf("ParseRegistry: %v", err) + } + if len(reg.grants) != 0 || len(reg.policies) != 0 { + t.Errorf("empty document = %d grants, %d policies, want none", len(reg.grants), len(reg.policies)) + } +} + +func TestParsePredicateForms(t *testing.T) { + cases := []struct { + name string + src string + want []Term + }{ + {"claim canonical", "tenant_id = request.jwt.claims.tenant", + []Term{{Column: "tenant_id", Op: OpEq, Claim: "tenant"}}}, + {"claim singular spelling", "tenant_id = request.jwt.claim.tenant", + []Term{{Column: "tenant_id", Op: OpEq, Claim: "tenant"}}}, + {"nested claim path", "org = request.jwt.claims.app_metadata.org", + []Term{{Column: "org", Op: OpEq, Claim: "app_metadata.org"}}}, + {"string literal", "status = 'open'", + []Term{{Column: "status", Op: OpEq, Literal: "open"}}}, + {"number literal", "tier = 2", + []Term{{Column: "tier", Op: OpEq, Literal: json.Number("2")}}}, + {"bool literal", "archived = false", + []Term{{Column: "archived", Op: OpEq, Literal: false}}}, + {"inequality", "status != 'deleted'", + []Term{{Column: "status", Op: OpNeq, Literal: "deleted"}}}, + {"conjunction", "tenant_id = request.jwt.claims.tenant and status = 'open'", + []Term{ + {Column: "tenant_id", Op: OpEq, Claim: "tenant"}, + {Column: "status", Op: OpEq, Literal: "open"}, + }}, + {"identifier containing and", "band = 'rush'", + []Term{{Column: "band", Op: OpEq, Literal: "rush"}}}, + {"and inside a string literal", "title = 'salt and pepper'", + []Term{{Column: "title", Op: OpEq, Literal: "salt and pepper"}}}, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + p, err := parsePredicate(c.src) + if err != nil { + t.Fatalf("parsePredicate(%q): %v", c.src, err) + } + if len(p.Terms) != len(c.want) { + t.Fatalf("terms = %+v, want %+v", p.Terms, c.want) + } + for i := range c.want { + if p.Terms[i] != c.want[i] { + t.Errorf("term %d = %+v, want %+v", i, p.Terms[i], c.want[i]) + } + } + }) + } +} + +func TestParseRegistryFailsClosed(t *testing.T) { + cases := []struct { + name string + src string + want string // a fragment the error must carry + }{ + {"malformed json", `{`, "policy-registry"}, + {"unknown top-level key", `{"grant": []}`, "unknown field"}, + {"trailing data", `{} {}`, "trailing data"}, + {"grant missing role", `{"grants": [{"relation": "films", "actions": ["select"]}]}`, "role and relation"}, + {"grant missing actions", `{"grants": [{"role": "r", "relation": "films"}]}`, "actions is required"}, + {"unknown action", `{"grants": [{"role": "r", "relation": "films", "actions": ["grant"]}]}`, `unknown action "grant"`}, + {"policy missing role", `{"policies": [{"relation": "films", "using": "a = 1"}]}`, "role and relation"}, + {"policy with no predicate", `{"policies": [{"role": "r", "relation": "films"}]}`, "at least one of"}, + {"predicate without operator", `{"policies": [{"role": "r", "relation": "films", "using": "owner"}]}`, "expected = or !="}, + {"predicate missing rhs", `{"policies": [{"role": "r", "relation": "films", "using": "owner ="}]}`, "missing right-hand side"}, + {"unterminated string", `{"policies": [{"role": "r", "relation": "films", "using": "owner = 'x"}]}`, "unterminated string"}, + {"bare word rhs", `{"policies": [{"role": "r", "relation": "films", "using": "owner = sub"}]}`, "not a claim reference"}, + {"bad column name", `{"policies": [{"role": "r", "relation": "films", "using": "owner; drop = 'x'"}]}`, "not a column name"}, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + _, err := ParseRegistry(c.src) + if err == nil { + t.Fatalf("ParseRegistry(%q) parsed, want an error", c.src) + } + if !strings.Contains(err.Error(), c.want) { + t.Errorf("error %q does not mention %q", err, c.want) + } + }) + } +} diff --git a/backend/driver.go b/backend/driver.go index 42a7aa7..24b5900 100644 --- a/backend/driver.go +++ b/backend/driver.go @@ -13,6 +13,23 @@ type Driver interface { Open(dsn string) (Backend, error) } +// OpenOptions carries cross-cutting open-time settings a driver may honor. A +// field left at its nil/zero value means "use the driver default", so a caller +// can pass only the settings it cares about. +type OpenOptions struct { + // PreparedStatements, when non-nil, enables or disables server-side prepared + // statements (PostgREST's db-prepared-statements). A driver that cannot vary + // this ignores it. + PreparedStatements *bool +} + +// OptionsDriver is an optional extension a Driver implements to receive +// OpenOptions. A driver that does not implement it is opened through Open and the +// options are dropped, so OpenWith stays safe for every backend. +type OptionsDriver interface { + OpenWithOptions(dsn string, opts OpenOptions) (Backend, error) +} + var ( driversMu sync.RWMutex drivers = make(map[string]Driver) @@ -45,6 +62,24 @@ func Open(name, dsn string) (Backend, error) { return d.Open(dsn) } +// OpenWith opens a Backend using the named driver, the given DSN, and the +// supplied open-time options. A driver that implements OptionsDriver receives the +// options; one that does not is opened through Open, ignoring them. This keeps the +// caller engine-agnostic: it always passes the options, and each backend honors +// what it can. +func OpenWith(name, dsn string, opts OpenOptions) (Backend, error) { + driversMu.RLock() + d, ok := drivers[name] + driversMu.RUnlock() + if !ok { + return nil, fmt.Errorf("backend: unknown driver %q (forgotten import?)", name) + } + if od, ok := d.(OptionsDriver); ok { + return od.OpenWithOptions(dsn, opts) + } + return d.Open(dsn) +} + // Drivers returns a sorted list of registered driver names. func Drivers() []string { driversMu.RLock() diff --git a/backend/driver_test.go b/backend/driver_test.go new file mode 100644 index 0000000..dd675a1 --- /dev/null +++ b/backend/driver_test.go @@ -0,0 +1,53 @@ +package backend + +import "testing" + +// plainDriver implements only Driver: OpenWith must fall back to Open and drop +// the options without error. +type plainDriver struct{ opened string } + +func (d *plainDriver) Open(dsn string) (Backend, error) { + d.opened = dsn + return nil, nil +} + +// optsDriver also implements OptionsDriver, so OpenWith must route the options to +// it instead of plain Open. +type optsDriver struct { + gotDSN string + gotOpts OpenOptions +} + +func (d *optsDriver) Open(dsn string) (Backend, error) { return nil, nil } +func (d *optsDriver) OpenWithOptions(dsn string, opts OpenOptions) (Backend, error) { + d.gotDSN = dsn + d.gotOpts = opts + return nil, nil +} + +func TestOpenWithRoutesOptions(t *testing.T) { + d := &optsDriver{} + Register("test-opts-driver", d) + prepared := false + if _, err := OpenWith("test-opts-driver", "dsn://x", OpenOptions{PreparedStatements: &prepared}); err != nil { + t.Fatalf("OpenWith: %v", err) + } + if d.gotDSN != "dsn://x" { + t.Errorf("dsn = %q, want dsn://x", d.gotDSN) + } + if d.gotOpts.PreparedStatements == nil || *d.gotOpts.PreparedStatements != false { + t.Errorf("PreparedStatements = %v, want a pointer to false", d.gotOpts.PreparedStatements) + } +} + +func TestOpenWithFallsBackForPlainDriver(t *testing.T) { + d := &plainDriver{} + Register("test-plain-driver", d) + prepared := true + if _, err := OpenWith("test-plain-driver", "dsn://y", OpenOptions{PreparedStatements: &prepared}); err != nil { + t.Fatalf("OpenWith: %v", err) + } + if d.opened != "dsn://y" { + t.Errorf("plain driver opened %q, want dsn://y", d.opened) + } +} diff --git a/backend/listener.go b/backend/listener.go new file mode 100644 index 0000000..8ed0d3d --- /dev/null +++ b/backend/listener.go @@ -0,0 +1,31 @@ +package backend + +import "context" + +// Listener is an optional backend capability for PostgREST's db-channel: a +// dedicated connection that waits for a database notification asking the server +// to reload. A backend that cannot listen does not implement it, and the server +// reloads on signals only. PostgreSQL implements it over LISTEN/NOTIFY. +type Listener interface { + // Listen opens a dedicated connection, waits for notifications on the named + // channel, and invokes the handler for each one until ctx is canceled. It + // blocks, so the caller runs it in a goroutine; it reconnects on a dropped + // connection with capped backoff and calls OnReconnect after re-establishing, + // because notifications sent while it was down are lost and the cache may be + // stale. It returns ctx.Err() when ctx is canceled. + Listen(ctx context.Context, channel string, h ListenHandler) error +} + +// ListenHandler carries the callbacks a Listener invokes. Decoding a payload into +// a reload action (the empty / "reload schema" / "reload config" contract) is the +// caller's job, so the wire contract lives in one engine-agnostic place rather +// than being duplicated in each backend. +type ListenHandler struct { + // OnNotify is called for each notification with its raw payload. A nil func + // is a no-op. + OnNotify func(payload string) + // OnReconnect is called after the connection is re-established following a + // drop, signaling that notifications may have been missed. A nil func is a + // no-op. It is not called for the first connection. + OnReconnect func() +} diff --git a/backend/mongo/execute.go b/backend/mongo/execute.go index d59f544..29b4001 100644 --- a/backend/mongo/execute.go +++ b/backend/mongo/execute.go @@ -23,6 +23,18 @@ func (b *Backend) Execute(ctx context.Context, plan *ir.Plan, rc *reqctx.Context if plan.Query == nil { return nil, pgerr.ErrUnsupported("this operation", backendName) } + // An empty column set (POST with an empty array, PATCH with an empty object) + // is a no-op: nothing is written, the affected count is zero, and the + // representation is the empty array. MongoDB rejects an empty $set anyway, so + // the short-circuit also keeps the update path from issuing an invalid op. + if backend.IsNoOpMutation(plan.Query) { + return &bodyResult{ + controls: rc.Controls(), + rows: newDocRowStream(nil), + affected: 0, + hasAff: true, + }, nil + } switch plan.Query.Kind { case ir.Read: return b.executeRead(ctx, plan, rc) @@ -100,6 +112,13 @@ func (b *Backend) executeInsert(ctx context.Context, plan *ir.Plan, rc *reqctx.C return res, nil } + // Prefer: max-affected. MongoDB writes here are not transactional, so the + // guard refuses an over-broad insert before any document is written rather + // than rolling one back; the would-insert count is known up front. + if apiErr := backend.EnforceMaxAffected(q.Write, int64(len(docs)), true); apiErr != nil { + return nil, apiErr + } + if len(docs) == 1 { _, err := coll.InsertOne(ctx, docs[0]) if err != nil { @@ -143,6 +162,19 @@ func (b *Backend) executeUpdate(ctx context.Context, plan *ir.Plan, rc *reqctx.C filter := filterDoc(q.Where, colTypes) setDoc := writePayloadToSetDoc(q.Write, plan.Rel) + // Prefer: max-affected. Without a transaction to roll back, count the + // would-update documents first and refuse before touching any when the match + // exceeds the bound. + if q.Write != nil && q.Write.MaxRows != nil { + n, err := coll.CountDocuments(ctx, filter) + if err != nil { + return nil, b.MapError(err) + } + if apiErr := backend.EnforceMaxAffected(q.Write, n, true); apiErr != nil { + return nil, apiErr + } + } + out, err := coll.UpdateMany(ctx, filter, bson.D{{Key: "$set", Value: setDoc}}) if err != nil { return nil, b.MapError(err) @@ -168,6 +200,19 @@ func (b *Backend) executeDelete(ctx context.Context, plan *ir.Plan, rc *reqctx.C filter := filterDoc(q.Where, colTypes) + // Prefer: max-affected. Count the would-delete documents and refuse before + // removing any when the match exceeds the bound, since the delete cannot be + // rolled back. + if q.Write != nil && q.Write.MaxRows != nil { + n, err := coll.CountDocuments(ctx, filter) + if err != nil { + return nil, b.MapError(err) + } + if apiErr := backend.EnforceMaxAffected(q.Write, n, true); apiErr != nil { + return nil, apiErr + } + } + if q.Write != nil && q.Write.Return == ir.ReturnRepresentation { // Capture rows before deleting. returnDocs, err := b.findDocs(ctx, coll, filter) diff --git a/backend/mongo/mongo.go b/backend/mongo/mongo.go index 356433d..cb3eceb 100644 --- a/backend/mongo/mongo.go +++ b/backend/mongo/mongo.go @@ -128,7 +128,12 @@ func (b *Backend) MapError(err error) *pgerr.APIError { return nil } if driver.IsDuplicateKeyError(err) { - return pgerr.ErrUniqueViolation(err.Error()) + // PostgreSQL's wording, not Mongo's native text: the driver gives no + // constraint name or key value in a form that reconstructs PG's message, + // so neither is invented and the native text is not leaked into details + // (an emulation limitation, documented in the spec). + return pgerr.ErrConstraintViolation(pgerr.CodeUniqueViolation, + "duplicate key value violates unique constraint", "", "") } if driver.IsTimeout(err) { return pgerr.ErrInternal("mongodb: timeout: " + err.Error()) diff --git a/backend/mysql/dialect.go b/backend/mysql/dialect.go index 62822f0..046ed2c 100644 --- a/backend/mysql/dialect.go +++ b/backend/mysql/dialect.go @@ -192,12 +192,43 @@ func (Dialect) SessionRead(string) string { return "" } // SessionWrite reports ok=false: there is no engine setting to write. func (Dialect) SessionWrite(string) (string, bool) { return "", false } -// ArrayOp returns false; MySQL has no native array types or containment operators. -func (Dialect) ArrayOp(_, _, _ string) (string, bool) { return "", false } +// ArrayOp renders a JSON array containment/overlap expression using MySQL's +// JSON_CONTAINS and JSON_OVERLAPS functions (MySQL 8.0.17+). The column must be +// declared as JSON type; for any other column type ok=false is returned so the +// compiler raises PGRST127. colType is the canonical column type enriched by the +// planner; op is one of "@>" (contains), "<@" (contained-by), "&&" (overlaps). +func (Dialect) ArrayOp(col, op, val, colType string) (string, bool) { + if colType != "json" && colType != "jsonb" { + return "", false + } + switch op { + case "@>": // contains: col contains all elements of val + return "JSON_CONTAINS(" + col + ", " + val + ")", true + case "<@": // contained-by: val contains all elements of col + return "JSON_CONTAINS(" + val + ", " + col + ")", true + case "&&": // overlaps: at least one common element + return "JSON_OVERLAPS(" + col + ", " + val + ")", true + } + return "", false +} + +// RangeOp declines: MySQL has no range types, so sl/sr/nxr/nxl/adj are PGRST127. +func (Dialect) RangeOp(_, _, _ string) (string, bool) { return "", false } // ILike uses plain LIKE; MySQL's default utf8mb4_unicode_ci collation is CI. func (Dialect) ILike(col, val string) (string, bool) { return col + " LIKE " + val, true } +// IsBool renders "col = 1" or "col = 0". MySQL 8's IS operator only accepts +// NULL/UNKNOWN/TRUE/FALSE, not integer literals, so "col IS 1" is a syntax +// error; equality works for TINYINT(1) boolean columns. +func (Dialect) IsBool(col string, v bool) (string, bool) { + return col + " = " + Dialect{}.BoolValue(v), true +} + +// IsUnknown falls back to "col IS NULL"; a TINYINT(1) boolean column's UNKNOWN +// state is its NULL, so the row set matches. +func (Dialect) IsUnknown(string) (string, bool) { return "", false } + // BoolValue renders a boolean as 1/0. MySQL's BOOL is an alias for TINYINT(1), // so there is no native boolean keyword. func (Dialect) BoolValue(v bool) string { @@ -206,3 +237,42 @@ func (Dialect) BoolValue(v bool) string { } return "0" } + +// InList reports ok=false: MySQL has no array-bound ANY, so the compiler emits +// the expanded col IN ($1, $2, ...) form. +func (Dialect) InList(_ string) (string, bool) { return "", false } + +// ArrayLiteral converts a PostgreSQL {a,b} array literal to a JSON array +// ["a","b"] so JSON_CONTAINS/JSON_OVERLAPS in ArrayOp can process it. +func (Dialect) ArrayLiteral(pgText string) string { + s := strings.TrimSpace(pgText) + if len(s) < 2 || s[0] != '{' || s[len(s)-1] != '}' { + return pgText // already JSON or empty; pass through + } + inner := s[1 : len(s)-1] + if inner == "" { + return "[]" + } + parts := strings.Split(inner, ",") + quoted := make([]string, len(parts)) + for i, p := range parts { + p = strings.TrimSpace(p) + if len(p) >= 2 && p[0] == '"' && p[len(p)-1] == '"' { + quoted[i] = p // already JSON-quoted + } else { + quoted[i] = `"` + strings.ReplaceAll(p, `"`, `\"`) + `"` + } + } + return "[" + strings.Join(quoted, ",") + "]" +} + +// ArrayArg stores a payload array as its JSON text: MySQL has no array +// columns, so a JSON column holds the array and reads it back as JSON. A +// PostgreSQL {a,b} literal here would corrupt the column. +func (Dialect) ArrayArg(elems []any, _ string) any { return sqlgen.JSONArrayArg(elems) } + +// JSONPath reports ok=false so the compiler raises PGRST127. MySQL has ->/->> +// operators, but lowering them faithfully needs a live server to verify against +// and is tracked as the per-driver remainder; until then JSON paths are an +// honest capability gap rather than an unverified spelling. +func (Dialect) JSONPath(string, []string, bool) (string, bool) { return "", false } diff --git a/backend/mysql/execute.go b/backend/mysql/execute.go index 0b3a7d8..662a542 100644 --- a/backend/mysql/execute.go +++ b/backend/mysql/execute.go @@ -5,6 +5,7 @@ import ( "database/sql" "fmt" "strings" + "time" "github.com/tamnd/dbrest/backend" "github.com/tamnd/dbrest/backend/sqlgen" @@ -14,6 +15,30 @@ import ( "github.com/tamnd/dbrest/schema" ) +// normalizeArgs converts ISO 8601 datetime strings (e.g. "2024-01-01T00:00:00Z") +// to time.Time so the MySQL driver can bind them correctly. MySQL rejects the ISO +// T-separator format; passing time.Time avoids the string-to-DATETIME cast entirely. +func normalizeArgs(args []any) []any { + if len(args) == 0 { + return args + } + out := make([]any, len(args)) + for i, a := range args { + if s, ok := a.(string); ok { + if t, err := time.Parse(time.RFC3339, s); err == nil { + out[i] = t + continue + } + if t, err := time.Parse("2006-01-02T15:04:05", s); err == nil { + out[i] = t + continue + } + } + out[i] = a + } + return out +} + // Execute lowers a resolved plan to MySQL operations and returns a streamable // result. Reads stream from an open cursor; writes run in a transaction and // buffer their rows (since MySQL 8 has no RETURNING, rows are re-selected after @@ -47,7 +72,7 @@ func (b *Backend) executeRead(ctx context.Context, plan *ir.Plan, rc *reqctx.Con if apiErr != nil { return nil, apiErr } - if err := b.db.QueryRowContext(ctx, cst.SQL, cst.Args...).Scan(&res.count); err != nil { + if err := b.db.QueryRowContext(ctx, cst.SQL, normalizeArgs(cst.Args)...).Scan(&res.count); err != nil { return nil, b.MapError(err) } res.hasCount = true @@ -57,7 +82,7 @@ func (b *Backend) executeRead(ctx context.Context, plan *ir.Plan, rc *reqctx.Con if apiErr != nil { return nil, apiErr } - rows, err := b.db.QueryContext(ctx, st.SQL, st.Args...) + rows, err := b.db.QueryContext(ctx, st.SQL, normalizeArgs(st.Args)...) if err != nil { return nil, b.MapError(err) } @@ -79,6 +104,19 @@ func (b *Backend) executeWrite(ctx context.Context, plan *ir.Plan, rc *reqctx.Co q := plan.Query returning := returningCols(q, plan.Rel) + // An empty column set (POST with an empty array, PATCH with an empty object) + // is a no-op: nothing is compiled or run, the affected count is zero, and the + // representation is the empty array. The HTTP layer turns that into 201/[] for + // an insert and 204 or 200/[] for an update. + if backend.IsNoOpMutation(q) { + return &writeResult{ + controls: rc.Controls(), + cols: returning, + affected: 0, + hasAff: true, + }, nil + } + tx, err := b.db.BeginTx(ctx, nil) if err != nil { return nil, b.MapError(err) @@ -98,6 +136,11 @@ func (b *Backend) executeWrite(ctx context.Context, plan *ir.Plan, rc *reqctx.Co return nil, b.MapError(err) } + // Prefer: max-affected rolls an over-broad write back instead of committing. + if apiErr := backend.EnforceMaxAffected(q.Write, res.affected, res.hasAff); apiErr != nil { + return nil, apiErr + } + if q.Write != nil && q.Write.Tx == ir.TxRollback { return res, nil } @@ -222,7 +265,10 @@ func (b *Backend) executeInsertEmulated( return nil } -// executeUpdateEmulated runs UPDATE then re-selects with the same filter. +// executeUpdateEmulated runs UPDATE then re-selects by pre-captured primary keys. +// The re-select must use PKs, not the original filter, because the UPDATE may +// change the very column being filtered (e.g. PATCH /todos?task=eq.old sets +// task=new — after the UPDATE, task=eq.old matches nothing). func (b *Backend) executeUpdateEmulated( ctx context.Context, tx *sql.Tx, q *ir.Query, returning []string, rel *schema.Relation, @@ -232,6 +278,16 @@ func (b *Backend) executeUpdateEmulated( if apiErr != nil { return apiErr } + + // Pre-capture PKs when we need to return representation. + var pkValues []any + if len(returning) > 0 && len(rel.PrimaryKey) == 1 { + pkValues, apiErr = b.selectPKs(ctx, tx, q, rel.PrimaryKey[0]) + if apiErr != nil { + return apiErr + } + } + out, err := tx.ExecContext(ctx, st.SQL, st.Args...) if err != nil { return err @@ -239,35 +295,84 @@ func (b *Backend) executeUpdateEmulated( n, _ := out.RowsAffected() res.affected, res.hasAff = n, true - if len(returning) == 0 || n == 0 { + if len(returning) == 0 || len(pkValues) == 0 { return nil } - // Re-select: compile the equivalent SELECT with the same filters. - readQ := *q - readQ.Kind = ir.Read - readST, apiErr := sqlgen.CompileRead(Dialect{}, &readQ) + // Re-select by PK (post-update values). + colNames, buf, err := b.selectByPKs(ctx, tx, rel, rel.PrimaryKey[0], pkValues, returning) + if err != nil { + return err + } + res.cols, res.rows = colNames, buf + return nil +} + +// selectPKs runs "SELECT pk FROM table WHERE " and returns +// the raw PK values. Used to anchor the post-write re-select. +func (b *Backend) selectPKs( + ctx context.Context, tx *sql.Tx, + q *ir.Query, pkCol string, +) ([]any, *pgerr.APIError) { + pkQ := *q + pkQ.Kind = ir.Read + pkQ.Select = []ir.SelectItem{ir.Column{Path: []string{pkCol}}} + pkQ.Embeds = nil + pkQ.Order = nil + pkQ.Singular = false + st, apiErr := sqlgen.CompileRead(Dialect{}, &pkQ) if apiErr != nil { - return apiErr + return nil, apiErr } - rows, err := tx.QueryContext(ctx, readST.SQL, readST.Args...) + rows, err := tx.QueryContext(ctx, st.SQL, normalizeArgs(st.Args)...) if err != nil { - return err + return nil, pgerr.New(500, "XX000", err.Error()) + } + defer rows.Close() + var vals []any + for rows.Next() { + var v any + if err := rows.Scan(&v); err != nil { + return nil, pgerr.New(500, "XX000", err.Error()) + } + vals = append(vals, v) + } + if err := rows.Err(); err != nil { + return nil, pgerr.New(500, "XX000", err.Error()) + } + return vals, nil +} + +// selectByPKs runs "SELECT cols FROM table WHERE pk IN (?,...)" using pre-captured +// PK values and returns the column names and buffered rows. +func (b *Backend) selectByPKs( + ctx context.Context, tx *sql.Tx, + rel *schema.Relation, pkCol string, pkValues []any, cols []string, +) ([]string, [][]any, error) { + d := Dialect{} + table := d.QuoteIdent(rel.Name) + pk := d.QuoteIdent(pkCol) + selCols := quotedCols(cols) + placeholders := make([]string, len(pkValues)) + for i := range pkValues { + placeholders[i] = "?" + } + sql := fmt.Sprintf("SELECT %s FROM %s WHERE %s IN (%s)", + selCols, table, pk, strings.Join(placeholders, ",")) + rows, err := tx.QueryContext(ctx, sql, pkValues...) + if err != nil { + return nil, nil, err } colNames, err := rows.Columns() if err != nil { rows.Close() - return err + return nil, nil, err } boolCols := buildBoolCols(rel) jsonIdx, boolIdx, _ := buildColMaps(rows, boolCols) buf, err := drain(rows, colNames, jsonIdx, boolIdx) rows.Close() - if err != nil { - return err - } - res.cols, res.rows = colNames, buf - return nil + return colNames, buf, err } // executeDeleteEmulated selects the rows to return, then deletes them. @@ -284,7 +389,7 @@ func (b *Backend) executeDeleteEmulated( if apiErr != nil { return apiErr } - rows, err := tx.QueryContext(ctx, readST.SQL, readST.Args...) + rows, err := tx.QueryContext(ctx, readST.SQL, normalizeArgs(readST.Args)...) if err != nil { return err } @@ -318,15 +423,16 @@ func (b *Backend) executeDeleteEmulated( // executeCall runs a stored procedure or portable RPC function. func (b *Backend) executeCall(ctx context.Context, plan *ir.Plan, rc *reqctx.Context) (backend.Result, error) { - st, apiErr := sqlgen.CompileCall(Dialect{}, plan.Call, plan.Func) + st, apiErr := sqlgen.CompileCall(Dialect{}, plan.Call, plan.Func, sqlgen.ContextArgs(rc)) if apiErr != nil { return nil, apiErr } + st.Args = normalizeArgs(st.Args) if plan.ReadOnly { res := &result{controls: rc.Controls()} if plan.Call.Count != ir.CountNone { - cst, apiErr := sqlgen.CompileCallCount(Dialect{}, plan.Call, plan.Func) + cst, apiErr := sqlgen.CompileCallCount(Dialect{}, plan.Call, plan.Func, sqlgen.ContextArgs(rc)) if apiErr != nil { return nil, apiErr } @@ -383,17 +489,26 @@ func (b *Backend) executeCall(ctx context.Context, plan *ir.Plan, rc *reqctx.Con // compileWrite dispatches to the right compiler for the mutation kind. // When returning is empty the compiler omits the RETURNING / OUTPUT clause. +// Args are normalized for MySQL (ISO 8601 → time.Time) before returning. func compileWrite(q *ir.Query, returning []string) (*sqlgen.Statement, *pgerr.APIError) { + var ( + st *sqlgen.Statement + apiErr *pgerr.APIError + ) switch q.Kind { case ir.Insert, ir.Upsert: - return sqlgen.CompileInsert(Dialect{}, q, returning) + st, apiErr = sqlgen.CompileInsert(Dialect{}, q, returning) case ir.Update: - return sqlgen.CompileUpdate(Dialect{}, q, returning) + st, apiErr = sqlgen.CompileUpdate(Dialect{}, q, returning) case ir.Delete: - return sqlgen.CompileDelete(Dialect{}, q, returning) + st, apiErr = sqlgen.CompileDelete(Dialect{}, q, returning) default: return nil, pgerr.ErrUnsupported("this operation", "mysql") } + if st != nil { + st.Args = normalizeArgs(st.Args) + } + return st, apiErr } // returningCols decides which columns to read back after a write. @@ -401,6 +516,9 @@ func compileWrite(q *ir.Query, returning []string) (*sqlgen.Statement, *pgerr.AP // primary key only (for the Location header); for minimal updates/deletes it is nil. func returningCols(q *ir.Query, rel *schema.Relation) []string { if q.Write != nil && q.Write.Return == ir.ReturnRepresentation { + if cols := q.ProjectedColumns(); cols != nil { + return cols + } return rel.ColumnNames() } if q.Kind == ir.Insert || q.Kind == ir.Upsert { diff --git a/backend/mysql/fulltext.go b/backend/mysql/fulltext.go index 0f95978..31b87f4 100644 --- a/backend/mysql/fulltext.go +++ b/backend/mysql/fulltext.go @@ -24,7 +24,7 @@ import ( // goes). The translation is Best-effort: MySQL boolean mode has no general // AND/OR/grouping the way to_tsquery does, so disjunction and grouping are // approximated and documented in the conformance allowlist (spec 22). -func (Dialect) FullText(col string, _ *sqlgen.FullTextRef, variant ir.FTSVariant, _, value string) (string, string, bool) { +func (Dialect) FullText(col, _ string, _ *sqlgen.FullTextRef, variant ir.FTSVariant, _, value string) (string, string, bool) { frag := "MATCH(" + col + ") AGAINST(" + sqlgen.PatternMark + " IN BOOLEAN MODE)" return frag, booleanModeQuery(variant, value), true } diff --git a/backend/mysql/fulltext_test.go b/backend/mysql/fulltext_test.go index 975139c..cafc06b 100644 --- a/backend/mysql/fulltext_test.go +++ b/backend/mysql/fulltext_test.go @@ -10,12 +10,12 @@ import ( // wrapper is fixed and snapshotted in compile_test; these cases pin the grammar // translation, the part that carries the per-variant divergence. func fts(v ir.FTSVariant, value string) string { - _, q, _ := Dialect{}.FullText("`c`", nil, v, "", value) + _, q, _ := Dialect{}.FullText("`c`", "", nil, v, "", value) return q } func TestFullTextWrapper(t *testing.T) { - frag, _, ok := Dialect{}.FullText("`c`", nil, ir.FTSPlain, "", "x") + frag, _, ok := Dialect{}.FullText("`c`", "", nil, ir.FTSPlain, "", "x") if !ok || frag != "MATCH(`c`) AGAINST($PAT$ IN BOOLEAN MODE)" { t.Errorf("frag = %q, ok = %v", frag, ok) } diff --git a/backend/mysql/mysql.go b/backend/mysql/mysql.go index e5e3996..6816e57 100644 --- a/backend/mysql/mysql.go +++ b/backend/mysql/mysql.go @@ -52,6 +52,7 @@ func Open(dsn string) (*Backend, error) { return nil, fmt.Errorf("invalid MySQL DSN: %w", err) } cfg.ParseTime = true + cfg.ClientFoundRows = true // report matched rows, not changed rows (UPDATE re-select gate) delete(cfg.Params, "tinyInt1IsBool") // removed in v1.8; schema-layer handles coercion connector, err := mysqldrv.NewConnector(cfg) @@ -122,15 +123,23 @@ func (b *Backend) MapError(err error) *pgerr.APIError { // mapMySQLError builds the unified API error from a MySQL driver error. func mapMySQLError(me *mysqldrv.MySQLError) *pgerr.APIError { + // Class-23 violations carry PostgreSQL's wording, not the native MySQL text: + // the MySQL driver gives no constraint name or offending value in a form that + // reconstructs PG's message, so neither is invented and the native text is not + // leaked into details (an emulation limitation, documented in the spec). switch me.Number { case 1062: // ER_DUP_ENTRY - return pgerr.ErrUniqueViolation(me.Message) + return pgerr.ErrConstraintViolation(pgerr.CodeUniqueViolation, + "duplicate key value violates unique constraint", "", "") case 1048: // ER_BAD_NULL_ERROR - return pgerr.ErrNotNullViolation(me.Message) + return pgerr.ErrConstraintViolation(pgerr.CodeNotNullViolation, + "null value violates not-null constraint", "", "") case 1406, 1264: // ER_DATA_TOO_LONG, ER_WARN_DATA_OUT_OF_RANGE - return pgerr.ErrCheckViolation(me.Message) + return pgerr.ErrConstraintViolation(pgerr.CodeCheckViolation, + "new row violates check constraint", "", "") case 1451, 1452: // ER_ROW_IS_REFERENCED_2, ER_NO_REFERENCED_ROW_2 - return pgerr.ErrForeignKeyViolation(me.Message) + return pgerr.ErrConstraintViolation(pgerr.CodeForeignKeyViolation, + "insert or update on table violates foreign key constraint", "", "") case 1054, 1247: // ER_BAD_FIELD_ERROR, ER_ILLEGAL_REFERENCE return pgerr.New(400, "42703", me.Message) case 1146: // ER_NO_SUCH_TABLE diff --git a/backend/postgres/compile_test.go b/backend/postgres/compile_test.go index 5b33106..80f1800 100644 --- a/backend/postgres/compile_test.go +++ b/backend/postgres/compile_test.go @@ -110,6 +110,39 @@ func TestCompileRegexSnapshot(t *testing.T) { } } +// An in-list lowers to col = ANY($1) binding a single array literal, the form +// PostgREST uses so a list of any length reuses one prepared statement. Finding +// P13. +func TestCompileInListSnapshot(t *testing.T) { + where := ir.Cond(ir.Compare{Path: []string{"id"}, Op: ir.OpIn, Value: ir.Value{List: []string{"1", "2", "3"}}}) + st, err := sqlgen.CompileRead(d, &ir.Query{Relation: ir.Ref{Name: "films"}, Where: &where}) + if err != nil { + t.Fatalf("CompileRead: %v", err) + } + want := `SELECT * FROM "films" WHERE "id" = ANY($1)` + if st.SQL != want { + t.Errorf("SQL = %q, want %q", st.SQL, want) + } + // One bound argument carries the whole list as a PostgreSQL array literal. + if len(st.Args) != 1 || st.Args[0] != "{1,2,3}" { + t.Errorf("Args = %v, want [{1,2,3}]", st.Args) + } +} + +// A negated in-list keeps the single-parameter form under a NOT wrapper, which +// returns the same rows as PostgREST's <> ALL. +func TestCompileNotInListSnapshot(t *testing.T) { + where := ir.Cond(ir.Compare{Path: []string{"id"}, Op: ir.OpIn, Negate: true, Value: ir.Value{List: []string{"1", "2"}}}) + st, err := sqlgen.CompileRead(d, &ir.Query{Relation: ir.Ref{Name: "films"}, Where: &where}) + if err != nil { + t.Fatalf("CompileRead: %v", err) + } + want := `SELECT * FROM "films" WHERE NOT ("id" = ANY($1))` + if st.SQL != want { + t.Errorf("SQL = %q, want %q", st.SQL, want) + } +} + func TestCompileFTSSnapshot(t *testing.T) { where := ir.Cond(ir.Compare{Path: []string{"body"}, Op: ir.OpFTS, FTS: ir.FTSWeb, Value: ir.Value{Text: "cat dog"}}) st, err := sqlgen.CompileRead(d, &ir.Query{Relation: ir.Ref{Name: "docs"}, Where: &where}) diff --git a/backend/postgres/computed.go b/backend/postgres/computed.go new file mode 100644 index 0000000..4978261 --- /dev/null +++ b/backend/postgres/computed.go @@ -0,0 +1,123 @@ +package postgres + +import ( + "context" + + "github.com/tamnd/dbrest/schema" +) + +// loadComputedFields maps every exposed relation's OID to the computed fields it +// carries: functions that take the relation's row type and return a scalar, +// exposed as virtual columns a client can select, filter, and order by (PostgREST +// computed fields, spec 11). PostgreSQL records a function's first argument type +// in pg_proc.proargtypes; when that type is a relation's composite row type +// (pg_class.reltype), the function is a candidate. A scalar return (pg_type.typtype +// other than 'c') marks it a computed field; a composite or set-returning function +// over the same row type is a computed relationship instead and is read elsewhere. +// +// The function must live in the same schema as the relation, matching PostgREST, +// which only exposes a computed field defined alongside its table. A real column of +// the same name takes precedence: the model indexes columns first, so a function +// shadowing a stored column is simply never reached. +func (b *Backend) loadComputedFields(ctx context.Context, schemas []string) (map[uint32][]schema.ComputedField, error) { + const q = ` +SELECT cls.oid AS rel_oid, + n.nspname AS fn_schema, + p.proname AS fn_name, + format_type(p.prorettype, NULL) AS ret_type + FROM pg_proc p + JOIN pg_namespace n ON n.oid = p.pronamespace + JOIN pg_type at ON at.oid = p.proargtypes[0] + JOIN pg_class cls ON cls.reltype = at.oid AND cls.relkind IN ('r','v','m','f','p') + JOIN pg_namespace cn ON cn.oid = cls.relnamespace + JOIN pg_type rt ON rt.oid = p.prorettype + WHERE n.nspname = ANY($1) + AND cn.nspname = ANY($1) + AND n.nspname = cn.nspname + AND p.prokind = 'f' + AND p.pronargs = 1 + AND NOT p.proretset + AND p.provariadic = 0 + AND rt.typtype <> 'c' + ORDER BY cls.oid, p.proname` + + rows, err := b.pool.Query(ctx, q, schemas) + if err != nil { + return nil, err + } + defer rows.Close() + + out := map[uint32][]schema.ComputedField{} + for rows.Next() { + var relOID uint32 + var cf schema.ComputedField + var retType string + if err := rows.Scan(&relOID, &cf.FuncSchema, &cf.Name, &retType); err != nil { + return nil, err + } + cf.Type = canonicalType(retType) + out[relOID] = append(out[relOID], cf) + } + return out, rows.Err() +} + +// loadComputedRels maps every exposed relation's OID to its computed +// relationships: functions taking the relation's row type and returning rows of +// another exposed relation, exposed as embeddable edges (PostgREST computed +// relationships, the escape hatch for recursive embeds, spec 11). The first +// argument is the parent relation's composite row type (pg_class.reltype); the +// return type is another relation's composite type, set-returning for a to-many +// edge and a single row for to-one. A function returning RETURNS TABLE(...) or a +// bare scalar is not a computed relationship: its return type is a pseudo or base +// type, not a relation's composite, so it never matches here (a scalar return is a +// computed field instead, read by loadComputedFields). +// +// The function must live in the same schema as the parent relation, matching +// PostgREST, which exposes a computed relationship defined alongside its table. +func (b *Backend) loadComputedRels(ctx context.Context, schemas []string) (map[uint32][]schema.ComputedRel, error) { + const q = ` +SELECT pc.oid AS parent_oid, + n.nspname AS fn_schema, + p.proname AS fn_name, + p.proretset, + tn.nspname AS target_schema, + tc.relname AS target_name + FROM pg_proc p + JOIN pg_namespace n ON n.oid = p.pronamespace + JOIN pg_type pat ON pat.oid = p.proargtypes[0] + JOIN pg_class pc ON pc.reltype = pat.oid AND pc.relkind IN ('r','v','m','f','p') + JOIN pg_namespace pn ON pn.oid = pc.relnamespace + JOIN pg_type rt ON rt.oid = p.prorettype AND rt.typtype = 'c' + JOIN pg_class tc ON tc.reltype = rt.oid AND tc.relkind IN ('r','v','m','f','p') + JOIN pg_namespace tn ON tn.oid = tc.relnamespace + WHERE n.nspname = ANY($1) + AND pn.nspname = ANY($1) + AND tn.nspname = ANY($1) + AND n.nspname = pn.nspname + AND p.prokind = 'f' + AND p.pronargs = 1 + AND p.provariadic = 0 + ORDER BY pc.oid, p.proname` + + rows, err := b.pool.Query(ctx, q, schemas) + if err != nil { + return nil, err + } + defer rows.Close() + + out := map[uint32][]schema.ComputedRel{} + for rows.Next() { + var parentOID uint32 + var setof bool + var cr schema.ComputedRel + if err := rows.Scan(&parentOID, &cr.FuncSchema, &cr.Name, &setof, &cr.TargetSchema, &cr.TargetName); err != nil { + return nil, err + } + cr.Card = schema.CardToOne + if setof { + cr.Card = schema.CardToMany + } + out[parentOID] = append(out[parentOID], cr) + } + return out, rows.Err() +} diff --git a/backend/postgres/count.go b/backend/postgres/count.go new file mode 100644 index 0000000..33576ec --- /dev/null +++ b/backend/postgres/count.go @@ -0,0 +1,118 @@ +package postgres + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/jackc/pgx/v5" + + "github.com/tamnd/dbrest/backend/sqlgen" + "github.com/tamnd/dbrest/ir" +) + +// computeCount returns the total a read's Content-Range reports, by the strategy +// the request asked for (item 07.7): +// +// - exact: count(*) over the same relation and predicates the body ran. +// - planned: the planner's row estimate, read from EXPLAIN. Fast and +// approximate, it never touches the heap. +// - estimated: exact while the result is small, the planner estimate once it +// grows past db-max-rows. The capped exact count stops at the threshold, so a +// large table pays only for the estimate. +// +// PostgreSQL's EXPLAIN output is a stable, documented format; the estimate is the +// root plan node's Plan Rows, which for a plain SELECT is the predicted output +// row count. +func (b *Backend) computeCount(ctx context.Context, tx pgx.Tx, q *ir.Query) (int64, error) { + switch q.Count { + case ir.CountPlanned: + return b.plannedCount(ctx, tx, q) + case ir.CountEstimated: + return b.estimatedCount(ctx, tx, q) + default: // CountExact + return b.exactCount(ctx, tx, q) + } +} + +// exactCount runs count(*) over the relation and predicates. +func (b *Backend) exactCount(ctx context.Context, tx pgx.Tx, q *ir.Query) (int64, error) { + cst, apiErr := sqlgen.CompileCount(Dialect{}, q) + if apiErr != nil { + return 0, apiErr + } + var n int64 + if err := tx.QueryRow(ctx, cst.SQL, cst.Args...).Scan(&n); err != nil { + return 0, b.MapError(err) + } + return n, nil +} + +// plannedCount returns the planner's row estimate for the count's source query. +func (b *Backend) plannedCount(ctx context.Context, tx pgx.Tx, q *ir.Query) (int64, error) { + src, apiErr := sqlgen.CompileRowEstimateSource(Dialect{}, q) + if apiErr != nil { + return 0, apiErr + } + var raw []byte + if err := tx.QueryRow(ctx, "EXPLAIN (FORMAT JSON) "+src.SQL, src.Args...).Scan(&raw); err != nil { + return 0, b.MapError(err) + } + rows, err := parseExplainRows(raw) + if err != nil { + return 0, b.MapError(err) + } + return rows, nil +} + +// estimatedCount counts exactly until the result passes db-max-rows, then falls +// back to the planner estimate. With no threshold configured it is exact. +func (b *Backend) estimatedCount(ctx context.Context, tx pgx.Tx, q *ir.Query) (int64, error) { + if q.CountMax <= 0 { + return b.exactCount(ctx, tx, q) + } + src, apiErr := sqlgen.CompileRowEstimateSource(Dialect{}, q) + if apiErr != nil { + return 0, apiErr + } + // Count the source rows but stop one past the threshold: a result at or below + // it is the exact total, while reaching threshold+1 only proves there are more, + // at which point the planner estimate is cheaper and good enough. + capped := fmt.Sprintf("SELECT count(*) FROM (%s LIMIT %d) _pgrst_capped", + src.SQL, q.CountMax+1) + var n int64 + if err := tx.QueryRow(ctx, capped, src.Args...).Scan(&n); err != nil { + return 0, b.MapError(err) + } + if n <= q.CountMax { + return n, nil + } + return b.plannedCount(ctx, tx, q) +} + +// explainNode is the slice of EXPLAIN (FORMAT JSON) output the estimate needs: +// the top plan node carries the planner's output-row estimate in "Plan Rows". +type explainNode struct { + Plan struct { + PlanRows float64 `json:"Plan Rows"` + } `json:"Plan"` +} + +// parseExplainRows reads the root node's row estimate out of EXPLAIN (FORMAT +// JSON) output, which is a one-element array of plan trees. The estimate is a +// float in the plan (PostgreSQL prints fractional estimates), rounded to the +// nearest whole row. +func parseExplainRows(raw []byte) (int64, error) { + var plans []explainNode + if err := json.Unmarshal(raw, &plans); err != nil { + return 0, fmt.Errorf("parse EXPLAIN output: %w", err) + } + if len(plans) == 0 { + return 0, fmt.Errorf("EXPLAIN output held no plan") + } + rows := plans[0].Plan.PlanRows + if rows < 0 { + rows = 0 + } + return int64(rows + 0.5), nil +} diff --git a/backend/postgres/count_test.go b/backend/postgres/count_test.go new file mode 100644 index 0000000..79e0db0 --- /dev/null +++ b/backend/postgres/count_test.go @@ -0,0 +1,177 @@ +package postgres + +import ( + "context" + "os" + "testing" + + "github.com/jackc/pgx/v5" + + "github.com/tamnd/dbrest/ir" +) + +// These count-strategy tests reach the unexported computeCount/parseExplainRows, +// so they live in the internal package and carry their own DSN gate rather than +// borrowing the external integration helpers. +func countDSN(t *testing.T) string { + t.Helper() + s := os.Getenv("DBREST_PG_DSN") + if s == "" { + t.Skip("DBREST_PG_DSN not set; skipping postgres count integration tests") + } + return s +} + +func openCount(t *testing.T, dsn string) *Backend { + t.Helper() + be, err := Open(dsn) + if err != nil { + t.Fatalf("Open: %v", err) + } + be.SetSchemas([]string{"public"}) + return be +} + +func mustExec(t *testing.T, b *Backend, sql string) { + t.Helper() + if _, err := b.Pool().Exec(context.Background(), sql); err != nil { + t.Fatalf("exec: %v", err) + } +} + +func beginTx(t *testing.T, b *Backend) pgx.Tx { + t.Helper() + tx, err := b.Pool().Begin(context.Background()) + if err != nil { + t.Fatalf("begin: %v", err) + } + return tx +} + +// parseExplainRows reads the root node's row estimate out of the documented +// EXPLAIN (FORMAT JSON) shape, rounding the planner's fractional estimate. +func TestParseExplainRows(t *testing.T) { + cases := []struct { + name string + raw string + want int64 + }{ + { + name: "seq scan estimate", + raw: `[{"Plan":{"Node Type":"Seq Scan","Relation Name":"films","Plan Rows":1234,"Plan Width":8}}]`, + want: 1234, + }, + { + name: "fractional estimate rounds", + raw: `[{"Plan":{"Node Type":"Index Scan","Plan Rows":41.6}}]`, + want: 42, + }, + { + name: "rounds down below half", + raw: `[{"Plan":{"Node Type":"Index Scan","Plan Rows":41.2}}]`, + want: 41, + }, + { + name: "nested child does not shadow the root estimate", + raw: `[{"Plan":{"Node Type":"Aggregate","Plan Rows":1,` + + `"Plans":[{"Node Type":"Seq Scan","Plan Rows":9999}]}}]`, + want: 1, + }, + { + name: "zero rows", + raw: `[{"Plan":{"Node Type":"Result","Plan Rows":0}}]`, + want: 0, + }, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + got, err := parseExplainRows([]byte(c.raw)) + if err != nil { + t.Fatalf("parseExplainRows: %v", err) + } + if got != c.want { + t.Errorf("rows = %d, want %d", got, c.want) + } + }) + } +} + +func TestParseExplainRowsRejectsGarbage(t *testing.T) { + if _, err := parseExplainRows([]byte("not json")); err == nil { + t.Error("want error on non-JSON EXPLAIN output") + } + if _, err := parseExplainRows([]byte("[]")); err == nil { + t.Error("want error on an empty plan array") + } +} + +// The estimated count is exact when no db-max-rows threshold is configured: with +// nothing to cross over at, the planner estimate never enters in. +func TestEstimatedCountExactWithoutThreshold(t *testing.T) { + dsn := countDSN(t) // skips without DBREST_PG_DSN + b := openCount(t, dsn) + defer b.Close() + + mustExec(t, b, `DROP TABLE IF EXISTS estc; CREATE TABLE estc(id int); + INSERT INTO estc SELECT g FROM generate_series(1, 30) g;`) + + tx := beginTx(t, b) + defer tx.Rollback(context.Background()) + q := &ir.Query{Relation: ir.Ref{Schema: "public", Name: "estc"}, Count: ir.CountEstimated} + got, err := b.computeCount(context.Background(), tx, q) + if err != nil { + t.Fatalf("computeCount: %v", err) + } + if got != 30 { + t.Errorf("estimated count without threshold = %d, want exact 30", got) + } +} + +// Below the threshold an estimated count is exact; the capped probe returns the +// true total without ever consulting the planner. +func TestEstimatedCountExactBelowThreshold(t *testing.T) { + dsn := countDSN(t) + b := openCount(t, dsn) + defer b.Close() + + mustExec(t, b, `DROP TABLE IF EXISTS estc; CREATE TABLE estc(id int); + INSERT INTO estc SELECT g FROM generate_series(1, 30) g;`) + + tx := beginTx(t, b) + defer tx.Rollback(context.Background()) + q := &ir.Query{ + Relation: ir.Ref{Schema: "public", Name: "estc"}, + Count: ir.CountEstimated, + CountMax: 100, + } + got, err := b.computeCount(context.Background(), tx, q) + if err != nil { + t.Fatalf("computeCount: %v", err) + } + if got != 30 { + t.Errorf("estimated count below threshold = %d, want exact 30", got) + } +} + +// A planned count returns the planner estimate, which after ANALYZE matches the +// real row count closely for a simple table. +func TestPlannedCountUsesPlannerEstimate(t *testing.T) { + dsn := countDSN(t) + b := openCount(t, dsn) + defer b.Close() + + mustExec(t, b, `DROP TABLE IF EXISTS estc; CREATE TABLE estc(id int); + INSERT INTO estc SELECT g FROM generate_series(1, 500) g; ANALYZE estc;`) + + tx := beginTx(t, b) + defer tx.Rollback(context.Background()) + q := &ir.Query{Relation: ir.Ref{Schema: "public", Name: "estc"}, Count: ir.CountPlanned} + got, err := b.computeCount(context.Background(), tx, q) + if err != nil { + t.Fatalf("computeCount: %v", err) + } + // The planner estimate of a freshly analyzed 500-row table is exact here. + if got != 500 { + t.Errorf("planned count = %d, want the planner estimate 500", got) + } +} diff --git a/backend/postgres/dialect.go b/backend/postgres/dialect.go index 77fac40..9bdc4d1 100644 --- a/backend/postgres/dialect.go +++ b/backend/postgres/dialect.go @@ -16,6 +16,7 @@ package postgres import ( + "fmt" "strconv" "strings" @@ -86,6 +87,13 @@ func (Dialect) Returning(cols []string) (string, bool) { // row, so a merge sets each column to its excluded value. An empty update set or // an ignore request becomes DO NOTHING. func (Dialect) Upsert(spec sqlgen.UpsertSpec) (string, error) { + // DO UPDATE without a conflict target is not valid PostgreSQL ("ON CONFLICT DO + // UPDATE requires inference specification or constraint name"). The compiler + // already degrades a no-target upsert to a plain INSERT, so this guards against + // a future caller emitting the invalid form. + if !spec.Ignore && len(spec.Update) > 0 && len(spec.Target) == 0 { + return "", fmt.Errorf("merge upsert needs a conflict target") + } var sb strings.Builder sb.WriteString("ON CONFLICT") if len(spec.Target) > 0 { @@ -110,12 +118,34 @@ func (Dialect) Upsert(spec sqlgen.UpsertSpec) (string, error) { // JSONObject assembles a JSON object with json_build_object, the function whose // argument order fixes the key order to the select order (spec 06, "JSON // assembly"). Keys are JSON string literals; values are already-compiled SQL. +// +// PostgreSQL caps a function call at 100 arguments (FUNC_MAX_ARGS), and each pair +// is two arguments, so an object of more than 50 keys (a wide embedded table) +// would raise 54023. Past that threshold the object is built in chunks of 50 +// pairs with jsonb_build_object and concatenated with jsonb's || , then cast back +// to json so the result type matches the unchunked form for json_agg and the +// json cast downstream. func (Dialect) JSONObject(pairs []sqlgen.Pair) string { - parts := make([]string, 0, len(pairs)*2) - for _, p := range pairs { - parts = append(parts, "'"+strings.ReplaceAll(p.Key, "'", "''")+"'", p.Value) + const maxPairs = 50 + buildChunk := func(chunk []sqlgen.Pair, fn string) string { + parts := make([]string, 0, len(chunk)*2) + for _, p := range chunk { + parts = append(parts, "'"+strings.ReplaceAll(p.Key, "'", "''")+"'", p.Value) + } + return fn + "(" + strings.Join(parts, ", ") + ")" + } + if len(pairs) <= maxPairs { + return buildChunk(pairs, "json_build_object") + } + var chunks []string + for i := 0; i < len(pairs); i += maxPairs { + end := i + maxPairs + if end > len(pairs) { + end = len(pairs) + } + chunks = append(chunks, buildChunk(pairs[i:end], "jsonb_build_object")) } - return "json_build_object(" + strings.Join(parts, ", ") + ")" + return "to_json(" + strings.Join(chunks, " || ") + ")" } // JSONAgg aggregates rows with json_agg. PostgreSQL takes an ORDER BY inside the @@ -130,15 +160,18 @@ func (Dialect) JSONAgg(elem, orderBy string) string { // Cast translates a canonical type to a PostgreSQL ::type cast, the form PG // itself uses. The expression is parenthesized so the cast binds to the whole -// expression, not just its tail. An unknown canonical type falls back to text, -// which is the safe rendering for an opaque value. +// expression, not just its tail. The type name is passed through to PostgreSQL +// after the parser has validated it against a safe grammar (ir.validCastType), +// so casts to money, interval, an enum, a domain, or an array type resolve the +// same way they do under PostgREST rather than degrading to text. func (Dialect) Cast(expr, canonicalType string) string { return "(" + expr + ")::" + pgType(canonicalType) } -// pgType maps a canonical type name to its PostgreSQL spelling. The canonical -// names are the PG type names already in most cases, so the map mostly -// normalizes aliases (int->int4, bool->boolean stays bool) to one spelling. +// pgType normalizes a handful of canonical aliases to one PostgreSQL spelling +// (int->int4 and friends) and passes every other type name through unchanged. +// The name has already been validated as a safe type spelling by the parser, so +// PostgreSQL resolves it directly the way PostgREST relies on. func pgType(canonical string) string { switch canonical { case "int", "integer", "int4": @@ -172,7 +205,7 @@ func pgType(canonical string) string { case "jsonb": return "jsonb" default: - return "text" + return canonical } } @@ -222,7 +255,12 @@ func sqlLiteral(s string) string { } // ArrayOp renders a PostgreSQL array containment/overlap expression. -func (Dialect) ArrayOp(col, op, val string) (string, bool) { +func (Dialect) ArrayOp(col, op, val, _ string) (string, bool) { + return col + " " + op + " " + val, true +} + +// RangeOp renders PostgreSQL's native range operators (<<, >>, &<, &>, -|-). +func (Dialect) RangeOp(col, op, val string) (string, bool) { return col + " " + op + " " + val, true } @@ -236,3 +274,54 @@ func (Dialect) BoolValue(v bool) string { } return "FALSE" } + +// IsBool falls back to the generic "IS TRUE"/"IS FALSE" form; PostgreSQL +// supports IS natively. +func (Dialect) IsBool(string, bool) (string, bool) { return "", false } + +// IsUnknown renders PostgreSQL's native three-valued "col IS UNKNOWN" test. +func (Dialect) IsUnknown(col string) (string, bool) { return col + " IS UNKNOWN", true } + +// InList renders an in-list as col = ANY($n), the form PostgREST uses so a list +// of any length binds as one array parameter and reuses a single prepared +// statement. The rows are identical to an expanded IN. +func (Dialect) InList(col string) (string, bool) { + return col + " = ANY(" + sqlgen.PatternMark + ")", true +} + +// ArrayLiteral returns the PostgreSQL {a,b} array literal unchanged; PostgreSQL +// accepts it natively. +func (Dialect) ArrayLiteral(pgText string) string { return pgText } + +// ArrayArg renders a payload array for the target column. A JSON array bound for +// a json/jsonb column is JSON, not a PostgreSQL array, so it is kept as JSON +// text; for an array column it becomes the {a,b} array-literal text so the +// server-side cast from text to text[]/int4[]/etc. succeeds with or without type +// OIDs. An unknown column type keeps the array-literal default. +func (Dialect) ArrayArg(elems []any, colType string) any { + if colType == "json" || colType == "jsonb" { + return sqlgen.JSONArrayArg(elems) + } + return sqlgen.PGArrayLiteral(elems) +} + +// JSONPath emits PostgreSQL's native -> / ->> operator chain: every hop is -> +// (json) except the final one, which is ->> when the access was text. A digit +// segment becomes an integer array index; any other segment is a quoted key. +func (Dialect) JSONPath(base string, hops []string, asText bool) (string, bool) { + var b strings.Builder + b.WriteString(base) + for i, h := range hops { + op := "->" + if asText && i == len(hops)-1 { + op = "->>" + } + b.WriteString(op) + if sqlgen.IsJSONArrayIndex(h) { + b.WriteString(h) + } else { + b.WriteString("'" + strings.ReplaceAll(h, "'", "''") + "'") + } + } + return b.String(), true +} diff --git a/backend/postgres/dialect_test.go b/backend/postgres/dialect_test.go index 4c5151d..c5ed5c4 100644 --- a/backend/postgres/dialect_test.go +++ b/backend/postgres/dialect_test.go @@ -1,6 +1,8 @@ package postgres import ( + "fmt" + "strings" "testing" "github.com/tamnd/dbrest/backend/sqlgen" @@ -107,11 +109,6 @@ func TestUpsert(t *testing.T) { sqlgen.UpsertSpec{Target: []string{`"id"`}, Ignore: true}, `ON CONFLICT ("id") DO NOTHING`, }, - { - "merge no target", - sqlgen.UpsertSpec{Update: []string{`"title"`}}, - `ON CONFLICT DO UPDATE SET "title" = excluded."title"`, - }, { "empty update degrades to nothing", sqlgen.UpsertSpec{Target: []string{`"id"`}}, @@ -127,6 +124,13 @@ func TestUpsert(t *testing.T) { t.Errorf("%s: = %q, want %q", c.name, got, c.want) } } + + // A merge with no conflict target is rejected: ON CONFLICT DO UPDATE without an + // inference specification is invalid PostgreSQL. The compiler degrades this to a + // plain INSERT before reaching here, matching PostgREST, so this is a guard. + if _, err := d.Upsert(sqlgen.UpsertSpec{Update: []string{`"title"`}}); err == nil { + t.Error("merge with empty target should return an error, got nil") + } } func TestJSONObject(t *testing.T) { @@ -170,7 +174,16 @@ func TestCast(t *testing.T) { "uuid": `("x")::uuid`, "json": `("x")::json`, "jsonb": `("x")::jsonb`, - "mystery": `("x")::text`, + // Types outside the alias table pass through verbatim rather than + // degrading to text, so they resolve the way PostgREST relies on. + "money": `("x")::money`, + "interval": `("x")::interval`, + "bytea": `("x")::bytea`, + "inet": `("x")::inet`, + "mood": `("x")::mood`, + "numeric(10,2)": `("x")::numeric(10,2)`, + "int[]": `("x")::int[]`, + "public.color": `("x")::public.color`, } for in, want := range cases { if got := d.Cast(`"x"`, in); got != want { @@ -205,16 +218,21 @@ func TestFullText(t *testing.T) { name string variant ir.FTSVariant config string + colType string want string }{ - {"plain no config", ir.FTSPlain, "", `to_tsvector("body") @@ to_tsquery(` + sqlgen.PatternMark + `)`}, - {"plaintext", ir.FTSPlainText, "", `to_tsvector("body") @@ plainto_tsquery(` + sqlgen.PatternMark + `)`}, - {"phrase", ir.FTSPhrase, "", `to_tsvector("body") @@ phraseto_tsquery(` + sqlgen.PatternMark + `)`}, - {"web", ir.FTSWeb, "", `to_tsvector("body") @@ websearch_to_tsquery(` + sqlgen.PatternMark + `)`}, - {"with config", ir.FTSPlain, "english", `to_tsvector('english', "body") @@ to_tsquery('english', ` + sqlgen.PatternMark + `)`}, + {"plain no config", ir.FTSPlain, "", "text", `to_tsvector("body") @@ to_tsquery(` + sqlgen.PatternMark + `)`}, + {"plaintext", ir.FTSPlainText, "", "text", `to_tsvector("body") @@ plainto_tsquery(` + sqlgen.PatternMark + `)`}, + {"phrase", ir.FTSPhrase, "", "text", `to_tsvector("body") @@ phraseto_tsquery(` + sqlgen.PatternMark + `)`}, + {"web", ir.FTSWeb, "", "text", `to_tsvector("body") @@ websearch_to_tsquery(` + sqlgen.PatternMark + `)`}, + {"with config", ir.FTSPlain, "english", "text", `to_tsvector('english', "body") @@ to_tsquery('english', ` + sqlgen.PatternMark + `)`}, + // A tsvector column is matched directly: PostgreSQL has no + // to_tsvector(tsvector) overload, so the wrap is skipped (PostgREST's rule). + {"tsvector column", ir.FTSPlain, "", "tsvector", `"body" @@ to_tsquery(` + sqlgen.PatternMark + `)`}, + {"tsvector with config", ir.FTSPlain, "english", "tsvector", `"body" @@ to_tsquery('english', ` + sqlgen.PatternMark + `)`}, } for _, c := range cases { - frag, bind, ok := d.FullText(`"body"`, nil, c.variant, c.config, "cat") + frag, bind, ok := d.FullText(`"body"`, c.colType, nil, c.variant, c.config, "cat") if !ok { t.Fatalf("%s: ok=false", c.name) } @@ -242,3 +260,54 @@ func TestBoolValue(t *testing.T) { t.Error("BoolValue should render the PostgreSQL keywords") } } + +// A JSON object of more than 50 keys exceeds json_build_object's 100-argument +// limit, so the dialect chunks it into jsonb_build_object calls concatenated with +// || and casts the result back to json. A small object stays a single +// json_build_object call. +func TestJSONObjectChunking(t *testing.T) { + small := []sqlgen.Pair{{Key: "a", Value: `t."a"`}, {Key: "b", Value: `t."b"`}} + got := d.JSONObject(small) + if got != `json_build_object('a', t."a", 'b', t."b")` { + t.Errorf("small object = %q", got) + } + + pairs := make([]sqlgen.Pair, 120) + for i := range pairs { + name := fmt.Sprintf("c%d", i) + pairs[i] = sqlgen.Pair{Key: name, Value: "t." + name} + } + got = d.JSONObject(pairs) + if !strings.HasPrefix(got, "to_json(jsonb_build_object(") { + t.Fatalf("large object did not chunk: %q", got[:60]) + } + // 120 pairs at 50 per chunk is three chunks, joined by two || operators. + if n := strings.Count(got, "jsonb_build_object("); n != 3 { + t.Errorf("chunk count = %d, want 3", n) + } + if n := strings.Count(got, " || "); n != 2 { + t.Errorf("concat count = %d, want 2", n) + } +} + +// A JSON array in a write payload is bound by target column type: a json/jsonb +// column takes JSON text (a JSON array there is JSON, not a PG array), an array +// column takes the {a,b} array literal, and an unknown type keeps the literal. +func TestArrayArgByColumnType(t *testing.T) { + elems := []any{"a", "b"} + cases := []struct { + colType string + want string + }{ + {"jsonb", `["a","b"]`}, + {"json", `["a","b"]`}, + {"text[]", `{a,b}`}, + {"integer[]", `{a,b}`}, + {"", `{a,b}`}, + } + for _, c := range cases { + if got := d.ArrayArg(elems, c.colType); got != c.want { + t.Errorf("ArrayArg(_, %q) = %q, want %q", c.colType, got, c.want) + } + } +} diff --git a/backend/postgres/execute.go b/backend/postgres/execute.go index 02531b5..d54341b 100644 --- a/backend/postgres/execute.go +++ b/backend/postgres/execute.go @@ -5,6 +5,7 @@ import ( "encoding/json" "strconv" "strings" + "time" "github.com/jackc/pgx/v5" @@ -13,6 +14,7 @@ import ( "github.com/tamnd/dbrest/ir" "github.com/tamnd/dbrest/pgerr" "github.com/tamnd/dbrest/reqctx" + "github.com/tamnd/dbrest/rpc" "github.com/tamnd/dbrest/schema" ) @@ -49,7 +51,19 @@ func (b *Backend) Execute(ctx context.Context, plan *ir.Plan, rc *reqctx.Context // completes its batch before the main query is issued, so Parse runs as the // request role, which has the required privileges. func (b *Backend) executeRead(ctx context.Context, plan *ir.Plan, rc *reqctx.Context) (backend.Result, error) { - tx, err := b.pool.BeginTx(ctx, pgx.TxOptions{AccessMode: pgx.ReadOnly}) + txOpts := b.txOptions(rc, pgx.ReadOnly) + // A counted read runs the count and the page as two statements. At READ + // COMMITTED each takes its own snapshot, so a concurrent write between them can + // make the Content-Range total disagree with the rows returned. PostgREST + // reads both from one statement, hence one snapshot; pinning the transaction to + // REPEATABLE READ gives the two statements that same single snapshot. A + // read-only REPEATABLE READ transaction never raises a serialization error, so + // this only fixes consistency without adding a failure mode. A role that pins a + // stronger level (its default_transaction_isolation) keeps it. + if plan.Query.Count != ir.CountNone && !isoAtLeastRepeatableRead(txOpts.IsoLevel) { + txOpts.IsoLevel = pgx.RepeatableRead + } + tx, err := b.pool.BeginTx(ctx, txOpts) if err != nil { return nil, b.MapError(err) } @@ -60,18 +74,25 @@ func (b *Backend) executeRead(ctx context.Context, plan *ir.Plan, rc *reqctx.Con return nil, b.MapError(err) } - res := &streamResult{ctx: ctx, tx: tx, controls: rc.Controls()} + res := &streamResult{ctx: ctx, tx: tx, controls: rc.Controls(), loc: b.loc} + + // db-pre-request runs inside applySession and may have set response.status or + // response.headers (PostgREST lets the pre-request hook steer any response, + // including a plain GET). Those headers must be read before the body streams, so + // read them now: the GUCs are already set, and a table SELECT does not set them + // itself, so reading here captures the same value PostgREST would. + if err := readResponseControls(ctx, tx, res.controls); err != nil { + rollback() + return nil, b.MapError(err) + } if plan.Query.Count != ir.CountNone { - cst, apiErr := sqlgen.CompileCount(Dialect{}, plan.Query) - if apiErr != nil { - rollback() - return nil, apiErr - } - if err := tx.QueryRow(ctx, cst.SQL, cst.Args...).Scan(&res.count); err != nil { + total, err := b.computeCount(ctx, tx, plan.Query) + if err != nil { rollback() - return nil, b.MapError(err) + return nil, err } + res.count = total res.hasCount = true } @@ -94,7 +115,7 @@ func (b *Backend) executeRead(ctx context.Context, plan *ir.Plan, rc *reqctx.Con // returned rows. The transaction commits unless the client requested tx=rollback, // in which case the computed representation is returned but nothing is persisted. func (b *Backend) executeWrite(ctx context.Context, plan *ir.Plan, rc *reqctx.Context) (backend.Result, error) { - tx, err := b.pool.BeginTx(ctx, pgx.TxOptions{AccessMode: pgx.ReadWrite}) + tx, err := b.pool.BeginTx(ctx, b.txOptions(rc, pgx.ReadWrite)) if err != nil { return nil, b.MapError(err) } @@ -107,6 +128,19 @@ func (b *Backend) executeWrite(ctx context.Context, plan *ir.Plan, rc *reqctx.Co q := plan.Query returning := returningCols(q, plan.Rel) + // An empty column set (POST with an empty array, PATCH with an empty object) + // is a no-op: nothing is compiled or run, the affected count is zero, and the + // representation is the empty array. The HTTP layer turns that into 201/[] for + // an insert and 204 or 200/[] for an update. + if backend.IsNoOpMutation(q) { + return &bufResult{ + controls: rc.Controls(), + cols: returning, + affected: 0, + hasAff: true, + }, nil + } + // For upserts, append xmax to the RETURNING list so we can distinguish // an INSERT from an ON CONFLICT UPDATE and set the 201/200 status correctly. isUpsert := q.Kind == ir.Upsert @@ -132,37 +166,42 @@ func (b *Backend) executeWrite(ctx context.Context, plan *ir.Plan, rc *reqctx.Co return nil, b.MapError(err) } cols := fieldNames(rows) - buf, err := drainRows(rows) + buf, err := drainRows(rows, b.loc) if err != nil { return nil, b.MapError(err) } // Strip the xmax column from the result and use it to decide insert/update status. if isUpsert && xmaxIdx >= 0 && xmaxIdx < len(cols) { - allInsert := true + inserted := 0 cleaned := make([][]any, len(buf)) for i, row := range buf { - // Check if xmax indicates an update (non-zero value means the row - // existed before and was updated via ON CONFLICT DO UPDATE). + // A zero (or empty) xmax means the row had no prior version: it was + // newly inserted. A non-zero xmax means ON CONFLICT DO UPDATE replaced + // an existing row. + rowInserted := true if xmaxIdx < len(row) { switch xv := row[xmaxIdx].(type) { case []byte: if string(xv) != "0" && string(xv) != "" { - allInsert = false + rowInserted = false } case string: if xv != "0" && xv != "" { - allInsert = false + rowInserted = false } case int64: if xv != 0 { - allInsert = false + rowInserted = false } case uint32: if xv != 0 { - allInsert = false + rowInserted = false } } } + if rowInserted { + inserted++ + } // Remove the xmax column from the row. r := make([]any, 0, len(row)-1) for j, v := range row { @@ -175,10 +214,13 @@ func (b *Backend) executeWrite(ctx context.Context, plan *ir.Plan, rc *reqctx.Co buf = cleaned cols = append(cols[:xmaxIdx], cols[xmaxIdx+1:]...) res.controls.UpsertStatusKnown = true - res.controls.UpsertInsert = allInsert + res.controls.InsertedRows = inserted } - res.cols, res.rows = cols, buf + // The affected count is the full mutated set, taken before the + // representation is shaped: order/limit/offset bound only the returned + // body, not the mutation (v13 dropped limited update/delete). res.affected, res.hasAff = int64(len(buf)), true + res.cols, res.rows = cols, backend.ShapeWriteRepresentation(cols, buf, q) } else { tag, err := tx.Exec(ctx, st.SQL, st.Args...) if err != nil { @@ -191,6 +233,18 @@ func (b *Backend) executeWrite(ctx context.Context, plan *ir.Plan, rc *reqctx.Co return nil, b.MapError(err) } + // Prefer: max-affected rolls an over-broad write back instead of committing. + if apiErr := backend.EnforceMaxAffected(q.Write, res.affected, res.hasAff); apiErr != nil { + return nil, apiErr + } + + // A singular write (vnd.pgrst.object+json) that touched zero or many rows + // fails closed before commit, so the deferred rollback discards it rather + // than the renderer rejecting an already-durable mutation. + if apiErr := backend.EnforceSingularWrite(q.Singular, res.affected, res.hasAff); apiErr != nil { + return nil, apiErr + } + if q.Write != nil && q.Write.Tx == ir.TxRollback { return res, nil } @@ -200,29 +254,79 @@ func (b *Backend) executeWrite(ctx context.Context, plan *ir.Plan, rc *reqctx.Co return res, nil } +// portableCall reports whether a call lowers through the portable registry rather +// than the native catalog. A portable function carries a PortableQuery; a native +// descriptor (resolved from pg_proc for its return shape) leaves Query nil and is +// lowered by splicing literals into a SELECT * FROM schema.fn(...). A nil Func is +// also native (the function was not introspected). +func portableCall(plan *ir.Plan) bool { + return plan.Func != nil && plan.Func.Query != nil +} + // executeCall lowers and runs an RPC call. A read-only function (stable or // immutable) runs in a read-only transaction like executeRead; a volatile // function runs in a read-write transaction that commits (or rolls back under // Prefer: tx=rollback) so its side effects persist. func (b *Backend) executeCall(ctx context.Context, plan *ir.Plan, rc *reqctx.Context) (backend.Result, error) { + // On the native path the function was not resolved through the portable + // registry, so plan.Func is nil. Resolve its descriptor from the introspected + // catalog now: it carries the real return shape the renderer needs (so a SETOF + // scalar is not truncated to one value and a single composite is not wrapped in + // an array) and leaves Query nil so the lowering below still uses the literal + // splice. portableCall is the dispatch predicate from here on: a portable + // function has a Query, a native descriptor does not. + if plan.Func == nil { + plan.Func = b.nativeFunc(plan.Call, b.callSchema(rc)) + } + + // On the native path the access mode follows volatility, not only the method: + // PostgREST runs a STABLE or IMMUTABLE function read-only even on POST, and + // only a VOLATILE function read-write. The registry path already set plan.ReadOnly + // from volatility, so only the native path needs the check. The mode is decided + // before lowering because it selects the volatile count mechanism below. + readOnly := plan.ReadOnly + if !portableCall(plan) { + readOnly = b.nativeCallReadOnly(plan, rc) + } + + // A volatile function must run exactly once, so its count cannot use the read + // path's separate count statement; instead count(*) OVER () rides the row query + // when the caller asked for a count, and the total is read off any returned row + // and the column dropped. This applies only to the native, read-write path: the + // portable count compiler is the read path's, and the read path counts with its + // own separate statement. + counted := !readOnly && !portableCall(plan) && plan.Call.Count != ir.CountNone + var ( st *sqlgen.Statement apiErr *pgerr.APIError ) - if plan.Func != nil { - st, apiErr = sqlgen.CompileCall(Dialect{}, plan.Call, plan.Func) + if portableCall(plan) { + st, apiErr = sqlgen.CompileCall(Dialect{}, plan.Call, plan.Func, sqlgen.ContextArgs(rc)) } else { - st, apiErr = b.compileNativeCall(plan.Call) + st, apiErr = b.compileNativeCall(plan.Call, b.callSchema(rc), plan.Func) + if apiErr == nil { + // A table-valued function result supports the same select, filters, + // ordering, and window a table read does; the registry path wraps for + // these inside CompileCall, so the native path wraps here too. When a + // count is requested the counted wrap also carries count(*) OVER (). + if counted { + st, apiErr = sqlgen.CompileNativeCallCountedWrap(Dialect{}, plan.Call, st) + } else { + st, apiErr = sqlgen.CompileNativeCallWrap(Dialect{}, plan.Call, st) + } + } } if apiErr != nil { return nil, apiErr } - if plan.ReadOnly { + if readOnly { return b.executeCallRead(ctx, plan, rc, st) } - tx, err := b.pool.BeginTx(ctx, pgx.TxOptions{AccessMode: pgx.ReadWrite}) + txOpts, hoisted := b.callTxOptions(plan, rc, pgx.ReadWrite) + tx, err := b.pool.BeginTx(ctx, txOpts) if err != nil { return nil, b.MapError(err) } @@ -231,6 +335,11 @@ func (b *Backend) executeCall(ctx context.Context, plan *ir.Plan, rc *reqctx.Con if err := applySession(ctx, tx, b, rc); err != nil { return nil, b.MapError(err) } + // Hoist the function's db-hoisted-tx-settings (statement_timeout and friends) + // after the session is set and before the call, so they bound the call itself. + if err := applyHoisted(ctx, tx, hoisted); err != nil { + return nil, b.MapError(err) + } rows, err := tx.Query(ctx, st.SQL, st.Args...) if err != nil { @@ -238,15 +347,34 @@ func (b *Backend) executeCall(ctx context.Context, plan *ir.Plan, rc *reqctx.Con } isVoid := isVoidResult(rows) cols := fieldNames(rows) - buf, err := drainRows(rows) + buf, err := drainRows(rows, b.loc) if err != nil { return nil, b.MapError(err) } res := &bufResult{cols: cols, rows: buf, controls: rc.Controls()} + if counted { + // The count(*) OVER () column repeats the full filtered total on every row; + // read it off the first row (an empty result is a total of zero) and drop the + // column so it never reaches the body. This is the single-execution count: the + // function ran once, in this same query. + res.cols, res.rows, res.count = extractCountWindow(cols, buf) + res.hasCount = true + } if err := readResponseControls(ctx, tx, res.controls); err != nil { return nil, b.MapError(err) } + // A portable registry function may steer the response with reserved columns + // instead of the GUCs (the engine-agnostic mechanism); lift them out here too + // so a portable function behaves the same on postgres as on an emulated + // backend. A native function sets the GUCs and carries no such columns, so + // this is a no-op for it. An invalid status or header set is PGRST112/111 + // before commit, so the deferred rollback discards the mutation. + var ctrlErr *pgerr.APIError + res.cols, res.rows, ctrlErr = backend.LiftResponseControls(res.cols, res.rows, res.controls) + if ctrlErr != nil { + return nil, ctrlErr + } // Void-returning functions produce no meaningful body; signal 204 to the // HTTP layer unless the function already set a status override via GUC. if isVoid && res.controls.Status == 0 { @@ -264,28 +392,46 @@ func (b *Backend) executeCall(ctx context.Context, plan *ir.Plan, rc *reqctx.Con // executeCallRead handles a stable/immutable function in a read-only transaction. // An optional count runs as a separate statement before the function call itself. +// +// Unlike a table read, the rows are buffered rather than streamed: a function +// invoked through GET can still set response.status or response.headers (the +// documented Cache-Control and 418 patterns), and those GUCs must be read back +// before the response is sent. Buffering lets readResponseControls run after the +// rows are drained and before the transaction commits, the same shape the write +// and volatile-call paths use; RPC results are small, so this costs little. func (b *Backend) executeCallRead(ctx context.Context, plan *ir.Plan, rc *reqctx.Context, st *sqlgen.Statement) (backend.Result, error) { - tx, err := b.pool.BeginTx(ctx, pgx.TxOptions{AccessMode: pgx.ReadOnly}) + txOpts, hoisted := b.callTxOptions(plan, rc, pgx.ReadOnly) + tx, err := b.pool.BeginTx(ctx, txOpts) if err != nil { return nil, b.MapError(err) } - rollback := func() { _ = tx.Rollback(ctx) } + defer func() { _ = tx.Rollback(ctx) }() if err := applySession(ctx, tx, b, rc); err != nil { - rollback() + return nil, b.MapError(err) + } + // Hoist db-hoisted-tx-settings before the count and the call so a function's + // statement_timeout bounds both, matching PostgREST. + if err := applyHoisted(ctx, tx, hoisted); err != nil { return nil, b.MapError(err) } - res := &streamResult{ctx: ctx, tx: tx, controls: rc.Controls()} + res := &bufResult{controls: rc.Controls()} if plan.Call.Count != ir.CountNone { - cst, apiErr := sqlgen.CompileCallCount(Dialect{}, plan.Call, plan.Func) + var ( + cst *sqlgen.Statement + apiErr *pgerr.APIError + ) + if portableCall(plan) { + cst, apiErr = sqlgen.CompileCallCount(Dialect{}, plan.Call, plan.Func, sqlgen.ContextArgs(rc)) + } else { + cst, apiErr = b.compileNativeCallCount(plan.Call, b.callSchema(rc), plan.Func) + } if apiErr != nil { - rollback() return nil, apiErr } if err := tx.QueryRow(ctx, cst.SQL, cst.Args...).Scan(&res.count); err != nil { - rollback() return nil, b.MapError(err) } res.hasCount = true @@ -293,14 +439,54 @@ func (b *Backend) executeCallRead(ctx context.Context, plan *ir.Plan, rc *reqctx rows, err := tx.Query(ctx, st.SQL, st.Args...) if err != nil { - rollback() return nil, b.MapError(err) } - res.rows = rows - res.cols = fieldNames(rows) + isVoid := isVoidResult(rows) + cols := fieldNames(rows) + buf, err := drainRows(rows, b.loc) + if err != nil { + return nil, b.MapError(err) + } + res.cols, res.rows = cols, buf + + // Read response.status / response.headers a stable function may have set, then + // lift any portable-registry reserved control columns, matching the volatile + // and write paths so a GET to a function steers its response the same way. + if err := readResponseControls(ctx, tx, res.controls); err != nil { + return nil, b.MapError(err) + } + var ctrlErr *pgerr.APIError + res.cols, res.rows, ctrlErr = backend.LiftResponseControls(res.cols, res.rows, res.controls) + if ctrlErr != nil { + return nil, ctrlErr + } + if isVoid && res.controls.Status == 0 { + res.controls.Status = 204 + } + + if err := tx.Commit(ctx); err != nil { + return nil, b.MapError(err) + } return res, nil } +// callSchema is the schema a native RPC resolves in: the request's negotiated +// profile (Accept-Profile on GET/HEAD, Content-Profile on POST) when set, else +// the first configured search-path schema, else public. The HTTP layer already +// rejected a profile outside the exposed list with PGRST106, so rc.Schema is a +// vetted member of the configured set by the time it reaches here. This lets a +// multi-schema deployment dispatch /rpc to the function in the active schema +// instead of always calling the first one. +func (b *Backend) callSchema(rc *reqctx.Context) string { + if rc != nil && rc.Schema != "" { + return rc.Schema + } + if len(b.searchPath) > 0 { + return b.searchPath[0] + } + return "public" +} + // compileNativeCall generates the PostgreSQL function-call SQL for the native // RPC path (NativeRPC=true), where there is no declared function registry. It // renders SELECT * FROM schema.fn(arg := , ...) with values embedded @@ -308,12 +494,7 @@ func (b *Backend) executeCallRead(ctx context.Context, plan *ir.Plan, rc *reqctx // signature and the call does not depend on pgx OID mapping. String values are // single-quote escaped; numeric JSON values are written as numeric literals; // booleans become TRUE/FALSE; null or absent values become NULL. -func (b *Backend) compileNativeCall(c *ir.Call) (*sqlgen.Statement, *pgerr.APIError) { - schema := "public" - if len(b.searchPath) > 0 { - schema = b.searchPath[0] - } - +func (b *Backend) compileNativeCall(c *ir.Call, schema string, fn *rpc.Function) (*sqlgen.Statement, *pgerr.APIError) { d := Dialect{} var sb strings.Builder sb.WriteString("SELECT * FROM ") @@ -322,20 +503,63 @@ func (b *Backend) compileNativeCall(c *ir.Call) (*sqlgen.Statement, *pgerr.APIEr sb.WriteString(d.QuoteIdent(c.Function.Name)) sb.WriteString("(") - i := 0 - for name, val := range c.Args { - if i > 0 { - sb.WriteString(", ") + if fn != nil && len(fn.Params) > 0 { + // With a resolved descriptor the arguments are spliced in declared parameter + // order, which keeps the generated SQL text stable across identical requests + // (Go map iteration is randomized) so the pgx statement cache hits. A + // single-raw-body parameter is unnamed in the catalog, so it is passed + // positionally; every other argument uses the name := value form. Omitted + // optional arguments are left out, taking the function's default. + first := true + for _, p := range fn.Params { + val, ok := c.Args[p.Name] + if !ok { + continue + } + if !first { + sb.WriteString(", ") + } + if !p.RawBody { + sb.WriteString(d.QuoteIdent(p.Name)) + sb.WriteString(" := ") + } + appendNativeArg(&sb, val) + first = false + } + } else { + // No descriptor (a function not introspected): fall back to the named form in + // map order. The result is correct, only non-deterministic in column text. + i := 0 + for name, val := range c.Args { + if i > 0 { + sb.WriteString(", ") + } + sb.WriteString(d.QuoteIdent(name)) + sb.WriteString(" := ") + appendNativeArg(&sb, val) + i++ } - sb.WriteString(d.QuoteIdent(name)) - sb.WriteString(" := ") - appendNativeArg(&sb, val) - i++ } sb.WriteString(")") return &sqlgen.Statement{SQL: sb.String()}, nil } +// compileNativeCallCount wraps the native function call in a count, the exact-count +// statement for a native RPC. There is no registry function to drive +// sqlgen.CompileCallCount (plan.Func is nil), so the count is built here over the +// same SELECT * FROM schema.fn(...) the row query runs; a scalar-returning function +// yields its single row and counts as one, a setof yields its rows. +func (b *Backend) compileNativeCallCount(c *ir.Call, schema string, fn *rpc.Function) (*sqlgen.Statement, *pgerr.APIError) { + inner, apiErr := b.compileNativeCall(c, schema, fn) + if apiErr != nil { + return nil, apiErr + } + // Count over the same post-filter the row query applies, so a count=exact + // total matches the rows returned (the select, order, and window do not + // change the count). + return sqlgen.CompileNativeCallCountWrap(Dialect{}, c, inner) +} + // appendNativeArg writes one function argument as a safe SQL literal. Numbers // are written unquoted so PostgreSQL resolves their type from context; strings // use single-quote escaping; booleans are TRUE/FALSE; anything else (including @@ -359,11 +583,16 @@ func appendNativeArg(sb *strings.Builder, val ir.Value) { sb.WriteString("FALSE") } default: - // JSON object / array: pass as json literal. + // JSON object / array: splice the encoded text as an UNTYPED literal. + // PostgreSQL's function resolution applies implicit casts only, and the + // json->jsonb cast is assignment-context, so a '...'::json literal fails + // to match an fn(jsonb) signature (42883 -> 404). An unknown-type literal + // instead coerces to json, jsonb, or text alike, which is also why the + // string/number/bool branches already work against any parameter type. enc, _ := json.Marshal(v) sb.WriteString("'") sb.WriteString(strings.ReplaceAll(string(enc), "'", "''")) - sb.WriteString("'::json") + sb.WriteString("'") } return } @@ -399,6 +628,9 @@ func returningCols(q *ir.Query, rel *schema.Relation) []string { return nil } if q.Write != nil && q.Write.Return == ir.ReturnRepresentation { + if cols := q.ProjectedColumns(); cols != nil { + return cols + } return rel.ColumnNames() } if q.Kind == ir.Insert || q.Kind == ir.Upsert { @@ -425,13 +657,105 @@ func fieldNames(rows pgx.Rows) []string { return names } -// ExplainRead runs EXPLAIN (FORMAT JSON) on the read query and returns the raw -// JSON plan from PostgreSQL. When analyze is true EXPLAIN ANALYZE is used -// instead, which also executes the query and includes timing. The request runs -// in a read-only transaction with the full session setup (role + GUCs) so the -// planner sees the same context as a real request. -func (b *Backend) ExplainRead(ctx context.Context, p *ir.Plan, rc *reqctx.Context, analyze bool) ([]byte, error) { - tx, err := b.pool.BeginTx(ctx, pgx.TxOptions{AccessMode: pgx.ReadOnly}) +// extractCountWindow pulls the count(*) OVER () total out of a buffered result that +// carries the _pgrst_count window column, returning the columns and rows with that +// column removed and the total. The window repeats the full filtered total on every +// row, so the first row carries it; an empty result is a total of zero. The rows are +// rewritten in place: each is reduced to a fresh slice that excludes the count cell. +func extractCountWindow(cols []string, buf [][]any) ([]string, [][]any, int64) { + idx := -1 + for i, c := range cols { + if c == sqlgen.CountColName { + idx = i + break + } + } + if idx < 0 { + return cols, buf, 0 + } + var total int64 + if len(buf) > 0 && idx < len(buf[0]) { + switch n := buf[0][idx].(type) { + case int64: + total = n + case int32: + total = int64(n) + case int: + total = int64(n) + } + } + cols = append(cols[:idx:idx], cols[idx+1:]...) + for i, row := range buf { + if idx < len(row) { + buf[i] = append(row[:idx:idx], row[idx+1:]...) + } + } + return cols, buf, total +} + +// explainPrefix builds the "EXPLAIN (...) " clause for a plan request from the +// parsed options: the output format plus whichever of analyze/verbose/settings/ +// buffers/wal were asked for, in PostgreSQL's option syntax. +func explainPrefix(opts backend.PlanOptions) string { + args := []string{"FORMAT TEXT"} + if opts.Format == backend.PlanJSON { + args[0] = "FORMAT JSON" + } + for _, o := range []struct { + on bool + name string + }{ + {opts.Analyze, "ANALYZE"}, {opts.Verbose, "VERBOSE"}, + {opts.Settings, "SETTINGS"}, {opts.Buffers, "BUFFERS"}, {opts.Wal, "WAL"}, + } { + if o.on { + args = append(args, o.name) + } + } + return "EXPLAIN (" + strings.Join(args, ", ") + ") " +} + +// runExplain executes the prefixed statement and returns the plan bytes. The +// JSON format yields a single document; the text format yields one row per plan +// line, which are joined with newlines into the text body PostgREST returns. +func (b *Backend) runExplain(ctx context.Context, tx pgx.Tx, opts backend.PlanOptions, sql string, args []any) ([]byte, error) { + rows, err := tx.Query(ctx, explainPrefix(opts)+sql, args...) + if err != nil { + return nil, b.MapError(err) + } + defer rows.Close() + if opts.Format == backend.PlanJSON { + var plan []byte + for rows.Next() { + if err := rows.Scan(&plan); err != nil { + return nil, b.MapError(err) + } + } + if err := rows.Err(); err != nil { + return nil, b.MapError(err) + } + return plan, nil + } + var lines []string + for rows.Next() { + var line string + if err := rows.Scan(&line); err != nil { + return nil, b.MapError(err) + } + lines = append(lines, line) + } + if err := rows.Err(); err != nil { + return nil, b.MapError(err) + } + return []byte(strings.Join(lines, "\n")), nil +} + +// ExplainRead runs EXPLAIN on the read query and returns the plan in the +// requested format. The request runs in a read-only transaction with the full +// session setup (role + GUCs) so the planner sees the same context as a real +// request. +func (b *Backend) ExplainRead(ctx context.Context, p *ir.Plan, rc *reqctx.Context, opts backend.PlanOptions) ([]byte, error) { + tx, err := b.pool.BeginTx(ctx, b.txOptions(rc, pgx.ReadOnly)) if err != nil { return nil, b.MapError(err) } @@ -445,34 +769,71 @@ func (b *Backend) ExplainRead(ctx context.Context, p *ir.Plan, rc *reqctx.Contex if apiErr != nil { return nil, apiErr } + return b.runExplain(ctx, tx, opts, st.SQL, st.Args) +} - var prefix string - if analyze { - prefix = "EXPLAIN (ANALYZE, FORMAT JSON) " - } else { - prefix = "EXPLAIN (FORMAT JSON) " +// ExplainWrite runs EXPLAIN on the mutation. It uses a read-write transaction +// that always rolls back, so EXPLAIN ANALYZE (which executes the statement) +// leaves nothing behind, matching PostgREST's plan-only contract. +func (b *Backend) ExplainWrite(ctx context.Context, p *ir.Plan, rc *reqctx.Context, opts backend.PlanOptions) ([]byte, error) { + tx, err := b.pool.BeginTx(ctx, b.txOptions(rc, "")) + if err != nil { + return nil, b.MapError(err) } - rows, err := tx.Query(ctx, prefix+st.SQL, st.Args...) + defer func() { _ = tx.Rollback(ctx) }() + + if err := applySession(ctx, tx, b, rc); err != nil { + return nil, b.MapError(err) + } + + st, apiErr := compileWrite(p.Query, returningCols(p.Query, p.Rel)) + if apiErr != nil { + return nil, apiErr + } + return b.runExplain(ctx, tx, opts, st.SQL, st.Args) +} + +// ExplainCall runs EXPLAIN on the RPC function call. The read-write transaction +// rolls back, so an EXPLAIN ANALYZE of a volatile function discards its effects. +func (b *Backend) ExplainCall(ctx context.Context, p *ir.Plan, rc *reqctx.Context, opts backend.PlanOptions) ([]byte, error) { + tx, err := b.pool.BeginTx(ctx, b.txOptions(rc, "")) if err != nil { return nil, b.MapError(err) } - defer rows.Close() - var plan []byte - for rows.Next() { - if err := rows.Scan(&plan); err != nil { - return nil, b.MapError(err) + defer func() { _ = tx.Rollback(ctx) }() + + if err := applySession(ctx, tx, b, rc); err != nil { + return nil, b.MapError(err) + } + + // EXPLAIN compiles the call the same way Execute does, so a native function + // (no registry Query) is planned through the literal-splice path rather than + // the registry compiler, matching what the call would actually run. + if p.Func == nil { + p.Func = b.nativeFunc(p.Call, b.callSchema(rc)) + } + var ( + st *sqlgen.Statement + apiErr *pgerr.APIError + ) + if portableCall(p) { + st, apiErr = sqlgen.CompileCall(Dialect{}, p.Call, p.Func, sqlgen.ContextArgs(rc)) + } else { + st, apiErr = b.compileNativeCall(p.Call, b.callSchema(rc), p.Func) + if apiErr == nil { + st, apiErr = sqlgen.CompileNativeCallWrap(Dialect{}, p.Call, st) } } - if err := rows.Err(); err != nil { - return nil, b.MapError(err) + if apiErr != nil { + return nil, apiErr } - return plan, nil + return b.runExplain(ctx, tx, opts, st.SQL, st.Args) } // drainRows reads every row of a pgx cursor into memory, normalizing values so // json/jsonb, bytea, and date columns render correctly. The rows are closed by // drainRows; the caller must not close them again. -func drainRows(rows pgx.Rows) ([][]any, error) { +func drainRows(rows pgx.Rows, loc *time.Location) ([][]any, error) { defer rows.Close() fields := rows.FieldDescriptions() var out [][]any @@ -481,7 +842,7 @@ func drainRows(rows pgx.Rows) ([][]any, error) { if err != nil { return nil, err } - out = append(out, normalizeValues(vals, fields)) + out = append(out, normalizeValues(vals, fields, loc)) } return out, rows.Err() } diff --git a/backend/postgres/fulltext.go b/backend/postgres/fulltext.go index 59958ba..ef4edf6 100644 --- a/backend/postgres/fulltext.go +++ b/backend/postgres/fulltext.go @@ -24,7 +24,7 @@ import ( // names, not raw client input, and the Dialect interface carries a bound operand // only for the query value. An empty config omits the argument, letting the // server's default_text_search_config apply, which is the PostgREST default. -func (Dialect) FullText(col string, _ *sqlgen.FullTextRef, variant ir.FTSVariant, config, value string) (string, string, bool) { +func (Dialect) FullText(col, colType string, _ *sqlgen.FullTextRef, variant ir.FTSVariant, config, value string) (string, string, bool) { ctor := map[ir.FTSVariant]string{ ir.FTSPlain: "to_tsquery", ir.FTSPlainText: "plainto_tsquery", @@ -36,7 +36,15 @@ func (Dialect) FullText(col string, _ *sqlgen.FullTextRef, variant ir.FTSVariant if config != "" { cfg = sqlLiteral(config) + ", " } - frag := "to_tsvector(" + cfg + col + ") @@ " + ctor + "(" + cfg + sqlgen.PatternMark + ")" + // A column already typed tsvector is matched directly: PostgreSQL has no + // to_tsvector(tsvector) overload, so wrapping it would raise 42883. This is + // PostgREST's "Do not apply to_tsvector to tsvector types" rule. Text and + // json/jsonb columns are wrapped so the server builds the vector on the fly. + lhs := "to_tsvector(" + cfg + col + ")" + if colType == "tsvector" { + lhs = col + } + frag := lhs + " @@ " + ctor + "(" + cfg + sqlgen.PatternMark + ")" // The value carries the variant's grammar (boolean operators, quoted phrases, // a web-style string), which PostgreSQL parses itself, so it is bound verbatim // rather than pre-translated the way the FTS5 dialect must translate it. diff --git a/backend/postgres/funcs.go b/backend/postgres/funcs.go new file mode 100644 index 0000000..5f64791 --- /dev/null +++ b/backend/postgres/funcs.go @@ -0,0 +1,137 @@ +package postgres + +import ( + "context" + + "github.com/tamnd/dbrest/rpc" +) + +// loadFunctionRegistry reads the full callable signature of every function in the +// exposed schemas from pg_proc and materializes one rpc.Registry per schema. This +// is the function half of the schema cache PostgREST keeps: with it the native RPC +// path resolves overloads (raising PGRST202 for no match and PGRST203 for an +// ambiguous one), partitions GET arguments from result filters by parameter name, +// and runs a POST to a STABLE or IMMUTABLE function read-only, all through the same +// planner code the portable registry uses. The descriptors carry Query nil, which +// keeps the executor lowering them through the native splice path. +// +// One row per input argument would be simpler to scan but loses the per-function +// grouping; instead each function row carries its input names and type names as +// arrays, reconstructed in SQL so the OUT and TABLE columns (which are not call +// arguments) are filtered out by argument mode. proargtypes already lists only the +// input arguments in order, so the type array needs no filtering; proargnames lists +// every argument, so it is filtered to the input modes ('i' in, 'b' inout, 'v' +// variadic) when proargmodes is present. +func (b *Backend) loadFunctionRegistry(ctx context.Context, schemas []string) (map[string]rpc.Registry, error) { + const q = ` +SELECT n.nspname, p.proname, + p.provolatile::text, + p.proretset, p.prorettype, rt.typtype::text, rt.typname, + p.pronargs, p.pronargdefaults, (p.provariadic <> 0), + (SELECT array_agg(tt.typname ORDER BY u.ord) + FROM unnest(p.proargtypes) WITH ORDINALITY AS u(typoid, ord) + JOIN pg_type tt ON tt.oid = u.typoid) AS in_types, + CASE WHEN p.proargmodes IS NULL THEN p.proargnames + ELSE (SELECT array_agg(nm ORDER BY ord) + FROM unnest(p.proargnames, p.proargmodes) WITH ORDINALITY AS m(nm, mode, ord) + WHERE mode IN ('i', 'b', 'v')) END AS in_names, + (p.prosecdef) AS secdef, + COALESCE(d.description, '') AS comment + FROM pg_proc p + JOIN pg_namespace n ON n.oid = p.pronamespace + JOIN pg_type rt ON rt.oid = p.prorettype + LEFT JOIN pg_description d ON d.objoid = p.oid AND d.classoid = 'pg_proc'::regclass AND d.objsubid = 0 + WHERE n.nspname = ANY($1) + AND p.prokind = 'f' + ORDER BY n.nspname, p.proname, p.oid` + rows, err := b.pool.Query(ctx, q, schemas) + if err != nil { + return nil, err + } + defer rows.Close() + + bySchema := map[string][]*rpc.Function{} + for rows.Next() { + var ( + nsp, name, vol, retTyptype, retTypname string + retset, variadic, secdef bool + rettype uint32 + nargs, ndefaults int + inTypes []string + inNames []*string + comment string + ) + if err := rows.Scan(&nsp, &name, &vol, &retset, &rettype, &retTyptype, &retTypname, + &nargs, &ndefaults, &variadic, &inTypes, &inNames, &secdef, &comment); err != nil { + return nil, err + } + fn := &rpc.Function{ + Name: name, + Returns: returnShapeFor(retset, rettype, retTyptype, retTypname), + Volatility: volatilityFromChar(vol), + Comment: comment, + } + if secdef { + fn.Security = rpc.Definer + } + fn.Params = buildParams(nargs, ndefaults, variadic, inTypes, inNames) + bySchema[nsp] = append(bySchema[nsp], fn) + } + if err := rows.Err(); err != nil { + return nil, err + } + + out := make(map[string]rpc.Registry, len(bySchema)) + for schema, fns := range bySchema { + out[schema] = rpc.NewStaticRegistry(fns) + } + return out, nil +} + +// buildParams reconstructs a function's input parameters in signature order from +// the pg_proc facts. proargtypes (inTypes) holds exactly the input argument types +// in order, so its length is the input arity; inNames is the input names in the +// same order, possibly nil (no names at all) or holding a nil entry for an unnamed +// argument. The trailing ndefaults inputs are optional, and a variadic function's +// last input collects its trailing arguments. A function whose single input is +// unnamed and of a raw-body type takes the whole request body as that argument. +func buildParams(nargs, ndefaults int, variadic bool, inTypes []string, inNames []*string) []rpc.Param { + if len(inTypes) == 0 { + return nil + } + params := make([]rpc.Param, 0, len(inTypes)) + for i, typ := range inTypes { + name := "" + if i < len(inNames) && inNames[i] != nil { + name = *inNames[i] + } + params = append(params, rpc.Param{ + Name: name, + Type: typ, + Optional: i >= len(inTypes)-ndefaults, + Variadic: variadic && i == len(inTypes)-1, + }) + } + // The single-unnamed-parameter form: one input, no name, and a body-shaped type. + // PostgREST binds the whole request body to it rather than reading the body as a + // JSON object of named arguments. + if len(params) == 1 && params[0].Name == "" && isRawBodyType(inTypes[0]) { + params[0].RawBody = true + // The lowering references the argument by name, so give the unnamed parameter + // a stable synthetic name; the wire contract is positional, so the spelling is + // internal only. + params[0].Name = "__raw_body" + } + return params +} + +// isRawBodyType reports whether a type can stand as a single-unnamed-parameter raw +// body. PostgREST accepts json, jsonb, text, xml, and bytea in this position: the +// request body is bound whole, decoded by Content-Type. +func isRawBodyType(typname string) bool { + switch typname { + case "json", "jsonb", "text", "xml", "bytea": + return true + } + return false +} diff --git a/backend/postgres/funcs_reg_test.go b/backend/postgres/funcs_reg_test.go new file mode 100644 index 0000000..e6756c3 --- /dev/null +++ b/backend/postgres/funcs_reg_test.go @@ -0,0 +1,111 @@ +package postgres + +import ( + "reflect" + "testing" + + "github.com/tamnd/dbrest/rpc" +) + +func strp(s string) *string { return &s } + +// TestBuildParams covers the input-parameter reconstruction from pg_proc facts that +// loadFunctionRegistry feeds it (finding 03-P03): proargtypes is input-only and in +// order, the trailing pronargdefaults inputs are optional, a variadic function's +// last input collects its tail, and a lone unnamed body-typed input is a raw body. +func TestBuildParams(t *testing.T) { + cases := []struct { + name string + nargs int + ndefaults int + variadic bool + inTypes []string + inNames []*string + want []rpc.Param + }{ + { + name: "no arguments", + inTypes: nil, + want: nil, + }, + { + name: "two required", + nargs: 2, + inTypes: []string{"int4", "int4"}, + inNames: []*string{strp("a"), strp("b")}, + want: []rpc.Param{ + {Name: "a", Type: "int4"}, + {Name: "b", Type: "int4"}, + }, + }, + { + name: "trailing default is optional", + nargs: 2, + ndefaults: 1, + inTypes: []string{"text", "text"}, + inNames: []*string{strp("name"), strp("greeting")}, + want: []rpc.Param{ + {Name: "name", Type: "text"}, + {Name: "greeting", Type: "text", Optional: true}, + }, + }, + { + name: "variadic last input", + nargs: 1, + variadic: true, + inTypes: []string{"_int4"}, + inNames: []*string{strp("vals")}, + want: []rpc.Param{ + {Name: "vals", Type: "_int4", Variadic: true}, + }, + }, + { + name: "single unnamed json is a raw body", + nargs: 1, + inTypes: []string{"json"}, + inNames: nil, + want: []rpc.Param{ + {Name: "__raw_body", Type: "json", RawBody: true}, + }, + }, + { + name: "single unnamed non-body type is an ordinary unnamed arg", + nargs: 1, + inTypes: []string{"int4"}, + inNames: nil, + want: []rpc.Param{ + {Name: "", Type: "int4"}, + }, + }, + { + name: "named json is not a raw body", + nargs: 1, + inTypes: []string{"jsonb"}, + inNames: []*string{strp("payload")}, + want: []rpc.Param{ + {Name: "payload", Type: "jsonb"}, + }, + }, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + got := buildParams(c.nargs, c.ndefaults, c.variadic, c.inTypes, c.inNames) + if !reflect.DeepEqual(got, c.want) { + t.Errorf("buildParams = %+v\nwant %+v", got, c.want) + } + }) + } +} + +func TestIsRawBodyType(t *testing.T) { + for _, ok := range []string{"json", "jsonb", "text", "xml", "bytea"} { + if !isRawBodyType(ok) { + t.Errorf("isRawBodyType(%q) = false, want true", ok) + } + } + for _, no := range []string{"int4", "numeric", "uuid", "timestamptz", ""} { + if isRawBodyType(no) { + t.Errorf("isRawBodyType(%q) = true, want false", no) + } + } +} diff --git a/backend/postgres/funcs_test.go b/backend/postgres/funcs_test.go new file mode 100644 index 0000000..5ebfcdb --- /dev/null +++ b/backend/postgres/funcs_test.go @@ -0,0 +1,94 @@ +package postgres + +import ( + "testing" + + "github.com/tamnd/dbrest/ir" + "github.com/tamnd/dbrest/rpc" +) + +// call builds a minimal ir.Call naming a function, for the native-resolution tests. +func call(name string) *ir.Call { return &ir.Call{Function: ir.Ref{Name: name}} } + +// TestReturnShapeFor covers finding 03-P06: the native RPC return shape is taken +// from pg_proc facts (proretset and the return type's class), not guessed from +// column names. A composite or record return is object-shaped; everything else is +// scalar-shaped; proretset then decides array vs single. +func TestReturnShapeFor(t *testing.T) { + cases := []struct { + name string + retset bool + rettype uint32 + typtype string + typname string + want rpc.ReturnKind + }{ + {"scalar integer", false, 23, "b", "int4", rpc.ReturnScalar}, + {"setof integer", true, 23, "b", "int4", rpc.ReturnSetOf}, + {"single composite", false, 16385, "c", "point_2d", rpc.ReturnObject}, + {"setof composite", true, 16385, "c", "point_2d", rpc.ReturnTable}, + {"returns table", true, oidRecord, "p", "record", rpc.ReturnTable}, + {"returns record single", false, oidRecord, "p", "record", rpc.ReturnObject}, + {"returns void", false, oidVoid, "p", "void", rpc.ReturnVoid}, + {"setof void stays void", true, oidVoid, "p", "void", rpc.ReturnVoid}, + {"scalar enum", false, 16400, "e", "mood", rpc.ReturnScalar}, + {"scalar json", false, 114, "b", "json", rpc.ReturnScalar}, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + got := returnShapeFor(c.retset, c.rettype, c.typtype, c.typname) + if got.Kind != c.want { + t.Errorf("returnShapeFor(%v,%d,%q) Kind = %v, want %v", c.retset, c.rettype, c.typtype, got.Kind, c.want) + } + }) + } +} + +// A scalar or setof-scalar shape carries the return type name so the renderer can +// embed a json/jsonb value verbatim; an object/table/void shape needs no Type. +func TestReturnShapeForType(t *testing.T) { + if got := returnShapeFor(false, 114, "b", "json"); got.Type != "json" { + t.Errorf("scalar Type = %q, want json", got.Type) + } + if got := returnShapeFor(true, 23, "b", "int4"); got.Type != "int4" { + t.Errorf("setof Type = %q, want int4", got.Type) + } + if got := returnShapeFor(false, 16385, "c", "point_2d"); got.Type != "" { + t.Errorf("object Type = %q, want empty", got.Type) + } +} + +// nativeFunc returns nil when the catalog has no entry for the call, so the +// renderer keeps its column-name fallback rather than asserting a wrong shape. +func TestNativeFuncUnknown(t *testing.T) { + b := &Backend{funcRet: map[string]rpc.ReturnShape{}} + if got := b.nativeFunc(call("missing"), "public"); got != nil { + t.Errorf("nativeFunc(missing) = %v, want nil", got) + } + b2 := &Backend{} // funcRet nil (never introspected) + if got := b2.nativeFunc(call("anything"), "public"); got != nil { + t.Errorf("nativeFunc with nil funcRet = %v, want nil", got) + } +} + +// nativeFunc builds a native descriptor (Query nil, so portableCall is false) +// carrying the introspected return shape and volatility. +func TestNativeFuncResolved(t *testing.T) { + b := &Backend{ + funcRet: map[string]rpc.ReturnShape{"public.ret_point": {Kind: rpc.ReturnObject}}, + funcVol: map[string]rpc.Volatility{"public.ret_point": rpc.Stable}, + } + fn := b.nativeFunc(call("ret_point"), "public") + if fn == nil { + t.Fatal("nativeFunc(ret_point) = nil, want descriptor") + } + if fn.Returns.Kind != rpc.ReturnObject { + t.Errorf("Returns.Kind = %v, want ReturnObject", fn.Returns.Kind) + } + if fn.Volatility != rpc.Stable { + t.Errorf("Volatility = %v, want Stable", fn.Volatility) + } + if fn.Query != nil { + t.Error("native descriptor must leave Query nil so it lowers through the splice path") + } +} diff --git a/backend/postgres/hoist.go b/backend/postgres/hoist.go new file mode 100644 index 0000000..c3841e8 --- /dev/null +++ b/backend/postgres/hoist.go @@ -0,0 +1,120 @@ +package postgres + +import ( + "context" + "slices" + "sort" + "strings" + + "github.com/jackc/pgx/v5" + + "github.com/tamnd/dbrest/ir" + "github.com/tamnd/dbrest/reqctx" +) + +// loadFunctionProconfig reads pg_proc.proconfig (a function's SET clause) for +// every function in the exposed schemas into a map keyed by "schema.name", so an +// RPC call can hoist the settings db-hoisted-tx-settings selects to the +// transaction the way PostgREST does. proconfig is a text[] of "name=value" +// entries; the value half can itself contain '=', so only the first '=' splits. +// +// A name with several overloads collapses to one key: the entries are appended +// and hoistFor takes the last value per setting, so overloads that disagree on a +// hoisted setting resolve to the last one introspected. This is the documented +// limit of static introspection (the actual overload is resolved by argument +// types at call time, which the static map does not model). +func (b *Backend) loadFunctionProconfig(ctx context.Context, schemas []string) (map[string][]roleSetting, error) { + const q = ` +SELECT n.nspname, p.proname, p.proconfig + FROM pg_proc p + JOIN pg_namespace n ON n.oid = p.pronamespace + WHERE n.nspname = ANY($1) AND p.proconfig IS NOT NULL` + rows, err := b.pool.Query(ctx, q, schemas) + if err != nil { + return nil, err + } + defer rows.Close() + + out := map[string][]roleSetting{} + for rows.Next() { + var nsp, name string + var cfg []string + if err := rows.Scan(&nsp, &name, &cfg); err != nil { + return nil, err + } + key := nsp + "." + name + for _, kv := range cfg { + i := strings.IndexByte(kv, '=') + if i <= 0 { + continue + } + out[key] = append(out[key], roleSetting{name: kv[:i], value: kv[i+1:]}) + } + } + return out, rows.Err() +} + +// hoistFor returns the hoisted transaction settings for an RPC call: the function +// SET options whose names are in db-hoisted-tx-settings, with the last value per +// name winning. default_transaction_isolation is split out as an isolation level +// because it cannot be set with set_config once the transaction has run a +// statement; the caller applies it at BeginTx. The remaining settings are +// returned sorted by name so the replay order is deterministic. +func (b *Backend) hoistFor(plan *ir.Plan, rc *reqctx.Context) ([]roleSetting, pgx.TxIsoLevel) { + if b.funcProconfig == nil || len(b.hoistedTxSettings) == 0 || plan.Call == nil { + return nil, "" + } + key := b.callSchema(rc) + "." + plan.Call.Function.Name + raw := b.funcProconfig[key] + if len(raw) == 0 { + return nil, "" + } + + picked := map[string]string{} + for _, s := range raw { + if slices.Contains(b.hoistedTxSettings, s.name) { + picked[s.name] = s.value // last wins + } + } + + var iso pgx.TxIsoLevel + var sets []roleSetting + for name, val := range picked { + if name == "default_transaction_isolation" { + if lvl, ok := isoLevelFromName(val); ok { + iso = lvl + } + continue + } + sets = append(sets, roleSetting{name: name, value: val}) + } + sort.Slice(sets, func(i, j int) bool { return sets[i].name < sets[j].name }) + return sets, iso +} + +// applyHoisted replays the hoisted settings as transaction-scoped settings after +// the session is set up and before the call statement, so they override the role +// and connection settings for the whole statement (including the count query of a +// set-returning call), matching PostgREST. default_transaction_isolation is not +// here; it is applied at BeginTx by the caller. +func applyHoisted(ctx context.Context, tx pgx.Tx, sets []roleSetting) error { + for _, s := range sets { + if _, err := tx.Exec(ctx, "SELECT set_config($1,$2,true)", s.name, s.value); err != nil { + return err + } + } + return nil +} + +// callTxOptions builds the transaction options for an RPC call: the role's +// options (access mode plus any role default_transaction_isolation) with the +// hoisted default_transaction_isolation overriding the role's, since a function's +// SET clause takes precedence over the role and connection settings. +func (b *Backend) callTxOptions(plan *ir.Plan, rc *reqctx.Context, mode pgx.TxAccessMode) (pgx.TxOptions, []roleSetting) { + opts := b.txOptions(rc, mode) + sets, iso := b.hoistFor(plan, rc) + if iso != "" { + opts.IsoLevel = iso + } + return opts, sets +} diff --git a/backend/postgres/integration_test.go b/backend/postgres/integration_test.go index 5d5c916..37d4728 100644 --- a/backend/postgres/integration_test.go +++ b/backend/postgres/integration_test.go @@ -2,12 +2,21 @@ package postgres_test import ( "context" + "encoding/json" + "fmt" "os" + "reflect" + "strings" "testing" "github.com/tamnd/dbrest/backend/postgres" "github.com/tamnd/dbrest/ir" + "github.com/tamnd/dbrest/openapi" + "github.com/tamnd/dbrest/pgerr" + planpkg "github.com/tamnd/dbrest/plan" "github.com/tamnd/dbrest/reqctx" + "github.com/tamnd/dbrest/rpc" + "github.com/tamnd/dbrest/schema" ) // dsn returns the DSN for the integration tests. The tests are skipped entirely @@ -117,7 +126,7 @@ func TestIntegrationReadWrite(t *testing.T) { Kind: ir.Insert, Relation: ir.Ref{Schema: "public", Name: "_dbrest_test_rw"}, Write: &ir.WriteSpec{ - Rows: []map[string]ir.Value{{"val": {Text: "hello"}}}, + Rows: []map[string]ir.Value{{"val": {JSON: "hello"}}}, Columns: []string{"val"}, Return: ir.ReturnMinimal, }, @@ -167,6 +176,2537 @@ func TestIntegrationReadWrite(t *testing.T) { } } +// TestIntegrationNativeCallPostFilter proves the native RPC path (plan.Func nil) +// applies select, filter, order, limit, and an exact count to a set-returning +// function's rows, the same post-filter a table read enjoys. Before the fix the +// native path ran SELECT * FROM fn(...) and silently dropped all of these. +// Finding 05-M08 / P01. +func TestIntegrationNativeCallPostFilter(t *testing.T) { + be := openBE(t) + ctx := context.Background() + + if _, err := be.Pool().Exec(ctx, ` + CREATE OR REPLACE FUNCTION _dbrest_test_films() + RETURNS TABLE(id int, title text, year int) + LANGUAGE sql STABLE AS $$ + SELECT * FROM (VALUES + (1, 'Metropolis', 1927), + (2, 'Blade Runner', 1982), + (3, 'Arrival', 2016) + ) AS t(id, title, year) + $$`); err != nil { + t.Fatalf("seed function: %v", err) + } + t.Cleanup(func() { + _, _ = be.Pool().Exec(ctx, "DROP FUNCTION IF EXISTS _dbrest_test_films()") + }) + + // year >= 1982, ordered year desc, limit 1, projecting title only. Of the two + // matching rows (Blade Runner 1982, Arrival 2016) the top of a year-desc order + // is Arrival, and limit 1 keeps just that one. + call := &ir.Call{ + Function: ir.Ref{Schema: "public", Name: "_dbrest_test_films"}, + Args: map[string]ir.Value{}, + ReadOnly: true, + Select: []ir.SelectItem{ir.Column{Path: []string{"title"}}}, + Where: condPtr(ir.Compare{Path: []string{"year"}, Op: ir.OpGte, Value: ir.Value{Text: "1982"}}), + Order: []ir.OrderTerm{{Path: []string{"year"}, Desc: true}}, + Limit: intPtr(1), + Count: ir.CountExact, + } + plan := &ir.Plan{ReadOnly: true, Call: call} + + res, err := be.Execute(ctx, plan, &reqctx.Context{Method: "GET", Path: "/rpc/_dbrest_test_films"}) + if err != nil { + t.Fatalf("Execute(native call): %v", err) + } + + // The count is exact over the filtered set: two rows match year >= 1982. + if c, ok := res.Count(); !ok || c != 2 { + t.Errorf("Count = (%d, %v), want (2, true) over the filtered rows", c, ok) + } + + rs := res.Rows() + defer rs.Close() + var titles []string + var cols int + for rs.Next() { + vals, err := rs.Values() + if err != nil { + t.Fatalf("Values: %v", err) + } + cols = len(vals) + titles = append(titles, vals[0].(string)) + } + if err := rs.Err(); err != nil { + t.Fatalf("row error: %v", err) + } + if len(titles) != 1 { + t.Fatalf("limit 1 returned %d rows, want 1: %v", len(titles), titles) + } + if cols != 1 { + t.Errorf("select=title returned %d columns, want 1", cols) + } + if titles[0] != "Arrival" { + t.Errorf("order=year.desc top row = %q, want Arrival", titles[0]) + } +} + +// TestIntegrationNativeReturnShapes covers finding 03-P06: the native RPC path +// resolves a function's return shape from pg_proc (proretset plus the return +// type's class) and carries it on plan.Func, so the renderer shapes the body by +// the real return kind instead of guessing from column names. Each seeded function +// exercises one shape; Execute must populate plan.Func with the matching kind. +func TestIntegrationNativeReturnShapes(t *testing.T) { + be := openBE(t) + ctx := context.Background() + + if _, err := be.Pool().Exec(ctx, ` + CREATE OR REPLACE FUNCTION _dbrest_ret_setof_integers() RETURNS SETOF integer + LANGUAGE sql STABLE AS $$ SELECT * FROM (VALUES (1),(2),(3)) v(n) $$; + CREATE OR REPLACE FUNCTION _dbrest_ret_point_2d(OUT x int, OUT y int) + LANGUAGE sql STABLE AS $$ SELECT 10, 5 $$; + CREATE OR REPLACE FUNCTION _dbrest_ret_films() RETURNS TABLE(id int, title text) + LANGUAGE sql STABLE AS $$ SELECT * FROM (VALUES (1, 'Dune')) v(id, title) $$; + CREATE OR REPLACE FUNCTION _dbrest_ret_scalar() RETURNS integer + LANGUAGE sql IMMUTABLE AS $$ SELECT 42 $$; + CREATE OR REPLACE FUNCTION _dbrest_ret_void() RETURNS void + LANGUAGE plpgsql VOLATILE AS $$ BEGIN END $$`); err != nil { + t.Fatalf("seed functions: %v", err) + } + t.Cleanup(func() { + _, _ = be.Pool().Exec(ctx, ` + DROP FUNCTION IF EXISTS _dbrest_ret_setof_integers(); + DROP FUNCTION IF EXISTS _dbrest_ret_point_2d(); + DROP FUNCTION IF EXISTS _dbrest_ret_films(); + DROP FUNCTION IF EXISTS _dbrest_ret_scalar(); + DROP FUNCTION IF EXISTS _dbrest_ret_void()`) + }) + + // Introspection fills the return-shape map the native path reads. + if _, err := be.Introspect(ctx); err != nil { + t.Fatalf("Introspect: %v", err) + } + + cases := []struct { + fn string + readOnly bool + want rpc.ReturnKind + }{ + {"_dbrest_ret_setof_integers", true, rpc.ReturnSetOf}, + {"_dbrest_ret_point_2d", true, rpc.ReturnObject}, + {"_dbrest_ret_films", true, rpc.ReturnTable}, + {"_dbrest_ret_scalar", true, rpc.ReturnScalar}, + {"_dbrest_ret_void", false, rpc.ReturnVoid}, + } + for _, c := range cases { + t.Run(c.fn, func(t *testing.T) { + plan := &ir.Plan{ReadOnly: c.readOnly, Call: &ir.Call{ + Function: ir.Ref{Name: c.fn}, + Args: map[string]ir.Value{}, + ReadOnly: c.readOnly, + }} + method := "GET" + if !c.readOnly { + method = "POST" + } + res, err := be.Execute(ctx, plan, &reqctx.Context{Method: method, Path: "/rpc/" + c.fn}) + if err != nil { + t.Fatalf("Execute(%s): %v", c.fn, err) + } + if rs := res.Rows(); rs != nil { + rs.Close() + } + if plan.Func == nil { + t.Fatalf("Execute(%s) did not populate plan.Func from the native catalog", c.fn) + } + if plan.Func.Returns.Kind != c.want { + t.Errorf("%s return kind = %v, want %v", c.fn, plan.Func.Returns.Kind, c.want) + } + }) + } +} + +// TestIntegrationFunctionRegistry covers finding 03-P03: introspection builds the +// function half of the schema cache from pg_proc. Each seeded function exercises one +// signature shape; the per-schema native registry must reconstruct its input +// parameters (names, optionality, variadic, raw body), its return shape, and its +// volatility, and group overloads under one name. +func TestIntegrationFunctionRegistry(t *testing.T) { + be := openBE(t) + be.SetSchemas([]string{"_dbrest_reg"}) + ctx := context.Background() + + if _, err := be.Pool().Exec(ctx, ` + CREATE SCHEMA IF NOT EXISTS _dbrest_reg; + CREATE OR REPLACE FUNCTION _dbrest_reg.add2(a int, b int) RETURNS int + LANGUAGE sql IMMUTABLE AS $$ SELECT a + b $$; + CREATE OR REPLACE FUNCTION _dbrest_reg.greet(name text, greeting text DEFAULT 'hi') RETURNS text + LANGUAGE sql STABLE AS $$ SELECT greeting || ' ' || name $$; + CREATE OR REPLACE FUNCTION _dbrest_reg.sumall(VARIADIC vals int[]) RETURNS int + LANGUAGE sql IMMUTABLE AS $$ SELECT coalesce((SELECT sum(v) FROM unnest(vals) v), 0)::int $$; + CREATE OR REPLACE FUNCTION _dbrest_reg.takejson(json) RETURNS int + LANGUAGE sql STABLE AS $$ SELECT 1 $$; + CREATE OR REPLACE FUNCTION _dbrest_reg.over1(a int) RETURNS int + LANGUAGE sql STABLE AS $$ SELECT a $$; + CREATE OR REPLACE FUNCTION _dbrest_reg.over1(a int, b int) RETURNS int + LANGUAGE sql STABLE AS $$ SELECT a + b $$; + CREATE OR REPLACE FUNCTION _dbrest_reg.films(year int) RETURNS TABLE(id int, title text) + LANGUAGE sql STABLE AS $$ SELECT 1, 'x' $$`); err != nil { + t.Fatalf("seed functions: %v", err) + } + t.Cleanup(func() { + _, _ = be.Pool().Exec(ctx, "DROP SCHEMA IF EXISTS _dbrest_reg CASCADE") + }) + if _, err := be.Introspect(ctx); err != nil { + t.Fatalf("Introspect: %v", err) + } + + reg := be.SchemaFunctions("_dbrest_reg") + + // add2: two required int parameters, immutable, scalar return. + fn, ok := reg.Lookup("add2", rpc.ArgSet{"a": true, "b": true}) + if !ok { + t.Fatal("add2 not resolved") + } + if len(fn.Params) != 2 || fn.Params[0].Name != "a" || fn.Params[1].Name != "b" { + t.Errorf("add2 params = %+v", fn.Params) + } + if fn.Volatility != rpc.Immutable { + t.Errorf("add2 volatility = %v, want Immutable", fn.Volatility) + } + if got := fn.Required(); len(got) != 2 { + t.Errorf("add2 required = %v, want both", got) + } + + // greet: the trailing defaulted parameter is optional, so a call with only name + // resolves. + if fn, ok := reg.Lookup("greet", rpc.ArgSet{"name": true}); !ok { + t.Error("greet(name) not resolved despite greeting having a default") + } else if p, _ := fn.Param("greeting"); !p.Optional { + t.Error("greet.greeting should be optional") + } + + // sumall: variadic, so it resolves with no arguments and the parameter is not + // required. + if fn, ok := reg.Lookup("sumall", rpc.ArgSet{}); !ok { + t.Error("sumall() not resolved despite variadic") + } else if !fn.Params[0].Variadic { + t.Error("sumall.vals should be variadic") + } + + // takejson: a lone unnamed json input is a raw body (the request body binds to + // it, so it is found by listing, not by an empty argument set). + var takejson *rpc.Function + for _, f := range reg.List() { + if f.Name == "takejson" { + takejson = f + } + } + if takejson == nil { + t.Error("takejson missing from registry") + } else if p, raw := takejson.SingleRawBody(); !raw || p.Type != "json" { + t.Errorf("takejson single-raw-body = %v, param %+v", raw, p) + } + + // over1: two overloads chosen by argument arity, PGRST203 territory when neither + // is more specific. One arg picks the unary, two args the binary. + if fn, ok := reg.Lookup("over1", rpc.ArgSet{"a": true}); !ok || len(fn.Params) != 1 { + t.Errorf("over1(a) overload = %+v, ok=%v", fn, ok) + } + if fn, ok := reg.Lookup("over1", rpc.ArgSet{"a": true, "b": true}); !ok || len(fn.Params) != 2 { + t.Errorf("over1(a,b) overload = %+v, ok=%v", fn, ok) + } + if _, ok := reg.Lookup("over1", rpc.ArgSet{"a": true, "z": true}); ok { + t.Error("over1(a,z) should not resolve: z names no parameter") + } + + // films: only the input year is a parameter; the TABLE columns are the return + // shape, not arguments. + if fn, ok := reg.Lookup("films", rpc.ArgSet{"year": true}); !ok { + t.Error("films(year) not resolved") + } else { + if len(fn.Params) != 1 || fn.Params[0].Name != "year" { + t.Errorf("films params = %+v, want [year]", fn.Params) + } + if fn.Returns.Kind != rpc.ReturnTable { + t.Errorf("films return kind = %v, want ReturnTable", fn.Returns.Kind) + } + } +} + +// TestIntegrationNativeResolution covers finding 03-P03 end to end: a native RPC +// resolves through the shared planner against the introspected registry, the same +// way the portable path does. It proves overload resolution and its error codes +// (PGRST202 for no match, PGRST203 for ambiguity), GET argument-versus-filter +// partitioning, the volatility-driven access mode (a POST to a STABLE function runs +// read-only), and that the resolved plan lowers and runs against the live engine. +func TestIntegrationNativeResolution(t *testing.T) { + be := openBE(t) + be.SetSchemas([]string{"_dbrest_res"}) + ctx := context.Background() + + if _, err := be.Pool().Exec(ctx, ` + CREATE SCHEMA IF NOT EXISTS _dbrest_res; + CREATE OR REPLACE FUNCTION _dbrest_res.add2(a int, b int) RETURNS int + LANGUAGE sql IMMUTABLE AS $$ SELECT a + b $$; + CREATE OR REPLACE FUNCTION _dbrest_res.films(year int) RETURNS TABLE(id int, title text, yr int) + LANGUAGE sql STABLE AS $$ SELECT * FROM (VALUES (1,'Dune',2021),(2,'Arrival',2016)) v(id,title,yr) WHERE yr >= year $$; + -- two overloads with the same parameter name but different types: a call + -- naming {a} satisfies both equally, which is PostgREST's PGRST203. + CREATE OR REPLACE FUNCTION _dbrest_res.amb(a int) RETURNS int + LANGUAGE sql STABLE AS $$ SELECT a $$; + CREATE OR REPLACE FUNCTION _dbrest_res.amb(a text) RETURNS text + LANGUAGE sql STABLE AS $$ SELECT a $$; + -- single unnamed json parameter: the whole POST body binds to it. + CREATE OR REPLACE FUNCTION _dbrest_res.takejson(json) RETURNS int + LANGUAGE sql STABLE AS $$ SELECT ($1->>'n')::int $$`); err != nil { + t.Fatalf("seed functions: %v", err) + } + t.Cleanup(func() { + _, _ = be.Pool().Exec(ctx, "DROP SCHEMA IF EXISTS _dbrest_res CASCADE") + }) + model, err := be.Introspect(ctx) + if err != nil { + t.Fatalf("Introspect: %v", err) + } + reg := be.SchemaFunctions("_dbrest_res") + schemas := []string{"_dbrest_res"} + + // PGRST202: an argument set no overload accepts (add2 has no parameter z). + t.Run("no overload is PGRST202", func(t *testing.T) { + call, apiErr := ir.ParseCall("add2", "", nil, false, "application/json", []byte(`{"a":1,"z":2}`), "", "") + if apiErr != nil { + t.Fatalf("ParseCall: %v", apiErr) + } + _, perr := planpkg.Call(reg, model, call, false, schemas) + if perr == nil || perr.Code != pgerr.CodeNoFunction { + t.Fatalf("plan.Call error = %v, want %s", perr, pgerr.CodeNoFunction) + } + }) + + // PGRST203: an argument set two overloads accept equally well. + t.Run("ambiguous overload is PGRST203", func(t *testing.T) { + call, apiErr := ir.ParseCall("amb", "", nil, false, "application/json", []byte(`{"a":1}`), "", "") + if apiErr != nil { + t.Fatalf("ParseCall: %v", apiErr) + } + _, perr := planpkg.Call(reg, model, call, false, schemas) + if perr == nil || perr.Code != pgerr.CodeAmbiguousFunc { + t.Fatalf("plan.Call error = %v, want %s", perr, pgerr.CodeAmbiguousFunc) + } + }) + + // GET argument-versus-filter partitioning: year is a parameter, title is a + // post-filter on the table result. After planning, the call carries year as an + // argument and title as a WHERE, and the lowered query applies both. + t.Run("GET partitions args from filters", func(t *testing.T) { + call, apiErr := ir.ParseCall("films", "year=2015&title=eq.Arrival", nil, true, "", nil, "", "") + if apiErr != nil { + t.Fatalf("ParseCall: %v", apiErr) + } + plan, perr := planpkg.Call(reg, model, call, true, schemas) + if perr != nil { + t.Fatalf("plan.Call: %v", perr) + } + if _, ok := call.Args["year"]; !ok { + t.Errorf("year should remain an argument, args = %v", call.Args) + } + if call.Where == nil { + t.Error("title=eq.Arrival should have been reclassified as a post-filter") + } + res, err := be.Execute(ctx, plan, &reqctx.Context{Method: "GET", Path: "/rpc/films"}) + if err != nil { + t.Fatalf("Execute: %v", err) + } + rs := res.Rows() + var titles []string + for rs.Next() { + vals, _ := rs.Values() + for i, c := range rs.Columns() { + if c == "title" { + titles = append(titles, vals[i].(string)) + } + } + } + rs.Close() + // year>=2015 leaves Dune and Arrival; title=eq.Arrival narrows to Arrival. + if len(titles) != 1 || titles[0] != "Arrival" { + t.Errorf("filtered titles = %v, want [Arrival]", titles) + } + }) + + // A POST to a STABLE function runs read-only: plan.ReadOnly is set from the + // introspected volatility, not from the HTTP method. + t.Run("POST to stable runs read-only", func(t *testing.T) { + call, apiErr := ir.ParseCall("add2", "", nil, false, "application/json", []byte(`{"a":2,"b":3}`), "", "") + if apiErr != nil { + t.Fatalf("ParseCall: %v", apiErr) + } + plan, perr := planpkg.Call(reg, model, call, false, schemas) + if perr != nil { + t.Fatalf("plan.Call: %v", perr) + } + if !plan.ReadOnly { + t.Error("POST to an IMMUTABLE function should plan read-only") + } + res, err := be.Execute(ctx, plan, &reqctx.Context{Method: "POST", Path: "/rpc/add2"}) + if err != nil { + t.Fatalf("Execute: %v", err) + } + rs := res.Rows() + var sum int32 + for rs.Next() { + vals, _ := rs.Values() + sum = vals[0].(int32) + } + rs.Close() + if sum != 5 { + t.Errorf("add2(2,3) = %d, want 5", sum) + } + }) + + // A function with a single unnamed body-typed parameter takes the whole POST + // body as that argument. The registry marks it raw-body, ParseCall binds the + // body to it positionally, and compileNativeCall splices it as the lone literal. + t.Run("single unnamed param binds the raw body", func(t *testing.T) { + var fn *rpc.Function + for _, f := range reg.List() { + if f.Name == "takejson" { + fn = f + } + } + if fn == nil { + t.Fatal("takejson missing from registry") + } + p, raw := fn.SingleRawBody() + if !raw { + t.Fatalf("takejson is not single-raw-body, params %+v", fn.Params) + } + call, apiErr := ir.ParseCall("takejson", "", nil, false, "application/json", []byte(`{"n":7}`), p.Name, p.Type) + if apiErr != nil { + t.Fatalf("ParseCall: %v", apiErr) + } + plan, perr := planpkg.Call(reg, model, call, false, schemas) + if perr != nil { + t.Fatalf("plan.Call: %v", perr) + } + res, err := be.Execute(ctx, plan, &reqctx.Context{Method: "POST", Path: "/rpc/takejson"}) + if err != nil { + t.Fatalf("Execute: %v", err) + } + rs := res.Rows() + var got int32 + for rs.Next() { + vals, _ := rs.Values() + got = vals[0].(int32) + } + rs.Close() + if got != 7 { + t.Errorf("takejson({n:7}) = %d, want 7", got) + } + }) +} + +// TestIntegrationNativeOpenAPI covers the last piece of finding 03-P03: an +// introspected native function appears in the generated OpenAPI document. The +// root handler feeds the active schema's native registry into openapi.Generate, so +// a /rpc/ path is emitted for every function the catalog reported, the same as +// the portable path does for a registered function. +func TestIntegrationNativeOpenAPI(t *testing.T) { + be := openBE(t) + be.SetSchemas([]string{"_dbrest_oa"}) + ctx := context.Background() + + if _, err := be.Pool().Exec(ctx, ` + CREATE SCHEMA IF NOT EXISTS _dbrest_oa; + CREATE OR REPLACE FUNCTION _dbrest_oa.add2(a int, b int) RETURNS int + LANGUAGE sql IMMUTABLE AS $$ SELECT a + b $$`); err != nil { + t.Fatalf("seed function: %v", err) + } + t.Cleanup(func() { + _, _ = be.Pool().Exec(ctx, "DROP SCHEMA IF EXISTS _dbrest_oa CASCADE") + }) + + model, err := be.Introspect(ctx) + if err != nil { + t.Fatalf("Introspect: %v", err) + } + reg := be.SchemaFunctions("_dbrest_oa") + body, err := openapi.Generate(model, reg, be.Capabilities(), openapi.Options{ + Host: "localhost", + ActiveSchema: "_dbrest_oa", + }) + if err != nil { + t.Fatalf("openapi.Generate: %v", err) + } + + var doc struct { + Paths map[string]json.RawMessage `json:"paths"` + } + if err := json.Unmarshal(body, &doc); err != nil { + t.Fatalf("unmarshal document: %v", err) + } + if _, ok := doc.Paths["/rpc/add2"]; !ok { + t.Errorf("native function add2 missing from OpenAPI paths; got %v", keysOf(doc.Paths)) + } +} + +func keysOf(m map[string]json.RawMessage) []string { + ks := make([]string, 0, len(m)) + for k := range m { + ks = append(ks, k) + } + return ks +} + +// TestIntegrationComputedFields covers finding 03-P11 (computed fields): a +// function taking a relation's row type and returning a scalar is introspected as +// a virtual column the planner accepts in select, filter, and order, and the +// compiler renders as a function call on the row. Before the fix introspection +// read nothing from pg_proc for relations, so a computed-field select 400'd as an +// unknown column. +func TestIntegrationComputedFields(t *testing.T) { + be := openBE(t) + be.SetSchemas([]string{"_p11cf"}) + ctx := context.Background() + + if _, err := be.Pool().Exec(ctx, ` + DROP SCHEMA IF EXISTS _p11cf CASCADE; + CREATE SCHEMA _p11cf; + CREATE TABLE _p11cf.authors (id int PRIMARY KEY, first_name text, last_name text); + INSERT INTO _p11cf.authors VALUES (1,'Ada','Lovelace'), (2,'Alan','Turing'); + -- computed field: a scalar function of the row, same schema as the table. + CREATE FUNCTION _p11cf.full_name(a _p11cf.authors) RETURNS text + LANGUAGE sql STABLE AS $$ SELECT a.first_name || ' ' || a.last_name $$; + -- a set-returning function over the same row type is a computed relationship, + -- not a field; it must not show up among the computed fields. + CREATE TABLE _p11cf.books (id int PRIMARY KEY, author_id int, title text); + CREATE FUNCTION _p11cf.books(a _p11cf.authors) RETURNS SETOF _p11cf.books + LANGUAGE sql STABLE AS $$ SELECT * FROM _p11cf.books WHERE author_id = a.id $$`); err != nil { + t.Fatalf("seed schema: %v", err) + } + t.Cleanup(func() { + _, _ = be.Pool().Exec(ctx, "DROP SCHEMA IF EXISTS _p11cf CASCADE") + }) + + model, err := be.Introspect(ctx) + if err != nil { + t.Fatalf("Introspect: %v", err) + } + rel, ok := model.Lookup("authors", []string{"_p11cf"}) + if !ok { + t.Fatal("authors relation not found") + } + + // The scalar function is a computed field; the set-returning one is not. + if len(rel.Computed) != 1 { + t.Fatalf("computed fields = %d, want 1: %+v", len(rel.Computed), rel.Computed) + } + cf := rel.Computed[0] + if cf.Name != "full_name" || cf.FuncSchema != "_p11cf" || cf.Type != "text" { + t.Errorf("computed field = %+v, want {full_name _p11cf text}", cf) + } + + // select=id,full_name with a filter and order on the computed field, the way a + // client uses one. The planner accepts all three references and the compiler + // renders the function call. + q := &ir.Query{ + Kind: ir.Read, + Relation: ir.Ref{Schema: "_p11cf", Name: "authors"}, + Select: []ir.SelectItem{ + ir.Column{Path: []string{"id"}}, + ir.Column{Path: []string{"full_name"}}, + }, + Order: []ir.OrderTerm{{Path: []string{"full_name"}}}, + } + var where ir.Cond = ir.Compare{ + Path: []string{"full_name"}, Op: ir.OpEq, Value: ir.Value{Text: "Ada Lovelace"}, + } + q.Where = &where + + plan, perr := planpkg.Read(model, q, []string{"_p11cf"}, planpkg.Options{}) + if perr != nil { + t.Fatalf("plan.Read: %v", perr) + } + if q.Computed["full_name"] != "_p11cf" { + t.Errorf("planner did not bind computed field, q.Computed = %v", q.Computed) + } + + res, err := be.Execute(ctx, plan, &reqctx.Context{Method: "GET", Path: "/authors"}) + if err != nil { + t.Fatalf("Execute: %v", err) + } + rs := res.Rows() + defer rs.Close() + var got []string + for rs.Next() { + vals, err := rs.Values() + if err != nil { + t.Fatalf("Values: %v", err) + } + if len(vals) != 2 { + t.Fatalf("columns = %d, want 2 (id, full_name)", len(vals)) + } + got = append(got, fmt.Sprintf("%v", vals[1])) + } + if err := rs.Err(); err != nil { + t.Fatalf("row error: %v", err) + } + if len(got) != 1 || got[0] != "Ada Lovelace" { + t.Errorf("computed-field read = %v, want [Ada Lovelace]", got) + } +} + +// TestIntegrationComputedRelationships covers finding 03-P11 (computed +// relationships, the recursive-embed escape hatch): a function taking a +// relation's row type and returning SETOF another (here the same) relation is +// introspected as an embeddable edge, so a client can embed it by name. The +// headline case is a self-referential tree, which a stored foreign key cannot +// embed without an explicit hint; the computed relationship names the edge. +// Before the fix introspection read no functions as relationships, so the embed +// 400'd as an unknown relationship. +func TestIntegrationComputedRelationships(t *testing.T) { + be := openBE(t) + be.SetSchemas([]string{"_p11cr"}) + ctx := context.Background() + + if _, err := be.Pool().Exec(ctx, ` + DROP SCHEMA IF EXISTS _p11cr CASCADE; + CREATE SCHEMA _p11cr; + CREATE TABLE _p11cr.employees (id int PRIMARY KEY, name text, manager_id int); + INSERT INTO _p11cr.employees VALUES + (1,'Grace',NULL), (2,'Ada',1), (3,'Alan',1), (4,'Edsger',2); + -- a set-returning function over the row type is a to-many computed + -- relationship: the direct reports of the given employee. + CREATE FUNCTION _p11cr.reports(e _p11cr.employees) RETURNS SETOF _p11cr.employees + LANGUAGE sql STABLE AS $$ + SELECT * FROM _p11cr.employees WHERE manager_id = e.id $$`); err != nil { + t.Fatalf("seed schema: %v", err) + } + t.Cleanup(func() { + _, _ = be.Pool().Exec(ctx, "DROP SCHEMA IF EXISTS _p11cr CASCADE") + }) + + model, err := be.Introspect(ctx) + if err != nil { + t.Fatalf("Introspect: %v", err) + } + rel, ok := model.Lookup("employees", []string{"_p11cr"}) + if !ok { + t.Fatal("employees relation not found") + } + if len(rel.ComputedRels) != 1 { + t.Fatalf("computed relationships = %d, want 1: %+v", len(rel.ComputedRels), rel.ComputedRels) + } + cr := rel.ComputedRels[0] + if cr.Name != "reports" || cr.TargetName != "employees" || cr.Card != schema.CardToMany { + t.Errorf("computed rel = %+v, want {reports ... employees to-many}", cr) + } + + // GET /employees?select=name,reports(name)&id=eq.1 embeds Grace's direct + // reports through the computed edge, the escape hatch a self-referential FK + // cannot offer on its own. + q, perr := ir.ParseRead("employees", "select=name,reports(name)&id=eq.1", nil) + if perr != nil { + t.Fatalf("parse: %v", perr) + } + rp, perr := planpkg.Read(model, q, []string{"_p11cr"}, planpkg.Options{}) + if perr != nil { + t.Fatalf("plan.Read: %v", perr) + } + rp.Rel = rel + + res, err := be.Execute(ctx, rp, &reqctx.Context{Method: "GET", Path: "/employees"}) + if err != nil { + t.Fatalf("Execute: %v", err) + } + rs := res.Rows() + defer rs.Close() + var got []string + for rs.Next() { + vals, err := rs.Values() + if err != nil { + t.Fatalf("Values: %v", err) + } + // columns: name, reports (a JSON array of {name}) + got = append(got, fmt.Sprintf("%v|%v", vals[0], vals[1])) + } + if err := rs.Err(); err != nil { + t.Fatalf("row error: %v", err) + } + if len(got) != 1 { + t.Fatalf("rows = %d, want 1: %v", len(got), got) + } + // Grace's reports are Ada and Alan; the embed must carry both names. + if !strings.Contains(got[0], "Grace") || !strings.Contains(got[0], "Ada") || !strings.Contains(got[0], "Alan") { + t.Errorf("computed-rel embed = %q, want Grace with reports Ada and Alan", got[0]) + } +} + +// TestIntegrationDataRepresentations covers finding 03-P11 (data +// representations, spec 11): a domain over a base type plus pg_cast casts to and +// from json/text reshapes a column on the wire. The to-json cast formats the +// stored value for a response, the from-json cast parses a write body, and the +// from-text cast parses a query-string filter literal. PostgreSQL ignores these +// casts in the `::` operator, so the introspector records the cast function per +// direction and the compiler calls it by name. The headline is a full round trip: +// POST a representation value, read it back formatted, and filter by the formatted +// form. The fixture is a "color" domain over integer presented as "#rrggbb". +func TestIntegrationDataRepresentations(t *testing.T) { + be := openBE(t) + be.SetSchemas([]string{"_p11dr"}) + ctx := context.Background() + + if _, err := be.Pool().Exec(ctx, ` + DROP SCHEMA IF EXISTS _p11dr CASCADE; + CREATE SCHEMA _p11dr; + -- a color is an integer presented on the wire as the string "#rrggbb". + CREATE DOMAIN _p11dr.color AS integer; + -- to-json: format the stored integer as "#rrggbb". + CREATE FUNCTION _p11dr.json(c _p11dr.color) RETURNS json + LANGUAGE sql IMMUTABLE AS $$ + SELECT to_json('#' || lpad(to_hex(c::int), 6, '0')) $$; + -- from-text: parse "#rrggbb" out of a filter literal. + CREATE FUNCTION _p11dr.color(t text) RETURNS _p11dr.color + LANGUAGE sql IMMUTABLE AS $$ + SELECT (('x' || lpad(substring(t from 2), 8, '0'))::bit(32)::int)::_p11dr.color $$; + -- from-json: parse a json string ("#rrggbb") out of a write body. + CREATE FUNCTION _p11dr.color(j json) RETURNS _p11dr.color + LANGUAGE sql IMMUTABLE AS $$ SELECT _p11dr.color(j #>> '{}') $$; + CREATE CAST (_p11dr.color AS json) WITH FUNCTION _p11dr.json(_p11dr.color) AS ASSIGNMENT; + CREATE CAST (text AS _p11dr.color) WITH FUNCTION _p11dr.color(text) AS ASSIGNMENT; + CREATE CAST (json AS _p11dr.color) WITH FUNCTION _p11dr.color(json) AS ASSIGNMENT; + CREATE TABLE _p11dr.shirts (id int PRIMARY KEY, c _p11dr.color)`); err != nil { + t.Fatalf("seed schema: %v", err) + } + t.Cleanup(func() { + _, _ = be.Pool().Exec(ctx, "DROP SCHEMA IF EXISTS _p11dr CASCADE") + }) + + model, err := be.Introspect(ctx) + if err != nil { + t.Fatalf("Introspect: %v", err) + } + rel, ok := model.Lookup("shirts", []string{"_p11dr"}) + if !ok { + t.Fatal("shirts relation not found") + } + + // The c column carries a representation with all three cast directions, each + // backed by a function in _p11dr. + var crep *schema.Representation + for i := range rel.Columns { + if rel.Columns[i].Name == "c" { + crep = rel.Columns[i].Rep + } + } + if crep == nil { + t.Fatal("column c carries no data representation") + } + if crep.ToJSON.IsZero() || crep.FromText.IsZero() || crep.FromJSON.IsZero() { + t.Fatalf("representation missing a direction: %+v", crep) + } + if crep.ToJSON.Schema != "_p11dr" || crep.ToJSON.Name != "json" || + crep.FromText.Name != "color" || crep.FromJSON.Name != "color" { + t.Errorf("representation functions = %+v, want json/color/color in _p11dr", crep) + } + + // POST a representation value: the from-json cast parses "#0000ff" out of the + // body, and return=representation reads it back formatted through to-json. + wq := &ir.Query{ + Kind: ir.Insert, + Relation: ir.Ref{Schema: "_p11dr", Name: "shirts"}, + Select: []ir.SelectItem{ + ir.Column{Path: []string{"id"}}, + ir.Column{Path: []string{"c"}}, + }, + Write: &ir.WriteSpec{ + Columns: []string{"id", "c"}, + Rows: []map[string]ir.Value{{"id": {JSON: json.Number("1")}, "c": {JSON: "#0000ff"}}}, + Return: ir.ReturnRepresentation, + }, + } + wp, perr := planpkg.Write(model, wq, []string{"_p11dr"}) + if perr != nil { + t.Fatalf("plan.Write: %v", perr) + } + wp.Rel = rel + wres, err := be.Execute(ctx, wp, &reqctx.Context{Method: "POST", Path: "/shirts"}) + if err != nil { + t.Fatalf("Execute(insert): %v", err) + } + wrs := wres.Rows() + var posted string + for wrs.Next() { + vals, err := wrs.Values() + if err != nil { + t.Fatalf("Values: %v", err) + } + // columns: id, c (formatted through to-json) + posted = fmt.Sprintf("%v", vals[1]) + } + wrs.Close() + if err := wrs.Err(); err != nil { + t.Fatalf("row error: %v", err) + } + if !strings.Contains(posted, "#0000ff") { + t.Errorf("return=representation c = %q, want it formatted as #0000ff", posted) + } + + // GET /shirts?select=id,c&c=eq.#0000ff filters by the formatted value: the + // from-text cast parses the literal and the to-json cast formats the output. + rq, perr := ir.ParseRead("shirts", "select=id,c&c=eq.%230000ff", nil) + if perr != nil { + t.Fatalf("parse: %v", perr) + } + rp, perr := planpkg.Read(model, rq, []string{"_p11dr"}, planpkg.Options{}) + if perr != nil { + t.Fatalf("plan.Read: %v", perr) + } + rp.Rel = rel + rres, err := be.Execute(ctx, rp, &reqctx.Context{Method: "GET", Path: "/shirts"}) + if err != nil { + t.Fatalf("Execute(read): %v", err) + } + rs := rres.Rows() + defer rs.Close() + var got []string + for rs.Next() { + vals, err := rs.Values() + if err != nil { + t.Fatalf("Values: %v", err) + } + got = append(got, fmt.Sprintf("%v", vals[1])) + } + if err := rs.Err(); err != nil { + t.Fatalf("row error: %v", err) + } + if len(got) != 1 { + t.Fatalf("filter by representation returned %d rows, want 1: %v", len(got), got) + } + if !strings.Contains(got[0], "#0000ff") { + t.Errorf("read c = %q, want it formatted as #0000ff", got[0]) + } + + // GET /shirts?select=id,c&c=in.(#0000ff,#00ff00) filters by an IN list of + // formatted values: each element parses through the from-text cast over the + // unpacked array (= ANY), matching PostgREST. Only the seeded #0000ff row + // exists, so the list still resolves to one match. + inq, perr := ir.ParseRead("shirts", "select=id,c&c=in.(%230000ff,%2300ff00)", nil) + if perr != nil { + t.Fatalf("parse in: %v", perr) + } + inp, perr := planpkg.Read(model, inq, []string{"_p11dr"}, planpkg.Options{}) + if perr != nil { + t.Fatalf("plan.Read in: %v", perr) + } + inp.Rel = rel + inres, err := be.Execute(ctx, inp, &reqctx.Context{Method: "GET", Path: "/shirts"}) + if err != nil { + t.Fatalf("Execute(in read): %v", err) + } + irs := inres.Rows() + defer irs.Close() + var ingot []string + for irs.Next() { + vals, err := irs.Values() + if err != nil { + t.Fatalf("Values: %v", err) + } + ingot = append(ingot, fmt.Sprintf("%v", vals[1])) + } + if err := irs.Err(); err != nil { + t.Fatalf("row error: %v", err) + } + if len(ingot) != 1 || !strings.Contains(ingot[0], "#0000ff") { + t.Errorf("filter by representation IN returned %v, want one #0000ff row", ingot) + } +} + +// TestIntegrationMergedRegistry covers finding 03-P13: a declared portable +// registry on postgres is reachable and shares one document with the native +// catalog. The merged registry (portable plus native, the exact composition the +// server builds per request) resolves both a portable function with no native +// equivalent and a native function; the portable one executes through the SQL +// compiler, and both appear in the OpenAPI document. +func TestIntegrationMergedRegistry(t *testing.T) { + be := openBE(t) + be.SetSchemas([]string{"_dbrest_mrg"}) + ctx := context.Background() + + if _, err := be.Pool().Exec(ctx, ` + CREATE SCHEMA IF NOT EXISTS _dbrest_mrg; + CREATE OR REPLACE FUNCTION _dbrest_mrg.native_add(a int, b int) RETURNS int + LANGUAGE sql IMMUTABLE AS $$ SELECT a + b $$`); err != nil { + t.Fatalf("seed function: %v", err) + } + t.Cleanup(func() { + _, _ = be.Pool().Exec(ctx, "DROP SCHEMA IF EXISTS _dbrest_mrg CASCADE") + }) + + // A portable function with no native equivalent, declared the way an operator + // supplies one via Register. + be.Register(rpc.NewStaticRegistry([]*rpc.Function{{ + Name: "portable_mul", + Params: []rpc.Param{{Name: "a", Type: "int"}, {Name: "b", Type: "int"}}, + Returns: rpc.ReturnShape{Kind: rpc.ReturnScalar, Type: "int"}, + Volatility: rpc.Immutable, + Query: &rpc.PortableQuery{SQL: "SELECT :a::int * :b::int"}, + }})) + + model, err := be.Introspect(ctx) + if err != nil { + t.Fatalf("Introspect: %v", err) + } + merged := rpc.Merge(be.Functions(), be.SchemaFunctions("_dbrest_mrg")) + schemas := []string{"_dbrest_mrg"} + + // The portable function resolves and runs through the SQL compiler. + t.Run("portable function is reachable", func(t *testing.T) { + call, apiErr := ir.ParseCall("portable_mul", "", nil, false, "application/json", []byte(`{"a":6,"b":7}`), "", "") + if apiErr != nil { + t.Fatalf("ParseCall: %v", apiErr) + } + plan, perr := planpkg.Call(merged, model, call, false, schemas) + if perr != nil { + t.Fatalf("plan.Call: %v", perr) + } + if plan.Func == nil || plan.Func.Query == nil { + t.Fatal("portable_mul should resolve to a portable function with a Query") + } + res, err := be.Execute(ctx, plan, &reqctx.Context{Method: "POST", Path: "/rpc/portable_mul"}) + if err != nil { + t.Fatalf("Execute: %v", err) + } + rs := res.Rows() + var got int + for rs.Next() { + vals, _ := rs.Values() + switch v := vals[0].(type) { + case int32: + got = int(v) + case int64: + got = int(v) + } + } + rs.Close() + if got != 42 { + t.Errorf("portable_mul(6,7) = %d, want 42", got) + } + }) + + // The native function still resolves through the same merged registry. + t.Run("native function is reachable", func(t *testing.T) { + if _, ok := merged.Lookup("native_add", rpc.ArgSet{"a": true, "b": true}); !ok { + t.Error("native_add should resolve in the merged registry") + } + }) + + // Both functions appear in the OpenAPI document. + t.Run("both appear in OpenAPI", func(t *testing.T) { + body, err := openapi.Generate(model, merged, be.Capabilities(), openapi.Options{ + Host: "localhost", + ActiveSchema: "_dbrest_mrg", + }) + if err != nil { + t.Fatalf("openapi.Generate: %v", err) + } + var doc struct { + Paths map[string]json.RawMessage `json:"paths"` + } + if err := json.Unmarshal(body, &doc); err != nil { + t.Fatalf("unmarshal document: %v", err) + } + for _, want := range []string{"/rpc/portable_mul", "/rpc/native_add"} { + if _, ok := doc.Paths[want]; !ok { + t.Errorf("%s missing from OpenAPI paths; got %v", want, keysOf(doc.Paths)) + } + } + }) +} + +// TestIntegrationNativeVolatileCount covers finding 03-P02: a POST to a VOLATILE +// set-returning function with Prefer: count=exact returns the exact total over the +// filtered set, and the function runs exactly once. The read path counts with a +// separate statement, but a volatile function has side effects, so the count must +// ride count(*) OVER () on the single row query rather than re-invoking the +// function. An audit table records each invocation, proving single execution. +func TestIntegrationNativeVolatileCount(t *testing.T) { + be := openBE(t) + ctx := context.Background() + + if _, err := be.Pool().Exec(ctx, ` + CREATE TABLE IF NOT EXISTS _dbrest_vc_audit (n int); + TRUNCATE _dbrest_vc_audit; + CREATE OR REPLACE FUNCTION _dbrest_vc_enroll() RETURNS TABLE(n int) + LANGUAGE plpgsql VOLATILE AS $$ + BEGIN + INSERT INTO _dbrest_vc_audit VALUES (1); + RETURN QUERY SELECT * FROM (VALUES (1),(2),(3),(4)) v(n); + END $$`); err != nil { + t.Fatalf("seed function: %v", err) + } + t.Cleanup(func() { + _, _ = be.Pool().Exec(ctx, ` + DROP FUNCTION IF EXISTS _dbrest_vc_enroll(); + DROP TABLE IF EXISTS _dbrest_vc_audit`) + }) + if _, err := be.Introspect(ctx); err != nil { + t.Fatalf("Introspect: %v", err) + } + + // POST with a filter (n >= 2) and a limit of 1: of the four returned values + // three match, so the exact count is 3, but only one row reaches the body. + plan := &ir.Plan{Call: &ir.Call{ + Function: ir.Ref{Name: "_dbrest_vc_enroll"}, + Args: map[string]ir.Value{}, + Where: condPtr(ir.Compare{Path: []string{"n"}, Op: ir.OpGte, Value: ir.Value{Text: "2"}}), + Limit: intPtr(1), + Count: ir.CountExact, + }} + + res, err := be.Execute(ctx, plan, &reqctx.Context{Method: "POST", Path: "/rpc/_dbrest_vc_enroll"}) + if err != nil { + t.Fatalf("Execute(volatile call): %v", err) + } + + if c, ok := res.Count(); !ok || c != 3 { + t.Errorf("Count = (%d, %v), want (3, true) over the filtered rows", c, ok) + } + + rs := res.Rows() + var rows, cols int + for rs.Next() { + vals, err := rs.Values() + if err != nil { + t.Fatalf("Values: %v", err) + } + cols = len(vals) + rows++ + } + rs.Close() + if rows != 1 { + t.Errorf("limit 1 returned %d rows, want 1", rows) + } + // The _pgrst_count window column must not leak into the body. + if cols != 1 { + t.Errorf("body row has %d columns, want 1 (count column stripped)", cols) + } + + // The function ran exactly once: a separate count statement would have inserted + // a second audit row. + var runs int + if err := be.Pool().QueryRow(ctx, "SELECT count(*) FROM _dbrest_vc_audit").Scan(&runs); err != nil { + t.Fatalf("audit count: %v", err) + } + if runs != 1 { + t.Errorf("function ran %d times, want exactly 1", runs) + } +} + +// TestIntegrationNativeCallSchemaDispatch proves a native RPC resolves in the +// request's negotiated schema (Accept-Profile / Content-Profile, carried as +// reqctx.Context.Schema), not always the first configured schema. Two schemas +// expose a same-named function with distinct results; switching rc.Schema picks +// the matching one. Finding 03-P04. +func TestIntegrationNativeCallSchemaDispatch(t *testing.T) { + be := openBE(t) + be.SetSchemas([]string{"_dbrest_s1", "_dbrest_s2"}) + ctx := context.Background() + + if _, err := be.Pool().Exec(ctx, ` + CREATE SCHEMA IF NOT EXISTS _dbrest_s1; + CREATE SCHEMA IF NOT EXISTS _dbrest_s2; + CREATE OR REPLACE FUNCTION _dbrest_s1.whoami() RETURNS text + LANGUAGE sql STABLE AS $$ SELECT 'schema-one' $$; + CREATE OR REPLACE FUNCTION _dbrest_s2.whoami() RETURNS text + LANGUAGE sql STABLE AS $$ SELECT 'schema-two' $$`); err != nil { + t.Fatalf("seed schemas: %v", err) + } + t.Cleanup(func() { + _, _ = be.Pool().Exec(ctx, "DROP SCHEMA IF EXISTS _dbrest_s1 CASCADE; DROP SCHEMA IF EXISTS _dbrest_s2 CASCADE") + }) + + call := func(schema string) string { + t.Helper() + plan := &ir.Plan{ReadOnly: true, Call: &ir.Call{ + Function: ir.Ref{Name: "whoami"}, + Args: map[string]ir.Value{}, + ReadOnly: true, + }} + rc := &reqctx.Context{Method: "GET", Path: "/rpc/whoami", Schema: schema} + res, err := be.Execute(ctx, plan, rc) + if err != nil { + t.Fatalf("Execute(%s): %v", schema, err) + } + rs := res.Rows() + defer rs.Close() + if !rs.Next() { + t.Fatalf("Execute(%s): no rows", schema) + } + vals, err := rs.Values() + if err != nil { + t.Fatalf("Values(%s): %v", schema, err) + } + return vals[0].(string) + } + + if got := call("_dbrest_s1"); got != "schema-one" { + t.Errorf("Accept-Profile _dbrest_s1 dispatched to %q, want schema-one", got) + } + if got := call("_dbrest_s2"); got != "schema-two" { + t.Errorf("Accept-Profile _dbrest_s2 dispatched to %q, want schema-two", got) + } +} + +// TestIntegrationNativeCallJSONArg proves a JSON object argument binds to a +// json, a jsonb, and a text parameter alike. The argument is spliced as an +// untyped literal so PostgreSQL's function resolution coerces it to whichever +// type the parameter declares; a '...'::json literal would fail to match a +// jsonb parameter (42883 -> 404). Finding 03-P05. +func TestIntegrationNativeCallJSONArg(t *testing.T) { + be := openBE(t) + ctx := context.Background() + + if _, err := be.Pool().Exec(ctx, ` + CREATE OR REPLACE FUNCTION _dbrest_test_jb(payload jsonb) RETURNS text + LANGUAGE sql IMMUTABLE AS $$ SELECT payload->>'name' $$; + CREATE OR REPLACE FUNCTION _dbrest_test_js(payload json) RETURNS text + LANGUAGE sql IMMUTABLE AS $$ SELECT payload->>'name' $$; + CREATE OR REPLACE FUNCTION _dbrest_test_tx(payload text) RETURNS text + LANGUAGE sql IMMUTABLE AS $$ SELECT payload $$`); err != nil { + t.Fatalf("seed functions: %v", err) + } + t.Cleanup(func() { + _, _ = be.Pool().Exec(ctx, ` + DROP FUNCTION IF EXISTS _dbrest_test_jb(jsonb); + DROP FUNCTION IF EXISTS _dbrest_test_js(json); + DROP FUNCTION IF EXISTS _dbrest_test_tx(text)`) + }) + + call := func(fn string) string { + t.Helper() + plan := &ir.Plan{ReadOnly: true, Call: &ir.Call{ + Function: ir.Ref{Name: fn}, + Args: map[string]ir.Value{"payload": {JSON: map[string]any{"name": "Ada"}}}, + ReadOnly: true, + }} + res, err := be.Execute(ctx, plan, &reqctx.Context{Method: "GET", Path: "/rpc/" + fn}) + if err != nil { + t.Fatalf("Execute(%s): %v", fn, err) + } + rs := res.Rows() + defer rs.Close() + if !rs.Next() { + t.Fatalf("Execute(%s): no rows", fn) + } + vals, err := rs.Values() + if err != nil { + t.Fatalf("Values(%s): %v", fn, err) + } + if vals[0] == nil { + return "" + } + return vals[0].(string) + } + + // json/jsonb parameters extract the name; the text parameter receives the + // serialized object. The point is that none of the three 404s. + if got := call("_dbrest_test_jb"); got != "Ada" { + t.Errorf("jsonb arg returned %q, want Ada", got) + } + if got := call("_dbrest_test_js"); got != "Ada" { + t.Errorf("json arg returned %q, want Ada", got) + } + if got := call("_dbrest_test_tx"); got != `{"name":"Ada"}` { + t.Errorf("text arg returned %q, want the serialized object", got) + } +} + +// TestIntegrationTemporalRendering proves date, time, timetz, interval, +// timestamp, and timestamptz columns render through the backend as the same JSON +// strings PostgreSQL itself emits (to_json), instead of Go struct or Z-suffixed +// RFC3339 spellings. The expected values are read back from the server with +// to_json so the assertion tracks the live server's TimeZone. Finding 03-P07. +func TestIntegrationTemporalRendering(t *testing.T) { + be := openBE(t) + ctx := context.Background() + + if _, err := be.Pool().Exec(ctx, ` + CREATE TABLE IF NOT EXISTS _dbrest_test_temporal ( + id int PRIMARY KEY, + d date, + t time, + ttz timetz, + iv interval, + ts timestamp, + tstz timestamptz + ); + TRUNCATE _dbrest_test_temporal; + INSERT INTO _dbrest_test_temporal VALUES ( + 1, + '2017-01-02', + '13:00:00.5', + '13:00:00+02', + '1 day 02:03:04.5', + '2017-01-01 14:30:00.123456', + '2017-07-01 14:30:00+05' + )`); err != nil { + t.Fatalf("seed temporal table: %v", err) + } + t.Cleanup(func() { + _, _ = be.Pool().Exec(ctx, "DROP TABLE IF EXISTS _dbrest_test_temporal") + }) + + cols := []string{"d", "t", "ttz", "iv", "ts", "tstz"} + // The oracle: PostgreSQL's own JSON spelling for each column, stripped of the + // surrounding quotes a JSON string carries. + want := make([]string, len(cols)) + for i, c := range cols { + var j string + if err := be.Pool().QueryRow(ctx, + "SELECT to_json("+c+")::text FROM _dbrest_test_temporal WHERE id = 1").Scan(&j); err != nil { + t.Fatalf("oracle to_json(%s): %v", c, err) + } + want[i] = strings.Trim(j, `"`) + } + + model, err := be.Introspect(ctx) + if err != nil { + t.Fatalf("Introspect: %v", err) + } + rel, ok := model.Lookup("_dbrest_test_temporal", []string{"public"}) + if !ok { + t.Fatal("_dbrest_test_temporal not found") + } + sel := make([]ir.SelectItem, len(cols)) + for i, c := range cols { + sel[i] = ir.Column{Path: []string{c}} + } + plan := &ir.Plan{Rel: rel, Query: &ir.Query{ + Kind: ir.Read, + Relation: ir.Ref{Schema: "public", Name: "_dbrest_test_temporal"}, + Select: sel, + }} + res, err := be.Execute(ctx, plan, &reqctx.Context{Method: "GET", Path: "/_dbrest_test_temporal"}) + if err != nil { + t.Fatalf("Execute: %v", err) + } + rs := res.Rows() + defer rs.Close() + if !rs.Next() { + t.Fatal("no rows") + } + vals, err := rs.Values() + if err != nil { + t.Fatalf("Values: %v", err) + } + for i, c := range cols { + got, ok := vals[i].(string) + if !ok { + t.Errorf("column %s rendered as %T (%v), want a string", c, vals[i], vals[i]) + continue + } + if got != want[i] { + t.Errorf("column %s = %q, want %q (PostgreSQL to_json)", c, got, want[i]) + } + } +} + +// TestIntegrationFullTextTSVector proves an fts filter on a real tsvector column +// returns rows instead of failing. PostgreSQL has no to_tsvector(tsvector) +// overload, so wrapping the column raised 42883 (surfaced as 404). With the +// column type threaded through, the dialect matches the column directly +// (col @@ to_tsquery(...)), the way PostgREST does. Finding 01-P01. +func TestIntegrationFullTextTSVector(t *testing.T) { + be := openBE(t) + ctx := context.Background() + + if _, err := be.Pool().Exec(ctx, ` + CREATE TABLE IF NOT EXISTS _dbrest_test_fts ( + id serial PRIMARY KEY, + doc tsvector NOT NULL + ); + TRUNCATE _dbrest_test_fts; + INSERT INTO _dbrest_test_fts (doc) VALUES + (to_tsvector('english', 'the quick brown fox')), + (to_tsvector('english', 'a lazy dog sleeps'))`); err != nil { + t.Fatalf("seed tsvector table: %v", err) + } + t.Cleanup(func() { + _, _ = be.Pool().Exec(ctx, "DROP TABLE IF EXISTS _dbrest_test_fts") + }) + + model, err := be.Introspect(ctx) + if err != nil { + t.Fatalf("Introspect: %v", err) + } + rel, ok := model.Lookup("_dbrest_test_fts", []string{"public"}) + if !ok { + t.Fatal("_dbrest_test_fts not found") + } + + // fts on the tsvector column: ?doc=fts.fox should match only the first row. + // ColumnType is "tsvector" as the planner resolves it from the schema. + plan := &ir.Plan{Rel: rel, Query: &ir.Query{ + Kind: ir.Read, + Relation: ir.Ref{Schema: "public", Name: "_dbrest_test_fts"}, + Select: []ir.SelectItem{ir.Column{Path: []string{"id"}}}, + Where: condPtr(ir.Compare{ + Path: []string{"doc"}, + Op: ir.OpFTS, + FTS: ir.FTSPlain, + Value: ir.Value{Text: "fox"}, + ColumnType: "tsvector", + }), + }} + res, err := be.Execute(ctx, plan, &reqctx.Context{Method: "GET", Path: "/_dbrest_test_fts"}) + if err != nil { + t.Fatalf("Execute(fts on tsvector): %v", err) + } + rs := res.Rows() + defer rs.Close() + rows := 0 + for rs.Next() { + if _, err := rs.Values(); err != nil { + t.Fatalf("Values: %v", err) + } + rows++ + } + if err := rs.Err(); err != nil { + t.Fatalf("row error: %v", err) + } + if rows != 1 { + t.Errorf("fts.fox matched %d rows, want 1", rows) + } +} + +// TestIntegrationArrayPayloadByColumnType proves a JSON array payload value +// lands as JSON in a jsonb column and as a PostgreSQL array in a text[] column. +// Before the fix every array became a {a,b} literal, so inserting an array into +// a jsonb column failed with 22P02. The planner resolves the target column type +// and the dialect routes the value accordingly. Finding 01-P06. +func TestIntegrationArrayPayloadByColumnType(t *testing.T) { + be := openBE(t) + ctx := context.Background() + + if _, err := be.Pool().Exec(ctx, ` + CREATE TABLE IF NOT EXISTS _dbrest_test_arr ( + id serial PRIMARY KEY, + tags jsonb NOT NULL, + labs text[] NOT NULL + ); + TRUNCATE _dbrest_test_arr`); err != nil { + t.Fatalf("seed table: %v", err) + } + t.Cleanup(func() { + _, _ = be.Pool().Exec(ctx, "DROP TABLE IF EXISTS _dbrest_test_arr") + }) + + model, err := be.Introspect(ctx) + if err != nil { + t.Fatalf("Introspect: %v", err) + } + rel, ok := model.Lookup("_dbrest_test_arr", []string{"public"}) + if !ok { + t.Fatal("_dbrest_test_arr not found") + } + + // The planner fills WriteSpec.ColumnTypes from the relation; build the same + // shape here so the compiler routes each array by its target column type. + plan := &ir.Plan{Rel: rel, Query: &ir.Query{ + Kind: ir.Insert, + Relation: ir.Ref{Schema: "public", Name: "_dbrest_test_arr"}, + Write: &ir.WriteSpec{ + Rows: []map[string]ir.Value{{ + "tags": {JSON: []any{"x", "y"}}, + "labs": {JSON: []any{"a", "b"}}, + }}, + Columns: []string{"tags", "labs"}, + ColumnTypes: map[string]string{"tags": "jsonb", "labs": "text[]"}, + Return: ir.ReturnMinimal, + }, + }} + if _, err := be.Execute(ctx, plan, &reqctx.Context{Method: "POST", Path: "/_dbrest_test_arr"}); err != nil { + t.Fatalf("Execute(insert arrays): %v", err) + } + + // Read the stored values straight from the pool to confirm the jsonb holds a + // JSON array and the text[] holds two elements. + var tags string + var labs []string + if err := be.Pool().QueryRow(ctx, + "SELECT tags::text, labs FROM _dbrest_test_arr LIMIT 1").Scan(&tags, &labs); err != nil { + t.Fatalf("read back: %v", err) + } + if tags != `["x", "y"]` { + t.Errorf("jsonb tags = %q, want a JSON array", tags) + } + if len(labs) != 2 || labs[0] != "a" || labs[1] != "b" { + t.Errorf("text[] labs = %v, want [a b]", labs) + } +} + +// TestIntegrationViewForeignKeyInference covers finding 03-P09: a base table's +// foreign keys are projected onto a view over it, so embedding works through a view +// the same way it does through the base table. It exercises a renamed-column view, a +// view-over-view chain (the projection runs to a fixpoint), a materialized view, and +// an end-to-end embed of the referenced table through the view. +func TestIntegrationViewForeignKeyInference(t *testing.T) { + be := openBE(t) + ctx := context.Background() + + if _, err := be.Pool().Exec(ctx, ` + CREATE TABLE IF NOT EXISTS _p9_authors (id int PRIMARY KEY, name text); + CREATE TABLE IF NOT EXISTS _p9_books ( + id int PRIMARY KEY, title text, author_id int REFERENCES _p9_authors(id)); + TRUNCATE _p9_books, _p9_authors; + INSERT INTO _p9_authors (id, name) VALUES (1, 'Le Guin'), (2, 'Butler'); + INSERT INTO _p9_books (id, title, author_id) VALUES (10, 'A Wizard of Earthsea', 1), (11, 'Kindred', 2); + CREATE OR REPLACE VIEW _p9_books_v AS SELECT id, title, author_id FROM _p9_books; + CREATE OR REPLACE VIEW _p9_books_renamed AS SELECT id AS book_id, author_id AS writer FROM _p9_books; + CREATE OR REPLACE VIEW _p9_books_chain AS SELECT book_id, writer FROM _p9_books_renamed; + DROP MATERIALIZED VIEW IF EXISTS _p9_books_m; + CREATE MATERIALIZED VIEW _p9_books_m AS SELECT id, author_id FROM _p9_books`); err != nil { + t.Fatalf("seed: %v", err) + } + t.Cleanup(func() { + _, _ = be.Pool().Exec(ctx, ` + DROP MATERIALIZED VIEW IF EXISTS _p9_books_m; + DROP VIEW IF EXISTS _p9_books_chain, _p9_books_renamed, _p9_books_v; + DROP TABLE IF EXISTS _p9_books, _p9_authors`) + }) + + model, err := be.Introspect(ctx) + if err != nil { + t.Fatalf("Introspect: %v", err) + } + + // projectedFK returns the view's foreign key whose source column matches, with + // the columns it was projected under. + hasFKTo := func(relName, refRel string, wantCols ...string) bool { + rel, ok := model.Lookup(relName, []string{"public"}) + if !ok { + return false + } + for _, fk := range rel.ForeignKeys { + if fk.RefRelation == refRel && reflect.DeepEqual(fk.Columns, wantCols) { + return true + } + } + return false + } + + t.Run("plain view inherits the base FK", func(t *testing.T) { + if !hasFKTo("_p9_books_v", "_p9_authors", "author_id") { + rel, _ := model.Lookup("_p9_books_v", []string{"public"}) + t.Errorf("books_v missing projected FK to _p9_authors; FKs = %+v", rel.ForeignKeys) + } + }) + + t.Run("renamed column carries the FK under the view name", func(t *testing.T) { + if !hasFKTo("_p9_books_renamed", "_p9_authors", "writer") { + rel, _ := model.Lookup("_p9_books_renamed", []string{"public"}) + t.Errorf("books_renamed missing projected FK on writer; FKs = %+v", rel.ForeignKeys) + } + }) + + t.Run("view-over-view chain resolves to a fixpoint", func(t *testing.T) { + if !hasFKTo("_p9_books_chain", "_p9_authors", "writer") { + rel, _ := model.Lookup("_p9_books_chain", []string{"public"}) + t.Errorf("books_chain missing projected FK on writer; FKs = %+v", rel.ForeignKeys) + } + }) + + t.Run("materialized view inherits the base FK", func(t *testing.T) { + if !hasFKTo("_p9_books_m", "_p9_authors", "author_id") { + rel, _ := model.Lookup("_p9_books_m", []string{"public"}) + t.Errorf("books_m missing projected FK to _p9_authors; FKs = %+v", rel.ForeignKeys) + } + }) + + t.Run("embedding the referenced table through the view works", func(t *testing.T) { + rel, ok := model.Lookup("_p9_books_v", []string{"public"}) + if !ok { + t.Fatal("_p9_books_v not found") + } + q, perr := ir.ParseRead("_p9_books_v", "select=title,_p9_authors(name)&order=id", nil) + if perr != nil { + t.Fatalf("parse: %v", perr) + } + rp, perr := planpkg.Read(model, q, []string{"public"}, planpkg.Options{}) + if perr != nil { + t.Fatalf("plan: %v", perr) + } + rp.Rel = rel + res, err := be.Execute(ctx, rp, &reqctx.Context{Method: "GET", Path: "/_p9_books_v"}) + if err != nil { + t.Fatalf("Execute(embed through view): %v", err) + } + rs := res.Rows() + defer rs.Close() + rows := 0 + for rs.Next() { + if _, err := rs.Values(); err != nil { + t.Fatalf("Values: %v", err) + } + rows++ + } + if err := rs.Err(); err != nil { + t.Fatalf("row error: %v", err) + } + if rows != 2 { + t.Errorf("embed through view returned %d rows, want 2", rows) + } + }) +} + +// TestIntegrationWideEmbed proves an embed of a table with more than 50 columns +// assembles instead of failing. json_build_object caps at 100 arguments (two per +// key), so a 60-column embed raised 54023; the dialect now chunks the object with +// jsonb_build_object and || past 50 keys. Finding 01-P07. +func TestIntegrationWideEmbed(t *testing.T) { + be := openBE(t) + ctx := context.Background() + + // A parent with a child whose 60 columns force the chunked path. + var childCols strings.Builder + for i := 0; i < 60; i++ { + fmt.Fprintf(&childCols, ", c%d int DEFAULT %d", i, i) + } + ddl := ` + CREATE TABLE IF NOT EXISTS _dbrest_test_parent (id int PRIMARY KEY); + CREATE TABLE IF NOT EXISTS _dbrest_test_child ( + id int PRIMARY KEY, + parent_id int REFERENCES _dbrest_test_parent(id)` + childCols.String() + ` + ); + TRUNCATE _dbrest_test_child, _dbrest_test_parent; + INSERT INTO _dbrest_test_parent (id) VALUES (1); + INSERT INTO _dbrest_test_child (id, parent_id) VALUES (10, 1);` + if _, err := be.Pool().Exec(ctx, ddl); err != nil { + t.Fatalf("seed wide tables: %v", err) + } + t.Cleanup(func() { + _, _ = be.Pool().Exec(ctx, "DROP TABLE IF EXISTS _dbrest_test_child; DROP TABLE IF EXISTS _dbrest_test_parent") + }) + + model, err := be.Introspect(ctx) + if err != nil { + t.Fatalf("Introspect: %v", err) + } + rel, ok := model.Lookup("_dbrest_test_parent", []string{"public"}) + if !ok { + t.Fatal("_dbrest_test_parent not found") + } + + // GET /_dbrest_test_parent?select=id,_dbrest_test_child(*) embeds every child + // column, which is the chunked-object case. + q, perr := ir.ParseRead("_dbrest_test_parent", "select=id,_dbrest_test_child(*)", nil) + if perr != nil { + t.Fatalf("parse: %v", perr) + } + rp, perr := planpkg.Read(model, q, []string{"public"}, planpkg.Options{}) + if perr != nil { + t.Fatalf("plan: %v", perr) + } + rp.Rel = rel + + res, err := be.Execute(ctx, rp, &reqctx.Context{Method: "GET", Path: "/_dbrest_test_parent"}) + if err != nil { + t.Fatalf("Execute(wide embed): %v", err) + } + rs := res.Rows() + defer rs.Close() + rows := 0 + for rs.Next() { + if _, err := rs.Values(); err != nil { + t.Fatalf("Values: %v", err) + } + rows++ + } + if err := rs.Err(); err != nil { + t.Fatalf("row error: %v", err) + } + if rows != 1 { + t.Errorf("wide embed returned %d parent rows, want 1", rows) + } +} + +// TestIntegrationCountedReadConsistent exercises the counted-read path, which +// runs the count and the page as two statements. The fix pins that transaction +// to REPEATABLE READ so both statements read one snapshot, the way PostgREST's +// single statement does. The test seeds a known set, reads it with a page +// smaller than the total, and proves the exact count reports the whole set while +// the page honours the limit. Finding P11. +func TestIntegrationCountedReadConsistent(t *testing.T) { + be := openBE(t) + ctx := context.Background() + + if _, err := be.Pool().Exec(ctx, ` + CREATE TABLE IF NOT EXISTS _dbrest_test_counted (id serial PRIMARY KEY); + TRUNCATE _dbrest_test_counted; + INSERT INTO _dbrest_test_counted SELECT generate_series(1, 7)`); err != nil { + t.Fatalf("seed: %v", err) + } + t.Cleanup(func() { + _, _ = be.Pool().Exec(ctx, "DROP TABLE IF EXISTS _dbrest_test_counted") + }) + + model, err := be.Introspect(ctx) + if err != nil { + t.Fatalf("Introspect: %v", err) + } + rel, ok := model.Lookup("_dbrest_test_counted", []string{"public"}) + if !ok { + t.Fatal("_dbrest_test_counted not found") + } + + plan := &ir.Plan{ + Rel: rel, + Query: &ir.Query{ + Kind: ir.Read, + Relation: ir.Ref{Schema: "public", Name: "_dbrest_test_counted"}, + Select: []ir.SelectItem{ir.Column{Path: []string{"id"}}}, + Limit: intPtr(3), + Count: ir.CountExact, + }, + } + + res, err := be.Execute(ctx, plan, &reqctx.Context{Method: "GET", Path: "/_dbrest_test_counted"}) + if err != nil { + t.Fatalf("Execute(counted read): %v", err) + } + if c, ok := res.Count(); !ok || c != 7 { + t.Errorf("Count = (%d, %v), want (7, true) over the whole set", c, ok) + } + rs := res.Rows() + defer rs.Close() + page := 0 + for rs.Next() { + if _, err := rs.Values(); err != nil { + t.Fatalf("Values: %v", err) + } + page++ + } + if err := rs.Err(); err != nil { + t.Fatalf("row error: %v", err) + } + if page != 3 { + t.Errorf("page returned %d rows, want 3 (the limit)", page) + } +} + +// TestIntegrationUpsertNoConflictTarget proves a merge upsert against a table +// with no primary key degrades to a plain INSERT instead of emitting an invalid +// ON CONFLICT DO UPDATE. This matches PostgREST 14, where a merge-duplicates POST +// to a key-less table inserts the rows and returns 201 (verified against a live +// PostgREST). Two identical rows therefore both land. Finding P12. +func TestIntegrationUpsertNoConflictTarget(t *testing.T) { + be := openBE(t) + ctx := context.Background() + + if _, err := be.Pool().Exec(ctx, ` + CREATE TABLE IF NOT EXISTS _dbrest_test_nopk (a int, b text); + TRUNCATE _dbrest_test_nopk`); err != nil { + t.Fatalf("seed: %v", err) + } + t.Cleanup(func() { + _, _ = be.Pool().Exec(ctx, "DROP TABLE IF EXISTS _dbrest_test_nopk") + }) + + model, err := be.Introspect(ctx) + if err != nil { + t.Fatalf("Introspect: %v", err) + } + rel, ok := model.Lookup("_dbrest_test_nopk", []string{"public"}) + if !ok { + t.Fatal("_dbrest_test_nopk not found") + } + + plan := &ir.Plan{ + Rel: rel, + Query: &ir.Query{ + Kind: ir.Upsert, + Relation: ir.Ref{Schema: "public", Name: "_dbrest_test_nopk"}, + Write: &ir.WriteSpec{ + Rows: []map[string]ir.Value{{"a": {JSON: "1"}, "b": {JSON: "x"}}}, + Columns: []string{"a", "b"}, + Return: ir.ReturnMinimal, + Conflict: &ir.Conflict{Resolution: ir.ConflictMerge}, + }, + }, + } + rc := &reqctx.Context{Method: "POST", Path: "/_dbrest_test_nopk"} + + for i := 0; i < 2; i++ { + if _, err := be.Execute(ctx, plan, rc); err != nil { + t.Fatalf("Execute(merge upsert, no PK) #%d: %v", i, err) + } + } + var n int + if err := be.Pool().QueryRow(ctx, "SELECT count(*) FROM _dbrest_test_nopk WHERE a=1").Scan(&n); err != nil { + t.Fatalf("count: %v", err) + } + if n != 2 { + t.Errorf("rows after two merge upserts = %d, want 2 (plain insert, no merge)", n) + } +} + +// TestIntegrationInListAny proves the col = ANY($1) lowering selects exactly the +// rows an expanded IN would, against a live server. The list binds as one array +// literal parameter. Finding P13. +func TestIntegrationInListAny(t *testing.T) { + be := openBE(t) + ctx := context.Background() + + if _, err := be.Pool().Exec(ctx, ` + CREATE TABLE IF NOT EXISTS _dbrest_test_inlist (id int PRIMARY KEY); + TRUNCATE _dbrest_test_inlist; + INSERT INTO _dbrest_test_inlist SELECT generate_series(1, 5)`); err != nil { + t.Fatalf("seed: %v", err) + } + t.Cleanup(func() { + _, _ = be.Pool().Exec(ctx, "DROP TABLE IF EXISTS _dbrest_test_inlist") + }) + + model, err := be.Introspect(ctx) + if err != nil { + t.Fatalf("Introspect: %v", err) + } + rel, ok := model.Lookup("_dbrest_test_inlist", []string{"public"}) + if !ok { + t.Fatal("_dbrest_test_inlist not found") + } + + plan := &ir.Plan{ + Rel: rel, + Query: &ir.Query{ + Kind: ir.Read, + Relation: ir.Ref{Schema: "public", Name: "_dbrest_test_inlist"}, + Select: []ir.SelectItem{ir.Column{Path: []string{"id"}}}, + Where: condPtr(ir.Compare{Path: []string{"id"}, Op: ir.OpIn, ColumnType: "integer", Value: ir.Value{List: []string{"2", "4", "9"}}}), + Order: []ir.OrderTerm{{Path: []string{"id"}}}, + }, + } + res, err := be.Execute(ctx, plan, &reqctx.Context{Method: "GET", Path: "/_dbrest_test_inlist"}) + if err != nil { + t.Fatalf("Execute(in-list): %v", err) + } + rs := res.Rows() + defer rs.Close() + var got []int32 + for rs.Next() { + vals, err := rs.Values() + if err != nil { + t.Fatalf("Values: %v", err) + } + got = append(got, vals[0].(int32)) + } + if err := rs.Err(); err != nil { + t.Fatalf("row error: %v", err) + } + // 2 and 4 exist; 9 does not. = ANY selects exactly the present members. + if len(got) != 2 || got[0] != 2 || got[1] != 4 { + t.Errorf("in-list rows = %v, want [2 4]", got) + } +} + +// TestIntegrationSearchPathShape proves the per-request search_path is the active +// schema followed by db-extra-search-path (default "public"), not the whole +// exposed schema set, and that the GUC string is the verbatim quoted value +// PostgREST writes. It reads current_setting('search_path') through a native RPC +// and switches the active schema via Accept-Profile (reqctx.Context.Schema). +// Finding 02-P01. Verified against PostgREST 14.12, which sets the path with +// set_config('search_path', '"", "public"', true) and does not dedup, so +// an active schema of public yields "public", "public". +func TestIntegrationSearchPathShape(t *testing.T) { + be := openBE(t) + ctx := context.Background() + + if _, err := be.Pool().Exec(ctx, ` + CREATE SCHEMA IF NOT EXISTS _dbrest_sp1; + CREATE SCHEMA IF NOT EXISTS _dbrest_sp2; + CREATE OR REPLACE FUNCTION public.show_path() RETURNS text + LANGUAGE sql STABLE AS $$ SELECT current_setting('search_path') $$; + CREATE OR REPLACE FUNCTION _dbrest_sp1.show_path() RETURNS text + LANGUAGE sql STABLE AS $$ SELECT current_setting('search_path') $$; + CREATE OR REPLACE FUNCTION _dbrest_sp2.show_path() RETURNS text + LANGUAGE sql STABLE AS $$ SELECT current_setting('search_path') $$`); err != nil { + t.Fatalf("seed: %v", err) + } + t.Cleanup(func() { + _, _ = be.Pool().Exec(ctx, `DROP FUNCTION IF EXISTS public.show_path(); + DROP SCHEMA IF EXISTS _dbrest_sp1 CASCADE; DROP SCHEMA IF EXISTS _dbrest_sp2 CASCADE`) + }) + + path := func(schema string) string { + t.Helper() + plan := &ir.Plan{ReadOnly: true, Call: &ir.Call{ + Function: ir.Ref{Name: "show_path"}, + Args: map[string]ir.Value{}, + ReadOnly: true, + }} + res, err := be.Execute(ctx, plan, &reqctx.Context{Method: "GET", Path: "/rpc/show_path", Schema: schema}) + if err != nil { + t.Fatalf("Execute(%q): %v", schema, err) + } + rs := res.Rows() + defer rs.Close() + if !rs.Next() { + t.Fatalf("Execute(%q): no rows", schema) + } + vals, err := rs.Values() + if err != nil { + t.Fatalf("Values(%q): %v", schema, err) + } + return vals[0].(string) + } + + // Default active schema is public (the single configured schema), extra is the + // default "public"; PostgREST does not dedup, so the path is "public", "public". + be.SetSchemas([]string{"public"}) + be.SetExtraSearchPath([]string{"public"}) + if got := path(""); got != `"public", "public"` { + t.Errorf(`default search_path = %q, want "public", "public"`, got) + } + + // Two exposed schemas: the active one (Accept-Profile) leads the path, not the + // first configured schema, and the whole set never appears. + be.SetSchemas([]string{"_dbrest_sp1", "_dbrest_sp2"}) + if got := path("_dbrest_sp1"); got != `"_dbrest_sp1", "public"` { + t.Errorf(`sp1 search_path = %q, want "_dbrest_sp1", "public"`, got) + } + if got := path("_dbrest_sp2"); got != `"_dbrest_sp2", "public"` { + t.Errorf(`sp2 search_path = %q, want "_dbrest_sp2", "public"`, got) + } +} + +// TestIntegrationSearchPathReachesExtra proves db-extra-search-path puts its +// schemas on the path: a function running in a non-public active schema resolves +// an unqualified helper defined in public because public is appended to the path. +// Finding 02-P01. +func TestIntegrationSearchPathReachesExtra(t *testing.T) { + be := openBE(t) + be.SetSchemas([]string{"_dbrest_spx"}) + be.SetExtraSearchPath([]string{"public"}) + ctx := context.Background() + + if _, err := be.Pool().Exec(ctx, ` + CREATE SCHEMA IF NOT EXISTS _dbrest_spx; + CREATE OR REPLACE FUNCTION public._dbrest_helper() RETURNS text + LANGUAGE sql IMMUTABLE AS $$ SELECT 'from-public' $$; + CREATE OR REPLACE FUNCTION _dbrest_spx.uses_helper() RETURNS text + LANGUAGE sql STABLE AS $$ SELECT _dbrest_helper() $$`); err != nil { + t.Fatalf("seed: %v", err) + } + t.Cleanup(func() { + _, _ = be.Pool().Exec(ctx, `DROP FUNCTION IF EXISTS public._dbrest_helper(); + DROP SCHEMA IF EXISTS _dbrest_spx CASCADE`) + }) + + plan := &ir.Plan{ReadOnly: true, Call: &ir.Call{ + Function: ir.Ref{Name: "uses_helper"}, + Args: map[string]ir.Value{}, + ReadOnly: true, + }} + res, err := be.Execute(ctx, plan, &reqctx.Context{Method: "GET", Path: "/rpc/uses_helper", Schema: "_dbrest_spx"}) + if err != nil { + t.Fatalf("Execute: %v", err) + } + rs := res.Rows() + defer rs.Close() + if !rs.Next() { + t.Fatal("no rows") + } + vals, err := rs.Values() + if err != nil { + t.Fatalf("Values: %v", err) + } + if got := vals[0].(string); got != "from-public" { + t.Errorf("unqualified helper resolved to %q, want from-public", got) + } +} + +// TestIntegrationNativeCallVolatilityAccessMode proves the native RPC access mode +// follows the function's volatility, not only the HTTP method: a POST to a STABLE +// or IMMUTABLE function runs in a read-only transaction, while a VOLATILE function +// runs read-write, matching PostgREST's access-mode table. Each function reports +// current_setting('transaction_read_only') so the transaction mode is observed +// directly, and a volatile insert proves the read-write path still commits. +// Finding 02-P06. Verified against PostgREST 14.12. +func TestIntegrationNativeCallVolatilityAccessMode(t *testing.T) { + be := openBE(t) + ctx := context.Background() + + if _, err := be.Pool().Exec(ctx, ` + CREATE TABLE IF NOT EXISTS _dbrest_test_vol (n int); + TRUNCATE _dbrest_test_vol; + CREATE OR REPLACE FUNCTION public._dbrest_txmode_v() RETURNS text + LANGUAGE sql VOLATILE AS $$ SELECT current_setting('transaction_read_only') $$; + CREATE OR REPLACE FUNCTION public._dbrest_txmode_s() RETURNS text + LANGUAGE sql STABLE AS $$ SELECT current_setting('transaction_read_only') $$; + CREATE OR REPLACE FUNCTION public._dbrest_vol_insert(x int) RETURNS int + LANGUAGE sql VOLATILE AS $$ INSERT INTO _dbrest_test_vol VALUES (x) RETURNING n $$`); err != nil { + t.Fatalf("seed: %v", err) + } + t.Cleanup(func() { + _, _ = be.Pool().Exec(ctx, `DROP FUNCTION IF EXISTS public._dbrest_txmode_v(); + DROP FUNCTION IF EXISTS public._dbrest_txmode_s(); + DROP FUNCTION IF EXISTS public._dbrest_vol_insert(int); + DROP TABLE IF EXISTS _dbrest_test_vol`) + }) + + // Refresh the catalog so the new functions' volatility is loaded. + if _, err := be.Introspect(ctx); err != nil { + t.Fatalf("Introspect: %v", err) + } + + txmode := func(fn, method string) string { + t.Helper() + plan := &ir.Plan{ + ReadOnly: method == "GET", + Call: &ir.Call{Function: ir.Ref{Name: fn}, Args: map[string]ir.Value{}}, + } + res, err := be.Execute(ctx, plan, &reqctx.Context{Method: method, Path: "/rpc/" + fn}) + if err != nil { + t.Fatalf("Execute(%s %s): %v", method, fn, err) + } + rs := res.Rows() + defer rs.Close() + if !rs.Next() { + t.Fatalf("Execute(%s %s): no rows", method, fn) + } + vals, err := rs.Values() + if err != nil { + t.Fatalf("Values(%s %s): %v", method, fn, err) + } + return vals[0].(string) + } + + // POST to a VOLATILE function runs read-write; POST to a STABLE function runs + // read-only (the fix); GET to either is read-only. + if got := txmode("_dbrest_txmode_v", "POST"); got != "off" { + t.Errorf("volatile POST transaction_read_only = %q, want off", got) + } + if got := txmode("_dbrest_txmode_s", "POST"); got != "on" { + t.Errorf("stable POST transaction_read_only = %q, want on (read-only)", got) + } + if got := txmode("_dbrest_txmode_s", "GET"); got != "on" { + t.Errorf("stable GET transaction_read_only = %q, want on", got) + } + + // The read-write path still commits: a volatile insert via POST persists. + volPlan := &ir.Plan{Call: &ir.Call{ + Function: ir.Ref{Name: "_dbrest_vol_insert"}, + Args: map[string]ir.Value{"x": {Text: "7"}}, + }} + if _, err := be.Execute(ctx, volPlan, &reqctx.Context{Method: "POST", Path: "/rpc/_dbrest_vol_insert"}); err != nil { + t.Fatalf("volatile insert POST: %v", err) + } + var n int + if err := be.Pool().QueryRow(ctx, "SELECT count(*) FROM _dbrest_test_vol").Scan(&n); err != nil { + t.Fatalf("count: %v", err) + } + if n != 1 { + t.Errorf("after volatile insert POST rows = %d, want 1", n) + } +} + +// TestIntegrationImpersonatedRoleSettings proves the backend replays an +// impersonated role's ALTER ROLE ... SET settings as transaction-scoped settings, +// like PostgREST: a role pinned to statement_timeout '50ms' carries that timeout +// on every request, a slow call as that role is cancelled (SQLSTATE 57014 -> 500), +// and the setting is transaction-scoped so it does not leak to a request that runs +// without the role. Finding 02-P02. Verified against PostgREST 14.12. +func TestIntegrationImpersonatedRoleSettings(t *testing.T) { + be := openBE(t) + ctx := context.Background() + + // A role granted to the connected authenticator, pinned to a short timeout, so + // loadRoleSettings (which reads roles the authenticator is a member of) picks it + // up. Functions are PUBLIC-executable by default, so the role can call them. + if _, err := be.Pool().Exec(ctx, ` + DROP ROLE IF EXISTS _dbrest_slow; + CREATE ROLE _dbrest_slow; + GRANT _dbrest_slow TO CURRENT_USER; + ALTER ROLE _dbrest_slow SET statement_timeout = '50ms'; + CREATE OR REPLACE FUNCTION public._dbrest_show_timeout() RETURNS text + LANGUAGE sql STABLE AS $$ SELECT current_setting('statement_timeout') $$; + CREATE OR REPLACE FUNCTION public._dbrest_sleep() RETURNS text + LANGUAGE sql VOLATILE AS $$ SELECT pg_sleep(3)::text $$`); err != nil { + t.Fatalf("seed: %v", err) + } + t.Cleanup(func() { + _, _ = be.Pool().Exec(ctx, `DROP FUNCTION IF EXISTS public._dbrest_show_timeout(); + DROP FUNCTION IF EXISTS public._dbrest_sleep(); + DROP ROLE IF EXISTS _dbrest_slow`) + }) + + // Refresh the catalog so the role's settings are loaded. + if _, err := be.Introspect(ctx); err != nil { + t.Fatalf("Introspect: %v", err) + } + + showTimeout := func(role string) string { + t.Helper() + plan := &ir.Plan{ReadOnly: true, Call: &ir.Call{ + Function: ir.Ref{Name: "_dbrest_show_timeout"}, + Args: map[string]ir.Value{}, + }} + res, err := be.Execute(ctx, plan, &reqctx.Context{Role: role, Method: "GET", Path: "/rpc/_dbrest_show_timeout"}) + if err != nil { + t.Fatalf("show_timeout(%q): %v", role, err) + } + rs := res.Rows() + defer rs.Close() + if !rs.Next() { + t.Fatalf("show_timeout(%q): no rows", role) + } + vals, err := rs.Values() + if err != nil { + t.Fatalf("Values(%q): %v", role, err) + } + return vals[0].(string) + } + + // The role carries its pinned timeout. + if got := showTimeout("_dbrest_slow"); got != "50ms" { + t.Errorf("statement_timeout as _dbrest_slow = %q, want 50ms", got) + } + // A request without the role does not inherit it (transaction-scoped, no leak). + if got := showTimeout(""); got == "50ms" { + t.Errorf("statement_timeout without the role = %q, want the server default (not 50ms)", got) + } + + // A slow call as the role is cancelled by the pinned timeout. + sleepPlan := &ir.Plan{Call: &ir.Call{ + Function: ir.Ref{Name: "_dbrest_sleep"}, + Args: map[string]ir.Value{}, + }} + _, err := be.Execute(ctx, sleepPlan, &reqctx.Context{Role: "_dbrest_slow", Method: "POST", Path: "/rpc/_dbrest_sleep"}) + if err == nil { + t.Fatal("slow call as _dbrest_slow: want a timeout error, got nil") + } + apiErr, ok := err.(*pgerr.APIError) + if !ok { + t.Fatalf("timeout error type = %T, want *pgerr.APIError", err) + } + if apiErr.Code != "57014" { + t.Errorf("timeout code = %q, want 57014", apiErr.Code) + } + if apiErr.HTTPStatus != 500 { + t.Errorf("timeout status = %d, want 500", apiErr.HTTPStatus) + } +} + +// TestIntegrationReadCallResponseControls proves a STABLE function reached over GET +// can still steer its response: response.status and response.headers it sets are +// read back and folded into the response controls. Before the fix the read-call +// path streamed straight from the cursor and never called readResponseControls, so +// the GUCs a function set on a GET were silently dropped. Finding 02-P05. +func TestIntegrationReadCallResponseControls(t *testing.T) { + be := openBE(t) + ctx := context.Background() + + // A STABLE function (so the call runs read-only, the path under test) that sets a + // status override and a Cache-Control response header the PostgREST way: a JSON + // array of single-key name->value objects in response.headers. + if _, err := be.Pool().Exec(ctx, ` + CREATE OR REPLACE FUNCTION public._dbrest_resp_ctl() RETURNS text + LANGUAGE plpgsql STABLE AS $$ + BEGIN + PERFORM set_config('response.status', '205', true); + PERFORM set_config('response.headers', '[{"Cache-Control": "max-age=60"}]', true); + RETURN 'ok'; + END $$`); err != nil { + t.Fatalf("seed: %v", err) + } + t.Cleanup(func() { + _, _ = be.Pool().Exec(ctx, `DROP FUNCTION IF EXISTS public._dbrest_resp_ctl()`) + }) + + if _, err := be.Introspect(ctx); err != nil { + t.Fatalf("Introspect: %v", err) + } + + plan := &ir.Plan{ReadOnly: true, Call: &ir.Call{ + Function: ir.Ref{Name: "_dbrest_resp_ctl"}, + Args: map[string]ir.Value{}, + }} + rc := &reqctx.Context{Method: "GET", Path: "/rpc/_dbrest_resp_ctl"} + res, err := be.Execute(ctx, plan, rc) + if err != nil { + t.Fatalf("Execute: %v", err) + } + + ctrl := res.ResponseControls() + if ctrl.Status != 205 { + t.Errorf("response status override = %d, want 205", ctrl.Status) + } + if got := ctrl.Headers["Cache-Control"]; got != "max-age=60" { + t.Errorf("Cache-Control header = %q, want max-age=60", got) + } + + // The body still carries the function's return value. + rs := res.Rows() + defer rs.Close() + if !rs.Next() { + t.Fatal("Execute: no rows") + } + vals, err := rs.Values() + if err != nil { + t.Fatalf("Values: %v", err) + } + if vals[0].(string) != "ok" { + t.Errorf("body = %q, want ok", vals[0].(string)) + } +} + +// TestIntegrationReadTablePreRequestControls proves a db-pre-request function can +// steer the response of a plain GET table read: a header it sets via +// response.headers is read back before the body streams. Before the fix the +// table-read path streamed from the cursor and never read the response GUCs, so a +// pre-request that set a header on a GET was silently dropped. Finding 02-P05. +func TestIntegrationReadTablePreRequestControls(t *testing.T) { + be := openBE(t) + ctx := context.Background() + + if _, err := be.Pool().Exec(ctx, ` + CREATE TABLE IF NOT EXISTS _dbrest_test_pr (id serial PRIMARY KEY, val text); + TRUNCATE _dbrest_test_pr; + INSERT INTO _dbrest_test_pr (val) VALUES ('a'); + CREATE OR REPLACE FUNCTION public._dbrest_pre() RETURNS void + LANGUAGE plpgsql AS $$ + BEGIN + PERFORM set_config('response.headers', '[{"X-Pre": "ran"}]', true); + END $$`); err != nil { + t.Fatalf("seed: %v", err) + } + t.Cleanup(func() { + _, _ = be.Pool().Exec(ctx, `DROP FUNCTION IF EXISTS public._dbrest_pre(); + DROP TABLE IF EXISTS _dbrest_test_pr`) + }) + + model, err := be.Introspect(ctx) + if err != nil { + t.Fatalf("Introspect: %v", err) + } + rel, ok := model.Lookup("_dbrest_test_pr", []string{"public"}) + if !ok { + t.Fatal("_dbrest_test_pr not found") + } + + rc := &reqctx.Context{Method: "GET", Path: "/_dbrest_test_pr", PreRequest: "_dbrest_pre"} + readPlan := &ir.Plan{ + Rel: rel, + Query: &ir.Query{ + Kind: ir.Read, + Relation: ir.Ref{Schema: "public", Name: "_dbrest_test_pr"}, + Select: []ir.SelectItem{ir.Column{Path: []string{"val"}}}, + }, + } + res, err := be.Execute(ctx, readPlan, rc) + if err != nil { + t.Fatalf("Execute(read): %v", err) + } + if got := res.ResponseControls().Headers["X-Pre"]; got != "ran" { + t.Errorf("X-Pre header = %q, want ran (pre-request header dropped on table read)", got) + } + // The body still streams the row. + rs := res.Rows() + defer rs.Close() + if !rs.Next() { + t.Fatal("read returned no rows") + } +} + +// TestIntegrationHoistedTxSettings proves db-hoisted-tx-settings: a function's SET +// clause for a hoisted setting is applied to the transaction, not only the +// function body. default_transaction_isolation is the cleanest probe because it +// can never take effect without hoisting (the transaction has already started by +// the time the function runs), so a function that returns the current isolation +// level reads the database default unless its SET clause was hoisted to BeginTx. +// Finding 02-P03. +func TestIntegrationHoistedTxSettings(t *testing.T) { + be := openBE(t) + ctx := context.Background() + + if _, err := be.Pool().Exec(ctx, ` + CREATE OR REPLACE FUNCTION public._dbrest_hoist_iso() RETURNS text + LANGUAGE sql STABLE SET default_transaction_isolation = 'serializable' + AS $$ SELECT current_setting('transaction_isolation') $$`); err != nil { + t.Fatalf("seed: %v", err) + } + t.Cleanup(func() { + _, _ = be.Pool().Exec(ctx, `DROP FUNCTION IF EXISTS public._dbrest_hoist_iso()`) + }) + + if _, err := be.Introspect(ctx); err != nil { + t.Fatalf("Introspect: %v", err) + } + + callIso := func() string { + t.Helper() + plan := &ir.Plan{ReadOnly: true, Call: &ir.Call{ + Function: ir.Ref{Name: "_dbrest_hoist_iso"}, + Args: map[string]ir.Value{}, + }} + res, err := be.Execute(ctx, plan, &reqctx.Context{Method: "GET", Path: "/rpc/_dbrest_hoist_iso"}) + if err != nil { + t.Fatalf("Execute: %v", err) + } + rs := res.Rows() + defer rs.Close() + if !rs.Next() { + t.Fatal("no rows") + } + vals, err := rs.Values() + if err != nil { + t.Fatalf("Values: %v", err) + } + return vals[0].(string) + } + + // With no hoisted settings configured, the function's SET clause stays inside + // the body and the transaction runs at the database default. + if got := callIso(); got == "serializable" { + t.Errorf("isolation without hoisting = %q, want the default (not serializable)", got) + } + + // With the v14 default hoist list, default_transaction_isolation is applied at + // BeginTx, so the transaction itself runs serializable. + be.SetHoistedTxSettings([]string{"statement_timeout", "plan_filter.statement_cost_limit", "default_transaction_isolation"}) + if got := callIso(); got != "serializable" { + t.Errorf("isolation with hoisting = %q, want serializable", got) + } +} + +// TestIntegrationRelationKinds proves the schema cache mirrors PostgREST's +// relation set: a materialized view is exposed (as the view kind), a foreign table +// is exposed (as the table kind), and a partitioned table exposes only the parent, +// never its leaf partitions. Before the fix the relkind filter was IN ('r','v','p') +// with no relispartition guard, so matviews and foreign tables were invisible and +// every partition leaked in as its own endpoint. Finding 03-P08 / 03-P14. +func TestIntegrationRelationKinds(t *testing.T) { + be := openBE(t) + ctx := context.Background() + + // A matview over a base table, a partitioned parent with two leaf partitions, + // and a foreign table over a file_fdw server. file_fdw ships with the standard + // contrib package and needs no network, so it is the lightest foreign table to + // stand up; if the extension is unavailable the foreign-table leg is skipped. + if _, err := be.Pool().Exec(ctx, ` + CREATE TABLE IF NOT EXISTS _dbrest_test_mvbase (id int PRIMARY KEY, n int); + TRUNCATE _dbrest_test_mvbase; + INSERT INTO _dbrest_test_mvbase VALUES (1, 10), (2, 20); + DROP MATERIALIZED VIEW IF EXISTS _dbrest_test_mv; + CREATE MATERIALIZED VIEW _dbrest_test_mv AS SELECT id, n FROM _dbrest_test_mvbase; + CREATE TABLE IF NOT EXISTS _dbrest_test_part (id int, region text) PARTITION BY LIST (region); + CREATE TABLE IF NOT EXISTS _dbrest_test_part_us PARTITION OF _dbrest_test_part FOR VALUES IN ('us'); + CREATE TABLE IF NOT EXISTS _dbrest_test_part_eu PARTITION OF _dbrest_test_part FOR VALUES IN ('eu')`); err != nil { + t.Fatalf("seed: %v", err) + } + t.Cleanup(func() { + _, _ = be.Pool().Exec(ctx, ` + DROP MATERIALIZED VIEW IF EXISTS _dbrest_test_mv; + DROP TABLE IF EXISTS _dbrest_test_mvbase; + DROP TABLE IF EXISTS _dbrest_test_part`) + }) + + // Best-effort foreign table over file_fdw; the test still asserts the matview + // and partition behaviour when the extension is not installed. + haveForeign := false + if _, err := be.Pool().Exec(ctx, ` + CREATE EXTENSION IF NOT EXISTS file_fdw; + DROP SERVER IF EXISTS _dbrest_test_files CASCADE; + CREATE SERVER _dbrest_test_files FOREIGN DATA WRAPPER file_fdw; + CREATE FOREIGN TABLE _dbrest_test_ft (line text) + SERVER _dbrest_test_files OPTIONS (filename '/etc/hostname')`); err == nil { + haveForeign = true + t.Cleanup(func() { + _, _ = be.Pool().Exec(ctx, `DROP SERVER IF EXISTS _dbrest_test_files CASCADE`) + }) + } else { + t.Logf("file_fdw unavailable, skipping foreign-table leg: %v", err) + } + + model, err := be.Introspect(ctx) + if err != nil { + t.Fatalf("Introspect: %v", err) + } + + // The materialized view is exposed and carries the view kind. + mv, ok := model.Lookup("_dbrest_test_mv", []string{"public"}) + if !ok { + t.Fatal("materialized view _dbrest_test_mv not exposed") + } + if mv.Kind != schema.KindView { + t.Errorf("matview kind = %v, want KindView", mv.Kind) + } + + // The partitioned parent is exposed; the leaf partitions are not. + if _, ok := model.Lookup("_dbrest_test_part", []string{"public"}); !ok { + t.Error("partitioned parent _dbrest_test_part not exposed") + } + if _, ok := model.Lookup("_dbrest_test_part_us", []string{"public"}); ok { + t.Error("leaf partition _dbrest_test_part_us leaked as an endpoint") + } + if _, ok := model.Lookup("_dbrest_test_part_eu", []string{"public"}); ok { + t.Error("leaf partition _dbrest_test_part_eu leaked as an endpoint") + } + + // The foreign table is exposed and carries the table kind (an FDW can write). + if haveForeign { + ft, ok := model.Lookup("_dbrest_test_ft", []string{"public"}) + if !ok { + t.Error("foreign table _dbrest_test_ft not exposed") + } else if ft.Kind != schema.KindTable { + t.Errorf("foreign table kind = %v, want KindTable", ft.Kind) + } + } +} + +// TestIntegrationCatalogMetadata proves the introspector populates the catalog +// metadata PostgREST's schema cache carries and dbrest's frontend already +// consumes: unique constraints and unique indexes (one-to-one detection, P10), +// identity columns folded into HasDefault with the Identity flag set (P15), and +// table, column, and schema comments (P16). Before the fix none of these reached +// the model: unique sets were empty, identity columns looked default-less, and the +// model carried no descriptions. +func TestIntegrationCatalogMetadata(t *testing.T) { + be := openBE(t) + ctx := context.Background() + + if _, err := be.Pool().Exec(ctx, ` + DROP TABLE IF EXISTS _dbrest_test_meta; + CREATE TABLE _dbrest_test_meta ( + id int GENERATED ALWAYS AS IDENTITY PRIMARY KEY, + email text NOT NULL UNIQUE, + slug text NOT NULL, + tenant int NOT NULL, + label text + ); + CREATE UNIQUE INDEX _dbrest_test_meta_slug_tenant ON _dbrest_test_meta (slug, tenant); + COMMENT ON TABLE _dbrest_test_meta IS 'People records'; + COMMENT ON COLUMN _dbrest_test_meta.email IS 'Primary contact email'; + COMMENT ON SCHEMA public IS 'The default schema'`); err != nil { + t.Fatalf("seed: %v", err) + } + t.Cleanup(func() { + _, _ = be.Pool().Exec(ctx, `DROP TABLE IF EXISTS _dbrest_test_meta; + COMMENT ON SCHEMA public IS NULL`) + }) + + model, err := be.Introspect(ctx) + if err != nil { + t.Fatalf("Introspect: %v", err) + } + rel, ok := model.Lookup("_dbrest_test_meta", []string{"public"}) + if !ok { + t.Fatal("_dbrest_test_meta not found") + } + + // P15: the identity column is folded into HasDefault and flags Identity. + idCol, ok := rel.Column("id") + if !ok { + t.Fatal("id column missing") + } + if !idCol.Identity { + t.Error("id Identity = false, want true (GENERATED ALWAYS AS IDENTITY)") + } + if !idCol.HasDefault { + t.Error("id HasDefault = false, want true (identity column is server-generated)") + } + + // P10: the single-column unique constraint on email and the composite unique + // index on (slug, tenant) both reach the model; the PK is not duplicated here. + hasUnique := func(want ...string) bool { + for _, u := range rel.Unique { + if len(u) == len(want) { + match := true + for i := range want { + if u[i] != want[i] { + match = false + break + } + } + if match { + return true + } + } + } + return false + } + if !hasUnique("email") { + t.Errorf("unique sets %v missing [email]", rel.Unique) + } + if !hasUnique("slug", "tenant") { + t.Errorf("unique sets %v missing [slug tenant]", rel.Unique) + } + for _, u := range rel.Unique { + if len(u) == 1 && u[0] == "id" { + t.Errorf("unique sets %v include the primary key, want it excluded", rel.Unique) + } + } + + // P16: table, column, and schema comments are populated. + if rel.Comment != "People records" { + t.Errorf("table comment = %q, want %q", rel.Comment, "People records") + } + emailCol, _ := rel.Column("email") + if emailCol.Comment != "Primary contact email" { + t.Errorf("email comment = %q, want %q", emailCol.Comment, "Primary contact email") + } + if got := model.SchemaComment("public"); got != "The default schema" { + t.Errorf("schema comment = %q, want %q", got, "The default schema") + } +} + +// TestIntegrationVoidCallStatus proves a void-returning function answers 204 on +// both verbs, not just POST. A STABLE void function runs through the read path +// (executeCallRead); before the fix that path never detected void, so a GET +// answered 200 with a body while a POST to the same function answered 204. Both +// now signal 204. Finding 03-P17. +func TestIntegrationVoidCallStatus(t *testing.T) { + be := openBE(t) + ctx := context.Background() + + if _, err := be.Pool().Exec(ctx, ` + CREATE OR REPLACE FUNCTION public._dbrest_void_stable() RETURNS void + LANGUAGE sql STABLE AS $$ SELECT $$; + CREATE OR REPLACE FUNCTION public._dbrest_void_volatile() RETURNS void + LANGUAGE sql VOLATILE AS $$ SELECT $$`); err != nil { + t.Fatalf("seed: %v", err) + } + t.Cleanup(func() { + _, _ = be.Pool().Exec(ctx, `DROP FUNCTION IF EXISTS public._dbrest_void_stable(); + DROP FUNCTION IF EXISTS public._dbrest_void_volatile()`) + }) + + status := func(fn, method string) int { + t.Helper() + plan := &ir.Plan{ + ReadOnly: method == "GET", + Call: &ir.Call{Function: ir.Ref{Name: fn}, Args: map[string]ir.Value{}}, + } + res, err := be.Execute(ctx, plan, &reqctx.Context{Method: method, Path: "/rpc/" + fn}) + if err != nil { + t.Fatalf("Execute(%s %s): %v", method, fn, err) + } + return res.ResponseControls().Status + } + + // GET to the stable function runs the read path; POST to the volatile function + // runs the write path. Both detect void and signal 204. + if got := status("_dbrest_void_stable", "GET"); got != 204 { + t.Errorf("GET void status = %d, want 204 (read path void detection)", got) + } + if got := status("_dbrest_void_volatile", "POST"); got != 204 { + t.Errorf("POST void status = %d, want 204", got) + } +} + +// TestIntegrationRangeRendering proves int4range, numrange, daterange, tsrange, +// tstzrange, and int4multirange columns render through the backend as the same +// text PostgreSQL itself emits, instead of the pgtype.Range/Multirange Go structs +// json would marshal. The expected values are read back with to_json so the +// assertion tracks the live server's TimeZone for tstzrange. Finding 04-E05. +func TestIntegrationRangeRendering(t *testing.T) { + be := openBE(t) + ctx := context.Background() + + if _, err := be.Pool().Exec(ctx, ` + CREATE TABLE IF NOT EXISTS _dbrest_test_range ( + id int PRIMARY KEY, + i4 int4range, + nr numrange, + dr daterange, + tsr tsrange, + ttzr tstzrange, + mr int4multirange, + emp int4range, + unb int8range + ); + TRUNCATE _dbrest_test_range; + INSERT INTO _dbrest_test_range VALUES ( + 1, + '[10,20)', + '(1.5,3.5]', + '[2020-01-01,2020-12-31)', + '[2020-01-01 10:00,2020-06-01 12:00)', + '[2020-01-01 10:00+05,2020-06-01 12:00+05)', + '{[1,3),[5,8)}', + 'empty', + '[100,)' + )`); err != nil { + t.Fatalf("seed range table: %v", err) + } + t.Cleanup(func() { + _, _ = be.Pool().Exec(ctx, "DROP TABLE IF EXISTS _dbrest_test_range") + }) + + cols := []string{"i4", "nr", "dr", "tsr", "ttzr", "mr", "emp", "unb"} + // The oracle: PostgreSQL's own JSON spelling for each column. to_json renders a + // range/multirange as a JSON string, so unmarshalling yields the bare text form + // (with the quoted-bound escaping already resolved) the renderer must produce. + want := make([]string, len(cols)) + for i, c := range cols { + var j string + if err := be.Pool().QueryRow(ctx, + "SELECT to_json("+c+")::text FROM _dbrest_test_range WHERE id = 1").Scan(&j); err != nil { + t.Fatalf("oracle to_json(%s): %v", c, err) + } + if err := json.Unmarshal([]byte(j), &want[i]); err != nil { + t.Fatalf("oracle unmarshal(%s) %q: %v", c, j, err) + } + } + + model, err := be.Introspect(ctx) + if err != nil { + t.Fatalf("Introspect: %v", err) + } + rel, ok := model.Lookup("_dbrest_test_range", []string{"public"}) + if !ok { + t.Fatal("_dbrest_test_range not found") + } + sel := make([]ir.SelectItem, len(cols)) + for i, c := range cols { + sel[i] = ir.Column{Path: []string{c}} + } + plan := &ir.Plan{Rel: rel, Query: &ir.Query{ + Kind: ir.Read, + Relation: ir.Ref{Schema: "public", Name: "_dbrest_test_range"}, + Select: sel, + }} + res, err := be.Execute(ctx, plan, &reqctx.Context{Method: "GET", Path: "/_dbrest_test_range"}) + if err != nil { + t.Fatalf("Execute: %v", err) + } + rs := res.Rows() + defer rs.Close() + if !rs.Next() { + t.Fatal("no rows") + } + vals, err := rs.Values() + if err != nil { + t.Fatalf("Values: %v", err) + } + for i, c := range cols { + got, ok := vals[i].(string) + if !ok { + t.Errorf("column %s rendered as %T (%v), want a string", c, vals[i], vals[i]) + continue + } + if got != want[i] { + t.Errorf("column %s = %q, want %q (PostgreSQL to_json)", c, got, want[i]) + } + } +} + +func condPtr(c ir.Cond) *ir.Cond { return &c } +func intPtr(n int) *int { return &n } + func BenchmarkIntegrationRead(b *testing.B) { dsn := os.Getenv("DBREST_PG_DSN") if dsn == "" { diff --git a/backend/postgres/introspect.go b/backend/postgres/introspect.go index 90223cf..ee2bffe 100644 --- a/backend/postgres/introspect.go +++ b/backend/postgres/introspect.go @@ -9,10 +9,11 @@ import ( // Introspect builds the unified schema model from PostgreSQL's system catalogs. // The exposed schemas come from b.searchPath; if none are configured, only the -// default search_path ($user, public) is used. Only ordinary tables and views are -// exposed; sequences, materialized views, and internal catalogs are omitted. -// Columns, primary keys, and foreign keys are read from pg_attribute and -// pg_constraint. See spec 08. +// default search_path ($user, public) is used. The exposed relations mirror +// PostgREST's schema cache: ordinary tables, views, materialized views, foreign +// tables, and partitioned parents (leaf partitions excluded). Columns, primary +// keys, unique sets, foreign keys, identity flags, and comments are read from +// pg_attribute, pg_constraint, pg_index, and pg_description. See spec 08. func (b *Backend) Introspect(ctx context.Context) (*schema.Model, error) { schemas := b.searchPath if len(schemas) == 0 { @@ -29,22 +30,124 @@ func (b *Backend) Introspect(ctx context.Context) (*schema.Model, error) { } fksByRel := groupFKs(fks) + // View output columns mapped to their base-relation columns, so the model can + // project base-table foreign keys onto views and embedding works through a view + // the same way it does through its base table (spec 09). + viewCols, err := b.loadViewColumns(ctx, schemas) + if err != nil { + return nil, err + } + + // Computed fields are functions taking a relation's row type and returning a + // scalar, exposed as virtual columns selectable, filterable, and orderable like + // stored ones (spec 11). They are read here with the rest of the catalog so they + // refresh on every rebuild and attach to their relation below. + computed, err := b.loadComputedFields(ctx, schemas) + if err != nil { + return nil, err + } + + // Computed relationships are functions taking a relation's row type and + // returning rows of another relation, exposed as embeddable edges (spec 11, the + // escape hatch for recursive embeds). Read here with the rest of the catalog and + // attached to their parent relation below. + computedRels, err := b.loadComputedRels(ctx, schemas) + if err != nil { + return nil, err + } + + // Data representations are domain types whose casts to and from json/text + // reformat a column's wire value (spec 11). Read once, keyed by domain type OID, + // and attached to each column of that domain as the columns are loaded below. + reps, err := b.loadRepresentations(ctx, schemas) + if err != nil { + return nil, err + } + + // Function volatility drives the native RPC transaction access mode (a STABLE + // or IMMUTABLE function runs read-only even on POST), so it is loaded here with + // the rest of the catalog and refreshed whenever the model is rebuilt. + vol, err := b.loadFunctionVolatility(ctx, schemas) + if err != nil { + return nil, err + } + b.funcVol = vol + + // Function return shapes drive the native RPC renderer (a SETOF scalar renders + // as an array of bare values, a single composite as one object), loaded with the + // rest of the catalog and refreshed on every rebuild like volatility. + ret, err := b.loadFunctionReturns(ctx, schemas) + if err != nil { + return nil, err + } + b.funcRet = ret + + // The native function registry is the function half of PostgREST's schema cache: + // full signatures per schema so the native RPC path resolves overloads, raises + // PGRST202/PGRST203, and partitions GET arguments from result filters through the + // shared planner. Loaded with the catalog and refreshed on every rebuild. + reg, err := b.loadFunctionRegistry(ctx, schemas) + if err != nil { + return nil, err + } + b.funcReg = reg + + // Impersonated-role settings (ALTER ROLE ... SET) are replayed per request as + // transaction-scoped settings, so they are loaded with the catalog and + // refreshed on every rebuild, the same lifecycle PostgREST gives them. + rs, iso, err := b.loadRoleSettings(ctx) + if err != nil { + return nil, err + } + b.roleSettings = rs + b.roleIsolation = iso + + // Function SET clauses (pg_proc.proconfig) drive db-hoisted-tx-settings: an RPC + // call hoists the named settings to the transaction. Loaded with the catalog so + // it refreshes on every rebuild, like role settings and volatility. + pc, err := b.loadFunctionProconfig(ctx, schemas) + if err != nil { + return nil, err + } + b.funcProconfig = pc + var out []*schema.Relation for _, r := range rels { - cols, pk, err := b.columns(ctx, r.oid) + cols, pk, err := b.columns(ctx, r.oid, reps) + if err != nil { + return nil, err + } + uniq, err := b.uniques(ctx, r.oid) if err != nil { return nil, err } out = append(out, &schema.Relation{ - Schema: r.schemaName, - Name: r.name, - Kind: r.kind, - Columns: cols, - PrimaryKey: pk, - ForeignKeys: fksByRel[r.oid], + Schema: r.schemaName, + Name: r.name, + Kind: r.kind, + Comment: r.comment, + Columns: cols, + PrimaryKey: pk, + Unique: uniq, + ForeignKeys: fksByRel[r.oid], + ViewColumns: viewCols[r.oid], + Computed: computed[r.oid], + ComputedRels: computedRels[r.oid], }) } - return schema.NewModel(out), nil + + // Schema-level comments feed the OpenAPI info block (title and description), + // the same source PostgREST uses. They are attached to the model before it is + // published, alongside the relation and column comments read above. + comments, err := b.schemaComments(ctx, schemas) + if err != nil { + return nil, err + } + model := schema.NewModel(out) + for name, comment := range comments { + model.SetSchemaComment(name, comment) + } + return model, nil } type relInfo struct { @@ -52,16 +155,26 @@ type relInfo struct { schemaName string name string kind schema.Kind + comment string } func (b *Backend) relationNames(ctx context.Context, schemas []string) ([]relInfo, error) { - // Build a literal array of quoted schema names for the ANY(...) test. + // The relation set mirrors PostgREST's schema cache: ordinary tables ('r'), + // views ('v'), materialized views ('m'), foreign tables ('f'), and partitioned + // parents ('p'). Materialized views map to the view kind (read-mostly; a write + // fails with PostgreSQL's own error, the same passthrough as PostgREST), while + // foreign tables map to the table kind since an FDW can accept writes. + // Partitions are excluded with NOT c.relispartition so only the partitioned + // parent is an endpoint, matching upstream; this drops both leaf partitions + // ('r') and intermediate sub-partitioned tables ('p'). q := ` SELECT c.oid, n.nspname, c.relname, - CASE c.relkind WHEN 'v' THEN 'v' ELSE 't' END AS kind + CASE c.relkind WHEN 'v' THEN 'v' WHEN 'm' THEN 'v' ELSE 't' END AS kind, + COALESCE(obj_description(c.oid, 'pg_class'), '') AS comment FROM pg_class c JOIN pg_namespace n ON n.oid = c.relnamespace - WHERE c.relkind IN ('r','v','p') + WHERE c.relkind IN ('r','v','m','f','p') + AND NOT c.relispartition AND n.nspname = ANY($1) ORDER BY n.nspname, c.relname` rows, err := b.pool.Query(ctx, q, schemas) @@ -74,7 +187,7 @@ SELECT c.oid, n.nspname, c.relname, for rows.Next() { var r relInfo var kindStr string - if err := rows.Scan(&r.oid, &r.schemaName, &r.name, &kindStr); err != nil { + if err := rows.Scan(&r.oid, &r.schemaName, &r.name, &kindStr, &r.comment); err != nil { return nil, err } if kindStr == "v" { @@ -87,16 +200,20 @@ SELECT c.oid, n.nspname, c.relname, return out, rows.Err() } -func (b *Backend) columns(ctx context.Context, relOID uint32) ([]*schema.Column, []string, error) { +func (b *Backend) columns(ctx context.Context, relOID uint32, reps map[uint32]*schema.Representation) ([]*schema.Column, []string, error) { // pg_attribute carries every attribute including system columns (attnum < 0) // and dropped columns (attisdropped). We want only live user columns. // pg_constraint with contype='p' gives the primary-key columns in confkey order // via unnest; the conkey[] entries are attribute numbers matching attnum. + // atttypid is the column's exact type OID, which carries the representation cast + // set when the type is a domain. colQ := ` SELECT a.attname, format_type(a.atttypid, a.atttypmod), NOT a.attnotnull AS nullable, pg_get_expr(d.adbin, d.adrelid) IS NOT NULL AS has_default, - a.attnum + a.attidentity <> '' AS is_identity, + COALESCE(col_description(a.attrelid, a.attnum), '') AS comment, + a.attnum, a.atttypid FROM pg_attribute a LEFT JOIN pg_attrdef d ON d.adrelid = a.attrelid AND d.adnum = a.attnum WHERE a.attrelid = $1 AND a.attnum > 0 AND NOT a.attisdropped @@ -110,18 +227,28 @@ SELECT a.attname, format_type(a.atttypid, a.atttypmod), attByNum := map[int]string{} var cols []*schema.Column for rows.Next() { - var name, pgType string - var nullable, hasDef bool + var name, pgType, comment string + var nullable, hasDef, isIdentity bool var attnum int - if err := rows.Scan(&name, &pgType, &nullable, &hasDef, &attnum); err != nil { + var typOID uint32 + if err := rows.Scan(&name, &pgType, &nullable, &hasDef, &isIdentity, &comment, &attnum, &typOID); err != nil { return nil, nil, err } cols = append(cols, &schema.Column{ - Name: name, - Type: canonicalType(pgType), + Name: name, + Type: canonicalType(pgType), + // An identity column has no pg_attrdef row, so fold it into HasDefault: + // it is server-generated and never required, the same way PostgREST + // treats GENERATED AS IDENTITY. Generated (STORED) columns already carry + // a pg_attrdef row, so has_default covers them. Nullable: nullable, - HasDefault: hasDef, + HasDefault: hasDef || isIdentity, + Identity: isIdentity, + Comment: comment, Position: attnum, + // A column whose type is a domain with representation casts carries the + // cast set so the compiler reformats it on the wire (spec 11). + Rep: reps[typOID], }) attByNum[attnum] = name } @@ -152,6 +279,72 @@ SELECT a.attname return cols, pk, pkRows.Err() } +// uniques reads the relation's unique column sets, the data the model needs to +// see a foreign key as one-to-one (an FK whose columns equal a unique set embeds +// as an object, not an array; spec 09). It reads unique indexes rather than only +// unique constraints, which captures both: every unique constraint is backed by a +// unique index, and a bare CREATE UNIQUE INDEX is just as good a one-to-one +// witness, which is what PostgREST's pks_uniques_cols covers. The primary key is +// excluded (indisprimary) because the model already carries it separately; a +// partial index (indpred) cannot guarantee uniqueness over the whole table, and an +// expression index has a zero attnum, so both are dropped. Only the key columns +// count, not INCLUDE columns past indnkeyatts. +func (b *Backend) uniques(ctx context.Context, relOID uint32) ([][]string, error) { + q := ` +SELECT array_agg(a.attname ORDER BY k.ord) + FROM pg_index i + CROSS JOIN LATERAL unnest(i.indkey) WITH ORDINALITY AS k(attnum, ord) + JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = k.attnum + WHERE i.indrelid = $1 + AND i.indisunique + AND NOT i.indisprimary + AND i.indpred IS NULL + AND k.ord <= i.indnkeyatts + GROUP BY i.indexrelid +HAVING bool_and(k.attnum > 0)` + rows, err := b.pool.Query(ctx, q, relOID) + if err != nil { + return nil, err + } + defer rows.Close() + + var out [][]string + for rows.Next() { + var cols []string + if err := rows.Scan(&cols); err != nil { + return nil, err + } + out = append(out, cols) + } + return out, rows.Err() +} + +// schemaComments reads the database comment on each exposed schema, the source of +// the OpenAPI info title (first line) and description (rest), the same as +// PostgREST. A schema with no comment is omitted from the map. +func (b *Backend) schemaComments(ctx context.Context, schemas []string) (map[string]string, error) { + q := ` +SELECT n.nspname, obj_description(n.oid, 'pg_namespace') + FROM pg_namespace n + WHERE n.nspname = ANY($1) + AND obj_description(n.oid, 'pg_namespace') IS NOT NULL` + rows, err := b.pool.Query(ctx, q, schemas) + if err != nil { + return nil, err + } + defer rows.Close() + + out := map[string]string{} + for rows.Next() { + var name, comment string + if err := rows.Scan(&name, &comment); err != nil { + return nil, err + } + out[name] = comment + } + return out, rows.Err() +} + type fkInfo struct { relOID uint32 name string diff --git a/backend/postgres/listener.go b/backend/postgres/listener.go new file mode 100644 index 0000000..5a5ef89 --- /dev/null +++ b/backend/postgres/listener.go @@ -0,0 +1,97 @@ +package postgres + +import ( + "context" + "time" + + "github.com/jackc/pgx/v5" + + "github.com/tamnd/dbrest/backend" +) + +// listenMaxBackoff caps the reconnect wait, matching PostgREST's listener. +const listenMaxBackoff = 32 * time.Second + +// Listen implements backend.Listener over LISTEN/NOTIFY. It dedicates a +// connection (separate from the pool, since a connection blocked waiting for a +// notification cannot serve queries) to the named channel and reconnects with +// exponential backoff capped at 32 seconds. After a reconnect it calls +// OnReconnect, because notifications sent while the connection was down are lost +// and the schema cache may be stale, mirroring PostgREST's "assume we lost +// notifications, refresh the schema cache" behavior. +func (b *Backend) Listen(ctx context.Context, channel string, h backend.ListenHandler) error { + backoff := time.Second + firstConnect := true + for { + if err := ctx.Err(); err != nil { + return err + } + conn, err := pgx.ConnectConfig(ctx, b.connConfig) + if err != nil { + if werr := waitBackoff(ctx, backoff); werr != nil { + return werr + } + backoff = nextBackoff(backoff) + continue + } + // (Re)connected: reset the backoff and signal a reconnect so the caller + // can recover any notifications missed while the connection was down. + backoff = time.Second + if !firstConnect && h.OnReconnect != nil { + h.OnReconnect() + } + firstConnect = false + + err = b.waitForNotifications(ctx, conn, channel, h) + _ = conn.Close(context.Background()) + if ctx.Err() != nil { + return ctx.Err() + } + // err is a lost-connection error; the loop above reconnects with backoff. + _ = err + if werr := waitBackoff(ctx, backoff); werr != nil { + return werr + } + backoff = nextBackoff(backoff) + } +} + +// waitForNotifications issues LISTEN on the channel, then blocks delivering each +// notification's payload to the handler until the connection drops or ctx is +// canceled. It returns the error that ended the loop. +func (b *Backend) waitForNotifications(ctx context.Context, conn *pgx.Conn, channel string, h backend.ListenHandler) error { + if _, err := conn.Exec(ctx, "LISTEN "+pgx.Identifier{channel}.Sanitize()); err != nil { + return err + } + for { + n, err := conn.WaitForNotification(ctx) + if err != nil { + return err + } + if h.OnNotify != nil { + h.OnNotify(n.Payload) + } + } +} + +// waitBackoff sleeps for d unless ctx is canceled first, in which case it returns +// ctx.Err(). +func waitBackoff(ctx context.Context, d time.Duration) error { + t := time.NewTimer(d) + defer t.Stop() + select { + case <-ctx.Done(): + return ctx.Err() + case <-t.C: + return nil + } +} + +// nextBackoff doubles d, capped at listenMaxBackoff. +func nextBackoff(d time.Duration) time.Duration { + d *= 2 + if d > listenMaxBackoff { + return listenMaxBackoff + } + return d +} diff --git a/backend/postgres/listener_integration_test.go b/backend/postgres/listener_integration_test.go new file mode 100644 index 0000000..34e1598 --- /dev/null +++ b/backend/postgres/listener_integration_test.go @@ -0,0 +1,69 @@ +package postgres_test + +import ( + "context" + "testing" + "time" + + "github.com/tamnd/dbrest/backend" +) + +// TestIntegrationListen drives the live db-channel path: a NOTIFY on the channel +// the backend LISTENs on is delivered to the handler with its payload intact. +// NOTIFY only reaches sessions already listening, so the test notifies on a +// ticker until the listener has subscribed and the payload arrives. +func TestIntegrationListen(t *testing.T) { + be := openBE(t) + const channel = "dbrest_test_chan" + + got := make(chan string, 8) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go func() { + _ = be.Listen(ctx, channel, backend.ListenHandler{ + OnNotify: func(payload string) { got <- payload }, + }) + }() + + tick := time.NewTicker(100 * time.Millisecond) + defer tick.Stop() + deadline := time.After(5 * time.Second) + for { + select { + case <-tick.C: + if _, err := be.Pool().Exec(ctx, "SELECT pg_notify($1, $2)", channel, "reload schema"); err != nil { + t.Fatalf("pg_notify: %v", err) + } + case p := <-got: + if p != "reload schema" { + t.Fatalf("payload = %q, want %q", p, "reload schema") + } + return + case <-deadline: + t.Fatal("no notification delivered within 5s") + } + } +} + +// TestIntegrationListenStopsOnCancel confirms Listen returns promptly with the +// context error once its context is canceled, so the boot-time goroutine does +// not leak. +func TestIntegrationListenStopsOnCancel(t *testing.T) { + be := openBE(t) + ctx, cancel := context.WithCancel(context.Background()) + + done := make(chan error, 1) + go func() { done <- be.Listen(ctx, "dbrest_test_cancel", backend.ListenHandler{}) }() + + // Give the listener a moment to connect and subscribe, then cancel. + time.Sleep(200 * time.Millisecond) + cancel() + select { + case err := <-done: + if err != context.Canceled { + t.Fatalf("Listen returned %v, want context.Canceled", err) + } + case <-time.After(3 * time.Second): + t.Fatal("Listen did not return within 3s of cancel") + } +} diff --git a/backend/postgres/native_count_test.go b/backend/postgres/native_count_test.go new file mode 100644 index 0000000..bd7b81a --- /dev/null +++ b/backend/postgres/native_count_test.go @@ -0,0 +1,160 @@ +package postgres + +import ( + "reflect" + "strings" + "testing" + + "github.com/tamnd/dbrest/backend/sqlgen" + "github.com/tamnd/dbrest/ir" +) + +// 07.13: a native (non-registry) RPC with count=exact used to crash, because the +// count path ran the registry count compiler on a nil function. The native count +// is now built in the backend by wrapping the same call the row query runs. + +// TestCompileNativeCallCountWrapsCall: the native count is a count(*) over the +// SELECT * FROM fn(...) the row statement issues. +func TestCompileNativeCallCountWrapsCall(t *testing.T) { + b := &Backend{searchPath: []string{"public"}} + c := &ir.Call{Function: ir.Ref{Name: "recent_films"}} + + row, apiErr := b.compileNativeCall(c, "public", nil) + if apiErr != nil { + t.Fatalf("compileNativeCall: %v", apiErr) + } + cnt, apiErr := b.compileNativeCallCount(c, "public", nil) + if apiErr != nil { + t.Fatalf("compileNativeCallCount: %v", apiErr) + } + + want := "SELECT count(*) FROM (" + row.SQL + ") _rpc" + if cnt.SQL != want { + t.Errorf("count SQL = %q, want %q", cnt.SQL, want) + } + if !strings.Contains(cnt.SQL, `"public"."recent_films"`) { + t.Errorf("count SQL missing schema-qualified call: %q", cnt.SQL) + } +} + +// TestCompileNativeCallCountWithArgs: the wrapper carries the call arguments +// through as embedded literals, the same as the row statement. +func TestCompileNativeCallCountWithArgs(t *testing.T) { + b := &Backend{searchPath: []string{"app"}} + c := &ir.Call{ + Function: ir.Ref{Name: "search"}, + Args: map[string]ir.Value{"q": {Text: "blade"}}, + } + cnt, apiErr := b.compileNativeCallCount(c, "app", nil) + if apiErr != nil { + t.Fatalf("compileNativeCallCount: %v", apiErr) + } + if !strings.HasPrefix(cnt.SQL, "SELECT count(*) FROM (SELECT * FROM ") { + t.Errorf("count SQL prefix = %q", cnt.SQL) + } + if !strings.Contains(cnt.SQL, "'blade'") { + t.Errorf("count SQL missing argument literal: %q", cnt.SQL) + } +} + +// 03-P02: a VOLATILE function must run exactly once, so a POST with count=exact +// cannot count with a separate statement (that would invoke the function twice and +// double its side effects). The counted wrap rides count(*) OVER () on the row +// query; the total is read off any returned row and the column dropped. + +// TestCompileNativeCallCountedWrapRidesWindow: the counted wrap projects the call +// columns plus count(*) OVER () AS "_pgrst_count", in one statement. +func TestCompileNativeCallCountedWrapRidesWindow(t *testing.T) { + b := &Backend{searchPath: []string{"public"}} + c := &ir.Call{Function: ir.Ref{Name: "enroll_and_list"}} + + inner, apiErr := b.compileNativeCall(c, "public", nil) + if apiErr != nil { + t.Fatalf("compileNativeCall: %v", apiErr) + } + st, apiErr := sqlgen.CompileNativeCallCountedWrap(Dialect{}, c, inner) + if apiErr != nil { + t.Fatalf("CompileNativeCallCountedWrap: %v", apiErr) + } + want := `SELECT *, count(*) OVER () AS "_pgrst_count" FROM (` + inner.SQL + `) _rpc` + if st.SQL != want { + t.Errorf("SQL = %q\nwant %q", st.SQL, want) + } +} + +// TestCompileNativeCallCountedWrapPostFilters: the page-shaping select, filter, and +// window apply to the wrapped call exactly as the uncounted wrap, and count(*) OVER +// () still counts the full filtered set because it is evaluated before the LIMIT. +func TestCompileNativeCallCountedWrapPostFilters(t *testing.T) { + b := &Backend{searchPath: []string{"public"}} + limit := 2 + c := &ir.Call{ + Function: ir.Ref{Name: "make_films"}, + Select: []ir.SelectItem{col("title")}, + Limit: &limit, + } + inner, apiErr := b.compileNativeCall(c, "public", nil) + if apiErr != nil { + t.Fatalf("compileNativeCall: %v", apiErr) + } + st, apiErr := sqlgen.CompileNativeCallCountedWrap(Dialect{}, c, inner) + if apiErr != nil { + t.Fatalf("CompileNativeCallCountedWrap: %v", apiErr) + } + want := `SELECT "title", count(*) OVER () AS "_pgrst_count" FROM (` + inner.SQL + `) _rpc LIMIT 2` + if st.SQL != want { + t.Errorf("SQL = %q\nwant %q", st.SQL, want) + } +} + +// TestExtractCountWindow: the helper reads the repeated total off the first row and +// drops the count column from the columns and every row, leaving the body shape the +// renderer expects. +func TestExtractCountWindow(t *testing.T) { + cols := []string{"id", "title", sqlgen.CountColName} + buf := [][]any{ + {int64(1), "Dune", int64(2)}, + {int64(2), "Arrival", int64(2)}, + } + gotCols, gotRows, total := extractCountWindow(cols, buf) + if total != 2 { + t.Errorf("total = %d, want 2", total) + } + if !reflect.DeepEqual(gotCols, []string{"id", "title"}) { + t.Errorf("cols = %v, want [id title]", gotCols) + } + want := [][]any{{int64(1), "Dune"}, {int64(2), "Arrival"}} + if !reflect.DeepEqual(gotRows, want) { + t.Errorf("rows = %v, want %v", gotRows, want) + } +} + +// An empty result carries no row to read the window off, so the total is zero and +// the count column is still dropped from the column list. +func TestExtractCountWindowEmpty(t *testing.T) { + cols := []string{"id", sqlgen.CountColName} + gotCols, gotRows, total := extractCountWindow(cols, nil) + if total != 0 { + t.Errorf("total = %d, want 0", total) + } + if !reflect.DeepEqual(gotCols, []string{"id"}) { + t.Errorf("cols = %v, want [id]", gotCols) + } + if len(gotRows) != 0 { + t.Errorf("rows = %v, want empty", gotRows) + } +} + +// Without the window column the helper is a no-op: a function whose result was not +// compiled with the counted wrap passes through unchanged. +func TestExtractCountWindowAbsent(t *testing.T) { + cols := []string{"id", "title"} + buf := [][]any{{int64(1), "Dune"}} + gotCols, gotRows, total := extractCountWindow(cols, buf) + if total != 0 { + t.Errorf("total = %d, want 0", total) + } + if !reflect.DeepEqual(gotCols, cols) || !reflect.DeepEqual(gotRows, buf) { + t.Errorf("expected pass-through, got cols=%v rows=%v", gotCols, gotRows) + } +} diff --git a/backend/postgres/postgres.go b/backend/postgres/postgres.go index 5441023..deb70ad 100644 --- a/backend/postgres/postgres.go +++ b/backend/postgres/postgres.go @@ -13,6 +13,8 @@ import ( "context" "errors" "strconv" + "strings" + "time" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" @@ -31,11 +33,21 @@ const defaultPoolMaxConns = 10 // connection pool, the server version (which grades a couple of capabilities), // the function registry, and the search path applied to every request. type Backend struct { - pool *pgxpool.Pool - version Version - funcs rpc.Registry - searchPath []string - searchPathSQL string // pre-built "SET LOCAL search_path TO ..." statement + pool *pgxpool.Pool + connConfig *pgx.ConnConfig // template for a dedicated LISTEN connection (db-channel) + version Version + funcs rpc.Registry + searchPath []string + extraSearchPath []string // db-extra-search-path, appended after the active schema + loc *time.Location // server TimeZone, for rendering timestamptz like PostgREST + funcVol map[string]rpc.Volatility // "schema.name" -> volatility, for native RPC access mode + funcRet map[string]rpc.ReturnShape // "schema.name" -> return shape, for native RPC result rendering + funcReg map[string]rpc.Registry // schema -> native function registry, the function half of the schema cache + roleSettings map[string][]roleSetting // impersonated-role ALTER ROLE ... SET replays + roleIsolation map[string]pgx.TxIsoLevel // impersonated-role default_transaction_isolation + + hoistedTxSettings []string // db-hoisted-tx-settings: which function SET options hoist to the tx + funcProconfig map[string][]roleSetting // "schema.name" -> function SET clause (pg_proc.proconfig) } // Open connects to PostgreSQL by connection string (a libpq URI or keyword/value @@ -48,6 +60,23 @@ type Backend struct { // queries avoid a server-side parse on every execution. This is one of the key // throughput advantages over PostgREST. func Open(dsn string) (*Backend, error) { + return OpenWith(dsn, Options{PreparedStatements: true}) +} + +// Options carries the open-time settings the postgres backend can vary. The zero +// value is not the default: callers use Open (prepared statements on) or pass an +// explicit Options. +type Options struct { + // PreparedStatements maps PostgREST's db-prepared-statements. On (the default), + // the pool uses cache_describe so each distinct query is parsed once per + // connection. Off selects the unprepared exec protocol, which parameterizes + // every query over the extended protocol without caching a statement, the + // pooler-safe equivalent of PostgREST's "parameterized but not prepared". + PreparedStatements bool +} + +// OpenWith connects like Open but honors the supplied Options. +func OpenWith(dsn string, opts Options) (*Backend, error) { cfg, err := pgxpool.ParseConfig(dsn) if err != nil { return nil, err @@ -56,11 +85,7 @@ func Open(dsn string) (*Backend, error) { if cfg.MaxConns < 1 { cfg.MaxConns = defaultPoolMaxConns } - // Enable automatic prepared-statement caching so the server parses each - // distinct query only once per connection. pgx stores the type-descriptor - // cache on the connection; pgxpool serializes reuse so the cache is - // consistent per connection lifetime. - cfg.ConnConfig.DefaultQueryExecMode = pgx.QueryExecModeCacheDescribe + cfg.ConnConfig.DefaultQueryExecMode = resolveExecMode(dsn, cfg.ConnConfig.DefaultQueryExecMode, opts.PreparedStatements) pool, err := pgxpool.NewWithConfig(context.Background(), cfg) if err != nil { return nil, err @@ -75,7 +100,38 @@ func Open(dsn string) (*Backend, error) { pool.Close() return nil, err } - return &Backend{pool: pool, version: ParseVersion(ver)}, nil + // PostgREST assembles JSON in the database, so a timestamptz renders in the + // server's TimeZone. Capture it once here and render timestamptz in the same + // zone (DST included) so the wire value matches; fall back to UTC when the + // name does not resolve to a Go location. + loc := time.UTC + var tz string + if err := pool.QueryRow(ctx, "SHOW timezone").Scan(&tz); err == nil { + if l, lerr := time.LoadLocation(tz); lerr == nil { + loc = l + } + } + return &Backend{pool: pool, connConfig: cfg.ConnConfig.Copy(), version: ParseVersion(ver), loc: loc}, nil +} + +// resolveExecMode picks the pool's default query exec mode. An explicit DSN +// choice wins (default_query_exec_mode=simple_protocol or exec, pgx's documented +// escape hatch for poolers); honor it rather than clobbering it. pgx parses the +// param into parsed, but an omitted value and an explicit cache_statement both +// decode to the same zero value, so the presence test keys on the raw DSN string, +// where the param name is unambiguous. With no DSN choice, db-prepared-statements +// decides: on (the default) selects cache_describe, which parses each distinct +// query once per connection while keeping unnamed statements a transaction-mode +// pooler (PgBouncer) tolerates; off selects exec, which parameterizes every query +// without preparing one, matching PostgREST's db-prepared-statements=false. +func resolveExecMode(dsn string, parsed pgx.QueryExecMode, prepared bool) pgx.QueryExecMode { + if strings.Contains(dsn, "default_query_exec_mode") { + return parsed + } + if !prepared { + return pgx.QueryExecModeExec + } + return pgx.QueryExecModeCacheDescribe } // Pool exposes the underlying connection pool, for tests that seed a database. @@ -84,13 +140,32 @@ func (b *Backend) Pool() *pgxpool.Pool { return b.pool } // ServerVersion reports the parsed server version, for logging and tests. func (b *Backend) ServerVersion() Version { return b.version } -// SetSchemas records the exposed schemas as the search path applied to every -// request (SET LOCAL search_path), matching PostgREST's db-schemas behaviour so -// unqualified names in policies and functions resolve the same way. The -// corresponding SQL statement is pre-built once here and reused per request. +// SetSchemas records the exposed schemas. The first is the default active +// schema; the rest are reachable by Accept-Profile/Content-Profile. The +// per-request search_path is built from the active schema (not the whole set), +// matching PostgREST, which puts only the active schema plus db-extra-search-path +// on the path so unqualified names resolve the same way (see queueSessionItems). func (b *Backend) SetSchemas(schemas []string) { b.searchPath = schemas - b.searchPathSQL = buildSearchPathSQL(schemas) +} + +// SetExtraSearchPath records db-extra-search-path: schemas appended to the +// search_path after the active schema so type and function resolution can reach +// them without exposing them as queryable schemas. PostgREST defaults this to +// "public" and does not dedup, so a request on the public schema gets the path +// "public", "public"; dbrest reproduces that verbatim. +func (b *Backend) SetExtraSearchPath(schemas []string) { + b.extraSearchPath = schemas +} + +// SetHoistedTxSettings records db-hoisted-tx-settings: the function SET options +// (statement_timeout, plan_filter.statement_cost_limit, +// default_transaction_isolation by default) that an RPC call hoists to the +// transaction so they override the role and connection settings for the whole +// statement, matching PostgREST. The named settings are applied per call from the +// function's introspected proconfig (see hoistFor). +func (b *Backend) SetHoistedTxSettings(names []string) { + b.hoistedTxSettings = names } // Register installs the portable function registry exposed at /rpc/. On @@ -108,12 +183,29 @@ func (b *Backend) Functions() rpc.Registry { return b.funcs } +// SchemaFunctions returns the native function registry introspected for one +// exposed schema, the function half of the schema cache. It is empty until +// Introspect has run, and empty for a schema with no functions, so a caller always +// has a registry to resolve against. The native RPC path uses it to resolve +// overloads and partition GET arguments through the shared planner. +func (b *Backend) SchemaFunctions(schema string) rpc.Registry { + if reg, ok := b.funcReg[schema]; ok { + return reg + } + return rpc.EmptyRegistry{} +} + // Capabilities reports the PostgreSQL feature tiers for the connected server // version (spec 04/06). func (b *Backend) Capabilities() backend.Capabilities { return Capabilities(b.version) } +// SupportsPreRequest reports that this backend runs the db-pre-request function. +// queueSessionItems issues SELECT () in the request transaction after the +// session settings, so main.go accepts the option here rather than refusing it. +func (b *Backend) SupportsPreRequest() bool { return true } + // Close releases the pool. func (b *Backend) Close() error { b.pool.Close() @@ -133,27 +225,49 @@ func (b *Backend) MapError(err error) *pgerr.APIError { if pg, ok := errors.AsType[*pgconn.PgError](err); ok { return mapPgError(pg) } + return mapTransportError(err) +} + +// mapTransportError classifies a driver-level failure that never reached a +// SQLSTATE into PostgREST's connection-error family (group 0). A failed or +// refused dial surfaces from pgx as *pgconn.ConnectError and becomes PGRST000 +// (503, retryable); a pool-acquisition timeout surfaces as a context deadline +// and becomes PGRST003 (504, the "Timed out acquiring connection" case); +// anything else stays an internal 500. PostgREST also has PGRST002 for a schema +// cache that cannot be built, but dbrest builds the cache at startup and refuses +// to start on failure, so that code has no runtime analog here. +func mapTransportError(err error) *pgerr.APIError { + var ce *pgconn.ConnectError + if errors.As(err, &ce) { + return pgerr.ErrDBConnection(err.Error()) + } + if errors.Is(err, context.DeadlineExceeded) { + return pgerr.ErrAcquireTimeout() + } return pgerr.ErrInternal(err.Error()) } // mapPgError builds the API envelope from a PostgreSQL error, passing the -// SQLSTATE through as the code and grading the HTTP status by the same rules -// PostgREST applies (its Error module's pgErrorStatus). The well-known -// constraint violations reuse dbrest's named constructors so their message and -// status match the other backends; everything else carries the server's own -// message with the status the SQLSTATE class implies. +// SQLSTATE through as the code, the server's own message and detail and hint +// verbatim, and the HTTP status graded by the same rules PostgREST applies (its +// Error module's pgErrorStatus). PostgREST forwards PostgreSQL errors unchanged, +// constraint name and "Key (col)=(val)" detail included, so the postgres backend +// does too rather than rewriting them to a canonical text; the SQLSTATE class +// alone fixes the status (a unique or foreign-key violation is 409, the rest of +// class 23 is 400). The named constructors stay for the backends whose driver +// reports a constraint without PostgreSQL's wording. func mapPgError(pg *pgconn.PgError) *pgerr.APIError { - switch pg.Code { - case "23505": // unique_violation - return withServerText(pgerr.ErrUniqueViolation(pg.Detail), pg) - case "23502": // not_null_violation - return withServerText(pgerr.ErrNotNullViolation(pg.Detail), pg) - case "23503": // foreign_key_violation - return withServerText(pgerr.ErrForeignKeyViolation(pg.Detail), pg) - case "23514": // check_violation - return withServerText(pgerr.ErrCheckViolation(pg.Detail), pg) - } - e := pgerr.New(statusForSQLState(pg.Code), pg.Code, pg.Message) + // A function can take full control of the response by raising SQLSTATE + // 'PGRST': the server reports the chosen envelope in MESSAGE and the status + // and headers in DETAIL, both as JSON. FromRaise parses both (or yields + // PGRST121 on a malformed payload); its headers ride on the error so the + // renderer emits them. This is distinct from the PTxxx status-only convention + // handled in statusForSQLState. + if pg.Code == "PGRST" { + e, headers := pgerr.FromRaise(pg.Message, pg.Detail) + return e.WithHeaders(headers) + } + e := pgerr.New(statusForSQLState(pg.Code, pg.Message), pg.Code, pg.Message) if pg.Detail != "" { e = e.WithDetails(pg.Detail) } @@ -163,46 +277,62 @@ func mapPgError(pg *pgconn.PgError) *pgerr.APIError { return e } -// withServerText keeps a named constructor's code and status but lets the -// server's hint ride through when it carries one, so a constraint error still -// reads like PostgREST's. -func withServerText(e *pgerr.APIError, pg *pgconn.PgError) *pgerr.APIError { - if pg.Hint != "" { - e = e.WithHint(pg.Hint) - } - return e -} - // statusForSQLState maps a PostgreSQL SQLSTATE to the HTTP status PostgREST -// returns for it. The table mirrors PostgREST's pgErrorStatus: most classes fold -// to 500, a few auth and resource classes have their own status, the constraint -// codes are 4xx, and a function can drive a custom status by raising a SQLSTATE -// in the PTxxx form (the three digits after PT are the status). The default for -// an unrecognized code is 400, as in PostgREST. -func statusForSQLState(code string) int { +// returns for it. The table mirrors PostgREST v14's pgErrorStatus (Error.hs) +// row for row: most classes fold to 500, a few auth and resource classes have +// their own status, the constraint codes are 4xx, two codes (21000, 22023) +// disambiguate on the server message, a function can drive a custom status with +// the PTxxx convention, and the default for an unrecognized code is 400. msg is +// the server message, needed only for the two message-sniffing rows. +// +// PostgREST reads the status off the raw integer for PTxxx and would emit even a +// nonsensical value; Go's response writer rejects a status below 100, so a PT +// status under 100 falls back to 500 here rather than panicking. Every PTxxx in +// the realistic 100-599 range (and up to 999) passes through unchanged. +func statusForSQLState(code, msg string) int { if len(code) != 5 { return 400 } // PTxxx lets a function set the response status directly (PostgREST's // "RAISE sqlstate 'PT403'" convention); the digits after PT are the status. + // PostgREST falls back to 500 when the suffix does not parse. if code[:2] == "PT" { - if n, err := strconv.Atoi(code[2:]); err == nil && n >= 100 && n <= 599 { + if n, err := strconv.Atoi(code[2:]); err == nil && n >= 100 && n <= 999 { return n } + return 500 } switch code { case "23503", "23505": // foreign_key / unique violation return 409 case "25006": // read_only_sql_transaction return 405 - case "42883": // undefined_function - return 404 + case "21000": // cardinality_violation: pg-safeupdate's missing-WHERE guard is + // a client error (400); the generic "more than one row" form is a server + // error (500), matching PostgREST's suffix test. + if strings.HasSuffix(msg, "requires a WHERE clause") { + return 400 + } + return 500 + case "22023": // invalid_parameter_value: a JWT naming a role that does not + // exist is an auth failure (401); everything else is a client error (400). + if strings.HasPrefix(msg, "role") && strings.HasSuffix(msg, "does not exist") { + return 401 + } + return 400 + case "53400": // configuration_limit_exceeded: 500, not the 503 of its class + return 500 + case "57P01": // admin_shutdown: 503-with-retry, not the 500 of its class + return 503 case "42P01": // undefined_table return 404 - case "42501": // insufficient_privilege → 401 matching PostgREST - return 401 case "42P17": // infinite_recursion return 500 + case "42501": // insufficient_privilege: 403 base, lifted to 401 for an + // anonymous request by mapExecError, mirroring PostgREST's pgErrorStatus. + return 403 + case "P0001": // raise_exception default code: client error + return 400 } switch code[:2] { case "08": // connection exception @@ -224,7 +354,7 @@ func statusForSQLState(code string) int { case "53": // insufficient resources return 503 case "54": // program limit exceeded (statement too complex) - return 413 + return 500 case "55": // object not in prerequisite state return 500 case "57": // operator intervention @@ -235,11 +365,17 @@ func statusForSQLState(code string) int { return 500 case "HV": // foreign data wrapper error return 500 - case "P0": // PL/pgSQL raise_exception and friends - return 400 + case "P0": // PL/pgSQL raise_exception and friends (P0001 handled above) + return 500 case "XX": // internal error return 500 - case "42": // syntax error / access rule violation (undefined column, ...) + case "42": // syntax / access rule violation; 42883 splits on the message + if code == "42883" { // undefined_function: xmlagg ambiguity is a 406 + if strings.HasPrefix(msg, "function xmlagg(") { + return 406 + } + return 404 + } return 400 } return 400 @@ -250,3 +386,14 @@ func init() { backend.Register("postgres", postgresDriver{}) } type postgresDriver struct{} func (postgresDriver) Open(dsn string) (backend.Backend, error) { return Open(dsn) } + +// OpenWithOptions implements backend.OptionsDriver so the server can thread +// db-prepared-statements through the generic registry. PreparedStatements +// defaults to on when the option is unset. +func (postgresDriver) OpenWithOptions(dsn string, opts backend.OpenOptions) (backend.Backend, error) { + prepared := true + if opts.PreparedStatements != nil { + prepared = *opts.PreparedStatements + } + return OpenWith(dsn, Options{PreparedStatements: prepared}) +} diff --git a/backend/postgres/postgres_test.go b/backend/postgres/postgres_test.go index 214a407..1dfdc70 100644 --- a/backend/postgres/postgres_test.go +++ b/backend/postgres/postgres_test.go @@ -2,13 +2,44 @@ package postgres import ( "context" + "errors" + "fmt" "testing" + "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" "github.com/tamnd/dbrest/pgerr" ) +// TestResolveExecMode covers finding 02-P09: the pooler-tolerant cache_describe +// default is used only when the DSN does not name a mode, and an operator's +// default_query_exec_mode choice in the DSN (the documented PgBouncer escape +// hatch) is honored rather than clobbered. +func TestResolveExecMode(t *testing.T) { + cases := []struct { + name string + dsn string + parsed pgx.QueryExecMode + prepared bool + want pgx.QueryExecMode + }{ + {"omitted, prepared on, defaults to cache_describe", "postgres://u:p@h/db", pgx.QueryExecModeCacheStatement, true, pgx.QueryExecModeCacheDescribe}, + {"omitted, prepared off, uses exec", "postgres://u:p@h/db", pgx.QueryExecModeCacheStatement, false, pgx.QueryExecModeExec}, + {"DSN choice wins over prepared off", "postgres://u:p@h/db?default_query_exec_mode=cache_statement", pgx.QueryExecModeCacheStatement, false, pgx.QueryExecModeCacheStatement}, + {"simple_protocol honored", "postgres://u:p@h/db?default_query_exec_mode=simple_protocol", pgx.QueryExecModeSimpleProtocol, true, pgx.QueryExecModeSimpleProtocol}, + {"exec honored", "postgres://u:p@h/db?default_query_exec_mode=exec", pgx.QueryExecModeExec, true, pgx.QueryExecModeExec}, + {"explicit cache_statement honored", "postgres://u:p@h/db?default_query_exec_mode=cache_statement", pgx.QueryExecModeCacheStatement, true, pgx.QueryExecModeCacheStatement}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + if got := resolveExecMode(tc.dsn, tc.parsed, tc.prepared); got != tc.want { + t.Errorf("resolveExecMode(%q, %v, %v) = %v, want %v", tc.dsn, tc.parsed, tc.prepared, got, tc.want) + } + }) + } +} + // MapError maps PostgreSQL SQLSTATE codes to the API error envelope the way // PostgREST does. Unit tests drive mapPgError and statusForSQLState directly so // there is no need for a live server. @@ -38,6 +69,73 @@ func TestMapErrorConstraintViolations(t *testing.T) { } } +// PostgREST forwards a PostgreSQL constraint error's message and detail +// verbatim, so an application reading the constraint name out of the message or +// the offending key out of the detail still finds them. The postgres backend +// passes both through unchanged rather than rewriting them to a canonical text. +func TestMapErrorConstraintMessageVerbatim(t *testing.T) { + pg := &pgconn.PgError{ + Code: "23505", + Message: `duplicate key value violates unique constraint "films_pkey"`, + Detail: "Key (id)=(1) already exists.", + Hint: "use a different id", + } + got := mapPgError(pg) + if got.Message != pg.Message { + t.Errorf("Message = %q, want verbatim %q", got.Message, pg.Message) + } + if got.Details == nil || *got.Details != pg.Detail { + t.Errorf("Details = %v, want verbatim %q", got.Details, pg.Detail) + } + if got.Hint == nil || *got.Hint != pg.Hint { + t.Errorf("Hint = %v, want verbatim %q", got.Hint, pg.Hint) + } +} + +// A function raising SQLSTATE 'PGRST' takes full control: mapPgError reads the +// envelope from MESSAGE and the status and headers from DETAIL, surfacing the +// headers on the error so the renderer emits them (item 04.9). +func TestMapErrorRaisePGRSTFullControl(t *testing.T) { + pg := &pgconn.PgError{ + Code: "PGRST", + Message: `{"code":"123","message":"Payment Required","details":"pay up","hint":"add a card"}`, + Detail: `{"status":402,"headers":{"X-Reason":"quota"}}`, + } + got := mapPgError(pg) + if got.Code != "123" || got.Message != "Payment Required" { + t.Errorf("envelope = %q/%q, want 123/Payment Required", got.Code, got.Message) + } + if got.HTTPStatus != 402 { + t.Errorf("status = %d, want 402 from detail.status", got.HTTPStatus) + } + if got.Details == nil || *got.Details != "pay up" { + t.Errorf("details = %v, want 'pay up'", got.Details) + } + if h := got.Headers.Get("X-Reason"); h != "quota" { + t.Errorf("X-Reason header = %q, want quota", h) + } +} + +// A malformed full-control payload is PGRST121 (500), not a leaked raw string +// (item 04.9). The DETAIL here is not valid JSON. +func TestMapErrorRaisePGRSTMalformed(t *testing.T) { + pg := &pgconn.PgError{ + Code: "PGRST", + Message: `{"code":"123","message":"ok"}`, + Detail: `not json`, + } + got := mapPgError(pg) + if got.Code != "PGRST121" { + t.Errorf("code = %q, want PGRST121", got.Code) + } + if got.HTTPStatus != 500 { + t.Errorf("status = %d, want 500", got.HTTPStatus) + } + if len(got.Headers) != 0 { + t.Errorf("a malformed payload must apply no headers, got %v", got.Headers) + } +} + func TestMapErrorPassthrough(t *testing.T) { pg := &pgconn.PgError{Code: "42P01", Message: "relation does not exist", Hint: "check your schema"} got := mapPgError(pg) @@ -61,46 +159,91 @@ func TestMapErrorNil(t *testing.T) { func TestMapErrorNonPg(t *testing.T) { b := &Backend{} - got := b.MapError(context.DeadlineExceeded) + got := b.MapError(errors.New("boom")) if got == nil { t.Fatal("MapError(non-PG) = nil, want internal error") } if got.HTTPStatus != 500 { t.Errorf("HTTPStatus = %d, want 500", got.HTTPStatus) } + if got.Code != pgerr.CodeInternal { + t.Errorf("Code = %q, want %q", got.Code, pgerr.CodeInternal) + } +} + +// TestMapTransportError covers finding 03-P19(e): a driver failure that never +// reached a SQLSTATE is classified into PostgREST's connection-error family +// instead of collapsing to 500. A refused dial (pgx *pgconn.ConnectError) is +// PGRST000/503; a pool-acquisition timeout (context deadline) is PGRST003/504. +func TestMapTransportError(t *testing.T) { + b := &Backend{} + + // A real *pgconn.ConnectError from a refused dial (the wrapped error is + // unexported, so a refused localhost port is the way to get a valid one). + _, dialErr := pgconn.Connect(context.Background(), "postgres://nobody@127.0.0.1:1/none") + var ce *pgconn.ConnectError + if !errors.As(dialErr, &ce) { + t.Fatalf("expected a *pgconn.ConnectError from a refused dial, got %T (%v)", dialErr, dialErr) + } + conn := b.MapError(dialErr) + if conn.HTTPStatus != 503 || conn.Code != pgerr.CodeDBConnection { + t.Errorf("connect error => %d %q, want 503 %q", conn.HTTPStatus, conn.Code, pgerr.CodeDBConnection) + } + + to := b.MapError(fmt.Errorf("acquire: %w", context.DeadlineExceeded)) + if to.HTTPStatus != 504 || to.Code != pgerr.CodeAcquireTimeout { + t.Errorf("acquire timeout => %d %q, want 504 %q", to.HTTPStatus, to.Code, pgerr.CodeAcquireTimeout) + } + if to.Message != "Timed out acquiring connection from connection pool." { + t.Errorf("acquire timeout message = %q", to.Message) + } } func TestStatusForSQLState(t *testing.T) { cases := []struct { code string + msg string want int }{ // well-known individual codes - {"23503", 409}, - {"23505", 409}, - {"25006", 405}, - {"42883", 404}, - {"42P01", 404}, - {"42501", 401}, // matches PostgREST: insufficient_privilege → 401 + {"23503", "", 409}, + {"23505", "", 409}, + {"25006", "", 405}, + {"42883", "function foo() does not exist", 404}, + {"42883", "function xmlagg() does not exist", 406}, // xmlagg ambiguity + {"42P01", "", 404}, + {"42501", "", 403}, // insufficient_privilege: 403 base, anon lifted to 401 by mapExecError + // message-sniffing rows + {"21000", "UPDATE requires a WHERE clause", 400}, + {"21000", "more than one row returned by a subquery", 500}, + {"22023", `role "ghost" does not exist`, 401}, + {"22023", "time zone displacement out of range", 400}, // PTxxx convention - {"PT403", 403}, - {"PT201", 201}, - // class rules - {"08000", 503}, - {"28000", 403}, - {"53100", 503}, - {"54001", 413}, - {"XX000", 500}, - {"P0001", 400}, + {"PT403", "", 403}, + {"PT201", "", 201}, + {"PT999", "", 999}, // out of the 100-599 band but still a parsed status + {"PTabc", "", 500}, // unparseable suffix falls back to 500 + {"PT042", "", 500}, // below 100, not emittable, falls back to 500 + // class rules and corrected edge rows + {"08000", "", 503}, + {"28000", "", 403}, + {"53100", "", 503}, + {"53400", "", 500}, // config limit exceeded: 500, not its class 503 + {"54001", "", 500}, // program limit exceeded: 500, not 413 + {"57000", "", 500}, + {"57P01", "", 503}, // admin shutdown: 503, not its class 500 + {"XX000", "", 500}, + {"P0001", "", 400}, // raise default: client error + {"P0002", "", 500}, // other PL/pgSQL: server error // default - {"00000", 400}, - {"ZZZZZ", 400}, - {"short", 400}, + {"00000", "", 400}, + {"ZZZZZ", "", 400}, + {"short", "", 400}, } for _, c := range cases { - got := statusForSQLState(c.code) + got := statusForSQLState(c.code, c.msg) if got != c.want { - t.Errorf("statusForSQLState(%q) = %d, want %d", c.code, got, c.want) + t.Errorf("statusForSQLState(%q, %q) = %d, want %d", c.code, c.msg, got, c.want) } } } diff --git a/backend/postgres/representation.go b/backend/postgres/representation.go new file mode 100644 index 0000000..8f40924 --- /dev/null +++ b/backend/postgres/representation.go @@ -0,0 +1,78 @@ +package postgres + +import ( + "context" + + "github.com/tamnd/dbrest/schema" +) + +// loadRepresentations maps each domain type's OID to its data-representation cast +// set: the functions PostgREST applies to reformat a column of that domain on the +// wire (domain representations, spec 11). A data representation is a domain over a +// base type plus casts registered in pg_cast: a cast from the domain to json +// formats the value for a response, a cast from json to the domain parses a write +// body, and a cast from text to the domain parses a query-string filter literal. +// +// PostgreSQL ignores these casts in the `::` operator (it warns "cast will be +// ignored because the source/target data type is a domain"), so the cast cannot be +// applied as col::json. The cast function is what does the work, and that is what +// this reads: the introspector records the cast function per direction so the +// compiler calls it directly. Only function-method casts ('f') carry a function; +// the rare binary-coercible domain cast has none and drives no representation. +func (b *Backend) loadRepresentations(ctx context.Context, schemas []string) (map[uint32]*schema.Representation, error) { + const q = ` +SELECT dt.oid AS domain_oid, + fn.nspname AS fn_schema, + p.proname AS fn_name, + st.typname AS src_name, + stt.typtype AS src_typtype, + tt.typname AS tgt_name, + ttt.typtype AS tgt_typtype + FROM pg_cast c + JOIN pg_proc p ON p.oid = c.castfunc + JOIN pg_namespace fn ON fn.oid = p.pronamespace + JOIN pg_type stt ON stt.oid = c.castsource + JOIN pg_type ttt ON ttt.oid = c.casttarget + JOIN pg_type st ON st.oid = c.castsource + JOIN pg_type tt ON tt.oid = c.casttarget + JOIN pg_type dt ON dt.oid = (CASE WHEN stt.typtype = 'd' THEN c.castsource ELSE c.casttarget END) + JOIN pg_namespace dn ON dn.oid = dt.typnamespace + WHERE c.castmethod = 'f' + AND dt.typtype = 'd' + AND dn.nspname = ANY($1) + AND ( + (stt.typtype = 'd' AND tt.typname IN ('json', 'jsonb')) + OR (st.typname IN ('json', 'jsonb') AND ttt.typtype = 'd') + OR (st.typname = 'text' AND ttt.typtype = 'd') + )` + + rows, err := b.pool.Query(ctx, q, schemas) + if err != nil { + return nil, err + } + defer rows.Close() + + out := map[uint32]*schema.Representation{} + for rows.Next() { + var domainOID uint32 + var fnSchema, fnName, srcName, srcTyptype, tgtName, tgtTyptype string + if err := rows.Scan(&domainOID, &fnSchema, &fnName, &srcName, &srcTyptype, &tgtName, &tgtTyptype); err != nil { + return nil, err + } + rep := out[domainOID] + if rep == nil { + rep = &schema.Representation{} + out[domainOID] = rep + } + ref := schema.FuncRef{Schema: fnSchema, Name: fnName} + switch { + case srcTyptype == "d" && (tgtName == "json" || tgtName == "jsonb"): + rep.ToJSON = ref // domain -> json: format on read + case (srcName == "json" || srcName == "jsonb") && tgtTyptype == "d": + rep.FromJSON = ref // json -> domain: parse a write value + case srcName == "text" && tgtTyptype == "d": + rep.FromText = ref // text -> domain: parse a filter literal + } + } + return out, rows.Err() +} diff --git a/backend/postgres/result.go b/backend/postgres/result.go index a222df8..bd742d4 100644 --- a/backend/postgres/result.go +++ b/backend/postgres/result.go @@ -2,7 +2,9 @@ package postgres import ( "context" + "fmt" "io" + "strings" "time" "github.com/jackc/pgx/v5" @@ -28,11 +30,12 @@ type streamResult struct { controls *reqctx.ResponseControls count int64 hasCount bool + loc *time.Location } func (r *streamResult) Body() io.Reader { return nil } func (r *streamResult) Rows() backend.RowStream { - return &streamRows{ctx: r.ctx, tx: r.tx, rows: r.rows, cols: r.cols} + return &streamRows{ctx: r.ctx, tx: r.tx, rows: r.rows, cols: r.cols, loc: r.loc} } func (r *streamResult) Count() (int64, bool) { return r.count, r.hasCount } func (r *streamResult) Affected() (int64, bool) { return 0, false } @@ -46,6 +49,7 @@ type streamRows struct { tx pgx.Tx rows pgx.Rows cols []string + loc *time.Location } func (s *streamRows) Columns() []string { return s.cols } @@ -61,7 +65,7 @@ func (s *streamRows) Values() ([]any, error) { if err != nil { return nil, err } - return normalizeValues(vals, s.rows.FieldDescriptions()), nil + return normalizeValues(vals, s.rows.FieldDescriptions(), s.loc), nil } // Close releases the cursor and commits the transaction that scoped the role and @@ -116,23 +120,298 @@ func (s *bufStream) Err() error { return nil } func (s *bufStream) Close() error { return nil } // normalizeValues adjusts pgx's decoded values to the shapes the renderer maps to -// JSON. json and jsonb arrive as raw bytes; they are turned into strings so the -// renderer's raw-JSON columns pass them through verbatim rather than base64. A -// bytea value also arrives as bytes, but its column is not a raw-JSON column, so -// it renders as a string like the other backends. PostgreSQL date columns -// (OID 1082) arrive as time.Time but must be formatted as "YYYY-MM-DD" to match -// PostgREST, not as a full RFC3339 timestamp. Every other value is left as pgx -// decoded it. -func normalizeValues(vals []any, fields []pgconn.FieldDescription) []any { +// JSON, so the wire value matches the JSON PostgREST assembles inside the +// database. json and jsonb arrive as raw bytes turned into strings so raw-JSON +// columns pass through verbatim; a bytea value also arrives as bytes and renders +// as a string. Temporal columns are formatted by OID to PostgreSQL's own JSON +// spellings rather than left to Go's default struct/RFC3339 marshalling: +// +// - date (1082): "2006-01-02". +// - time (1083): pgx returns pgtype.Time, which json would render as a struct; +// format as "HH:MM:SS[.ffffff]". +// - timetz (1266): pgx already returns the correct "HH:MM:SS[.ffffff]+TZ" +// string, so it passes through. +// - interval (1186): pgx returns pgtype.Interval; format in PostgreSQL's +// default (postgres) IntervalStyle. +// - timestamp (1114): no zone, so format the wall clock as +// "2006-01-02T15:04:05[.ffffff]" with no suffix (Go would append "Z"). +// - timestamptz (1184): render the instant in the server TimeZone with an ISO +// "+HH:MM" offset, matching PostgreSQL (Go's RFC3339 emits "Z" for UTC). +// - range / multirange: pgx returns pgtype.Range / pgtype.Multirange structs; +// format them to PostgreSQL's own range text ("[10,20)", "{[1,2),[5,8)}") +// rather than the Go struct json would marshal. +// +// loc is the server TimeZone; a nil loc defaults to UTC. Every other value is +// left as pgx decoded it. +func normalizeValues(vals []any, fields []pgconn.FieldDescription, loc *time.Location) []any { + if loc == nil { + loc = time.UTC + } for i, v := range vals { + var oid uint32 + if i < len(fields) { + oid = fields[i].DataTypeOID + } switch t := v.(type) { case []byte: vals[i] = string(t) case time.Time: - if i < len(fields) && fields[i].DataTypeOID == pgtype.DateOID { + switch oid { + case pgtype.DateOID: vals[i] = t.Format("2006-01-02") + case pgtype.TimestamptzOID: + vals[i] = t.In(loc).Format("2006-01-02T15:04:05.999999-07:00") + case pgtype.TimestampOID: + vals[i] = t.Format("2006-01-02T15:04:05.999999") + } + case pgtype.Time: + if t.Valid { + vals[i] = formatTimeOfDay(t.Microseconds) + } + case pgtype.Interval: + if t.Valid { + vals[i] = formatInterval(t) } + case pgtype.Range[any]: + if t.Valid { + vals[i] = formatRange(t, oid, loc) + } + case pgtype.Multirange[pgtype.Range[any]]: + vals[i] = formatMultirange(t, oid, loc) } } return vals } + +// formatRange renders a range value as PostgreSQL's own text output (the spelling +// `col::text` produces and the form PostgREST emits in JSON), instead of the +// pgtype.Range Go struct json would otherwise marshal. An empty range is "empty"; +// otherwise the bracket reflects each bound's inclusivity ('[' / '(' lower, ']' / +// ')' upper), an unbounded side renders as the empty string, and each present +// bound is formatted by the range's element type and quoted by PostgreSQL's range +// rules. oid is the range column OID, which selects the element formatting. +func formatRange(r pgtype.Range[any], oid uint32, loc *time.Location) string { + if r.LowerType == pgtype.Empty || r.UpperType == pgtype.Empty { + return "empty" + } + var sb strings.Builder + if r.LowerType == pgtype.Inclusive { + sb.WriteByte('[') + } else { + sb.WriteByte('(') + } + if r.LowerType != pgtype.Unbounded { + sb.WriteString(quoteRangeBound(formatRangeElem(r.Lower, oid, loc))) + } + sb.WriteByte(',') + if r.UpperType != pgtype.Unbounded { + sb.WriteString(quoteRangeBound(formatRangeElem(r.Upper, oid, loc))) + } + if r.UpperType == pgtype.Inclusive { + sb.WriteByte(']') + } else { + sb.WriteByte(')') + } + return sb.String() +} + +// formatMultirange renders a multirange as PostgreSQL's text output: the +// brace-wrapped, comma-separated list of its member ranges. Each member is +// formatted with the corresponding range OID so its element type is rendered the +// same as a bare range. +func formatMultirange(m pgtype.Multirange[pgtype.Range[any]], oid uint32, loc *time.Location) string { + relemOID := multirangeRangeOID(oid) + var sb strings.Builder + sb.WriteByte('{') + for i, r := range m { + if i > 0 { + sb.WriteByte(',') + } + sb.WriteString(formatRange(r, relemOID, loc)) + } + sb.WriteByte('}') + return sb.String() +} + +// multirangeRangeOID maps a multirange OID to the OID of its member range type, +// so formatRange formats each member's bounds by the right element type. +func multirangeRangeOID(oid uint32) uint32 { + switch oid { + case pgtype.Int4multirangeOID: + return pgtype.Int4rangeOID + case pgtype.Int8multirangeOID: + return pgtype.Int8rangeOID + case pgtype.NummultirangeOID: + return pgtype.NumrangeOID + case pgtype.DatemultirangeOID: + return pgtype.DaterangeOID + case pgtype.TsmultirangeOID: + return pgtype.TsrangeOID + case pgtype.TstzmultirangeOID: + return pgtype.TstzrangeOID + } + return 0 +} + +// formatRangeElem renders one range bound by the range's element type. Temporal +// element types are formatted to PostgreSQL's range text spelling (which uses the +// raw timestamp output, not the ISO json spelling, so the timestamptz offset is +// "+07" rather than "+07:00"); numeric elements use their decimal text, and the +// rest fall back to their default string form. +func formatRangeElem(v any, oid uint32, loc *time.Location) string { + switch oid { + case pgtype.DaterangeOID, pgtype.DatemultirangeOID: + if t, ok := v.(time.Time); ok { + return t.Format("2006-01-02") + } + case pgtype.TsrangeOID, pgtype.TsmultirangeOID: + if t, ok := v.(time.Time); ok { + return t.Format("2006-01-02 15:04:05.999999") + } + case pgtype.TstzrangeOID, pgtype.TstzmultirangeOID: + if t, ok := v.(time.Time); ok { + return formatTimestamptzText(t, loc) + } + } + switch x := v.(type) { + case pgtype.Numeric: + if b, err := x.MarshalJSON(); err == nil { + return string(b) + } + case []byte: + return string(x) + case string: + return x + } + return fmt.Sprint(v) +} + +// formatTimestamptzText renders a timestamptz the way PostgreSQL's text output +// (and thus range text) does: the wall clock in the server zone followed by a +// signed offset that carries minutes only when non-zero and seconds rarer still, +// e.g. "+07", "+05:30". This differs from the ISO "+07:00" json spelling used for +// a bare timestamptz column. +func formatTimestamptzText(t time.Time, loc *time.Location) string { + t = t.In(loc) + base := t.Format("2006-01-02 15:04:05.999999") + _, off := t.Zone() + sign := byte('+') + if off < 0 { + sign = '-' + off = -off + } + h := off / 3600 + m := (off % 3600) / 60 + s := off % 60 + out := fmt.Sprintf("%s%c%02d", base, sign, h) + if m != 0 || s != 0 { + out += fmt.Sprintf(":%02d", m) + } + if s != 0 { + out += fmt.Sprintf(":%02d", s) + } + return out +} + +// quoteRangeBound quotes a range bound the way PostgreSQL does: an empty string +// or one containing a comma, brackets, parentheses, a quote, a backslash, or +// whitespace is double-quoted with embedded quotes and backslashes escaped; +// anything else is left bare. +func quoteRangeBound(s string) string { + if s == "" { + return `""` + } + if !strings.ContainsAny(s, "(),[]\"\\ \t\n\r") { + return s + } + var sb strings.Builder + sb.WriteByte('"') + for _, r := range s { + if r == '"' || r == '\\' { + sb.WriteByte('\\') + } + sb.WriteRune(r) + } + sb.WriteByte('"') + return sb.String() +} + +// formatTimeOfDay renders a time-of-day microsecond count as PostgreSQL's JSON +// time spelling "HH:MM:SS" with a fractional part only when non-zero, trailing +// zeros trimmed (so 13:00:00.5, not 13:00:00.500000). +func formatTimeOfDay(micros int64) string { + h := micros / 3_600_000_000 + m := (micros / 60_000_000) % 60 + s := (micros / 1_000_000) % 60 + frac := micros % 1_000_000 + out := fmt.Sprintf("%02d:%02d:%02d", h, m, s) + if frac != 0 { + out += strings.TrimRight(fmt.Sprintf(".%06d", frac), "0") + } + return out +} + +// formatInterval renders a pgtype.Interval in PostgreSQL's default (postgres) +// IntervalStyle, matching EncodeInterval: each non-zero year/month/day field is +// emitted with its unit (pluralized unless the value is exactly 1), a field +// after a negative one gets an explicit leading "+" when positive, and the time +// part carries a "-" when negative or a "+" when it follows a negative field. +// The all-zero interval is "00:00:00". Months fold to years (12 per year). +func formatInterval(iv pgtype.Interval) string { + years := iv.Months / 12 + mons := iv.Months % 12 + + var sb strings.Builder + wrote := false // a field has been emitted + prevNeg := false // the previous emitted field was negative + + addInt := func(value int32, unit string) { + if value == 0 { + return + } + if wrote { + sb.WriteByte(' ') + } + if prevNeg && value > 0 { + sb.WriteByte('+') + } + plural := "s" + if value == 1 { + plural = "" + } + fmt.Fprintf(&sb, "%d %s%s", value, unit, plural) + prevNeg = value < 0 + wrote = true + } + + addInt(years, "year") + addInt(mons, "mon") + addInt(iv.Days, "day") + + micros := iv.Microseconds + if !wrote || micros != 0 { + neg := micros < 0 + abs := micros + if neg { + abs = -micros + } + h := abs / 3_600_000_000 + m := (abs / 60_000_000) % 60 + s := (abs / 1_000_000) % 60 + frac := abs % 1_000_000 + if wrote { + sb.WriteByte(' ') + } + switch { + case neg: + sb.WriteByte('-') + case prevNeg: + sb.WriteByte('+') + } + fmt.Fprintf(&sb, "%02d:%02d:%02d", h, m, s) + if frac != 0 { + sb.WriteString(strings.TrimRight(fmt.Sprintf(".%06d", frac), "0")) + } + } + return sb.String() +} diff --git a/backend/postgres/result_format_test.go b/backend/postgres/result_format_test.go new file mode 100644 index 0000000..8b23480 --- /dev/null +++ b/backend/postgres/result_format_test.go @@ -0,0 +1,74 @@ +package postgres + +import ( + "testing" + + "github.com/jackc/pgx/v5/pgtype" +) + +// The expected strings below are PostgreSQL's own to_json spellings, captured +// from a live server (IntervalStyle=postgres). formatInterval and formatTimeOfDay +// must reproduce them exactly so a temporal column renders on the wire the way +// PostgREST renders it. Finding 03-P07 / E05. + +func TestFormatInterval(t *testing.T) { + cases := []struct { + months, days int32 + micros int64 + want string + }{ + {0, 0, 90000000, "00:01:30"}, + {0, 0, -90000000, "-00:01:30"}, + {0, 1, 90000000, "1 day 00:01:30"}, + {14, 3, 14706000000, "1 year 2 mons 3 days 04:05:06"}, + {1, 0, 0, "1 mon"}, + {0, 2, 0, "2 days"}, + {0, -1, -7384000000, "-1 days -02:03:04"}, + {0, 1, 7384500000, "1 day 02:03:04.5"}, + {13, 0, 0, "1 year 1 mon"}, + {12, 0, 0, "1 year"}, + {-14, 0, 0, "-1 years -2 mons"}, + {0, 3, -14706000000, "3 days -04:05:06"}, + {0, 0, 0, "00:00:00"}, + {0, 0, 5400000000, "01:30:00"}, + {-1, 2, 0, "-1 mons +2 days"}, + {0, 0, 86400000000, "24:00:00"}, + {2, -3, 0, "2 mons -3 days"}, + {0, -3, 90000000, "-3 days +00:01:30"}, + {-1, 0, 90000000, "-1 mons +00:01:30"}, + {1, -2, -4000000, "1 mon -2 days -00:00:04"}, + {0, -3, 14706000000, "-3 days +04:05:06"}, + {-2, 3, -4000000, "-2 mons +3 days -00:00:04"}, + {0, 2, -90000000, "2 days -00:01:30"}, + {0, -2, 90000000, "-2 days +00:01:30"}, + {11, 0, 0, "11 mons"}, + {-11, 0, 0, "-11 mons"}, + {5, -10, 1000000, "5 mons -10 days +00:00:01"}, + {0, 0, 1, "00:00:00.000001"}, + {0, 0, -500000, "-00:00:00.5"}, + } + for _, c := range cases { + iv := pgtype.Interval{Months: c.months, Days: c.days, Microseconds: c.micros, Valid: true} + if got := formatInterval(iv); got != c.want { + t.Errorf("formatInterval(mon=%d days=%d us=%d) = %q, want %q", c.months, c.days, c.micros, got, c.want) + } + } +} + +func TestFormatTimeOfDay(t *testing.T) { + cases := []struct { + micros int64 + want string + }{ + {46800500000, "13:00:00.5"}, + {46800000000, "13:00:00"}, + {0, "00:00:00"}, + {1, "00:00:00.000001"}, + {86399999999, "23:59:59.999999"}, + } + for _, c := range cases { + if got := formatTimeOfDay(c.micros); got != c.want { + t.Errorf("formatTimeOfDay(%d) = %q, want %q", c.micros, got, c.want) + } + } +} diff --git a/backend/postgres/role_settings.go b/backend/postgres/role_settings.go new file mode 100644 index 0000000..054a6ca --- /dev/null +++ b/backend/postgres/role_settings.go @@ -0,0 +1,127 @@ +package postgres + +import ( + "context" + "strings" + + "github.com/jackc/pgx/v5" + + "github.com/tamnd/dbrest/reqctx" +) + +// roleSetting is one ALTER ROLE ... SET key/value the backend replays as a +// transaction-scoped setting for the impersonated role. +type roleSetting struct { + name string + value string +} + +// loadRoleSettings reads the per-role configuration the impersonated role carries +// (ALTER ROLE SET ...), so the backend can replay it as transaction-scoped +// settings the way PostgREST does. It mirrors PostgREST's queryRoleSettings +// (src/PostgREST/Config/Database.hs): settings come from pg_roles.rolconfig for +// every role the connected authenticator is a member of, and a setting is kept +// only when it is USERSET (pg_settings.context = 'user') or, on PostgreSQL 15+, +// the authenticator holds SET privilege on it (has_parameter_privilege), so a +// setting the session could not apply is skipped instead of aborting the request. +// +// default_transaction_isolation is pulled out separately: it cannot be applied +// with set_config once the transaction has run a statement, so it is returned as +// a per-role isolation level the execute paths pass to BeginTx, while the rest are +// returned as set_config replays. +func (b *Backend) loadRoleSettings(ctx context.Context) (map[string][]roleSetting, map[string]pgx.TxIsoLevel, error) { + // has_parameter_privilege is only available on PostgreSQL 15+, matching the + // gate PostgREST applies to the same filter. + privClause := "" + if b.version.Major >= 15 { + privClause = " OR has_parameter_privilege(quote_ident(current_user)::regrole::oid, ps.name, 'set')" + } + q := ` +WITH role_setting AS ( + SELECT r.rolname, unnest(r.rolconfig) AS setting + FROM pg_auth_members m + JOIN pg_roles r ON r.oid = m.roleid + WHERE m.member = quote_ident(current_user)::regrole::oid +), +kv AS ( + SELECT rolname, + substr(setting, 1, strpos(setting, '=') - 1) AS key, + substr(setting, strpos(setting, '=') + 1) AS value + FROM role_setting +) +SELECT kv.rolname, kv.key, kv.value + FROM kv + LEFT JOIN pg_settings ps ON ps.name = kv.key + WHERE kv.value IS NOT NULL + AND (kv.key = 'default_transaction_isolation' OR ps.context = 'user'` + privClause + `)` + + rows, err := b.pool.Query(ctx, q) + if err != nil { + return nil, nil, err + } + defer rows.Close() + + settings := map[string][]roleSetting{} + isolation := map[string]pgx.TxIsoLevel{} + for rows.Next() { + var role, key, value string + if err := rows.Scan(&role, &key, &value); err != nil { + return nil, nil, err + } + if key == "default_transaction_isolation" { + if lvl, ok := isoLevelFromName(value); ok { + isolation[role] = lvl + } + continue + } + settings[role] = append(settings[role], roleSetting{name: key, value: value}) + } + return settings, isolation, rows.Err() +} + +// isoLevelFromName maps a default_transaction_isolation value to the pgx isolation +// level. The names are PostgreSQL's canonical spellings; an unrecognized value +// leaves the server default in place. +func isoLevelFromName(name string) (pgx.TxIsoLevel, bool) { + switch strings.ToLower(strings.TrimSpace(name)) { + case "serializable": + return pgx.Serializable, true + case "repeatable read": + return pgx.RepeatableRead, true + case "read committed": + return pgx.ReadCommitted, true + case "read uncommitted": + return pgx.ReadUncommitted, true + default: + return "", false + } +} + +// roleIso returns the impersonated role's default_transaction_isolation level, or +// "" when the role pins none, so BeginTx keeps the server default. +func (b *Backend) roleIso(rc *reqctx.Context) pgx.TxIsoLevel { + if b.roleIsolation == nil || rc == nil || rc.Role == "" { + return "" + } + return b.roleIsolation[rc.Role] +} + +// txOptions builds the transaction options for a request: the given access mode +// plus the impersonated role's default_transaction_isolation when it pins one, so +// a role's ALTER ROLE ... SET default_transaction_isolation takes effect on every +// request the way it does under PostgREST. An empty access mode keeps pgx's server +// default (read-write). +func (b *Backend) txOptions(rc *reqctx.Context, mode pgx.TxAccessMode) pgx.TxOptions { + o := pgx.TxOptions{AccessMode: mode} + if iso := b.roleIso(rc); iso != "" { + o.IsoLevel = iso + } + return o +} + +// isoAtLeastRepeatableRead reports whether lvl already gives a single transaction +// snapshot (REPEATABLE READ or SERIALIZABLE), so the counted-read path does not +// downgrade a role that pins a stronger level. +func isoAtLeastRepeatableRead(lvl pgx.TxIsoLevel) bool { + return lvl == pgx.RepeatableRead || lvl == pgx.Serializable +} diff --git a/backend/postgres/session.go b/backend/postgres/session.go index f28bb50..1da0056 100644 --- a/backend/postgres/session.go +++ b/backend/postgres/session.go @@ -22,9 +22,28 @@ func queueSessionItems(batch *pgx.Batch, b *Backend, rc *reqctx.Context) int { if rc.Role != "" { batch.Queue("SET LOCAL ROLE " + (Dialect{}).QuoteIdent(rc.Role)) n++ + // Replay the impersonated role's ALTER ROLE ... SET settings as + // transaction-scoped settings, after the role switch and before the main + // statement, matching PostgREST. default_transaction_isolation is not here; + // it cannot be set after a statement has run, so it is applied at BeginTx. + for _, s := range b.roleSettings[rc.Role] { + batch.Queue("SELECT set_config($1,$2,true)", s.name, s.value) + n++ + } + } + if sp := b.searchPathValue(rc); sp != "" { + // set_config(...,true) is SET LOCAL search_path. PostgREST sets it the same + // way rather than with SET ... TO , so the GUC string is the literal + // quoted value verbatim ("schema", "public"); a SET ... TO would let the + // server re-canonicalize and strip quotes from simple names, so a policy that + // reads current_setting('search_path') would see a different string. + batch.Queue("SELECT set_config('search_path',$1,true)", sp) + n++ } - if b.searchPathSQL != "" { - batch.Queue(b.searchPathSQL) + if rc.TimeZone != "" { + // set_config(...,true) is the SET LOCAL timezone analog, parameterized so a + // name with a slash (America/Los_Angeles) needs no identifier quoting. + batch.Queue("SELECT set_config('timezone',$1,true)", rc.TimeZone) n++ } batch.Queue( @@ -36,7 +55,27 @@ func queueSessionItems(batch *pgx.Batch, b *Backend, rc *reqctx.Context) int { "request.headers", string(rc.HeadersJSON()), "request.cookies", string(rc.CookiesJSON()), ) - return n + 1 + n++ + if rc.PreRequest != "" { + // db-pre-request runs after the transaction-scoped settings and before the + // main query, in the same transaction, so it sees the request context and + // can raise to abort or write response.status/response.headers. A raised + // error surfaces when the batch is drained and aborts the request. + batch.Queue("SELECT " + preRequestCall(rc.PreRequest) + "()") + n++ + } + return n +} + +// preRequestCall renders the db-pre-request function name as a quoted, possibly +// schema-qualified callable, so a name like auth.check or one needing quoting is +// safe to interpolate. +func preRequestCall(fn string) string { + parts := strings.Split(fn, ".") + for i, p := range parts { + parts[i] = (Dialect{}).QuoteIdent(p) + } + return strings.Join(parts, ".") } // applySession sends the per-request GUC setup as a SINGLE batch within tx, @@ -86,19 +125,23 @@ func readResponseControls(ctx context.Context, tx pgx.Tx, controls *reqctx.Respo return nil } -// buildSearchPathSQL pre-computes the SET LOCAL search_path statement for a -// Backend so the string is built once and reused per request. -func buildSearchPathSQL(schemas []string) string { - if len(schemas) == 0 { +// searchPathValue builds the per-request search_path GUC value. The path is the +// request's active schema (the Accept-Profile/Content-Profile choice, or the +// first configured schema by default) followed by db-extra-search-path, matching +// PostgREST: only the active schema is on the path, not the whole exposed set, +// and the extra entries are appended without deduplication. Each name is quoted, +// so the joined string is the literal value PostgREST writes ("schema", "public"). +// An empty active schema (the emulated-namespace marker the engines without named +// schemas use) yields an empty value, so no search_path is set. +func (b *Backend) searchPathValue(rc *reqctx.Context) string { + active := b.callSchema(rc) + if active == "" { return "" } - var b strings.Builder - b.WriteString("SET LOCAL search_path TO ") + schemas := append([]string{active}, b.extraSearchPath...) + parts := make([]string, len(schemas)) for i, s := range schemas { - if i > 0 { - b.WriteString(", ") - } - b.WriteString((Dialect{}).QuoteIdent(s)) + parts[i] = (Dialect{}).QuoteIdent(s) } - return b.String() + return strings.Join(parts, ", ") } diff --git a/backend/postgres/session_test.go b/backend/postgres/session_test.go new file mode 100644 index 0000000..276ae45 --- /dev/null +++ b/backend/postgres/session_test.go @@ -0,0 +1,88 @@ +package postgres + +import ( + "strings" + "testing" + + "github.com/jackc/pgx/v5" + + "github.com/tamnd/dbrest/reqctx" +) + +// queuedSQL collects the SQL text of every item in a batch, for asserting which +// session-setup statements were queued. +func queuedSQL(batch *pgx.Batch) []string { + out := make([]string, 0, len(batch.QueuedQueries)) + for _, q := range batch.QueuedQueries { + out = append(out, q.SQL) + } + return out +} + +// TestQueueSessionTimeZone checks Prefer: timezone= becomes a SET LOCAL timezone +// (via set_config(...,true)) carrying the validated zone as a parameter. +func TestQueueSessionTimeZone(t *testing.T) { + b := &Backend{} + batch := &pgx.Batch{} + rc := &reqctx.Context{Role: "web_anon", TimeZone: "America/Los_Angeles"} + queueSessionItems(batch, b, rc) + + var tzItem *pgx.QueuedQuery + for _, q := range batch.QueuedQueries { + if strings.Contains(q.SQL, "'timezone'") { + tzItem = q + } + } + if tzItem == nil { + t.Fatalf("no timezone item queued; queued: %v", queuedSQL(batch)) + } + if !strings.Contains(tzItem.SQL, "set_config('timezone',$1,true)") { + t.Errorf("timezone SQL = %q", tzItem.SQL) + } + if len(tzItem.Arguments) != 1 || tzItem.Arguments[0] != "America/Los_Angeles" { + t.Errorf("timezone args = %v, want [America/Los_Angeles]", tzItem.Arguments) + } +} + +// TestQueueSessionNoTimeZone checks the timezone item is absent when the request +// stated no zone, so the engine default stands. +func TestQueueSessionNoTimeZone(t *testing.T) { + b := &Backend{} + batch := &pgx.Batch{} + queueSessionItems(batch, b, &reqctx.Context{Role: "web_anon"}) + for _, sql := range queuedSQL(batch) { + if strings.Contains(sql, "'timezone'") { + t.Errorf("unexpected timezone item: %q", sql) + } + } +} + +// TestQueueSessionPreRequest checks db-pre-request becomes a SELECT of the quoted +// function as the last session item, so it runs after the GUCs and before the +// main query. +func TestQueueSessionPreRequest(t *testing.T) { + b := &Backend{} + batch := &pgx.Batch{} + queueSessionItems(batch, b, &reqctx.Context{Role: "web_anon", PreRequest: "auth.check_request"}) + last := queuedSQL(batch) + if len(last) == 0 { + t.Fatal("no items queued") + } + got := last[len(last)-1] + if got != `SELECT "auth"."check_request"()` { + t.Errorf("pre-request item = %q", got) + } +} + +// TestQueueSessionNoPreRequest checks no pre-request item is queued when none is +// configured. +func TestQueueSessionNoPreRequest(t *testing.T) { + b := &Backend{} + batch := &pgx.Batch{} + queueSessionItems(batch, b, &reqctx.Context{Role: "web_anon"}) + for _, sql := range queuedSQL(batch) { + if strings.HasPrefix(sql, "SELECT \"") { + t.Errorf("unexpected pre-request item: %q", sql) + } + } +} diff --git a/backend/postgres/views.go b/backend/postgres/views.go new file mode 100644 index 0000000..8232d3c --- /dev/null +++ b/backend/postgres/views.go @@ -0,0 +1,84 @@ +package postgres + +import ( + "context" + + "github.com/tamnd/dbrest/schema" +) + +// loadViewColumns maps every exposed view's output columns to the base-relation +// columns they project, the data the model needs to carry base-table foreign keys +// onto views (spec 09, finding 03-P09). PostgreSQL records the origin of each view +// output column directly in the view's rewrite rule: every TARGETENTRY of the +// _RETURN rule carries resorigtbl (the OID of the relation the column came from) +// and resorigcol (its attribute number there), which survive column renames and +// point at the immediate source even through a view-over-view chain. We read those +// out of pg_rewrite.ev_action, the rule's parsed query tree rendered as text, and +// resolve them to names through pg_attribute. +// +// The mapping is per immediate source, not the ultimate base table: a view over a +// view points at the inner view, and the model's projection runs to a fixpoint so +// the inner view's inherited keys are available when the outer view projects. +// Columns with resorigtbl 0 (an expression or a literal, not a plain column +// reference) are skipped, so an expression column simply carries no mapping and +// inherits nothing. Set-operation views (UNION/INTERSECT/EXCEPT) are skipped +// entirely, matching PostgREST, because a set operation can combine rows from +// unrelated relations under one output column, so no single base column owns it. +func (b *Backend) loadViewColumns(ctx context.Context, schemas []string) (map[uint32][]schema.ViewColumn, error) { + // resname (the output column name) escapes spaces and other specials with a + // backslash in the node-tree text, so it is matched as a run of escaped chars + // or non-spaces and discarded; the output name is read from pg_attribute by + // resno instead, which avoids unescaping. resorigtbl and resorigcol are the + // fields we keep. The set-operation guard drops any view whose rule carries a + // SETOPERATIONSTMT node (an empty :setOperations <> stays). + const q = ` +WITH views AS ( + SELECT c.oid AS view_oid, rw.ev_action::text AS act + FROM pg_rewrite rw + JOIN pg_class c ON c.oid = rw.ev_class + JOIN pg_namespace n ON n.oid = c.relnamespace + WHERE c.relkind IN ('v','m') + AND n.nspname = ANY($1) + AND rw.rulename = '_RETURN' + AND rw.ev_action::text NOT LIKE '%:setOperations {%' +), +entries AS ( + SELECT view_oid, + (m[1])::int AS resno, + (m[2])::oid AS resorigtbl, + (m[3])::int AS resorigcol + FROM views, + regexp_matches(act, + ':resno (\d+) :resname (?:\\.|[^ ])+ :ressortgroupref \d+ :resorigtbl (\d+) :resorigcol (\d+) :resjunk false', + 'g') AS m + WHERE (m[2])::oid <> 0 +) +SELECT e.view_oid, + va.attname AS view_column, + bn.nspname AS base_schema, + bc.relname AS base_relation, + ba.attname AS base_column + FROM entries e + JOIN pg_attribute va ON va.attrelid = e.view_oid AND va.attnum = e.resno + JOIN pg_class bc ON bc.oid = e.resorigtbl + JOIN pg_namespace bn ON bn.oid = bc.relnamespace + JOIN pg_attribute ba ON ba.attrelid = e.resorigtbl AND ba.attnum = e.resorigcol + ORDER BY e.view_oid, e.resno` + + rows, err := b.pool.Query(ctx, q, schemas) + if err != nil { + return nil, err + } + defer rows.Close() + + out := map[uint32][]schema.ViewColumn{} + for rows.Next() { + var viewOID uint32 + var vc schema.ViewColumn + if err := rows.Scan(&viewOID, &vc.Name, &vc.BaseSchema, &vc.BaseRelation, &vc.BaseColumn); err != nil { + return nil, err + } + out[viewOID] = append(out[viewOID], vc) + } + return out, rows.Err() +} diff --git a/backend/postgres/volatility.go b/backend/postgres/volatility.go new file mode 100644 index 0000000..a964db4 --- /dev/null +++ b/backend/postgres/volatility.go @@ -0,0 +1,176 @@ +package postgres + +import ( + "context" + + "github.com/tamnd/dbrest/ir" + "github.com/tamnd/dbrest/reqctx" + "github.com/tamnd/dbrest/rpc" +) + +// loadFunctionVolatility reads pg_proc.provolatile for every function in the +// exposed schemas into a map keyed by "schema.name", so the native RPC path can +// pick the transaction access mode from volatility the way PostgREST does: +// a STABLE or IMMUTABLE function runs read-only even when called with POST, only +// a VOLATILE function runs read-write. When a name has several overloads with +// differing volatility, the most write-capable one wins (Volatile), so a POST +// never loses its write transaction; the cost is only that a read-only overload +// sharing a name with a volatile one runs read-write, the safe direction. +func (b *Backend) loadFunctionVolatility(ctx context.Context, schemas []string) (map[string]rpc.Volatility, error) { + const q = ` +SELECT n.nspname, p.proname, p.provolatile::text + FROM pg_proc p + JOIN pg_namespace n ON n.oid = p.pronamespace + WHERE n.nspname = ANY($1)` + rows, err := b.pool.Query(ctx, q, schemas) + if err != nil { + return nil, err + } + defer rows.Close() + + out := map[string]rpc.Volatility{} + for rows.Next() { + var nsp, name, vol string + if err := rows.Scan(&nsp, &name, &vol); err != nil { + return nil, err + } + v := volatilityFromChar(vol) + key := nsp + "." + name + // Volatile wins for a name with mixed overloads, so a write overload keeps + // its read-write transaction. Volatile is the zero value, so an existing + // Volatile entry is never downgraded. + if cur, ok := out[key]; ok && cur == rpc.Volatile { + continue + } + out[key] = v + } + return out, rows.Err() +} + +// volatilityFromChar maps a pg_proc.provolatile char to the portable Volatility. +// Anything unexpected falls back to Volatile, the safe (read-write) classification. +func volatilityFromChar(c string) rpc.Volatility { + switch c { + case "i": + return rpc.Immutable + case "s": + return rpc.Stable + default: + return rpc.Volatile + } +} + +// loadFunctionReturns reads each function's return shape (pg_proc.proretset and +// the return type's class) for every function in the exposed schemas, keyed by +// "schema.name". The native RPC renderer uses it to shape a result the way +// PostgREST does instead of guessing from column names: a SETOF scalar function +// renders as a JSON array of bare values, a function returning a single composite +// row renders as one object, a SETOF or TABLE function as an array of objects, a +// scalar function as the bare value, and a void function as a null body. +// +// When a name has several overloads with differing return shapes the first by oid +// wins; resolving the exact overload's shape needs full parameter introspection +// (the native registry, a later slice), and same-named overloads almost always +// share a return shape in practice. +func (b *Backend) loadFunctionReturns(ctx context.Context, schemas []string) (map[string]rpc.ReturnShape, error) { + const q = ` +SELECT n.nspname, p.proname, p.proretset, p.prorettype, t.typtype::text, t.typname + FROM pg_proc p + JOIN pg_namespace n ON n.oid = p.pronamespace + JOIN pg_type t ON t.oid = p.prorettype + WHERE n.nspname = ANY($1) + ORDER BY n.nspname, p.proname, p.oid` + rows, err := b.pool.Query(ctx, q, schemas) + if err != nil { + return nil, err + } + defer rows.Close() + + out := map[string]rpc.ReturnShape{} + for rows.Next() { + var ( + nsp, name, typtype, typname string + retset bool + rettype uint32 + ) + if err := rows.Scan(&nsp, &name, &retset, &rettype, &typtype, &typname); err != nil { + return nil, err + } + key := nsp + "." + name + if _, ok := out[key]; ok { + continue // first overload by oid wins + } + out[key] = returnShapeFor(retset, rettype, typtype, typname) + } + return out, rows.Err() +} + +// PostgreSQL OIDs for the pseudo-types the return shape keys on. +const ( + oidRecord = 2249 // RETURNS record / RETURNS TABLE(...) carry this prorettype + oidVoid = 2278 // RETURNS void +) + +// returnShapeFor maps a function's pg_proc return facts to a portable ReturnShape. +// A composite return (pg_type.typtype 'c') or a record (the TABLE/OUT-parameter +// form) is object-shaped; everything else is scalar-shaped. proretset then decides +// array vs single: a set of objects is a table, a single object is one object; a +// set of scalars is a setof, a single scalar is a scalar. Type carries the return +// type name so the scalar renderer can embed a json/jsonb value verbatim. +func returnShapeFor(retset bool, rettype uint32, typtype, typname string) rpc.ReturnShape { + if rettype == oidVoid { + return rpc.ReturnShape{Kind: rpc.ReturnVoid} + } + composite := typtype == "c" || rettype == oidRecord + switch { + case retset && composite: + return rpc.ReturnShape{Kind: rpc.ReturnTable} + case retset: + return rpc.ReturnShape{Kind: rpc.ReturnSetOf, Type: typname} + case composite: + return rpc.ReturnShape{Kind: rpc.ReturnObject} + default: + return rpc.ReturnShape{Kind: rpc.ReturnScalar, Type: typname} + } +} + +// nativeFunc builds the descriptor for a native RPC call from the introspected +// catalog: the return shape (funcRet) and volatility (funcVol), keyed by the +// call's schema and name. Query stays nil, which marks the function native so the +// executor keeps lowering it through the literal-splice path; the descriptor only +// gives the renderer and the access-mode check a real return kind instead of a +// column-name guess. It returns nil when the function was not introspected (for +// example a search-path schema outside the exposed set), leaving the legacy +// heuristic in place rather than asserting a shape that may be wrong. +func (b *Backend) nativeFunc(c *ir.Call, schema string) *rpc.Function { + if b.funcRet == nil { + return nil + } + key := schema + "." + c.Function.Name + shape, ok := b.funcRet[key] + if !ok { + return nil + } + fn := &rpc.Function{Name: c.Function.Name, Returns: shape} + if v, ok := b.funcVol[key]; ok { + fn.Volatility = v + } + return fn +} + +// nativeCallReadOnly reports whether a native RPC call should run in a read-only +// transaction. A GET/HEAD is already read-only (plan.ReadOnly). For a POST, it +// downgrades to read-only when the resolved function is known to be STABLE or +// IMMUTABLE, matching PostgREST's access-mode table; an unknown function (not yet +// introspected) keeps the method-derived mode so a write still runs read-write. +func (b *Backend) nativeCallReadOnly(plan *ir.Plan, rc *reqctx.Context) bool { + if plan.ReadOnly { + return true + } + if b.funcVol == nil { + return false + } + key := b.callSchema(rc) + "." + plan.Call.Function.Name + v, ok := b.funcVol[key] + return ok && v.ReadOnly() +} diff --git a/backend/responsecontrols.go b/backend/responsecontrols.go new file mode 100644 index 0000000..234beb8 --- /dev/null +++ b/backend/responsecontrols.go @@ -0,0 +1,194 @@ +package backend + +import ( + "encoding/json" + "maps" + "strconv" + + "github.com/tamnd/dbrest/pgerr" + "github.com/tamnd/dbrest/reqctx" +) + +// The reserved output columns a portable registry function projects to steer the +// response. A backend with a SQL-readable session store (PostgreSQL) lets a +// function call set_config('response.status', ...) and current_setting reads it +// back; an emulated backend has no setting a single SELECT can write, so a +// portable function carries the same intent as result columns named exactly like +// the GUCs. The column values use the same shapes the GUCs take: an integer +// status and a JSON array of single-key {name: value} header objects. +const ( + ColResponseStatus = "response.status" + ColResponseHeaders = "response.headers" +) + +// HasResponseControlCols reports whether a result carries either reserved +// response-control column, so the caller can keep streaming the common case and +// only buffer when the controls must be lifted out. +func HasResponseControlCols(cols []string) bool { + for _, c := range cols { + if c == ColResponseStatus || c == ColResponseHeaders { + return true + } + } + return false +} + +// LiftResponseControls folds a portable registry function's reserved +// response-control columns into the response controls and removes them from the +// body. The values are read from the first row (a function sets one status and +// one header set per request, matching the GUC model); a result with no rows +// leaves the controls untouched. The returned columns and rows have the reserved +// columns stripped so they never reach the rendered body. A result with neither +// reserved column is returned unchanged. +// +// A response.status that is not a valid HTTP status code is PGRST112, and a +// response.headers that is not the array-of-single-key-objects shape is PGRST111, +// matching the way PostgREST rejects a junk GUC rather than forwarding it. The +// error returns before the controls are applied, so a volatile function's +// transaction rolls back through the caller's deferred rollback. +func LiftResponseControls(cols []string, rows [][]any, controls *reqctx.ResponseControls) ([]string, [][]any, *pgerr.APIError) { + statusIdx, headersIdx := -1, -1 + for i, c := range cols { + switch c { + case ColResponseStatus: + statusIdx = i + case ColResponseHeaders: + headersIdx = i + } + } + if statusIdx < 0 && headersIdx < 0 { + return cols, rows, nil + } + + if len(rows) > 0 && controls != nil { + first := rows[0] + if statusIdx >= 0 && statusIdx < len(first) { + if v := first[statusIdx]; v != nil { + code, ok := toStatus(v) + if !ok || !validStatus(code) { + return cols, rows, pgerr.ErrInvalidResponseStatus() + } + controls.SetStatus(code) + } + } + if headersIdx >= 0 && headersIdx < len(first) { + if v := first[headersIdx]; v != nil { + hdrs, ok := toHeaders(v) + if !ok { + return cols, rows, pgerr.ErrInvalidResponseHeaders() + } + maps.Copy(controlHeaders(controls), hdrs) + } + } + } + + drop := map[int]bool{} + if statusIdx >= 0 { + drop[statusIdx] = true + } + if headersIdx >= 0 { + drop[headersIdx] = true + } + cols, rows = stripColumns(cols, rows, drop) + return cols, rows, nil +} + +// validStatus reports whether a status override is in the range an HTTP response +// can carry. net/http panics outside 100..999; PostgREST rejects anything that is +// not a real status code, so the tighter 100..599 range is used. +func validStatus(code int) bool { return code >= 100 && code <= 599 } + +// controlHeaders returns the controls' header map, allocating it on first use so +// maps.Copy has a destination. +func controlHeaders(controls *reqctx.ResponseControls) map[string]string { + if controls.Headers == nil { + controls.Headers = map[string]string{} + } + return controls.Headers +} + +// toStatus reads a status override from a reserved column value, accepting the +// integer the column most often holds as well as the float and string forms a +// driver may surface. +func toStatus(v any) (int, bool) { + switch n := v.(type) { + case int64: + return int(n), true + case int: + return n, true + case float64: + return int(n), true + case json.Number: + if i, err := n.Int64(); err == nil { + return int(i), true + } + case string: + if i, err := strconv.Atoi(n); err == nil { + return i, true + } + case json.RawMessage: + if i, err := strconv.Atoi(string(n)); err == nil { + return i, true + } + } + return 0, false +} + +// toHeaders reads response headers from a reserved column value. The value is the +// JSON the GUC convention uses: an array of single-key {name: value} objects. A +// lone object is also accepted for convenience. ok is false when the value is +// present but not a JSON shape that can carry headers, the PGRST111 case. +func toHeaders(v any) (map[string]string, bool) { + var raw []byte + switch s := v.(type) { + case string: + raw = []byte(s) + case json.RawMessage: + raw = []byte(s) + case []byte: + raw = s + default: + return nil, false + } + out := map[string]string{} + var list []map[string]string + if err := json.Unmarshal(raw, &list); err == nil { + for _, obj := range list { + maps.Copy(out, obj) + } + return out, true + } + var obj map[string]string + if err := json.Unmarshal(raw, &obj); err == nil { + maps.Copy(out, obj) + return out, true + } + return nil, false +} + +// stripColumns returns the columns and rows with the dropped indices removed, +// preserving order. It allocates new slices so the caller's buffers are left +// intact. +func stripColumns(cols []string, rows [][]any, drop map[int]bool) ([]string, [][]any) { + keep := make([]int, 0, len(cols)) + for i := range cols { + if !drop[i] { + keep = append(keep, i) + } + } + outCols := make([]string, len(keep)) + for i, idx := range keep { + outCols[i] = cols[idx] + } + outRows := make([][]any, len(rows)) + for r, row := range rows { + nr := make([]any, len(keep)) + for i, idx := range keep { + if idx < len(row) { + nr[i] = row[idx] + } + } + outRows[r] = nr + } + return outCols, outRows +} diff --git a/backend/responsecontrols_test.go b/backend/responsecontrols_test.go new file mode 100644 index 0000000..d08a78a --- /dev/null +++ b/backend/responsecontrols_test.go @@ -0,0 +1,142 @@ +package backend + +import ( + "testing" + + "github.com/tamnd/dbrest/reqctx" +) + +func TestLiftResponseControlsNoReservedColumns(t *testing.T) { + cols := []string{"id", "title"} + rows := [][]any{{int64(1), "a"}} + var c reqctx.ResponseControls + gotCols, gotRows, err := LiftResponseControls(cols, rows, &c) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(gotCols) != 2 || len(gotRows) != 1 { + t.Fatalf("result reshaped without reserved columns: %v %v", gotCols, gotRows) + } + if c.Status != 0 { + t.Errorf("status set without a reserved column: %d", c.Status) + } +} + +func TestLiftResponseControlsStatusAndStrip(t *testing.T) { + cols := []string{"message", ColResponseStatus} + rows := [][]any{{"gone", int64(410)}} + var c reqctx.ResponseControls + gotCols, gotRows, err := LiftResponseControls(cols, rows, &c) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(gotCols) != 1 || gotCols[0] != "message" { + t.Errorf("columns = %v, want [message]", gotCols) + } + if len(gotRows[0]) != 1 || gotRows[0][0] != "gone" { + t.Errorf("row = %v, want [gone]", gotRows[0]) + } + if c.Status != 410 { + t.Errorf("status = %d, want 410", c.Status) + } +} + +func TestLiftResponseControlsStatusFromString(t *testing.T) { + cols := []string{ColResponseStatus} + rows := [][]any{{"201"}} + var c reqctx.ResponseControls + if _, _, err := LiftResponseControls(cols, rows, &c); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if c.Status != 201 { + t.Errorf("status = %d, want 201", c.Status) + } +} + +func TestLiftResponseControlsHeadersArray(t *testing.T) { + cols := []string{ColResponseHeaders} + rows := [][]any{{`[{"X-A":"1"},{"X-B":"2"}]`}} + var c reqctx.ResponseControls + if _, _, err := LiftResponseControls(cols, rows, &c); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if c.Headers["X-A"] != "1" || c.Headers["X-B"] != "2" { + t.Errorf("headers = %v, want X-A=1 X-B=2", c.Headers) + } +} + +func TestLiftResponseControlsHeadersObject(t *testing.T) { + cols := []string{ColResponseHeaders} + rows := [][]any{{`{"X-A":"1"}`}} + var c reqctx.ResponseControls + if _, _, err := LiftResponseControls(cols, rows, &c); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if c.Headers["X-A"] != "1" { + t.Errorf("headers = %v, want X-A=1", c.Headers) + } +} + +func TestLiftResponseControlsNoRowsStillStrips(t *testing.T) { + cols := []string{"message", ColResponseStatus} + var c reqctx.ResponseControls + gotCols, gotRows, err := LiftResponseControls(cols, nil, &c) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(gotCols) != 1 || gotCols[0] != "message" { + t.Errorf("columns = %v, want [message]", gotCols) + } + if len(gotRows) != 0 { + t.Errorf("rows = %v, want empty", gotRows) + } + if c.Status != 0 { + t.Errorf("status set from an empty result: %d", c.Status) + } +} + +// An out-of-range status is PGRST112, matching PostgREST's rejection of a junk +// response.status rather than forwarding it (and avoiding a net/http panic). +func TestLiftResponseControlsInvalidStatusRange(t *testing.T) { + cols := []string{ColResponseStatus} + rows := [][]any{{int64(9999)}} + var c reqctx.ResponseControls + _, _, err := LiftResponseControls(cols, rows, &c) + if err == nil || err.Code != "PGRST112" { + t.Fatalf("err = %v, want PGRST112", err) + } +} + +// A non-numeric status is PGRST112 too. +func TestLiftResponseControlsInvalidStatusText(t *testing.T) { + cols := []string{ColResponseStatus} + rows := [][]any{{"not-a-number"}} + var c reqctx.ResponseControls + _, _, err := LiftResponseControls(cols, rows, &c) + if err == nil || err.Code != "PGRST112" { + t.Fatalf("err = %v, want PGRST112", err) + } +} + +// A response.headers value that is not the array/object shape is PGRST111. +func TestLiftResponseControlsInvalidHeaders(t *testing.T) { + cols := []string{ColResponseHeaders} + rows := [][]any{{`"just a string"`}} + var c reqctx.ResponseControls + _, _, err := LiftResponseControls(cols, rows, &c) + if err == nil || err.Code != "PGRST111" { + t.Fatalf("err = %v, want PGRST111", err) + } +} + +func TestHasResponseControlCols(t *testing.T) { + if HasResponseControlCols([]string{"a", "b"}) { + t.Error("false positive on plain columns") + } + if !HasResponseControlCols([]string{"a", ColResponseStatus}) { + t.Error("missed response.status") + } + if !HasResponseControlCols([]string{ColResponseHeaders}) { + t.Error("missed response.headers") + } +} diff --git a/backend/spi.go b/backend/spi.go index 8573a29..4c8d94b 100644 --- a/backend/spi.go +++ b/backend/spi.go @@ -2,6 +2,7 @@ package backend import ( "context" + "fmt" "io" "github.com/tamnd/dbrest/ir" @@ -44,6 +45,17 @@ type Backend interface { Close() error } +// SchemaFunctioner is an optional capability of a NativeRPC backend that +// introspects its own functions: it exposes them as a registry per exposed schema, +// the function half of the schema cache. The frontend uses it to resolve native +// overloads, raise PGRST202/PGRST203, and partition GET arguments from result +// filters through the same planner the portable registry uses, instead of building +// a minimal plan and deferring everything to the engine. A backend that does not +// implement it keeps the verb-derived minimal plan. PostgreSQL implements it. +type SchemaFunctioner interface { + SchemaFunctions(schema string) rpc.Registry +} + // Result is the streaming response abstraction. A backend returns either an // assembled Body (the engine built the JSON) or a RowStream the renderer shapes // in Go. Which one is recorded by the JSONAssembly capability (spec 03). @@ -60,14 +72,40 @@ type Result interface { ResponseControls() *reqctx.ResponseControls } -// Explainer is an optional backend capability for the vnd.pgrst.plan+json -// Accept type. Backends that support EXPLAIN implement this interface; -// the frontend type-asserts to it and falls back to 406 when absent. +// PlanFormat is the output format an Accept: application/vnd.pgrst.plan request +// asks for. PostgREST defaults to text (bare type and the +text suffix); +json +// asks for the machine-readable form. +type PlanFormat uint8 + +const ( + PlanText PlanFormat = iota // default: EXPLAIN text output + PlanJSON // +json suffix: EXPLAIN (FORMAT JSON) +) + +// PlanOptions carries the parsed parameters of a plan Accept header. Format +// selects text vs json; For is the media type the plan is computed for (the +// for="" parameter, informational on the wire and echoed back); the +// booleans are the options= flags PostgREST forwards to EXPLAIN. +type PlanOptions struct { + Format PlanFormat + For string + Analyze bool + Verbose bool + Settings bool + Buffers bool + Wal bool +} + +// Explainer is an optional backend capability for the application/vnd.pgrst.plan +// Accept type. Backends that support EXPLAIN implement this interface; the +// frontend type-asserts to it and 406s when it is absent. The three methods +// mirror the three execution paths so a plan can be requested for a read, a +// write, or an RPC call. Each returns the engine's EXPLAIN output already +// formatted per opts.Format (text bytes or a JSON document). type Explainer interface { - // ExplainRead runs EXPLAIN on the read query and returns raw JSON from the - // engine's query planner. If analyze is true the engine also executes and - // times the query (EXPLAIN ANALYZE equivalent). - ExplainRead(ctx context.Context, p *ir.Plan, rc *reqctx.Context, analyze bool) ([]byte, error) + ExplainRead(ctx context.Context, p *ir.Plan, rc *reqctx.Context, opts PlanOptions) ([]byte, error) + ExplainWrite(ctx context.Context, p *ir.Plan, rc *reqctx.Context, opts PlanOptions) ([]byte, error) + ExplainCall(ctx context.Context, p *ir.Plan, rc *reqctx.Context, opts PlanOptions) ([]byte, error) } // RowStream is a forward-only cursor over result rows. The renderer drives it to @@ -84,3 +122,41 @@ type RowStream interface { // Close releases the cursor. Close() error } + +// EnforceMaxAffected is the Prefer: max-affected contract every write backend +// shares. WriteSpec.MaxRows is set only under handling=strict (ir.ParsePrefer +// clears it under lenient), so a non-nil bound always means "enforce". When the +// mutation affected more rows than the bound, it returns PGRST124; the backend +// then returns before commit and its deferred rollback discards the over-broad +// write. It returns nil when no bound is set, the affected count is unknown, or +// the count is within the bound. Callers must invoke it after the affected count +// is known and before commit. +func EnforceMaxAffected(w *ir.WriteSpec, affected int64, hasAffected bool) *pgerr.APIError { + if w == nil || w.MaxRows == nil || !hasAffected { + return nil + } + if affected > *w.MaxRows { + return pgerr.ErrMaxAffected(affected) + } + return nil +} + +// EnforceSingularWrite is the single-object guarantee a write makes when the +// client negotiated application/vnd.pgrst.object+json (q.Singular): the mutation +// must affect exactly one row. A zero-or-many result is PGRST116. Callers invoke +// it after the affected count is known and before commit, so the failure rolls +// the mutation back through the backend's deferred rollback rather than the +// renderer noticing the wrong count after the write is already durable +// (PostgREST's condemn discipline). The renderer keeps the same check for reads, +// where there is no transaction to roll back. It is a no-op for a non-singular +// request or when the count is unknown. +func EnforceSingularWrite(singular bool, affected int64, hasAffected bool) *pgerr.APIError { + if !singular || !hasAffected { + return nil + } + if affected != 1 { + return pgerr.ErrSingularZeroMany(). + WithDetails(fmt.Sprintf("The result contains %d rows", affected)) + } + return nil +} diff --git a/backend/sqlgen/call.go b/backend/sqlgen/call.go index 49dd87b..fba8f76 100644 --- a/backend/sqlgen/call.go +++ b/backend/sqlgen/call.go @@ -18,36 +18,188 @@ import ( // CompileCall lowers a resolved RPC call to a parameterized statement. The // function's SQL is rendered with its :name placeholders bound to the arguments -// (defaults filling omitted optional parameters); a table return additionally -// wraps the result so post-filters compile around it. -func CompileCall(d Dialect, c *ir.Call, fn *rpc.Function) (*Statement, *pgerr.APIError) { +// (defaults filling omitted optional parameters); a placeholder that is not a +// declared parameter binds a reserved request-context value from ctxArgs (see +// ContextArgs); a table return additionally wraps the result so post-filters +// compile around it. +func CompileCall(d Dialect, c *ir.Call, fn *rpc.Function, ctxArgs map[string]any) (*Statement, *pgerr.APIError) { + if fn == nil { + return nil, pgerr.ErrInternal("CompileCall requires a registry function; native calls compile in the backend") + } if fn.Query == nil || strings.TrimSpace(fn.Query.SQL) == "" { return nil, pgerr.ErrUnsupported("this function realization", "sql") } b := newBuilder(d) + b.ctxArgs = ctxArgs + + // A call with embeds projects the function's rows as a parent resource and + // nests each embedded relation, exactly as a table read does. The function + // result is the parent source; bindNamed runs inside the source writer so its + // placeholders bind after the projection's, keeping arguments in textual order. + if len(c.Embeds) > 0 { + return b.writeEmbeddedQuery(callQuery(c), func() *pgerr.APIError { + inner, err := b.bindNamed(fn, c.Args) + if err != nil { + return err + } + b.sb.WriteString("(") + b.sb.WriteString(inner) + b.sb.WriteString(")") + return nil + }) + } inner, err := b.bindNamed(fn, c.Args) if err != nil { return nil, err } - // Only a table return can be projected, filtered, ordered, and paginated; a - // scalar or setof-scalar return is the function's value(s) verbatim. - if fn.Returns.Kind != rpc.ReturnTable || !callHasPostFilter(c) { + // A single scalar or a void return is the function's value verbatim, with no + // row window to slice. Anything else with no post-filter clause is also + // verbatim. + isTable := fn.Returns.Kind == rpc.ReturnTable + isSetof := isTable || fn.Returns.Kind == rpc.ReturnSetOf + if !isSetof || !callHasPostFilter(c) { b.sb.WriteString(inner) return &Statement{SQL: b.sb.String(), Args: b.args}, nil } + // A setof return is wrapped so it can be paginated like a table read. A table + // return carries named columns, so it additionally supports projection, + // horizontal filters, and ordering; a setof-scalar has one anonymous column, + // so it takes only the limit/offset window. const alias = "_rpc" b.sb.WriteString("SELECT ") - if err := b.writeSelect(c.Select); err != nil { - return nil, err + if isTable { + if err := b.writeSelect(c.Select); err != nil { + return nil, err + } + } else { + b.sb.WriteString("*") } b.sb.WriteString(" FROM (") b.sb.WriteString(inner) b.sb.WriteString(") ") b.sb.WriteString(alias) + if isTable && c.Where != nil { + b.sb.WriteString(" WHERE ") + if err := b.writeCond(*c.Where); err != nil { + return nil, err + } + } + hasOrder := isTable && len(c.Order) > 0 + if hasOrder { + if err := b.writeOrder(c.Order); err != nil { + return nil, err + } + } + if clause := b.d.LimitOffset(c.Limit, c.Offset, hasOrder); clause != "" { + b.sb.WriteString(" ") + b.sb.WriteString(clause) + } + return &Statement{SQL: b.sb.String(), Args: b.args}, nil +} + +// CompileNativeCallWrap wraps a backend-built function-call statement in the read +// clauses a table-valued function result supports: vertical projection, horizontal +// filters, ordering, and the limit/offset window, aliased as _rpc. A NativeRPC +// backend (no portable function body, so CompileCall does not apply) renders the +// inner `SELECT * FROM schema.fn(args)` itself and passes it here so a native call +// shapes its result the same way the registry path does. With no post-filter +// clause the inner statement is returned unchanged. The inner statement's bound +// arguments are preserved and the wrapper's own filters bind after them, so +// placeholder numbering stays consistent. +func CompileNativeCallWrap(d Dialect, c *ir.Call, inner *Statement) (*Statement, *pgerr.APIError) { + if !callHasPostFilter(c) { + return inner, nil + } + b := newBuilder(d) + b.args = append(b.args, inner.Args...) + + const alias = "_rpc" + b.sb.WriteString("SELECT ") + if len(c.Select) > 0 { + if err := b.writeSelect(c.Select); err != nil { + return nil, err + } + } else { + b.sb.WriteString("*") + } + b.sb.WriteString(" FROM (") + b.sb.WriteString(inner.SQL) + b.sb.WriteString(") ") + b.sb.WriteString(alias) + + if c.Where != nil { + b.sb.WriteString(" WHERE ") + if err := b.writeCond(*c.Where); err != nil { + return nil, err + } + } + hasOrder := len(c.Order) > 0 + if hasOrder { + if err := b.writeOrder(c.Order); err != nil { + return nil, err + } + } + if clause := b.d.LimitOffset(c.Limit, c.Offset, hasOrder); clause != "" { + b.sb.WriteString(" ") + b.sb.WriteString(clause) + } + return &Statement{SQL: b.sb.String(), Args: b.args}, nil +} + +// CompileNativeCallCountWrap wraps a backend-built function-call statement in a +// count over its rows with the horizontal filter applied, so a native call's +// count=exact total matches the rows the post-filter returns. The select, order, +// and window do not change the count. Like CompileNativeCallWrap it is for the +// NativeRPC path where there is no portable body to drive CompileCallCount. +func CompileNativeCallCountWrap(d Dialect, c *ir.Call, inner *Statement) (*Statement, *pgerr.APIError) { + b := newBuilder(d) + b.args = append(b.args, inner.Args...) + const alias = "_rpc" + b.sb.WriteString("SELECT count(*) FROM (") + b.sb.WriteString(inner.SQL) + b.sb.WriteString(") ") + b.sb.WriteString(alias) + if c.Where != nil { + b.sb.WriteString(" WHERE ") + if err := b.writeCond(*c.Where); err != nil { + return nil, err + } + } + return &Statement{SQL: b.sb.String(), Args: b.args}, nil +} + +// CompileNativeCallCountedWrap wraps a backend-built function-call statement so the +// row query carries the total alongside the page in a single execution. It is the +// volatile path's count: a STABLE or IMMUTABLE function may be counted with a +// separate statement (it has no side effects), but a VOLATILE function must run +// exactly once, so count(*) OVER () rides the projection. The caller reads the +// total off any returned row and drops the _pgrst_count column. The select, +// filter, order, and window match CompileNativeCallWrap so the page is identical; +// count(*) OVER () counts the full filtered set because it is evaluated before the +// LIMIT. Unlike CompileNativeCallWrap it always wraps, since even a bare call needs +// the extra column. +func CompileNativeCallCountedWrap(d Dialect, c *ir.Call, inner *Statement) (*Statement, *pgerr.APIError) { + b := newBuilder(d) + b.args = append(b.args, inner.Args...) + const alias = "_rpc" + b.sb.WriteString("SELECT ") + if len(c.Select) > 0 { + if err := b.writeSelect(c.Select); err != nil { + return nil, err + } + } else { + b.sb.WriteString("*") + } + b.sb.WriteString(`, count(*) OVER () AS "`) + b.sb.WriteString(CountColName) + b.sb.WriteString(`" FROM (`) + b.sb.WriteString(inner.SQL) + b.sb.WriteString(") ") + b.sb.WriteString(alias) if c.Where != nil { b.sb.WriteString(" WHERE ") if err := b.writeCond(*c.Where); err != nil { @@ -73,18 +225,61 @@ func CompileCall(d Dialect, c *ir.Call, fn *rpc.Function) (*Statement, *pgerr.AP // read-only statement for a count=exact request, exactly as a table read's count // does. It is only valid for a read-only function; a volatile function must not // run twice. -func CompileCallCount(d Dialect, c *ir.Call, fn *rpc.Function) (*Statement, *pgerr.APIError) { +func CompileCallCount(d Dialect, c *ir.Call, fn *rpc.Function, ctxArgs map[string]any) (*Statement, *pgerr.APIError) { + if fn == nil { + // A native (non-registry) call has no portable body to count; the backend + // must build its own count wrapper. Returning an error rather than + // dereferencing a nil function keeps a misrouted native call from panicking. + return nil, pgerr.ErrInternal("CompileCallCount requires a registry function; native calls count in the backend") + } if fn.Query == nil || strings.TrimSpace(fn.Query.SQL) == "" { return nil, pgerr.ErrUnsupported("this function realization", "sql") } b := newBuilder(d) + b.ctxArgs = ctxArgs inner, err := b.bindNamed(fn, c.Args) if err != nil { return nil, err } + const alias = "_rpc" b.sb.WriteString("SELECT count(*) FROM (") b.sb.WriteString(inner) - b.sb.WriteString(") _rpc") + b.sb.WriteString(") ") + b.sb.WriteString(alias) + + // A count over an embedded call carries the same parent restriction the row + // query does: the post-filter WHERE plus one EXISTS per !inner embed, so the + // reported total matches the rows returned. + if len(c.Embeds) > 0 { + b.qual = alias + b.parentRef = alias + b.embeds = c.Embeds + wrote := false + if c.Where != nil { + b.sb.WriteString(" WHERE ") + if err := b.writeCond(*c.Where); err != nil { + return nil, err + } + wrote = true + } + for i := range c.Embeds { + emb := &c.Embeds[i] + if emb.Join != ir.JoinInner { + continue + } + if wrote { + b.sb.WriteString(" AND ") + } else { + b.sb.WriteString(" WHERE ") + wrote = true + } + if err := b.writeEmbedExists(emb, alias); err != nil { + return nil, err + } + } + return &Statement{SQL: b.sb.String(), Args: b.args}, nil + } + if fn.Returns.Kind == rpc.ReturnTable && c.Where != nil { b.sb.WriteString(" WHERE ") if err := b.writeCond(*c.Where); err != nil { @@ -94,6 +289,23 @@ func CompileCallCount(d Dialect, c *ir.Call, fn *rpc.Function) (*Statement, *pge return &Statement{SQL: b.sb.String(), Args: b.args}, nil } +// callQuery projects an RPC call's read shape onto an ir.Query so the embedded +// read writer (shared with table reads) drives the parent projection, filters, +// ordering, and window. The relation is the function's resolved result relation, +// already bound onto each embed by the planner; the embedded writer reads the +// embeds, not q.Relation, so a bare Ref naming the result relation is enough. +func callQuery(c *ir.Call) *ir.Query { + return &ir.Query{ + Kind: ir.Read, + Select: c.Select, + Where: c.Where, + Order: c.Order, + Limit: c.Limit, + Offset: c.Offset, + Embeds: c.Embeds, + } +} + // callHasPostFilter reports whether a call carries any clause that wraps the // function result (a projection, a filter, an ordering, or a window). func callHasPostFilter(c *ir.Call) bool { @@ -142,17 +354,33 @@ func (b *builder) bindNamed(fn *rpc.Function, args map[string]ir.Value) (string, func (b *builder) argValue(fn *rpc.Function, name string, args map[string]ir.Value) (string, *pgerr.APIError) { p, ok := fn.Param(name) if !ok { + // Not a declared parameter: a reserved request-context placeholder + // binds the frontend-built value (spec 15: the emulated engines bind + // context, they never read a session store). A declared parameter of + // the same name takes this path only when undeclared, so it cannot be + // shadowed away by a caller. + if v, isCtx := b.ctxArgs[name]; isCtx { + return b.bind(v), nil + } return "", pgerr.ErrInternal("rpc body references undeclared parameter :" + name) } if v, ok := args[name]; ok { - return b.bind(callArg(v)), nil + if p.Variadic { + return b.bindVariadic(v), nil + } + return b.bind(callArg(b.d, v)), nil + } + if p.Variadic { + // A variadic call with no trailing arguments expands to nothing, so a body + // spelling the placeholder inside a call list (f(:nums)) becomes f(). + return "", nil } if p.Optional { return b.bind(p.Default), nil } // A required parameter with no argument cannot happen: Lookup only returns an // overload whose required parameters are all present. Guard anyway. - return "", pgerr.ErrNoFunction(fn.Name) + return "", pgerr.ErrInternal("rpc call is missing required parameter :" + name) } // singleObjectArgs implements the single-unnamed-argument form: a function whose @@ -178,14 +406,54 @@ func singleObjectArgs(fn *rpc.Function, args map[string]ir.Value) map[string]ir. // callArg converts an argument value to a driver argument. A POST argument has a // decoded JSON value (numbers preserved, objects/arrays re-encoded to text); a -// GET argument is the raw query-string text. Type coercion to the declared -// parameter type lands with the types subsystem (spec 16). -func callArg(v ir.Value) any { +// GET argument is the raw query-string text, bound verbatim. An empty query value +// is an empty string, not NULL: PostgREST passes "" to the parameter, and a NULL +// is expressed only by omitting the argument (which binds the parameter default). +func callArg(d Dialect, v ir.Value) any { + if v.JSON != nil { + return writeArg(d, v, "") + } + return v.Text +} + +// bindVariadic expands a variadic argument into a comma-separated list of bound +// placeholders, one per collected element, so a body spelling the placeholder +// inside a call list or an IN (:name) clause receives every value. A GET call +// arrives as a text list; a POST call arrives as a decoded JSON array (a lone +// scalar is treated as a one-element list). An empty list binds nothing. +func (b *builder) bindVariadic(v ir.Value) string { + elems := variadicElems(b.d, v) + parts := make([]string, len(elems)) + for i, e := range elems { + parts[i] = b.bind(e) + } + return strings.Join(parts, ", ") +} + +// variadicElems flattens a variadic argument value into its driver elements. A +// GET text list maps each item verbatim; a POST JSON array maps each element +// through the write-value path (numbers preserved, nested documents re-encoded); +// any other single value is a one-element list. +func variadicElems(d Dialect, v ir.Value) []any { + if v.List != nil { + out := make([]any, len(v.List)) + for i, s := range v.List { + out[i] = s + } + return out + } + if arr, ok := v.JSON.([]any); ok { + out := make([]any, len(arr)) + for i, e := range arr { + out[i] = writeArg(d, ir.Value{JSON: e}, "") + } + return out + } if v.JSON != nil { - return writeArg(v) + return []any{writeArg(d, v, "")} } if v.Text != "" { - return v.Text + return []any{v.Text} } return nil } diff --git a/backend/sqlgen/call_test.go b/backend/sqlgen/call_test.go index cb63f43..1f38424 100644 --- a/backend/sqlgen/call_test.go +++ b/backend/sqlgen/call_test.go @@ -9,7 +9,7 @@ import ( func compileCall(t *testing.T, c *ir.Call, fn *rpc.Function) *Statement { t.Helper() - st, err := CompileCall(stub{}, c, fn) + st, err := CompileCall(stub{}, c, fn, nil) if err != nil { t.Fatalf("CompileCall: %v", err) } @@ -62,6 +62,44 @@ func TestCompileCallOptionalDefault(t *testing.T) { } } +// A variadic parameter expands its placeholder into one bound value per element, +// so an IN (:ids) clause binds every collected id. The GET list form is exercised +// here; the POST array form lands the same elements through the JSON path. +func TestCompileCallVariadicExpandsList(t *testing.T) { + fn := &rpc.Function{ + Name: "pick", + Params: []rpc.Param{{Name: "ids", Variadic: true}}, + Returns: rpc.ReturnShape{Kind: rpc.ReturnSetOf}, + Query: &rpc.PortableQuery{SQL: "SELECT title FROM films WHERE id IN (:ids)"}, + } + c := &ir.Call{Args: map[string]ir.Value{"ids": {List: []string{"1", "3"}}}} + st := compileCall(t, c, fn) + if st.SQL != "SELECT title FROM films WHERE id IN ($1, $2)" { + t.Errorf("SQL = %q", st.SQL) + } + if len(st.Args) != 2 || st.Args[0] != "1" || st.Args[1] != "3" { + t.Errorf("Args = %v, want [1 3]", st.Args) + } +} + +// A variadic call with no trailing arguments expands to nothing, so f(:ids) +// becomes f() and binds no values. +func TestCompileCallVariadicEmpty(t *testing.T) { + fn := &rpc.Function{ + Name: "pick", + Params: []rpc.Param{{Name: "ids", Variadic: true}}, + Returns: rpc.ReturnShape{Kind: rpc.ReturnSetOf}, + Query: &rpc.PortableQuery{SQL: "SELECT count_ids(:ids)"}, + } + st := compileCall(t, &ir.Call{Args: map[string]ir.Value{}}, fn) + if st.SQL != "SELECT count_ids()" { + t.Errorf("SQL = %q, want SELECT count_ids()", st.SQL) + } + if len(st.Args) != 0 { + t.Errorf("Args = %v, want none", st.Args) + } +} + func TestCompileCallTableWithPostFilter(t *testing.T) { fn := &rpc.Function{ Name: "films_after", @@ -117,7 +155,7 @@ func TestCompileCallSingleObjectArg(t *testing.T) { func TestCompileCallNoRealizationUnsupported(t *testing.T) { fn := &rpc.Function{Name: "native_only", Returns: rpc.ReturnShape{Kind: rpc.ReturnScalar}} - _, err := CompileCall(stub{}, &ir.Call{}, fn) + _, err := CompileCall(stub{}, &ir.Call{}, fn, nil) if err == nil || err.Code != "PGRST127" { t.Fatalf("want PGRST127, got %v", err) } @@ -135,7 +173,7 @@ func TestCompileCallCountWrapsAndFilters(t *testing.T) { } where := ir.Cond(ir.Compare{Path: []string{"id"}, Op: ir.OpGt, Value: ir.Value{Text: "10"}}) c := &ir.Call{Args: map[string]ir.Value{"y": {Text: "2000"}}, Where: &where} - st, err := CompileCallCount(stub{}, c, fn) + st, err := CompileCallCount(stub{}, c, fn, nil) if err != nil { t.Fatalf("CompileCallCount: %v", err) } @@ -152,8 +190,74 @@ func TestCompileCallCountWrapsAndFilters(t *testing.T) { // PGRST127 rather than running an empty body. func TestCompileCallCountNoRealizationUnsupported(t *testing.T) { fn := &rpc.Function{Name: "native_only", Returns: rpc.ReturnShape{Kind: rpc.ReturnScalar}} - _, err := CompileCallCount(stub{}, &ir.Call{}, fn) + _, err := CompileCallCount(stub{}, &ir.Call{}, fn, nil) if err == nil || err.Code != "PGRST127" { t.Fatalf("want PGRST127, got %v", err) } } + +// A nil function (a misrouted native call) reports an error instead of +// dereferencing the nil pointer, the regression behind the count=exact crash. +func TestCompileCallCountNilFunctionErrors(t *testing.T) { + _, err := CompileCallCount(stub{}, &ir.Call{}, nil, nil) + if err == nil { + t.Fatal("want an error for a nil function, got nil") + } +} + +func TestCompileCallNilFunctionErrors(t *testing.T) { + _, err := CompileCall(stub{}, &ir.Call{}, nil, nil) + if err == nil { + t.Fatal("want an error for a nil function, got nil") + } +} + +// A placeholder that is not a declared parameter binds the reserved request- +// context value, the emulated analog of current_setting('request.method'). +func TestCompileCallContextPlaceholder(t *testing.T) { + fn := &rpc.Function{ + Name: "get_request_method", + Returns: rpc.ReturnShape{Kind: rpc.ReturnScalar}, + Query: &rpc.PortableQuery{SQL: "SELECT :request_method"}, + } + st, err := CompileCall(stub{}, &ir.Call{}, fn, map[string]any{"request_method": "GET"}) + if err != nil { + t.Fatalf("CompileCall: %v", err) + } + if st.SQL != "SELECT $1" { + t.Errorf("SQL = %q, want SELECT $1", st.SQL) + } + if len(st.Args) != 1 || st.Args[0] != "GET" { + t.Errorf("Args = %v, want [GET]", st.Args) + } +} + +// A declared parameter of the same name keeps winning over the context value. +func TestCompileCallDeclaredParamBeatsContext(t *testing.T) { + fn := &rpc.Function{ + Name: "f", + Params: []rpc.Param{{Name: "request_method"}}, + Returns: rpc.ReturnShape{Kind: rpc.ReturnScalar}, + Query: &rpc.PortableQuery{SQL: "SELECT :request_method"}, + } + c := &ir.Call{Args: map[string]ir.Value{"request_method": {Text: "caller"}}} + st, err := CompileCall(stub{}, c, fn, map[string]any{"request_method": "GET"}) + if err != nil { + t.Fatalf("CompileCall: %v", err) + } + if len(st.Args) != 1 || st.Args[0] != "caller" { + t.Errorf("Args = %v, want [caller]", st.Args) + } +} + +// Without context values an undeclared placeholder is still an internal error. +func TestCompileCallUndeclaredPlaceholderRejected(t *testing.T) { + fn := &rpc.Function{ + Name: "f", + Returns: rpc.ReturnShape{Kind: rpc.ReturnScalar}, + Query: &rpc.PortableQuery{SQL: "SELECT :nope"}, + } + if _, err := CompileCall(stub{}, &ir.Call{}, fn, nil); err == nil { + t.Fatal("want error for undeclared placeholder") + } +} diff --git a/backend/sqlgen/compile.go b/backend/sqlgen/compile.go index a31e955..5aa0916 100644 --- a/backend/sqlgen/compile.go +++ b/backend/sqlgen/compile.go @@ -9,6 +9,7 @@ import ( "github.com/tamnd/dbrest/ir" "github.com/tamnd/dbrest/pgerr" + "github.com/tamnd/dbrest/pgtypes" ) // CountColName is the synthetic column appended by CompileReadCounted to carry @@ -36,6 +37,35 @@ type builder struct { args []any qual string aliasN int + // parentRef is how an EmbedPredicate's EXISTS/NOT EXISTS correlates back to the + // outer row: the parent alias (t0) in an embedded read, or the qualified table + // name in a count where the parent has no alias. embeds is the parent query's + // embed list an EmbedPredicate indexes into. + parentRef string + embeds []ir.Embed + // groupBy collects the non-aggregate projected column expressions while the + // select list is written; when the projection also carries an aggregate, these + // become the GROUP BY so the aggregate folds per distinct value. hasAgg records + // whether any aggregate was seen. + groupBy []string + hasAgg bool + // ctxArgs are the reserved :request_* values an RPC body may bind when a + // placeholder is not a declared parameter; see ContextArgs. + ctxArgs map[string]any + // computed maps the current relation's computed-field names to the schema of + // the function backing each, so colRef can render a computed field as a function + // call on the row instead of a bare column. It is swapped alongside qual when + // descending into an embed, since each relation has its own computed fields. + // rootRow is the row reference passed to a computed-field function when no alias + // is in force (qual is empty): the unqualified base-relation name, which names + // the row of the single FROM relation. See spec 11 (computed fields). + computed map[string]string + rootRow string + // reps maps the current relation's column names to their data-representation + // cast functions (spec 11), swapped alongside computed when descending into an + // embed. A column with one reformats on the wire: ToJSON on read, FromJSON on a + // write value, FromText on a filter literal. + reps map[string]ir.Rep } // newBuilder starts a builder with an empty output buffer. @@ -66,12 +96,85 @@ func (b *builder) bind(v any) string { // colRef renders a column reference, qualified by the current table alias when // one is set (inside an embed subquery) and bare otherwise. func (b *builder) colRef(name string) string { + if sch, ok := b.computed[name]; ok { + // A computed field renders as schema.func(row): the row is the current alias + // inside an embed, or the bare relation name at the top level (qual empty). + rowArg := b.qual + if rowArg == "" { + rowArg = b.rootRow + } + return b.d.QuoteIdent(sch) + "." + b.d.QuoteIdent(name) + "(" + rowArg + ")" + } if b.qual == "" { return b.d.QuoteIdent(name) } return b.qual + "." + b.d.QuoteIdent(name) } +// useRelation points the builder's computed-field and data-representation +// rendering at one relation: the name-to-schema map for its computed fields, the +// column-to-cast map for its representations, and the unqualified name to pass as +// the row argument when no alias is in force. It returns the previous trio so a +// caller descending into an embed can restore the parent's on the way out. +func (b *builder) useRelation(q *ir.Query, relName string) (map[string]string, map[string]ir.Rep, string) { + savedC, savedReps, savedR := b.computed, b.reps, b.rootRow + b.computed = q.Computed + b.reps = q.Reps + b.rootRow = b.d.QuoteIdent(relName) + return savedC, savedReps, savedR +} + +// repCall renders a representation cast-function call: schema.func(arg). It is how +// a domain's to-json/from-text/from-json cast is applied, since PostgreSQL ignores +// the cast in the `::` operator and only the function does the reformatting. +func (b *builder) repCall(funcSchema, funcName, arg string) string { + return b.d.QuoteIdent(funcSchema) + "." + b.d.QuoteIdent(funcName) + "(" + arg + ")" +} + +// fromTextValue binds a filter operand, parsing it through the column's from-text +// data representation when one is present (spec 11). It mirrors PostgREST's +// pgFmtUnknownLiteralForField: the domain's text cast wraps the operand for every +// operator that compares against a typed value (eq, neq, the orderings, regex +// match, the array/range operators, and each IN element). The placeholder is +// typed text so the schema-qualified cast function resolves as a call rather than +// as the domain's own input syntax. PostgREST skips the parse for like/ilike (the +// operand is a wildcard pattern), full-text search, and is, so those callers bind +// the literal directly instead of going through here. A JSON-path operand is +// never a represented column, and a column with no from-text cast binds the +// literal unchanged. +func (b *builder) fromTextValue(colName string, isJSON bool, raw string) string { + ph := b.bind(raw) + if isJSON { + return ph + } + if rep, ok := b.reps[colName]; ok && rep.FromTextFunc != "" { + return b.repCall(rep.FromTextSchema, rep.FromTextFunc, ph+"::text") + } + return ph +} + +// filterValue binds a comparison literal through the column's from-text data +// representation, the common path for eq/neq and the orderings. +func (b *builder) filterValue(c ir.Compare) string { + return b.fromTextValue(c.Path[0], len(c.Path) > 1, c.Value.Text) +} + +// writeValue binds an insert/update value, parsing it through the column's +// from-json data representation when one is present (spec 11): the body value is +// bound as json text and passed to the domain's json cast, the same parse +// PostgREST applies to a write. A column with no from-json cast binds the coerced +// value through writeArg as usual. +func (b *builder) writeValue(col string, v ir.Value, colType string) string { + if rep, ok := b.reps[col]; ok && rep.FromJSONFunc != "" { + js, err := json.Marshal(v.JSON) + if err != nil { + js = []byte("null") + } + return b.repCall(rep.FromJSONSchema, rep.FromJSONFunc, b.bind(string(js))+"::json") + } + return b.bind(writeArg(b.d, v, colType)) +} + // CompileRead lowers a resolved read query to a row-returning SELECT. The result // is a parameterized statement the backend hands to the driver; the renderer // shapes the returned rows into the response document. @@ -104,6 +207,7 @@ func CompileReadCounted(d Dialect, q *ir.Query) (*Statement, *pgerr.APIError) { // can extract the total alongside the result rows. func compileReadPlain(d Dialect, q *ir.Query, withCount bool) (*Statement, *pgerr.APIError) { b := newBuilder(d) + b.useRelation(q, q.Relation.Name) b.sb.WriteString("SELECT ") if err := b.writeSelect(q.Select); err != nil { @@ -126,6 +230,14 @@ func compileReadPlain(d Dialect, q *ir.Query, withCount bool) (*Statement, *pger } } + // An aggregate folds over the rest of the projection: the plain columns become + // the GROUP BY keys. With only aggregates and no plain column, the whole + // relation is one group and no GROUP BY is emitted. + if b.hasAgg && len(b.groupBy) > 0 { + b.sb.WriteString(" GROUP BY ") + b.sb.WriteString(strings.Join(b.groupBy, ", ")) + } + hasOrder := len(q.Order) > 0 if hasOrder { if err := b.writeOrder(q.Order); err != nil { @@ -148,14 +260,68 @@ func compileReadPlain(d Dialect, q *ir.Query, withCount bool) (*Statement, *pger func CompileCount(d Dialect, q *ir.Query) (*Statement, *pgerr.APIError) { b := newBuilder(d) b.sb.WriteString("SELECT count(*) FROM ") - b.sb.WriteString(b.qualify(q.Relation)) + if err := b.writeCountFromAndPredicates(q); err != nil { + return nil, err + } + return &Statement{SQL: b.sb.String(), Args: b.args}, nil +} + +// CompileRowEstimateSource lowers a read query to a row-producing SELECT over the +// same relation and predicates the count covers, with no aggregate. A backend +// that estimates a count (count=planned / count=estimated) EXPLAINs this query +// and reads the planner's row estimate off the root node; the count(*) wrapper +// would instead estimate the aggregate's single output row. The empty target +// list (SELECT FROM) keeps it estimate-only: it is never fetched (item 07.7). +func CompileRowEstimateSource(d Dialect, q *ir.Query) (*Statement, *pgerr.APIError) { + b := newBuilder(d) + b.sb.WriteString("SELECT FROM ") + if err := b.writeCountFromAndPredicates(q); err != nil { + return nil, err + } + return &Statement{SQL: b.sb.String(), Args: b.args}, nil +} + +// writeCountFromAndPredicates emits the parent relation and the predicates a +// count ranges over: the horizontal WHERE and an EXISTS per !inner embed, the +// same set the windowed read applies so an exact count matches its body. The +// caller has already written the SELECT list up to FROM. +func (b *builder) writeCountFromAndPredicates(q *ir.Query) *pgerr.APIError { + b.useRelation(q, q.Relation.Name) + parent := b.qualify(q.Relation) + b.sb.WriteString(parent) + + // An embed-existence filter and an !inner embed's EXISTS both correlate to the + // parent by its bare table name here, since a count gives the parent no alias. + b.parentRef = parent + b.embeds = q.Embeds + + wrote := false if q.Where != nil { b.sb.WriteString(" WHERE ") if err := b.writeCond(*q.Where); err != nil { - return nil, err + return err } + wrote = true } - return &Statement{SQL: b.sb.String(), Args: b.args}, nil + // The row query restricts the parent with an EXISTS per !inner embed + // (compileReadEmbedded), so the count must carry the same predicates or + // Content-Range disagrees with the body it accompanies. + for i := range q.Embeds { + emb := &q.Embeds[i] + if emb.Join != ir.JoinInner { + continue + } + if wrote { + b.sb.WriteString(" AND ") + } else { + b.sb.WriteString(" WHERE ") + wrote = true + } + if err := b.writeEmbedExists(emb, parent); err != nil { + return err + } + } + return nil } // CompileInsert lowers an insert (or upsert) to a parameterized INSERT. Every @@ -170,6 +336,7 @@ func CompileInsert(d Dialect, q *ir.Query, returning []string) (*Statement, *pge return nil, pgerr.ErrParse("insert payload is empty") } b := newBuilder(d) + b.useRelation(q, q.Relation.Name) b.sb.WriteString("INSERT INTO ") b.sb.WriteString(b.qualify(q.Relation)) @@ -195,7 +362,7 @@ func CompileInsert(d Dialect, q *ir.Query, returning []string) (*Statement, *pge b.sb.WriteString(", ") } if val, ok := row[c]; ok { - b.sb.WriteString(b.bind(writeArg(val))) + b.sb.WriteString(b.writeValue(c, val, w.ColumnTypes[c])) } else if w.Missing == ir.MissingNull { b.sb.WriteString(b.bind(nil)) } else { @@ -206,7 +373,12 @@ func CompileInsert(d Dialect, q *ir.Query, returning []string) (*Statement, *pge } } - if w.Conflict != nil { + // An upsert with no resolvable conflict target (a table or view without a + // primary key, no on_conflict given) has nothing to merge or ignore on, so it + // degrades to a plain INSERT, the same as PostgREST: a merge or ignore POST to + // a key-less table inserts the rows and returns 201. Emitting ON CONFLICT here + // would produce invalid SQL ("ON CONFLICT DO UPDATE requires inference"). + if w.Conflict != nil && len(w.Conflict.Target) > 0 { if err := b.writeConflict(w); err != nil { return nil, err } @@ -226,6 +398,7 @@ func CompileUpdate(d Dialect, q *ir.Query, returning []string) (*Statement, *pge return nil, pgerr.ErrParse("update payload is empty") } b := newBuilder(d) + b.useRelation(q, q.Relation.Name) b.sb.WriteString("UPDATE ") b.sb.WriteString(b.qualify(q.Relation)) b.sb.WriteString(" SET ") @@ -236,7 +409,7 @@ func CompileUpdate(d Dialect, q *ir.Query, returning []string) (*Statement, *pge } b.sb.WriteString(d.QuoteIdent(c)) b.sb.WriteString(" = ") - b.sb.WriteString(b.bind(writeArg(w.Set[c]))) + b.sb.WriteString(b.writeValue(c, w.Set[c], w.ColumnTypes[c])) } if q.Where != nil { b.sb.WriteString(" WHERE ") @@ -254,6 +427,7 @@ func CompileUpdate(d Dialect, q *ir.Query, returning []string) (*Statement, *pge // update, a delete without a filter removes every row. func CompileDelete(d Dialect, q *ir.Query, returning []string) (*Statement, *pgerr.APIError) { b := newBuilder(d) + b.useRelation(q, q.Relation.Name) b.sb.WriteString("DELETE FROM ") b.sb.WriteString(b.qualify(q.Relation)) if q.Where != nil { @@ -301,6 +475,15 @@ func (b *builder) writeReturning(cols []string) *pgerr.APIError { } quoted := make([]string, len(cols)) for i, c := range cols { + // A data-representation column reads back through its to-json cast, the same + // formatting a read applies, so return=representation carries what a later GET + // would return (spec 11). The cast function would otherwise name the output + // column after itself, so alias it back to the column name. + if rep, ok := b.reps[c]; ok && rep.ToJSONFunc != "" { + id := b.d.QuoteIdent(c) + quoted[i] = b.repCall(rep.ToJSONSchema, rep.ToJSONFunc, id) + " AS " + id + continue + } quoted[i] = b.d.QuoteIdent(c) } clause, ok := b.d.Returning(quoted) @@ -313,14 +496,18 @@ func (b *builder) writeReturning(cols []string) *pgerr.APIError { } // WriteArg converts a decoded JSON payload value to a driver argument. Numbers -// arrive as json.Number (the decoder preserves integer precision); objects and -// arrays are re-encoded to their JSON text so they land in a json/text column. -// It is exported for backends (e.g. the COPY path) that need the same coercion -// without going through the SQL builder. -func WriteArg(v ir.Value) any { return writeArg(v) } - -// writeArg is the unexported implementation used by the builder methods. -func writeArg(v ir.Value) any { +// arrive as json.Number (the decoder preserves integer precision); objects are +// re-encoded to their JSON text so they land in a json/text column; arrays go +// through the dialect, which knows whether the engine wants a PostgreSQL +// {a,b} array literal or JSON text. It is exported for backends (e.g. the +// MERGE path) that need the same coercion without going through the SQL +// builder. +func WriteArg(d Dialect, v ir.Value, colType string) any { return writeArg(d, v, colType) } + +// writeArg is the unexported implementation used by the builder methods. colType +// is the target column's canonical type, which steers how a JSON array value is +// bound (see Dialect.ArrayArg); an empty colType keeps the engine default. +func writeArg(d Dialect, v ir.Value, colType string) any { switch x := v.JSON.(type) { case nil: return nil @@ -333,10 +520,7 @@ func writeArg(v ir.Value) any { } return x.String() case []any: - // PostgreSQL array columns use {elem1,elem2} input syntax, not JSON - // ["elem1","elem2"]. Build the array literal so the server-side cast - // from text to text[]/int4[]/etc. succeeds with or without type OIDs. - return pgArrayLiteral(x) + return d.ArrayArg(x, colType) case map[string]any: bs, err := json.Marshal(x) if err != nil { @@ -348,13 +532,43 @@ func writeArg(v ir.Value) any { } } -// pgArrayLiteral converts a JSON array into a PostgreSQL array literal string +// IsJSONArrayIndex reports whether a JSON path segment is an array index: a +// non-empty run of ASCII digits. PostgREST treats a digit hop (data->arr->0) as +// an array subscript rather than an object key, and the dialects spell it as +// one. A leading-zero or oversized run still counts as an index; the engine +// decides what a missing element yields. +func IsJSONArrayIndex(seg string) bool { + if seg == "" { + return false + } + for i := 0; i < len(seg); i++ { + if seg[i] < '0' || seg[i] > '9' { + return false + } + } + return true +} + +// JSONArrayArg re-encodes a decoded JSON array to its JSON text. It is the +// ArrayArg implementation for engines without array columns, where a write +// payload array lands in a json/text column and must read back as JSON. +func JSONArrayArg(elems []any) any { + bs, err := json.Marshal(elems) + if err != nil { + return nil + } + return string(bs) +} + +// PGArrayLiteral converts a JSON array into a PostgreSQL array literal string // of the form {elem1,"elem with spaces",NULL}. Elements that are plain // alphanumeric strings (and json.Number/bool) are emitted unquoted; strings // that contain commas, braces, backslashes, double-quotes, or whitespace are // double-quoted with internal backslash escaping, matching PostgreSQL's own -// array output format. -func pgArrayLiteral(elems []any) string { +// array output format. It is the PostgreSQL Dialect's ArrayArg: the literal +// text lets the server-side cast from text to text[]/int4[]/etc. succeed with +// or without type OIDs. +func PGArrayLiteral(elems []any) string { var sb strings.Builder sb.WriteByte('{') for i, e := range elems { @@ -447,27 +661,75 @@ func (b *builder) writeSelect(items []ir.SelectItem) *pgerr.APIError { if i > 0 { b.sb.WriteString(", ") } - col, ok := it.(ir.Column) - if !ok { - return pgerr.ErrUnsupported("aggregates and embedded resources in select", "sql") - } - expr, err := b.columnExpr(col) - if err != nil { - return err - } - b.sb.WriteString(expr) - // Alias the output so the renderer sees the PostgREST key, not the raw - // column. Only needed when the key differs from the bare column name. - if name := col.Name(); name != "" && name != lastPath(col.Path) { + switch v := it.(type) { + case ir.Column: + expr, err := b.columnExpr(v) + if err != nil { + return err + } + b.sb.WriteString(expr) + // Alias the output so the renderer sees the PostgREST key, not the raw + // column expression. Always alias when a cast is present, an explicit + // alias was set, the column is a JSON path (data->>x names its column + // after the last hop, the way upstream does), or a data representation + // wrapped the column in a cast function (which would otherwise name the + // output column after the function, not the column). + if name := v.Name(); name != "" && (name != lastPath(v.Path) || v.Cast != "" || len(v.Path) > 1 || b.repAppliedToJSON(v)) { + b.sb.WriteString(" AS ") + b.sb.WriteString(b.d.QuoteIdent(name)) + } + // A plain column alongside an aggregate is a GROUP BY key. + b.groupBy = append(b.groupBy, expr) + case ir.Aggregate: + expr, err := b.aggregateExpr(v) + if err != nil { + return err + } + b.sb.WriteString(expr) b.sb.WriteString(" AS ") - b.sb.WriteString(b.d.QuoteIdent(name)) + b.sb.WriteString(b.d.QuoteIdent(v.Name())) + b.hasAgg = true + default: + return pgerr.ErrUnsupported("embedded resources in select", "sql") } } return nil } -// columnExpr renders a base column with an optional cast. JSON sub-paths are a -// later subsystem; a column carrying one is rejected explicitly. +// aggregateExpr renders an aggregate call: count(*) for a bare count, or +// func(arg) over the aggregated column, with an optional input cast on the +// column and an optional output cast wrapping the result. +func (b *builder) aggregateExpr(a ir.Aggregate) (string, *pgerr.APIError) { + fn := a.Func.String() + var inner string + if a.Arg == nil { + inner = fn + "(*)" + } else { + arg, err := b.columnExpr(*a.Arg) + if err != nil { + return "", err + } + inner = fn + "(" + arg + ")" + } + if a.Cast != "" { + inner = b.d.Cast(inner, a.Cast) + } + return inner, nil +} + +// jsonPathExpr lowers a base column carrying a JSON sub-path through the dialect. +// hops are the segments after the base; last reports the final ->/->> kind. An +// engine without JSON paths reports ok=false and the request is PGRST127. +func (b *builder) jsonPathExpr(base string, hops []string, last ir.JSONStep) (string, *pgerr.APIError) { + frag, ok := b.d.JSONPath(base, hops, last == ir.JSONArrow2) + if !ok { + return "", pgerr.ErrUnsupported("JSON path", "sql") + } + return frag, nil +} + +// columnExpr renders a base column with an optional cast, lowering a JSON +// sub-path (col->a->>b) through the dialect when the column carries one. func (b *builder) columnExpr(c ir.Column) (string, *pgerr.APIError) { if len(c.Path) == 1 && c.Path[0] == "*" && c.Last == ir.JSONNone && c.Cast == "" { if b.qual == "" { @@ -475,16 +737,44 @@ func (b *builder) columnExpr(c ir.Column) (string, *pgerr.APIError) { } return b.qual + ".*", nil } - if len(c.Path) != 1 || c.Last != ir.JSONNone { - return "", pgerr.ErrUnsupported("JSON path projection", "sql") + var expr string + if len(c.Path) > 1 { + frag, err := b.jsonPathExpr(b.colRef(c.Path[0]), c.Path[1:], c.Last) + if err != nil { + return "", err + } + expr = frag + } else { + expr = b.colRef(c.Path[0]) + // A data-representation column reformats on output through its to-json cast + // (spec 11): the stored value is passed to the cast function, which yields the + // json the response carries. An explicit client cast (col::type) opts out, the + // client having asked for a specific rendering instead. + if c.Cast == "" { + if rep, ok := b.reps[c.Path[0]]; ok && rep.ToJSONFunc != "" { + expr = b.repCall(rep.ToJSONSchema, rep.ToJSONFunc, expr) + } + } } - expr := b.colRef(c.Path[0]) if c.Cast != "" { expr = b.d.Cast(expr, c.Cast) } return expr, nil } +// repAppliedToJSON reports whether a plain base column carries a to-json data +// representation that columnExpr will apply, so writeSelect knows to alias the +// projection to the column name (the cast function would otherwise name the output +// column after itself). A JSON sub-path or an explicit client cast opts out, the +// same conditions columnExpr checks. +func (b *builder) repAppliedToJSON(c ir.Column) bool { + if len(c.Path) != 1 || c.Cast != "" { + return false + } + rep, ok := b.reps[c.Path[0]] + return ok && rep.ToJSONFunc != "" +} + func lastPath(path []string) string { if len(path) == 0 { return "" @@ -508,11 +798,28 @@ func (b *builder) writeCond(c ir.Cond) *pgerr.APIError { return nil case ir.Compare: return b.writeCompare(n) + case ir.EmbedPredicate: + return b.writeEmbedPredicate(n) default: return pgerr.ErrInternal(fmt.Sprintf("unknown filter node %T", c)) } } +// writeEmbedPredicate lowers an embed-existence filter (films?actors=is.null / +// not.is.null). not.is.null is a semi-join, the same EXISTS an !inner embed +// adds; is.null is its anti-join complement (NOT EXISTS). The correlation hangs +// off parentRef so the predicate works both in an embedded read (alias t0) and +// in a plain count (the bare table name). See item 01.12. +func (b *builder) writeEmbedPredicate(p ir.EmbedPredicate) *pgerr.APIError { + if p.Index < 0 || p.Index >= len(b.embeds) { + return pgerr.ErrInternal("embed predicate index out of range") + } + if !p.Exists { + b.sb.WriteString("NOT ") + } + return b.writeEmbedExists(&b.embeds[p.Index], b.parentRef) +} + func (b *builder) writeLogical(kids []ir.Cond, sep string) *pgerr.APIError { if len(kids) == 0 { return nil @@ -530,13 +837,32 @@ func (b *builder) writeLogical(kids []ir.Cond, sep string) *pgerr.APIError { return nil } -// writeCompare lowers a single column-operator-value predicate. The column is a -// base column for now (JSON-path filters arrive with the JSON subsystem). +// writeCompare lowers a single column-operator-value predicate. A column may +// carry a JSON sub-path (data->>field), lowered through the dialect. func (b *builder) writeCompare(c ir.Compare) *pgerr.APIError { - if len(c.Path) != 1 { - return pgerr.ErrUnsupported("JSON path filters", "sql") - } col := b.colRef(c.Path[0]) + isJSON := len(c.Path) > 1 + if isJSON { + frag, err := b.jsonPathExpr(col, c.Path[1:], c.Last) + if err != nil { + return err + } + col = frag + } + + // A quantified filter (op(any)/op(all) over a {…} list) expands to a disjunction + // or conjunction of the real operator, one predicate per element (item 01.1). + if c.Quant != ir.QNone { + frag, err := b.writeQuantified(col, c) + if err != nil { + return err + } + if c.Negate { + frag = "NOT (" + frag + ")" + } + b.sb.WriteString(frag) + return nil + } var frag string var err *pgerr.APIError @@ -544,33 +870,35 @@ func (b *builder) writeCompare(c ir.Compare) *pgerr.APIError { case ir.OpEq, ir.OpNeq: // Boolean literals "true"/"false" are rendered via BoolValue so engines // without a native BOOL type (MySQL TINYINT) produce correct predicates - // (e.g. done = 1 rather than done = 'true' which MySQL coerces to 0). - switch c.Value.Text { - case "true": + // (e.g. done = 1 rather than done = 'true' which MySQL coerces to 0). The + // coercion is column-type-aware: against a non-boolean column (a text + // column literally holding the word "true") the words stay text, matching + // PostgreSQL's type-driven coercion (item 07.4). An unknown column type + // keeps the boolean rendering, the common ?col=is-not-the-point filter. + // A JSON ->>/-> extract is a text/json value, never a typed boolean column, + // so the words "true"/"false" bind as text and a JSON field holding the + // string still matches (the eq.true coercion is column-type driven). + boolColumn := !isJSON && (c.ColumnType == "" || pgtypes.ClassOf(c.ColumnType) == pgtypes.ClassBool) + switch { + case c.Value.Text == "true" && boolColumn: frag = col + " " + binaryOp(c.Op) + " " + b.d.BoolValue(true) - case "false": + case c.Value.Text == "false" && boolColumn: frag = col + " " + binaryOp(c.Op) + " " + b.d.BoolValue(false) default: - frag = col + " " + binaryOp(c.Op) + " " + b.bind(c.Value.Text) - } - case ir.OpGt, ir.OpGte, ir.OpLt, ir.OpLte, ir.OpLike: - if c.Quant != ir.QNone { - frag, err = b.writeLikeQuantified(col, ir.OpLike, c.Quant, c.Value.List) - } else { - frag = col + " " + binaryOp(c.Op) + " " + b.bind(c.Value.Text) + frag = col + " " + binaryOp(c.Op) + " " + b.filterValue(c) } + case ir.OpGt, ir.OpGte, ir.OpLt, ir.OpLte: + frag = col + " " + binaryOp(c.Op) + " " + b.filterValue(c) + case ir.OpLike: + frag = col + " " + binaryOp(c.Op) + " " + b.bind(c.Value.Text) case ir.OpILike: - if c.Quant != ir.QNone { - frag, err = b.writeLikeQuantified(col, ir.OpILike, c.Quant, c.Value.List) - } else { - var ok bool - frag, ok = b.d.ILike(col, b.bind(c.Value.Text)) - if !ok { - return pgerr.ErrUnsupported("case-insensitive LIKE", "sql") - } + var ok bool + frag, ok = b.d.ILike(col, b.bind(c.Value.Text)) + if !ok { + return pgerr.ErrUnsupported("case-insensitive LIKE", "sql") } case ir.OpIn: - frag, err = b.writeIn(col, c.Value.List) + frag, err = b.writeIn(col, c.Path[0], isJSON, c.Value.List) case ir.OpIs: frag, err = b.writeIs(col, c.Value.Text) case ir.OpMatch, ir.OpIMatch: @@ -585,8 +913,9 @@ func (b *builder) writeCompare(c ir.Compare) *pgerr.APIError { return pgerr.ErrUnsupported("regular-expression match", "sql") } // Regex returns an already-formed boolean expression carrying PatternMark - // where the bound pattern placeholder goes. - frag = strings.Replace(expr, PatternMark, b.bind(c.Value.Text), 1) + // where the bound pattern placeholder goes. A represented column parses the + // pattern through its from-text cast, as PostgREST does for match/imatch. + frag = strings.Replace(expr, PatternMark, b.fromTextValue(c.Path[0], isJSON, c.Value.Text), 1) case ir.OpFTS: frag, err = b.writeFTS(c, col) case ir.OpIsDistinct: @@ -601,12 +930,35 @@ func (b *builder) writeCompare(c ir.Compare) *pgerr.APIError { default: sqlOp = "&&" } - val := b.bind(c.Value.Text) + // Normalize the PostgreSQL {a,b} array literal to the engine's format + // before binding; the dialect is a no-op for engines that accept {a,b}. A + // represented column parses the literal through its from-text cast, matching + // PostgREST's simple-operator path. + val := b.fromTextValue(c.Path[0], isJSON, b.d.ArrayLiteral(c.Value.Text)) var ok bool - frag, ok = b.d.ArrayOp(col, sqlOp, val) + frag, ok = b.d.ArrayOp(col, sqlOp, val, c.ColumnType) if !ok { return pgerr.ErrUnsupported("array operator "+sqlOp, "sql") } + case ir.OpRangeSL, ir.OpRangeSR, ir.OpRangeNXR, ir.OpRangeNXL, ir.OpRangeAdj: + var rop string + switch c.Op { + case ir.OpRangeSL: + rop = "<<" + case ir.OpRangeSR: + rop = ">>" + case ir.OpRangeNXR: + rop = "&<" + case ir.OpRangeNXL: + rop = "&>" + default: + rop = "-|-" + } + var ok bool + frag, ok = b.d.RangeOp(col, rop, b.fromTextValue(c.Path[0], isJSON, c.Value.Text)) + if !ok { + return pgerr.ErrUnsupported("range operator "+opName(c.Op), "sql") + } default: return pgerr.ErrUnsupported("filter operator "+opName(c.Op), "sql") } @@ -637,52 +989,77 @@ func (b *builder) writeFTS(c ir.Compare, col string) (string, *pgerr.APIError) { RowidRef: b.colRef(rowid), } } - expr, bindVal, ok := b.d.FullText(col, ref, c.FTS, c.Config, c.Value.Text) + expr, bindVal, ok := b.d.FullText(col, c.ColumnType, ref, c.FTS, c.Config, c.Value.Text) if !ok { return "", pgerr.ErrFullTextUnavailable(c.Path[0], "sql") } return strings.Replace(expr, PatternMark, b.bind(bindVal), 1), nil } -func (b *builder) writeIn(col string, list []string) (string, *pgerr.APIError) { +func (b *builder) writeIn(col, colName string, isJSON bool, list []string) (string, *pgerr.APIError) { if len(list) == 0 { // `col IN ()` is a syntax error; an empty IN matches nothing. return "1 = 0", nil } + rep, hasRep := b.reps[colName] + useRep := !isJSON && hasRep && rep.FromTextFunc != "" + // On an engine that binds the list as a single array (PostgreSQL's = ANY), every + // list length is one prepared statement instead of one per length. The element + // quoting is PostgreSQL's array-literal format, the same the array operators use, + // so a value with a comma or brace stays a single element. The bind happens only + // on this branch so the expansion path's placeholder numbering is unaffected. + if frag, ok := b.d.InList(col); ok { + elems := make([]any, len(list)) + for i, v := range list { + elems[i] = v + } + ph := b.bind(PGArrayLiteral(elems)) + operand := ph + if useRep { + // A represented column parses each element through its from-text cast, + // applied over the unpacked array, matching PostgREST's + // pgFmtArrayLiteralForField. + operand = "(SELECT " + b.repCall(rep.FromTextSchema, rep.FromTextFunc, "unnest("+ph+"::text[])") + ")" + } + return strings.Replace(frag, PatternMark, operand, 1), nil + } parts := make([]string, len(list)) for i, v := range list { - parts[i] = b.bind(v) + if useRep { + parts[i] = b.repCall(rep.FromTextSchema, rep.FromTextFunc, b.bind(v)+"::text") + } else { + parts[i] = b.bind(v) + } } return col + " IN (" + strings.Join(parts, ", ") + ")", nil } -// writeLikeQuantified expands like(any)/{...} and like(all)/{...} into a -// conjunction or disjunction of individual LIKE / ILIKE predicates. An empty -// list generates a no-match literal (1 = 0) for ANY and always-match (1 = 1) -// for ALL, consistent with SQL ANY/ALL semantics over an empty set. -func (b *builder) writeLikeQuantified(col string, op ir.Op, q ir.Quant, list []string) (string, *pgerr.APIError) { +// writeQuantified expands a quantified filter (op(any)/op(all) over a {…} list) +// into a disjunction (ANY) or conjunction (ALL) of the real operator, one +// predicate per element. An empty list is a no-match literal (1 = 0) for ANY and +// always-match (1 = 1) for ALL, consistent with SQL ANY/ALL over an empty set, +// though the parser now rejects an empty list upstream. See item 01.1. +func (b *builder) writeQuantified(col string, c ir.Compare) (string, *pgerr.APIError) { + list := c.Value.List if len(list) == 0 { - if q == ir.QAny { + if c.Quant == ir.QAny { return "1 = 0", nil } return "1 = 1", nil } sep := " OR " - if q == ir.QAll { + if c.Quant == ir.QAll { sep = " AND " } + colName := c.Path[0] + isJSON := len(c.Path) > 1 parts := make([]string, len(list)) - for i, pat := range list { - bound := b.bind(pat) - if op == ir.OpILike { - expr, ok := b.d.ILike(col, bound) - if !ok { - return "", pgerr.ErrUnsupported("case-insensitive LIKE", "sql") - } - parts[i] = expr - } else { - parts[i] = col + " LIKE " + bound + for i, v := range list { + frag, err := b.quantElem(col, colName, isJSON, c.Op, v) + if err != nil { + return "", err } + parts[i] = frag } if len(parts) == 1 { return parts[0], nil @@ -690,6 +1067,36 @@ func (b *builder) writeLikeQuantified(col string, op ir.Op, q ir.Quant, list []s return "(" + strings.Join(parts, sep) + ")", nil } +// quantElem lowers one element of a quantified list to its single-operator SQL +// predicate, using the operator's real infix/regex/ILIKE form. +func (b *builder) quantElem(col, colName string, isJSON bool, op ir.Op, v string) (string, *pgerr.APIError) { + switch op { + case ir.OpEq, ir.OpGt, ir.OpGte, ir.OpLt, ir.OpLte: + return col + " " + binaryOp(op) + " " + b.fromTextValue(colName, isJSON, v), nil + case ir.OpLike: + // like carries a wildcard pattern, so PostgREST binds it raw even on a + // represented column. + return col + " " + binaryOp(op) + " " + b.bind(v), nil + case ir.OpILike: + expr, ok := b.d.ILike(col, b.bind(v)) + if !ok { + return "", pgerr.ErrUnsupported("case-insensitive LIKE", "sql") + } + return expr, nil + case ir.OpMatch, ir.OpIMatch: + if feat := b.d.RegexFeatureGap(v); feat != "" { + return "", pgerr.ErrUnsupported(feat, "sql") + } + expr, ok := b.d.Regex(col, v, op == ir.OpIMatch) + if !ok { + return "", pgerr.ErrUnsupported("regular-expression match", "sql") + } + return strings.Replace(expr, PatternMark, b.fromTextValue(colName, isJSON, v), 1), nil + default: + return "", pgerr.ErrUnsupported("quantifier on "+opName(op), "sql") + } +} + func (b *builder) writeIs(col, text string) (string, *pgerr.APIError) { switch text { case "null": @@ -697,9 +1104,20 @@ func (b *builder) writeIs(col, text string) (string, *pgerr.APIError) { case "not_null": return col + " IS NOT NULL", nil case "true": + if frag, ok := b.d.IsBool(col, true); ok { + return frag, nil + } return col + " IS " + b.d.BoolValue(true), nil case "false": + if frag, ok := b.d.IsBool(col, false); ok { + return frag, nil + } return col + " IS " + b.d.BoolValue(false), nil + case "unknown": + if frag, ok := b.d.IsUnknown(col); ok { + return frag, nil + } + return col + " IS NULL", nil default: return "", pgerr.ErrParse("unknown is value " + text) } @@ -707,12 +1125,29 @@ func (b *builder) writeIs(col, text string) (string, *pgerr.APIError) { // writeOrder emits the ORDER BY, delegating NULLs placement to the dialect. func (b *builder) writeOrder(terms []ir.OrderTerm) *pgerr.APIError { + // The parent reference for a related-order subquery is the qualifier in force + // as the ORDER BY is written (t0 in an embedded read, the bare table name in a + // count). + parentAlias := b.qual var sortKeys, orderTerms []string for _, t := range terms { - if len(t.Path) != 1 { - return pgerr.ErrUnsupported("JSON path ordering", "sql") + var col string + if t.Rel != "" { + frag, err := b.relatedOrderExpr(t, parentAlias) + if err != nil { + return err + } + col = frag + } else { + col = b.colRef(t.Path[0]) + if len(t.Path) > 1 { + frag, err := b.jsonPathExpr(col, t.Path[1:], t.Last) + if err != nil { + return err + } + col = frag + } } - col := b.colRef(t.Path[0]) dir := "ASC" if t.Desc { dir = "DESC" @@ -731,6 +1166,59 @@ func (b *builder) writeOrder(terms []ir.OrderTerm) *pgerr.APIError { return nil } +// relatedOrderExpr lowers an order=rel(col) term to a correlated scalar subquery +// selecting the embed's column for the matching to-one row: a parent with no +// related row yields NULL, which the dialect's NULLs placement then orders. The +// embed is matched by the same written name the planner validated, and its +// to-one join condition correlates the subquery back to the parent (item 07.6). +func (b *builder) relatedOrderExpr(t ir.OrderTerm, parentAlias string) (string, *pgerr.APIError) { + emb := b.findEmbed(t.Rel) + if emb == nil { + // The planner validates the relation is embedded; reaching here is a bug. + return "", pgerr.ErrInternal("related order names an unresolved embed: " + t.Rel) + } + rel := emb.Rel + b.aliasN++ + alias := "o" + strconv.Itoa(b.aliasN) + + saved := b.qual + b.qual = alias + savedC, savedReps, savedR := b.useRelation(&emb.Query, rel.Target.Name) + col := b.colRef(t.Path[0]) + if len(t.Path) > 1 { + frag, err := b.jsonPathExpr(col, t.Path[1:], t.Last) + if err != nil { + b.qual = saved + b.computed, b.reps, b.rootRow = savedC, savedReps, savedR + return "", err + } + col = frag + } + b.qual = saved + b.computed, b.reps, b.rootRow = savedC, savedReps, savedR + + from, cond := b.embedSource(rel, alias, parentAlias) + return "(SELECT " + col + " FROM " + from + " WHERE " + cond + ")", nil +} + +// findEmbed returns the embed an order=rel(col) term refers to, matched by the +// embed's alias or, when it has none, its written target name. It mirrors the +// planner's findEmbedByName so the compiler resolves the same relation the +// planner validated. +func (b *builder) findEmbed(name string) *ir.Embed { + for i := range b.embeds { + emb := &b.embeds[i] + written := emb.Alias + if written == "" { + written = emb.Target.Name + } + if written == name { + return emb + } + } + return nil +} + // binaryOp maps an infix operator to its SQL spelling. Only the operators with a // direct infix form go through here; the rest are handled in writeCompare. func binaryOp(op ir.Op) string { diff --git a/backend/sqlgen/compile_test.go b/backend/sqlgen/compile_test.go index 9a80da1..df6cdb1 100644 --- a/backend/sqlgen/compile_test.go +++ b/backend/sqlgen/compile_test.go @@ -67,7 +67,7 @@ func (stub) RegexFeatureGap(string) string { return "" } // FullText models a PostgreSQL-flavored, column-agnostic full text: the index is // ignored (tsvector works on any column), so a nil idx is fine. -func (stub) FullText(col string, _ *FullTextRef, v ir.FTSVariant, _, _ string) (string, string, bool) { +func (stub) FullText(col, _ string, _ *FullTextRef, v ir.FTSVariant, _, _ string) (string, string, bool) { ctor := map[ir.FTSVariant]string{ ir.FTSPlain: "to_tsquery", ir.FTSPlainText: "plainto_tsquery", ir.FTSPhrase: "phraseto_tsquery", ir.FTSWeb: "websearch_to_tsquery", @@ -76,9 +76,15 @@ func (stub) FullText(col string, _ *FullTextRef, v ir.FTSVariant, _, _ string) ( } func (stub) SessionRead(k string) string { return "" } func (stub) SessionWrite(k string) (string, bool) { return "", false } -func (stub) ArrayOp(col, op, val string) (string, bool) { +func (stub) ArrayOp(col, op, val, _ string) (string, bool) { return col + " " + op + " " + val, true } +func (stub) RangeOp(col, op, val string) (string, bool) { + return col + " " + op + " " + val, true +} +func (stub) ArrayLiteral(s string) string { return s } +func (stub) InList(_ string) (string, bool) { return "", false } +func (stub) ArrayArg(e []any, _ string) any { return JSONArrayArg(e) } func (stub) ILike(col, val string) (string, bool) { return col + " ILIKE " + val, true } func (stub) BoolValue(v bool) string { if v { @@ -86,6 +92,26 @@ func (stub) BoolValue(v bool) string { } return "FALSE" } +func (stub) IsBool(string, bool) (string, bool) { return "", false } +func (stub) IsUnknown(string) (string, bool) { return "", false } + +// JSONPath mirrors the PostgreSQL native ->/->> chain so the shared compiler's +// JSON-path routing is assertable without a real engine. +func (stub) JSONPath(base string, hops []string, asText bool) (string, bool) { + expr := base + for i, h := range hops { + op := "->" + if asText && i == len(hops)-1 { + op = "->>" + } + if IsJSONArrayIndex(h) { + expr += op + h + } else { + expr += op + "'" + h + "'" + } + } + return expr, true +} func compile(t *testing.T, q *ir.Query) *Statement { t.Helper() @@ -229,6 +255,76 @@ func TestCompileEmptyInMatchesNothing(t *testing.T) { } } +// TestCompileQuantifiedEqExpandsToOr checks eq(any) over a list fans out into an +// OR of equalities, each value bound (item 01.1). +func TestCompileQuantifiedEqExpandsToOr(t *testing.T) { + where := ir.Cond(ir.Compare{ + Path: []string{"id"}, + Op: ir.OpEq, + Quant: ir.QAny, + Value: ir.Value{List: []string{"1", "2", "3"}}, + }) + st := compile(t, &ir.Query{Relation: ir.Ref{Name: "t"}, Where: &where}) + want := `SELECT * FROM "t" WHERE ("id" = $1 OR "id" = $2 OR "id" = $3)` + if st.SQL != want { + t.Errorf("SQL = %q, want %q", st.SQL, want) + } + if len(st.Args) != 3 || st.Args[0] != "1" || st.Args[2] != "3" { + t.Errorf("Args = %v", st.Args) + } +} + +// TestCompileQuantifiedGtExpandsToAnd checks gt(all) fans out into an AND. +func TestCompileQuantifiedGtExpandsToAnd(t *testing.T) { + where := ir.Cond(ir.Compare{ + Path: []string{"year"}, + Op: ir.OpGt, + Quant: ir.QAll, + Value: ir.Value{List: []string{"1990", "2000"}}, + }) + st := compile(t, &ir.Query{Relation: ir.Ref{Name: "t"}, Where: &where}) + want := `SELECT * FROM "t" WHERE ("year" > $1 AND "year" > $2)` + if st.SQL != want { + t.Errorf("SQL = %q, want %q", st.SQL, want) + } +} + +// TestCompileQuantifiedMatchUsesDialectRegex checks match(any) routes each +// element through the dialect regex seam (PatternMark replaced by the bind). +func TestCompileQuantifiedMatchUsesDialectRegex(t *testing.T) { + where := ir.Cond(ir.Compare{ + Path: []string{"c"}, + Op: ir.OpMatch, + Quant: ir.QAny, + Value: ir.Value{List: []string{"^a", "b$"}}, + }) + st := compile(t, &ir.Query{Relation: ir.Ref{Name: "t"}, Where: &where}) + want := `SELECT * FROM "t" WHERE ("c" ~ $1 OR "c" ~ $2)` + if st.SQL != want { + t.Errorf("SQL = %q, want %q", st.SQL, want) + } + if len(st.Args) != 2 || st.Args[0] != "^a" || st.Args[1] != "b$" { + t.Errorf("Args = %v", st.Args) + } +} + +// TestCompileQuantifiedNegated checks a negated quantified compare wraps the +// whole fan-out in NOT (…). +func TestCompileQuantifiedNegated(t *testing.T) { + where := ir.Cond(ir.Compare{ + Path: []string{"id"}, + Op: ir.OpEq, + Quant: ir.QAny, + Negate: true, + Value: ir.Value{List: []string{"1", "2"}}, + }) + st := compile(t, &ir.Query{Relation: ir.Ref{Name: "t"}, Where: &where}) + want := `SELECT * FROM "t" WHERE NOT (("id" = $1 OR "id" = $2))` + if st.SQL != want { + t.Errorf("SQL = %q, want %q", st.SQL, want) + } +} + func TestCompileCount(t *testing.T) { st, err := CompileCount(stub{}, &ir.Query{Relation: ir.Ref{Name: "films"}}) if err != nil { @@ -300,23 +396,23 @@ func TestCompileInsertMultiRow(t *testing.T) { } func TestCompileInsertMissingDefaultAndNull(t *testing.T) { - // A row missing a column takes DEFAULT under missing=default ... + // A row missing a column takes DEFAULT only under an explicit missing=default ... st, _ := CompileInsert(stub{}, &ir.Query{ Relation: ir.Ref{Name: "t"}, Write: &ir.WriteSpec{ Columns: []string{"a", "b"}, + Missing: ir.MissingDefault, Rows: []map[string]ir.Value{{"a": jstr("x")}}, }, }, nil) if st.SQL != `INSERT INTO "t" ("a", "b") VALUES ($1, DEFAULT)` { t.Errorf("default: SQL = %q", st.SQL) } - // ... and a bound NULL under missing=null. + // ... and a bound NULL by default (MissingNull is the zero value, item 01.18). st, _ = CompileInsert(stub{}, &ir.Query{ Relation: ir.Ref{Name: "t"}, Write: &ir.WriteSpec{ Columns: []string{"a", "b"}, - Missing: ir.MissingNull, Rows: []map[string]ir.Value{{"a": jstr("x")}}, }, }, nil) @@ -428,12 +524,13 @@ func TestCompileInsertEmptyPayloadRejected(t *testing.T) { } } -// The base SQL compiler does not lower the array and range operators; a backend -// grades them per dialect. A read using one here reports PGRST127 and names the -// operator, rather than emitting a quietly different predicate. +// A range operator on an engine whose dialect declines (no range types) reports +// PGRST127 and names the PostgREST token, rather than emitting a quietly +// different predicate. A range-capable dialect lowers it instead; see +// rangeop_test.go. func TestCompileRangeOperatorRejectedNamed(t *testing.T) { where := ir.Cond(ir.Compare{Path: []string{"period"}, Op: ir.OpRangeSL, Value: ir.Value{Text: "[1,2)"}}) - _, err := CompileRead(stub{}, &ir.Query{Relation: ir.Ref{Name: "t"}, Where: &where}) + _, err := CompileRead(noRangeDialect{}, &ir.Query{Relation: ir.Ref{Name: "t"}, Where: &where}) if err == nil || err.Code != "PGRST127" { t.Fatalf("want PGRST127, got %v", err) } @@ -442,12 +539,106 @@ func TestCompileRangeOperatorRejectedNamed(t *testing.T) { } } -func TestCompileAggregateRejected(t *testing.T) { - _, err := CompileRead(stub{}, &ir.Query{ +// TestCompileBareCount renders count() with no grouping column as count(*) over +// the whole relation, keyed to its default response name (item 01.4). +func TestCompileBareCount(t *testing.T) { + st := compile(t, &ir.Query{ Relation: ir.Ref{Name: "t"}, Select: []ir.SelectItem{ir.Aggregate{Func: ir.AggCount}}, }) - if err == nil || err.Code != "PGRST127" { - t.Fatalf("want PGRST127 for aggregate, got %v", err) + want := `SELECT count(*) AS "count" FROM "t"` + if st.SQL != want { + t.Errorf("SQL = %q, want %q", st.SQL, want) + } +} + +// TestCompileColumnAggregateGroupsBy renders category, amount.sum() as a grouped +// aggregate: the plain column is the GROUP BY key, the aggregate folds per group. +func TestCompileColumnAggregateGroupsBy(t *testing.T) { + st := compile(t, &ir.Query{ + Relation: ir.Ref{Schema: "public", Name: "sales"}, + Select: []ir.SelectItem{ + col("category"), + ir.Aggregate{Func: ir.AggSum, Arg: &ir.Column{Path: []string{"amount"}}}, + }, + }) + want := `SELECT "category", sum("amount") AS "sum" FROM "public"."sales" GROUP BY "category"` + if st.SQL != want { + t.Errorf("SQL = %q, want %q", st.SQL, want) + } +} + +// TestCompileAggregateAliasAndCasts honors a response-key alias, an input cast on +// the aggregated column, and an output cast on the result. +func TestCompileAggregateAliasAndCasts(t *testing.T) { + st := compile(t, &ir.Query{ + Relation: ir.Ref{Name: "sales"}, + Select: []ir.SelectItem{ + ir.Aggregate{ + Func: ir.AggSum, + Arg: &ir.Column{Path: []string{"amount"}, Cast: "numeric"}, + Cast: "text", + Alias: "total", + }, + }, + }) + want := `SELECT CAST(sum(CAST("amount" AS numeric)) AS text) AS "total" FROM "sales"` + if st.SQL != want { + t.Errorf("SQL = %q, want %q", st.SQL, want) + } +} + +// A payload array goes through the dialect on the write path: the stub (like +// every engine without array columns) binds the JSON text, never a PostgreSQL +// {a,b} literal. This is what lets a JSON column round-trip ["go","sql"]. +func TestCompileUpdateArrayBindsDialectForm(t *testing.T) { + st, err := CompileUpdate(stub{}, &ir.Query{ + Relation: ir.Ref{Name: "todos"}, + Write: &ir.WriteSpec{Set: map[string]ir.Value{ + "tags": {JSON: []any{"go", "sql"}}, + }}, + }, nil) + if err != nil { + t.Fatalf("CompileUpdate: %v", err) + } + if len(st.Args) != 1 || st.Args[0] != `["go","sql"]` { + t.Errorf("Args = %#v, want JSON text", st.Args) + } +} + +func TestCompileInsertArrayBindsDialectForm(t *testing.T) { + st, err := CompileInsert(stub{}, &ir.Query{ + Relation: ir.Ref{Name: "todos"}, + Write: &ir.WriteSpec{ + Columns: []string{"tags"}, + Rows: []map[string]ir.Value{{"tags": {JSON: []any{json.Number("1"), "two words"}}}}, + }, + }, nil) + if err != nil { + t.Fatalf("CompileInsert: %v", err) + } + if len(st.Args) != 1 || st.Args[0] != `[1,"two words"]` { + t.Errorf("Args = %#v, want JSON text", st.Args) + } +} + +// PGArrayLiteral is the PostgreSQL form of the same payload: bare elements +// unquoted, strings with spaces or quotes double-quoted and escaped, NULL for +// JSON null. +func TestPGArrayLiteral(t *testing.T) { + cases := []struct { + in []any + want string + }{ + {[]any{"go", "sql"}, `{go,sql}`}, + {[]any{json.Number("1"), json.Number("2.5")}, `{1,2.5}`}, + {[]any{"two words", `qu"ote`, nil}, `{"two words","qu\"ote",NULL}`}, + {[]any{true, false}, `{t,f}`}, + {[]any{}, `{}`}, + } + for _, c := range cases { + if got := PGArrayLiteral(c.in); got != c.want { + t.Errorf("PGArrayLiteral(%v) = %q, want %q", c.in, got, c.want) + } } } diff --git a/backend/sqlgen/computed_test.go b/backend/sqlgen/computed_test.go new file mode 100644 index 0000000..403f58b --- /dev/null +++ b/backend/sqlgen/computed_test.go @@ -0,0 +1,159 @@ +package sqlgen + +import ( + "strings" + "testing" + + "github.com/tamnd/dbrest/ir" + "github.com/tamnd/dbrest/schema" +) + +// computedRelModel wires authors with two computed relationships to books: a +// set-returning one (to-many) and a single-row one (to-one). The edges carry the +// function to call, not join columns. +func computedRelModel() *schema.Model { + cols := func(names ...string) []*schema.Column { + out := make([]*schema.Column, len(names)) + for i, n := range names { + out[i] = &schema.Column{Name: n, Type: "text", Position: i + 1} + } + return out + } + books := &schema.Relation{Schema: "public", Name: "books", Columns: cols("id", "title")} + authors := &schema.Relation{ + Schema: "public", + Name: "authors", + Columns: cols("id", "name"), + ComputedRels: []schema.ComputedRel{ + {Name: "books", FuncSchema: "public", TargetSchema: "public", TargetName: "books", Card: schema.CardToMany}, + {Name: "first_book", FuncSchema: "public", TargetSchema: "public", TargetName: "books", Card: schema.CardToOne}, + }, + } + return schema.NewModel([]*schema.Relation{authors, books}) +} + +// A to-many computed relationship embeds by calling the function on the parent +// row in the subquery FROM, with no join predicate (the row argument correlates). +func TestComputedRelToMany(t *testing.T) { + m := computedRelModel() + q := &ir.Query{ + Relation: ir.Ref{Schema: "public", Name: "authors"}, + Select: []ir.SelectItem{ir.Column{Path: []string{"name"}}, ir.EmbedRef{Index: 0}}, + Embeds: []ir.Embed{{ + OutKey: "books", + Target: ir.Ref{Schema: "public", Name: "books"}, + Rel: relateNamed(t, m, "authors", "books", "books"), + Query: ir.Query{Select: []ir.SelectItem{ir.Column{Path: []string{"title"}}}}, + }}, + } + got := compileEmbed(t, q).SQL + if !strings.Contains(got, `FROM "public"."books"(t0) t1 WHERE TRUE`) { + t.Errorf("to-many computed-rel embed did not render function call:\n%s", got) + } +} + +// A to-one computed relationship renders a single-object subquery over the +// function call, again correlating through the row argument. +func TestComputedRelToOne(t *testing.T) { + m := computedRelModel() + rel := relateNamed(t, m, "authors", "books", "first_book") + q := &ir.Query{ + Relation: ir.Ref{Schema: "public", Name: "authors"}, + Select: []ir.SelectItem{ir.Column{Path: []string{"name"}}, ir.EmbedRef{Index: 0}}, + Embeds: []ir.Embed{{ + OutKey: "first_book", + Target: ir.Ref{Schema: "public", Name: "books"}, + Rel: rel, + Query: ir.Query{Select: []ir.SelectItem{ir.Column{Path: []string{"title"}}}}, + }}, + } + got := compileEmbed(t, q).SQL + want := `(SELECT json_object('title', t1."title") FROM "public"."first_book"(t0) t1 WHERE TRUE LIMIT 1)` + if !strings.Contains(got, want) { + t.Errorf("to-one computed-rel embed:\n got %s\nwant substring %s", got, want) + } +} + +// relateNamed picks the single edge from parent to target whose name matches, +// when more than one edge connects them (two computed relationships here). +func relateNamed(t *testing.T, m *schema.Model, parent, target, name string) *schema.Relationship { + t.Helper() + p, ok := m.Lookup(parent, []string{"public"}) + if !ok { + t.Fatalf("parent %q not in model", parent) + } + cands, found := m.Relationships(p, target, []string{"public"}) + if !found { + t.Fatalf("relateNamed(%s,%s): target not found", parent, target) + } + for i := range cands { + if cands[i].Name == name { + return &cands[i] + } + } + t.Fatalf("relateNamed(%s,%s,%s): no edge named %q among %d", parent, target, name, name, len(cands)) + return nil +} + +// A computed field in the select list renders as schema.func(row), where the row +// is the bare relation name at the top level, and is aliased to the field name +// only when the client renamed it. +func TestComputedFieldSelect(t *testing.T) { + q := &ir.Query{ + Relation: ir.Ref{Schema: "public", Name: "authors"}, + Select: []ir.SelectItem{col("id"), col("full_name")}, + Computed: map[string]string{"full_name": "public"}, + } + st := compile(t, q) + want := `SELECT "id", "public"."full_name"("authors") FROM "public"."authors"` + if st.SQL != want { + t.Errorf("SQL = %q, want %q", st.SQL, want) + } +} + +// A renamed computed field carries an explicit alias so the output key is the one +// the client asked for, not the function name. +func TestComputedFieldAliased(t *testing.T) { + q := &ir.Query{ + Relation: ir.Ref{Name: "authors"}, + Select: []ir.SelectItem{ir.Column{Path: []string{"full_name"}, Alias: "name"}}, + Computed: map[string]string{"full_name": "public"}, + } + st := compile(t, q) + want := `SELECT "public"."full_name"("authors") AS "name" FROM "authors"` + if st.SQL != want { + t.Errorf("SQL = %q, want %q", st.SQL, want) + } +} + +// A computed field is filterable: a predicate on it lowers to the function call, +// not a bare column, so the WHERE references schema.func(row). +func TestComputedFieldFilter(t *testing.T) { + var where ir.Cond = ir.Compare{ + Path: []string{"full_name"}, Op: ir.OpEq, Value: ir.Value{Text: "Ada Lovelace"}, + } + q := &ir.Query{ + Relation: ir.Ref{Name: "authors"}, + Select: []ir.SelectItem{col("id")}, + Where: &where, + Computed: map[string]string{"full_name": "public"}, + } + st := compile(t, q) + if !strings.Contains(st.SQL, `"public"."full_name"("authors") = `) { + t.Errorf("filter did not render computed call: %q", st.SQL) + } +} + +// A computed field is orderable: ORDER BY references the function call. +func TestComputedFieldOrder(t *testing.T) { + q := &ir.Query{ + Relation: ir.Ref{Name: "authors"}, + Select: []ir.SelectItem{col("id")}, + Order: []ir.OrderTerm{{Path: []string{"full_name"}}}, + Computed: map[string]string{"full_name": "public"}, + } + st := compile(t, q) + if !strings.Contains(st.SQL, `ORDER BY "public"."full_name"("authors")`) { + t.Errorf("order did not render computed call: %q", st.SQL) + } +} diff --git a/backend/sqlgen/cond_test.go b/backend/sqlgen/cond_test.go index c0e5b5e..ca252fb 100644 --- a/backend/sqlgen/cond_test.go +++ b/backend/sqlgen/cond_test.go @@ -112,7 +112,7 @@ func TestCompileLikeAll(t *testing.T) { // when the planner attached none. type indexFTSDialect struct{ stub } -func (indexFTSDialect) FullText(col string, idx *FullTextRef, _ ir.FTSVariant, _, _ string) (string, string, bool) { +func (indexFTSDialect) FullText(col, _ string, idx *FullTextRef, _ ir.FTSVariant, _, _ string) (string, string, bool) { if idx == nil { return "", "", false } diff --git a/backend/sqlgen/context.go b/backend/sqlgen/context.go new file mode 100644 index 0000000..eaf119d --- /dev/null +++ b/backend/sqlgen/context.go @@ -0,0 +1,32 @@ +package sqlgen + +import "github.com/tamnd/dbrest/reqctx" + +// ContextArgs builds the reserved :request_* placeholder values a registry +// function may reference in its SQL. On PostgreSQL a function reads the +// request context with current_setting('request.method', true); on engines +// with no SQL-readable session store the same values bind as parameters +// (spec 15), under these names: +// +// :request_method the HTTP method +// :request_path the request path +// :request_role the resolved request role +// :request_jwt_claims the verified claims as a JSON object ("{}" when none) +// :request_headers lower-cased request headers as a JSON object +// :request_cookies request cookies as a JSON object +// +// The call compiler resolves these only when the placeholder is not a +// declared parameter, so a function parameter of the same name keeps winning. +func ContextArgs(rc *reqctx.Context) map[string]any { + if rc == nil { + return nil + } + return map[string]any{ + "request_method": rc.Method, + "request_path": rc.Path, + "request_role": rc.Role, + "request_jwt_claims": string(rc.ClaimsJSON()), + "request_headers": string(rc.HeadersJSON()), + "request_cookies": string(rc.CookiesJSON()), + } +} diff --git a/backend/sqlgen/dialect.go b/backend/sqlgen/dialect.go index a18da81..dbfd798 100644 --- a/backend/sqlgen/dialect.go +++ b/backend/sqlgen/dialect.go @@ -61,13 +61,16 @@ type Dialect interface { RegexFeatureGap(pattern string) string // FullText lowers a full-text predicate. col is the quoted column reference - // (PostgreSQL builds to_tsvector over it); idx is the resolved covering index, - // or nil when the schema has none (an engine that requires one reports ok=false - // and the compiler raises PGRST127). variant is the fts/plfts/phfts/wfts - // grammar; config is the language argument (may be empty); value is the raw - // query text. The returned fragment carries PatternMark where the bound, - // engine-translated query value goes, and bind is that value. See spec 21. - FullText(col string, idx *FullTextRef, variant ir.FTSVariant, config, value string) (frag, bind string, ok bool) + // (PostgreSQL builds to_tsvector over it); colType is the column's canonical + // type, so a dialect can skip the to_tsvector wrap when the column is already + // tsvector (it may be empty when the type is unknown); idx is the resolved + // covering index, or nil when the schema has none (an engine that requires one + // reports ok=false and the compiler raises PGRST127). variant is the + // fts/plfts/phfts/wfts grammar; config is the language argument (may be empty); + // value is the raw query text. The returned fragment carries PatternMark where + // the bound, engine-translated query value goes, and bind is that value. See + // spec 21. + FullText(col, colType string, idx *FullTextRef, variant ir.FTSVariant, config, value string) (frag, bind string, ok bool) // SessionRead reads a request-context value (the GUC analog). SessionRead(key string) string @@ -86,12 +89,71 @@ type Dialect interface { // BoolValue renders a boolean literal. BoolValue(v bool) string + // IsBool renders "col IS TRUE" or "col IS FALSE" in the engine's syntax. + // Engines that restrict IS to NULL/UNKNOWN (SQL Server) return ok=true with + // a = expression; engines that support IS return ok=false to fall back + // to "col IS ". + IsBool(col string, v bool) (string, bool) + + // IsUnknown renders the three-valued "col IS UNKNOWN" test in the engine's + // syntax. PostgreSQL has the native operator and returns ok=true; engines + // without it return ok=false to fall back to "col IS NULL", which selects the + // same rows for a boolean column (its UNKNOWN state is its NULL). + IsUnknown(col string) (string, bool) + // ArrayOp renders an array containment/overlap operator expression, or - // reports ok=false when the engine does not support array types (SQLite, - // MySQL, SQL Server). The compiler emits PGRST127 when ok=false. - // op is one of "@>", "<@", "&&"; col is the quoted column; val is the - // placeholder returned by bind(). - ArrayOp(col, op, val string) (string, bool) + // reports ok=false when the engine does not support array types (MySQL, SQL + // Server) or when the column type does not support array semantics (SQLite + // requires a JSON-typed column for json_each). The compiler emits PGRST127 + // when ok=false. op is one of "@>", "<@", "&&"; col is the quoted column; + // val is the placeholder returned by bind(); colType is the canonical + // column type resolved by the planner ("json", "text", "integer", …). + ArrayOp(col, op, val, colType string) (string, bool) + + // ArrayLiteral converts a PostgREST array literal (PostgreSQL {a,b} syntax) + // to the engine's native format for use as a bound parameter. PostgreSQL + // accepts {a,b} natively; SQLite needs JSON ["a","b"]. Other engines that + // do not support arrays may return the text unchanged (they never reach + // ArrayOp either). + ArrayLiteral(pgText string) string + + // InList renders an in-list filter as a single-parameter form when the engine + // supports it, or reports ok=false to fall back to col IN ($1, $2, ...). col is + // the quoted column; the returned fragment carries PatternMark where the bound + // array placeholder goes, so the compiler binds the one argument only on this + // path (an unused bind would shift every later placeholder). PostgreSQL returns + // col = ANY(PatternMark), which PostgREST uses so a list of any length is one + // prepared statement instead of a distinct statement per length; the rows are + // identical to the expanded IN. Engines without an array-bound ANY return + // ok=false. + InList(col string) (string, bool) + + // ArrayArg converts a decoded JSON array from a write payload into the + // bound driver argument the engine expects. colType is the target column's + // canonical type, so PostgreSQL renders the {elem1,elem2} array-literal text + // for an array column but keeps JSON text for a json/jsonb column (a JSON + // array there is JSON, not an array literal); engines without array columns + // (SQLite, MySQL, SQL Server) keep the JSON text regardless so a json/text + // column stores the array unchanged and reads it back as a JSON array. + ArrayArg(elems []any, colType string) any + + // JSONPath lowers a JSON sub-path access into the engine's spelling. base is + // the already-qualified, quoted base column; hops are the path segments after + // it (data->phones->0->>number gives base data, hops {phones,0,number}); a + // segment that is all digits is an array index. asText reports whether the + // final hop was ->> (returns text) rather than -> (returns json). PostgreSQL + // emits its native ->/->> chain; SQLite emits the ->/->> operators over a + // JSON path. An engine without JSON paths returns ok=false and the compiler + // raises PGRST127. + JSONPath(base string, hops []string, asText bool) (frag string, ok bool) + + // RangeOp renders a range-type operator expression (PostgREST sl/sr/nxr/nxl/ + // adj), or reports ok=false when the engine has no range types so the compiler + // raises PGRST127. op is the engine-neutral PostgreSQL spelling "<<", ">>", + // "&<", "&>", or "-|-"; col is the quoted column; val is the placeholder + // returned by bind(). PostgreSQL emits the native operator; engines without + // range types decline. + RangeOp(col, op, val string) (string, bool) } // PatternMark is the sentinel a Dialect.Regex fragment carries where the bound diff --git a/backend/sqlgen/embed.go b/backend/sqlgen/embed.go index 7747d15..68227c5 100644 --- a/backend/sqlgen/embed.go +++ b/backend/sqlgen/embed.go @@ -28,7 +28,21 @@ import ( // aliased t1, t2, ... as they are emitted, in a stable left-to-right order. func compileReadEmbedded(d Dialect, q *ir.Query) (*Statement, *pgerr.APIError) { b := newBuilder(d) + return b.writeEmbeddedQuery(q, func() *pgerr.APIError { + b.sb.WriteString(b.qualify(q.Relation)) + return nil + }) +} + +// writeEmbeddedQuery emits an embedded read: the parent projection (plain columns +// plus one JSON subquery per embed), a parent source written by writeSource (a +// base relation for a table read, the wrapped function result for an RPC call), +// the parent WHERE with one EXISTS per !inner embed, and the order/window. The +// projection is written before the source so its embed-subquery placeholders bind +// ahead of the source's, keeping positional arguments in textual order. +func (b *builder) writeEmbeddedQuery(q *ir.Query, writeSource func() *pgerr.APIError) (*Statement, *pgerr.APIError) { const parentAlias = "t0" + b.useRelation(q, q.Relation.Name) b.sb.WriteString("SELECT ") if err := b.writeEmbeddedSelect(q, parentAlias); err != nil { @@ -36,7 +50,9 @@ func compileReadEmbedded(d Dialect, q *ir.Query) (*Statement, *pgerr.APIError) { } b.sb.WriteString(" FROM ") - b.sb.WriteString(b.qualify(q.Relation)) + if err := writeSource(); err != nil { + return nil, err + } b.sb.WriteString(" ") b.sb.WriteString(parentAlias) @@ -46,6 +62,8 @@ func compileReadEmbedded(d Dialect, q *ir.Query) (*Statement, *pgerr.APIError) { if q.Where != nil { b.sb.WriteString(" WHERE ") b.qual = parentAlias + b.parentRef = parentAlias + b.embeds = q.Embeds if err := b.writeCond(*q.Where); err != nil { return nil, err } @@ -70,6 +88,9 @@ func compileReadEmbedded(d Dialect, q *ir.Query) (*Statement, *pgerr.APIError) { hasOrder := len(q.Order) > 0 if hasOrder { b.qual = parentAlias + // A related-order term (order=rel(col)) resolves its relation against the + // parent's embeds, so they must be in scope even when no WHERE set them. + b.embeds = q.Embeds if err := b.writeOrder(q.Order); err != nil { return nil, err } @@ -107,12 +128,32 @@ func (b *builder) writeEmbeddedSelect(q *ir.Query, parentAlias string) *pgerr.AP } sep() b.sb.WriteString(expr) - if name := v.Name(); name != "" && name != lastPath(v.Path) && !isStar(v) { + if name := v.Name(); name != "" && !isStar(v) && (name != lastPath(v.Path) || len(v.Path) > 1) { b.sb.WriteString(" AS ") b.sb.WriteString(b.d.QuoteIdent(name)) } case ir.EmbedRef: emb := &q.Embeds[v.Index] + // An empty-parenthesis embed, client(), joins for filtering but is + // not projected; the parent WHERE still carries its !inner EXISTS. + if emb.EmptySelect { + continue + } + // A spread embed, ...client(name), lifts its columns into the parent + // row rather than nesting them under a key. + if emb.Spread { + pairs, err := b.spreadPairs(emb, parentAlias) + if err != nil { + return err + } + for _, p := range pairs { + sep() + b.sb.WriteString(p.Value) + b.sb.WriteString(" AS ") + b.sb.WriteString(b.d.QuoteIdent(p.Key)) + } + continue + } sub, err := b.embedExpr(emb, parentAlias) if err != nil { return err @@ -125,6 +166,11 @@ func (b *builder) writeEmbeddedSelect(q *ir.Query, parentAlias string) *pgerr.AP return pgerr.ErrUnsupported("aggregates in select", "sql") } } + // A select list that named only hidden embeds projects nothing; fall back to + // the parent's columns so the statement stays valid. + if first { + b.sb.WriteString(parentAlias + ".*") + } return nil } @@ -156,7 +202,7 @@ func (b *builder) writeEmbed(emb *ir.Embed, parentAlias string) *pgerr.APIError if err != nil { return err } - from := b.qualify(ir.Ref{Schema: rel.Target.Schema, Name: rel.Target.Name}) + " " + alias + from, corr := b.embedSource(rel, alias, parentAlias) if rel.Card == schema.CardToOne { b.sb.WriteString("(SELECT ") @@ -164,11 +210,16 @@ func (b *builder) writeEmbed(emb *ir.Embed, parentAlias string) *pgerr.APIError b.sb.WriteString(" FROM ") b.sb.WriteString(from) b.sb.WriteString(" WHERE ") - b.sb.WriteString(b.joinCond(alias, rel.Foreign, parentAlias, rel.Local)) + b.sb.WriteString(corr) if err := b.writeEmbedFilter(emb, alias); err != nil { return err } - b.sb.WriteString(" LIMIT 1)") + lim := 1 + if lo := b.d.LimitOffset(&lim, nil, false); lo != "" { + b.sb.WriteString(" ") + b.sb.WriteString(lo) + } + b.sb.WriteString(")") return nil } @@ -197,7 +248,7 @@ func (b *builder) writeEmbed(emb *ir.Embed, parentAlias string) *pgerr.APIError b.sb.WriteString(b.joinCond(jx, rel.JLocal, parentAlias, rel.Local)) } else { b.sb.WriteString(" WHERE ") - b.sb.WriteString(b.joinCond(alias, rel.Foreign, parentAlias, rel.Local)) + b.sb.WriteString(corr) } if err := b.writeEmbedFilter(emb, alias); err != nil { return err @@ -205,12 +256,22 @@ func (b *builder) writeEmbed(emb *ir.Embed, parentAlias string) *pgerr.APIError hasOrder := len(emb.Query.Order) > 0 if hasOrder { saved := b.qual + savedEmbeds := b.embeds b.qual = alias - if err := b.writeOrder(emb.Query.Order); err != nil { + // A related-order term inside this embed resolves against the embed's own + // nested embeds, so scope them in for the duration of its ORDER BY. + b.embeds = emb.Query.Embeds + savedC, savedReps, savedR := b.useRelation(&emb.Query, rel.Target.Name) + restore := func() { b.qual = saved + b.embeds = savedEmbeds + b.computed, b.reps, b.rootRow = savedC, savedReps, savedR + } + if err := b.writeOrder(emb.Query.Order); err != nil { + restore() return err } - b.qual = saved + restore() } if clause := b.d.LimitOffset(emb.Query.Limit, emb.Query.Offset, hasOrder); clause != "" { b.sb.WriteString(" ") @@ -220,6 +281,21 @@ func (b *builder) writeEmbed(emb *ir.Embed, parentAlias string) *pgerr.APIError return nil } +// embedSource renders an embed's FROM entry and its correlation predicate. A +// foreign-key edge selects the target relation aliased and correlates by a join +// on the FK columns. A computed relationship (spec 11) instead calls the backing +// function on the parent row, FuncSchema.FuncName(parentAlias) alias; the function +// argument is the correlation, so the predicate is just TRUE and the embed's own +// filters AND onto it the same way. +func (b *builder) embedSource(rel *schema.Relationship, alias, parentAlias string) (from, corr string) { + if rel.FuncName != "" { + call := b.d.QuoteIdent(rel.FuncSchema) + "." + b.d.QuoteIdent(rel.FuncName) + "(" + parentAlias + ")" + return call + " " + alias, b.d.BoolValue(true) + } + from = b.qualify(ir.Ref{Schema: rel.Target.Schema, Name: rel.Target.Name}) + " " + alias + return from, b.joinCond(alias, rel.Foreign, parentAlias, rel.Local) +} + // writeEmbedExists emits the EXISTS predicate that an !inner embed adds to the // parent WHERE, so a parent row with no embedded match is excluded. The same // embedded filters apply, matching PostgREST's inner-join semantics. @@ -227,7 +303,7 @@ func (b *builder) writeEmbedExists(emb *ir.Embed, parentAlias string) *pgerr.API rel := emb.Rel b.aliasN++ alias := "x" + strconv.Itoa(b.aliasN) - from := b.qualify(ir.Ref{Schema: rel.Target.Schema, Name: rel.Target.Name}) + " " + alias + from, corr := b.embedSource(rel, alias, parentAlias) b.sb.WriteString("EXISTS (SELECT 1 FROM ") b.sb.WriteString(from) @@ -242,7 +318,7 @@ func (b *builder) writeEmbedExists(emb *ir.Embed, parentAlias string) *pgerr.API b.sb.WriteString(b.joinCond(jx, rel.JLocal, parentAlias, rel.Local)) } else { b.sb.WriteString(" WHERE ") - b.sb.WriteString(b.joinCond(alias, rel.Foreign, parentAlias, rel.Local)) + b.sb.WriteString(corr) } if err := b.writeEmbedFilter(emb, alias); err != nil { return err @@ -269,7 +345,8 @@ func (b *builder) embedObject(emb *ir.Embed, alias string) (string, *pgerr.APIEr saved := b.qual b.qual = alias - defer func() { b.qual = saved }() + savedC, savedReps, savedR := b.useRelation(&emb.Query, emb.Rel.Target.Name) + defer func() { b.qual = saved; b.computed, b.reps, b.rootRow = savedC, savedReps, savedR }() for _, it := range emb.Query.Select { switch v := it.(type) { @@ -285,6 +362,21 @@ func (b *builder) embedObject(emb *ir.Embed, alias string) (string, *pgerr.APIEr pairs = append(pairs, Pair{Key: v.Name(), Value: expr}) case ir.EmbedRef: nested := &emb.Query.Embeds[v.Index] + // A nested empty-parenthesis embed joins for filtering but is not + // projected into the parent object, mirroring the top-level rule. + if nested.EmptySelect { + continue + } + // A nested spread lifts its columns into this object, just as a + // top-level spread lifts into the parent row. + if nested.Spread { + lifted, err := b.spreadPairs(nested, alias) + if err != nil { + return "", err + } + pairs = append(pairs, lifted...) + continue + } sub, err := b.embedExpr(nested, alias) if err != nil { return "", err @@ -301,6 +393,89 @@ func (b *builder) embedObject(emb *ir.Embed, alias string) (string, *pgerr.APIEr return b.d.JSONObject(pairs), nil } +// spreadPairs lowers a spread embed (...rel) to the parent-level columns it +// lifts, each a correlated subquery the caller projects flat into the parent row +// (top level) or merges into the enclosing JSON object (nested). A to-one spread +// lifts each column as a scalar; a to-many spread lifts each column as a JSON +// array of that column's values across the related rows (v12.1). Renaming and +// star expansion follow the ordinary projection rules. A spread over a +// many-to-many relationship is not lowered and reports PGRST127 rather than emit +// wrong SQL (item 07.9). +func (b *builder) spreadPairs(emb *ir.Embed, parentAlias string) ([]Pair, *pgerr.APIError) { + rel := emb.Rel + if rel.Junction != nil { + return nil, pgerr.ErrUnsupported("spread over a many-to-many relationship", "sql") + } + b.aliasN++ + alias := "t" + strconv.Itoa(b.aliasN) + from := b.qualify(ir.Ref{Schema: rel.Target.Schema, Name: rel.Target.Name}) + " " + alias + + // The correlation predicate (join plus the embed's own filters) is shared by + // every lifted column, so render it once. + where, err := b.capture(func() *pgerr.APIError { + b.sb.WriteString(b.joinCond(alias, rel.Foreign, parentAlias, rel.Local)) + return b.writeEmbedFilter(emb, alias) + }) + if err != nil { + return nil, err + } + + type lifted struct{ name, expr string } + var cols []lifted + saved := b.qual + b.qual = alias + savedC, savedReps, savedR := b.useRelation(&emb.Query, rel.Target.Name) + defer func() { b.computed, b.reps, b.rootRow = savedC, savedReps, savedR }() + addAll := func() { + for _, n := range rel.Target.ColumnNames() { + cols = append(cols, lifted{n, alias + "." + b.d.QuoteIdent(n)}) + } + } + if len(emb.Query.Select) == 0 { + addAll() + } else { + for _, it := range emb.Query.Select { + col, ok := it.(ir.Column) + if !ok { + b.qual = saved + return nil, pgerr.ErrUnsupported("non-column item in a spread", "sql") + } + if isStar(col) { + addAll() + continue + } + expr, e := b.columnExpr(col) + if e != nil { + b.qual = saved + return nil, e + } + cols = append(cols, lifted{col.Name(), expr}) + } + } + b.qual = saved + + toMany := rel.Card != schema.CardToOne + pairs := make([]Pair, 0, len(cols)) + for _, c := range cols { + var sub string + if toMany { + // COALESCE so a parent with no related rows lifts [] rather than NULL, + // matching the nested to-many array's empty case. + sub = "(SELECT COALESCE(" + b.d.JSONAgg(c.expr, "") + ", " + + b.d.Cast("'[]'", "json") + ") FROM " + from + " WHERE " + where + ")" + } else { + limClause := "" + lim := 1 + if lo := b.d.LimitOffset(&lim, nil, false); lo != "" { + limClause = " " + lo + } + sub = "(SELECT " + c.expr + " FROM " + from + " WHERE " + where + limClause + ")" + } + pairs = append(pairs, Pair{Key: c.name, Value: sub}) + } + return pairs, nil +} + // writeEmbedFilter appends the embed's own horizontal filters, ANDed onto the // join predicate and qualified by the target alias. func (b *builder) writeEmbedFilter(emb *ir.Embed, alias string) *pgerr.APIError { @@ -309,10 +484,12 @@ func (b *builder) writeEmbedFilter(emb *ir.Embed, alias string) *pgerr.APIError } saved := b.qual b.qual = alias + savedC, savedReps, savedR := b.useRelation(&emb.Query, emb.Rel.Target.Name) b.sb.WriteString(" AND (") err := b.writeCond(*emb.Query.Where) b.sb.WriteString(")") b.qual = saved + b.computed, b.reps, b.rootRow = savedC, savedReps, savedR return err } diff --git a/backend/sqlgen/embed_test.go b/backend/sqlgen/embed_test.go index d589b81..348850e 100644 --- a/backend/sqlgen/embed_test.go +++ b/backend/sqlgen/embed_test.go @@ -51,9 +51,10 @@ func embedModel() *schema.Model { } actors := &schema.Relation{Schema: "public", Name: "actors", Columns: cols("id", "name")} roles := &schema.Relation{ - Schema: "public", - Name: "roles", - Columns: cols("film_id", "actor_id"), + Schema: "public", + Name: "roles", + Columns: cols("film_id", "actor_id"), + PrimaryKey: []string{"film_id", "actor_id"}, // composite PK marks roles a junction ForeignKeys: []*schema.ForeignKey{ {Name: "roles_film_id_fkey", Columns: []string{"film_id"}, RefSchema: "public", RefRelation: "films", RefColumns: []string{"id"}}, {Name: "roles_actor_id_fkey", Columns: []string{"actor_id"}, RefSchema: "public", RefRelation: "actors", RefColumns: []string{"id"}}, @@ -196,6 +197,153 @@ func TestEmbedInnerAddsExists(t *testing.T) { } } +// films?actors=not.is.null filters the parent on the existence of a related +// actor: a semi-join, the same EXISTS an !inner embed adds, correlated to t0 +// and crossing the roles junction (item 01.12). +func TestEmbedPredicateNotIsNullSemiJoin(t *testing.T) { + m := embedModel() + where := ir.Cond(ir.EmbedPredicate{Index: 0, Exists: true}) + q := &ir.Query{ + Relation: ir.Ref{Schema: "public", Name: "films"}, + Select: []ir.SelectItem{ir.Column{Path: []string{"title"}}, ir.EmbedRef{Index: 0}}, + Where: &where, + Embeds: []ir.Embed{{ + OutKey: "actors", + Target: ir.Ref{Schema: "public", Name: "actors"}, + Rel: relate(t, m, "films", "actors"), + }}, + } + got := compileEmbed(t, q).SQL + if strings.Contains(got, "NOT EXISTS") { + t.Errorf("not.is.null should be a plain EXISTS, not anti-join\n in %q", got) + } + for _, want := range []string{ + `WHERE EXISTS (SELECT 1 FROM "public"."actors" x2`, + `JOIN "public"."roles" xj2 ON xj2."actor_id" = x2."id"`, + `WHERE xj2."film_id" = t0."id"`, + } { + if !strings.Contains(got, want) { + t.Errorf("semi-join missing %q\n in %q", want, got) + } + } +} + +// films?actors=is.null is the anti-join complement: a parent with no related +// actor, lowered to NOT EXISTS over the same relationship (item 01.12). +func TestEmbedPredicateIsNullAntiJoin(t *testing.T) { + m := embedModel() + where := ir.Cond(ir.EmbedPredicate{Index: 0, Exists: false}) + q := &ir.Query{ + Relation: ir.Ref{Schema: "public", Name: "directors"}, + Select: []ir.SelectItem{ir.Column{Path: []string{"name"}}, ir.EmbedRef{Index: 0}}, + Where: &where, + Embeds: []ir.Embed{{ + OutKey: "films", + Target: ir.Ref{Schema: "public", Name: "films"}, + Rel: relate(t, m, "directors", "films"), + }}, + } + got := compileEmbed(t, q).SQL + if !strings.Contains(got, `WHERE NOT EXISTS (SELECT 1 FROM "public"."films" x2 WHERE x2."director_id" = t0."id")`) { + t.Errorf("is.null missing NOT EXISTS anti-join\n in %q", got) + } +} + +// The embed-existence predicate composes under or=(...): one disjunct is the +// semi-join EXISTS, the other an ordinary parent-column compare. +func TestEmbedPredicateInsideOr(t *testing.T) { + m := embedModel() + where := ir.Cond(ir.Or{Kids: []ir.Cond{ + ir.EmbedPredicate{Index: 0, Exists: true}, + ir.Compare{Path: []string{"name"}, Op: ir.OpEq, Value: ir.Value{Text: "Lynch"}}, + }}) + q := &ir.Query{ + Relation: ir.Ref{Schema: "public", Name: "directors"}, + Select: []ir.SelectItem{ir.Column{Path: []string{"name"}}, ir.EmbedRef{Index: 0}}, + Where: &where, + Embeds: []ir.Embed{{ + OutKey: "films", + Target: ir.Ref{Schema: "public", Name: "films"}, + Rel: relate(t, m, "directors", "films"), + }}, + } + got := compileEmbed(t, q).SQL + if !strings.Contains(got, `WHERE (EXISTS (SELECT 1 FROM "public"."films" x2 WHERE x2."director_id" = t0."id") OR t0."name" = $1)`) { + t.Errorf("or= with embed predicate not lowered as expected\n in %q", got) + } +} + +// A count over a query carrying an embed-existence filter correlates the EXISTS +// to the parent by its bare table name, since the count gives it no alias. +func TestEmbedPredicateInCount(t *testing.T) { + m := embedModel() + where := ir.Cond(ir.EmbedPredicate{Index: 0, Exists: true}) + q := &ir.Query{ + Relation: ir.Ref{Schema: "public", Name: "directors"}, + Where: &where, + Embeds: []ir.Embed{{ + OutKey: "films", + Target: ir.Ref{Schema: "public", Name: "films"}, + Rel: relate(t, m, "directors", "films"), + }}, + } + st, err := CompileCount(embedStub{}, q) + if err != nil { + t.Fatalf("CompileCount: %v", err) + } + if !strings.Contains(st.SQL, `SELECT count(*) FROM "public"."directors" WHERE EXISTS (SELECT 1 FROM "public"."films" x1 WHERE x1."director_id" = "public"."directors"."id")`) { + t.Errorf("count did not correlate embed EXISTS to the bare table\n in %q", st.SQL) + } +} + +// A count over a query carrying an !inner embed restricts the parent with the +// same EXISTS the row query adds, so an exact count matches the filtered body +// (item 07.7). The EXISTS correlates to the bare table name, since the count +// gives the parent no alias. +func TestCountAppliesInnerEmbedExists(t *testing.T) { + m := embedModel() + q := &ir.Query{ + Relation: ir.Ref{Schema: "public", Name: "directors"}, + Embeds: []ir.Embed{{ + OutKey: "films", + Join: ir.JoinInner, + Target: ir.Ref{Schema: "public", Name: "films"}, + Rel: relate(t, m, "directors", "films"), + }}, + } + st, err := CompileCount(embedStub{}, q) + if err != nil { + t.Fatalf("CompileCount: %v", err) + } + want := `SELECT count(*) FROM "public"."directors" ` + + `WHERE EXISTS (SELECT 1 FROM "public"."films" x1 ` + + `WHERE x1."director_id" = "public"."directors"."id")` + if st.SQL != want { + t.Errorf("\n got %q\nwant %q", st.SQL, want) + } +} + +// A non-inner embed leaves the count unrestricted: only !inner embeds prune the +// parent, so a plain to-many embed adds no EXISTS to the count. +func TestCountIgnoresNonInnerEmbed(t *testing.T) { + m := embedModel() + q := &ir.Query{ + Relation: ir.Ref{Schema: "public", Name: "directors"}, + Embeds: []ir.Embed{{ + OutKey: "films", + Target: ir.Ref{Schema: "public", Name: "films"}, + Rel: relate(t, m, "directors", "films"), + }}, + } + st, err := CompileCount(embedStub{}, q) + if err != nil { + t.Fatalf("CompileCount: %v", err) + } + if strings.Contains(st.SQL, "WHERE") { + t.Errorf("non-inner embed should add no predicate, got %q", st.SQL) + } +} + // An embed's own horizontal filter is ANDed onto the join predicate, bound, and // qualified by the target alias. func TestEmbedHorizontalFilterIsBound(t *testing.T) { @@ -239,22 +387,158 @@ func TestEmbedStarProjectsAllColumns(t *testing.T) { } } +// An empty-parenthesis embed, director(), joins for filtering but is not +// projected: the parent select carries no key for it, the opposite of an absent +// or star projection which takes every column (item 07.8). +func TestEmbedEmptySelectHidesKey(t *testing.T) { + m := embedModel() + q := &ir.Query{ + Relation: ir.Ref{Schema: "public", Name: "films"}, + Select: []ir.SelectItem{ + ir.Column{Path: []string{"title"}}, + ir.EmbedRef{Index: 0}, + }, + Embeds: []ir.Embed{{ + OutKey: "director", + EmptySelect: true, + Target: ir.Ref{Schema: "public", Name: "directors"}, + Rel: relate(t, m, "films", "directors"), + }}, + } + got := compileEmbed(t, q).SQL + if strings.Contains(got, `"director"`) { + t.Errorf("empty-paren embed should project no key, got %q", got) + } + if !strings.HasPrefix(got, `SELECT t0."title" FROM`) { + t.Errorf("parent should project only its own columns, got %q", got) + } +} + +// An empty-parenthesis embed marked !inner still restricts the parent through +// EXISTS even though it projects nothing: the filter-without-fetch idiom. +func TestEmbedEmptySelectInnerStillFilters(t *testing.T) { + m := embedModel() + q := &ir.Query{ + Relation: ir.Ref{Schema: "public", Name: "films"}, + Select: []ir.SelectItem{ + ir.Column{Path: []string{"title"}}, + ir.EmbedRef{Index: 0}, + }, + Embeds: []ir.Embed{{ + OutKey: "director", + EmptySelect: true, + Join: ir.JoinInner, + Target: ir.Ref{Schema: "public", Name: "directors"}, + Rel: relate(t, m, "films", "directors"), + }}, + } + got := compileEmbed(t, q).SQL + if strings.Contains(got, `"director"`) { + t.Errorf("empty-paren embed should project no key, got %q", got) + } + if !strings.Contains(got, `WHERE EXISTS (SELECT 1 FROM "public"."directors"`) { + t.Errorf("!inner empty-paren embed should still filter the parent, got %q", got) + } +} + // A spread embed is not yet lowered to SQL; it must report PGRST127 rather than -// emit something wrong. -func TestEmbedSpreadUnsupported(t *testing.T) { +// A to-one spread (...director(name)) lifts the embedded column into the parent +// row as a correlated scalar subquery aliased to the column name, not nested +// under a relation key (item 07.9). +func TestEmbedSpreadToOneLiftsColumns(t *testing.T) { m := embedModel() q := &ir.Query{ Relation: ir.Ref{Schema: "public", Name: "films"}, - Select: []ir.SelectItem{ir.EmbedRef{Index: 0}}, + Select: []ir.SelectItem{ + ir.Column{Path: []string{"title"}}, + ir.EmbedRef{Index: 0}, + }, Embeds: []ir.Embed{{ - OutKey: "director", + OutKey: "directors", + Spread: true, + Target: ir.Ref{Schema: "public", Name: "directors"}, + Rel: relate(t, m, "films", "directors"), + Query: ir.Query{Select: []ir.SelectItem{ir.Column{Path: []string{"name"}}}}, + }}, + } + want := `SELECT t0."title", (SELECT t1."name" FROM "public"."directors" t1 ` + + `WHERE t1."id" = t0."director_id" LIMIT 1) AS "name" FROM "public"."films" t0` + if got := compileEmbed(t, q).SQL; got != want { + t.Errorf("\n got %q\nwant %q", got, want) + } +} + +// A spread renames the lifted column when written col:alias, and the subquery is +// aliased to the new name. +func TestEmbedSpreadRenamesColumn(t *testing.T) { + m := embedModel() + q := &ir.Query{ + Relation: ir.Ref{Schema: "public", Name: "films"}, + Select: []ir.SelectItem{ + ir.Column{Path: []string{"title"}}, + ir.EmbedRef{Index: 0}, + }, + Embeds: []ir.Embed{{ + OutKey: "directors", Spread: true, Target: ir.Ref{Schema: "public", Name: "directors"}, Rel: relate(t, m, "films", "directors"), + Query: ir.Query{Select: []ir.SelectItem{ir.Column{Path: []string{"name"}, Alias: "director_name"}}}, + }}, + } + if got := compileEmbed(t, q).SQL; !strings.Contains(got, `LIMIT 1) AS "director_name"`) { + t.Errorf("spread rename not applied, got %q", got) + } +} + +// A to-many spread (...films(title)) lifts each column as a JSON array of that +// column's values across the related rows, COALESCEd to [] for a parent with no +// children (v12.1 semantics). +func TestEmbedSpreadToManyLiftsArrays(t *testing.T) { + m := embedModel() + q := &ir.Query{ + Relation: ir.Ref{Schema: "public", Name: "directors"}, + Select: []ir.SelectItem{ + ir.Column{Path: []string{"name"}}, + ir.EmbedRef{Index: 0}, + }, + Embeds: []ir.Embed{{ + OutKey: "films", + Spread: true, + Target: ir.Ref{Schema: "public", Name: "films"}, + Rel: relate(t, m, "directors", "films"), + Query: ir.Query{Select: []ir.SelectItem{ir.Column{Path: []string{"title"}}}}, + }}, + } + got := compileEmbed(t, q).SQL + for _, want := range []string{ + `json_group_array(t1."title")`, + `FROM "public"."films" t1 WHERE t1."director_id" = t0."id"`, + `) AS "title"`, + } { + if !strings.Contains(got, want) { + t.Errorf("to-many spread missing %q\n in %q", want, got) + } + } +} + +// A spread over a many-to-many relationship is not lowered: it reports PGRST127 +// rather than emit wrong SQL. +func TestEmbedSpreadManyToManyUnsupported(t *testing.T) { + m := embedModel() + q := &ir.Query{ + Relation: ir.Ref{Schema: "public", Name: "films"}, + Select: []ir.SelectItem{ir.EmbedRef{Index: 0}}, + Embeds: []ir.Embed{{ + OutKey: "actors", + Spread: true, + Target: ir.Ref{Schema: "public", Name: "actors"}, + Rel: relate(t, m, "films", "actors"), + Query: ir.Query{Select: []ir.SelectItem{ir.Column{Path: []string{"name"}}}}, }}, } if _, err := CompileRead(embedStub{}, q); err == nil || err.Code != "PGRST127" { - t.Fatalf("spread embed err = %v, want PGRST127", err) + t.Fatalf("many-to-many spread err = %v, want PGRST127", err) } } diff --git a/backend/sqlgen/isunknown_test.go b/backend/sqlgen/isunknown_test.go new file mode 100644 index 0000000..eeb3e71 --- /dev/null +++ b/backend/sqlgen/isunknown_test.go @@ -0,0 +1,75 @@ +package sqlgen + +import ( + "strings" + "testing" + + "github.com/tamnd/dbrest/ir" +) + +// 07.4 task 1: is.unknown lowers to the three-valued test. The stub dialect has +// no native spelling (IsUnknown returns ok=false), so the compiler falls back to +// "col IS NULL", which selects the same rows for a boolean column. +func TestCompileIsUnknownFallback(t *testing.T) { + where := ir.Cond(ir.Compare{Path: []string{"done"}, Op: ir.OpIs, Value: ir.Value{Text: "unknown"}}) + st := compile(t, &ir.Query{Relation: ir.Ref{Name: "t"}, Where: &where}) + if !strings.Contains(st.SQL, `"done" IS NULL`) { + t.Errorf("SQL = %q, want a `done IS NULL` predicate", st.SQL) + } +} + +// A dialect that spells the operator natively (IsUnknown returns ok=true) keeps +// it, mirroring the IsBool seam: this stub stands in for PostgreSQL. +type unknownDialect struct{ stub } + +func (unknownDialect) IsUnknown(col string) (string, bool) { return col + " IS UNKNOWN", true } + +func TestCompileIsUnknownNative(t *testing.T) { + where := ir.Cond(ir.Compare{Path: []string{"done"}, Op: ir.OpIs, Value: ir.Value{Text: "unknown"}}) + st, err := CompileRead(unknownDialect{}, &ir.Query{Relation: ir.Ref{Name: "t"}, Where: &where}) + if err != nil { + t.Fatalf("CompileRead: %v", err) + } + if !strings.Contains(st.SQL, `"done" IS UNKNOWN`) { + t.Errorf("SQL = %q, want a native `done IS UNKNOWN` predicate", st.SQL) + } +} + +// 07.4 task 2: eq.true binds a boolean against a boolean column... +func TestCompileEqTrueBooleanColumn(t *testing.T) { + where := ir.Cond(ir.Compare{ + Path: []string{"done"}, Op: ir.OpEq, ColumnType: "bool", Value: ir.Value{Text: "true"}, + }) + st := compile(t, &ir.Query{Relation: ir.Ref{Name: "t"}, Where: &where}) + if !strings.Contains(st.SQL, `"done" = TRUE`) { + t.Errorf("SQL = %q, want `done = TRUE`", st.SQL) + } + if len(st.Args) != 0 { + t.Errorf("Args = %v, want none (boolean rendered inline)", st.Args) + } +} + +// ...but binds the literal word against a text column, where "true" is data, not +// a boolean, so a text column holding the word still matches. +func TestCompileEqTrueTextColumn(t *testing.T) { + where := ir.Cond(ir.Compare{ + Path: []string{"label"}, Op: ir.OpEq, ColumnType: "text", Value: ir.Value{Text: "true"}, + }) + st := compile(t, &ir.Query{Relation: ir.Ref{Name: "t"}, Where: &where}) + if strings.Contains(st.SQL, "TRUE") { + t.Errorf("SQL = %q, want the word bound as a parameter, not the boolean TRUE", st.SQL) + } + if len(st.Args) != 1 || st.Args[0] != "true" { + t.Errorf("Args = %v, want [true] bound as text", st.Args) + } +} + +// An unknown column type keeps the boolean rendering: the common filter against +// a boolean column whose type the planner did not stamp must not regress. +func TestCompileEqTrueUnknownColumnType(t *testing.T) { + where := ir.Cond(ir.Compare{Path: []string{"done"}, Op: ir.OpEq, Value: ir.Value{Text: "true"}}) + st := compile(t, &ir.Query{Relation: ir.Ref{Name: "t"}, Where: &where}) + if !strings.Contains(st.SQL, `"done" = TRUE`) { + t.Errorf("SQL = %q, want `done = TRUE` for an untyped column", st.SQL) + } +} diff --git a/backend/sqlgen/jsonpath_test.go b/backend/sqlgen/jsonpath_test.go new file mode 100644 index 0000000..7d44f38 --- /dev/null +++ b/backend/sqlgen/jsonpath_test.go @@ -0,0 +1,89 @@ +package sqlgen + +import ( + "strings" + "testing" + + "github.com/tamnd/dbrest/ir" +) + +// 07.1: a JSON-path projection lowers through the dialect's JSONPath. The stub +// spells the PostgreSQL native chain, with a digit hop as an array index and a +// final ->> producing text. +func TestCompileJSONPathProjection(t *testing.T) { + st := compile(t, &ir.Query{ + Relation: ir.Ref{Name: "t"}, + Select: []ir.SelectItem{ + ir.Column{Path: []string{"data", "phones", "0", "number"}, Last: ir.JSONArrow2}, + }, + }) + if !strings.Contains(st.SQL, `"data"->'phones'->0->>'number'`) { + t.Errorf("SQL = %q, want the native ->/->> chain", st.SQL) + } + // The output field is named for the last hop. + if !strings.Contains(st.SQL, `AS "number"`) { + t.Errorf("SQL = %q, want the projection aliased to the last hop", st.SQL) + } +} + +// A JSON-path filter lowers the same way; a final -> keeps the json typing. +func TestCompileJSONPathFilter(t *testing.T) { + where := ir.Cond(ir.Compare{ + Path: []string{"data", "blood_type"}, Last: ir.JSONArrow2, + Op: ir.OpEq, Value: ir.Value{Text: "A-"}, + }) + st := compile(t, &ir.Query{Relation: ir.Ref{Name: "t"}, Where: &where}) + if !strings.Contains(st.SQL, `"data"->>'blood_type' = `) { + t.Errorf("SQL = %q, want a ->> text predicate", st.SQL) + } + if len(st.Args) != 1 || st.Args[0] != "A-" { + t.Errorf("Args = %v, want the value bound", st.Args) + } +} + +// eq.true against a JSON ->> extract binds the literal word as text, never the +// boolean TRUE: a JSON field holding "true" must match (07.4 coercion is +// column-type driven and a JSON access is not a boolean column). +func TestCompileJSONPathEqTrueBindsText(t *testing.T) { + where := ir.Cond(ir.Compare{ + Path: []string{"data", "flag"}, Last: ir.JSONArrow2, + Op: ir.OpEq, Value: ir.Value{Text: "true"}, + }) + st := compile(t, &ir.Query{Relation: ir.Ref{Name: "t"}, Where: &where}) + if strings.Contains(st.SQL, "TRUE") { + t.Errorf("SQL = %q, want the word bound, not the boolean TRUE", st.SQL) + } + if len(st.Args) != 1 || st.Args[0] != "true" { + t.Errorf("Args = %v, want [true] bound as text", st.Args) + } +} + +// Ordering by a JSON path lowers through the dialect in the ORDER BY. +func TestCompileJSONPathOrder(t *testing.T) { + st := compile(t, &ir.Query{ + Relation: ir.Ref{Name: "t"}, + Order: []ir.OrderTerm{{Path: []string{"data", "created_at"}, Last: ir.JSONArrow2, Desc: true}}, + }) + if !strings.Contains(st.SQL, `ORDER BY "data"->>'created_at' DESC`) { + t.Errorf("SQL = %q, want ORDER BY on the ->> extract", st.SQL) + } +} + +// An engine without JSON paths reports ok=false and the request is PGRST127 +// rather than emitting wrong SQL. +type noJSONDialect struct{ stub } + +func (noJSONDialect) JSONPath(string, []string, bool) (string, bool) { return "", false } + +func TestCompileJSONPathCapabilityGap(t *testing.T) { + where := ir.Cond(ir.Compare{ + Path: []string{"data", "x"}, Last: ir.JSONArrow2, Op: ir.OpEq, Value: ir.Value{Text: "1"}, + }) + _, err := CompileRead(noJSONDialect{}, &ir.Query{Relation: ir.Ref{Name: "t"}, Where: &where}) + if err == nil { + t.Fatal("expected an unsupported error for a JSON path on an engine without one") + } + if err.Code != "PGRST127" { + t.Errorf("code = %s, want PGRST127", err.Code) + } +} diff --git a/backend/sqlgen/rangeop_test.go b/backend/sqlgen/rangeop_test.go new file mode 100644 index 0000000..ab2c592 --- /dev/null +++ b/backend/sqlgen/rangeop_test.go @@ -0,0 +1,41 @@ +package sqlgen + +import ( + "testing" + + "github.com/tamnd/dbrest/ir" +) + +// The five range operators (sl/sr/nxr/nxl/adj) lower through the dialect's +// RangeOp hook to the native PostgreSQL spellings (item 07.5). The stub models a +// range-capable engine, so each compiles to "col $1". +func TestCompileRangeOperators(t *testing.T) { + cases := []struct { + op ir.Op + want string + }{ + {ir.OpRangeSL, "<<"}, + {ir.OpRangeSR, ">>"}, + {ir.OpRangeNXR, "&<"}, + {ir.OpRangeNXL, "&>"}, + {ir.OpRangeAdj, "-|-"}, + } + for _, c := range cases { + where := ir.Cond(ir.Compare{Path: []string{"period"}, Op: c.op, Value: ir.Value{Text: "[2000-01-01,2000-12-31]"}}) + st := compile(t, &ir.Query{Relation: ir.Ref{Name: "events"}, Where: &where}) + want := `SELECT * FROM "events" WHERE "period" ` + c.want + ` $1` + if st.SQL != want { + t.Errorf("op %v: SQL = %q, want %q", c.op, st.SQL, want) + } + if len(st.Args) != 1 || st.Args[0] != "[2000-01-01,2000-12-31]" { + t.Errorf("op %v: Args = %v, want one range literal", c.op, st.Args) + } + } +} + +// noRangeDialect models an engine without range types: its RangeOp declines, so a +// range filter is PGRST127 naming the operator rather than invalid SQL. The +// decline path is asserted in compile_test.go TestCompileRangeOperatorRejectedNamed. +type noRangeDialect struct{ stub } + +func (noRangeDialect) RangeOp(_, _, _ string) (string, bool) { return "", false } diff --git a/backend/sqlgen/related_order_test.go b/backend/sqlgen/related_order_test.go new file mode 100644 index 0000000..e3e5b45 --- /dev/null +++ b/backend/sqlgen/related_order_test.go @@ -0,0 +1,92 @@ +package sqlgen + +import ( + "strings" + "testing" + + "github.com/tamnd/dbrest/ir" +) + +// A top-level order=rel(col) lowers to a correlated scalar subquery selecting the +// to-one embed's column, joined back to the parent: a parent with no related row +// yields NULL, which the dialect's NULLs placement then orders (item 07.6). +func TestRelatedOrderToOneSubquery(t *testing.T) { + m := embedModel() + q := &ir.Query{ + Relation: ir.Ref{Schema: "public", Name: "films"}, + Select: []ir.SelectItem{ + ir.Column{Path: []string{"title"}}, + ir.EmbedRef{Index: 0}, + }, + Embeds: []ir.Embed{{ + Cardinality: ir.CardToOne, + OutKey: "directors", + Target: ir.Ref{Schema: "public", Name: "directors"}, + Rel: relate(t, m, "films", "directors"), + Query: ir.Query{Select: []ir.SelectItem{ir.Column{Path: []string{"name"}}}}, + }}, + // Match by written target name, since this embed has no alias. + Order: []ir.OrderTerm{{Rel: "directors", Path: []string{"name"}}}, + } + got := compileEmbed(t, q).SQL + // The embed subquery consumes t1; the order subquery takes the next alias, o2. + want := ` ORDER BY (SELECT o2."name" FROM "public"."directors" o2 ` + + `WHERE o2."id" = t0."director_id") ASC NULLS LAST` + if !strings.Contains(got, want) { + t.Errorf("related order subquery missing\n want %q\n in %q", want, got) + } +} + +// The embed an order term names is matched by its alias when one is given, so +// order=client(...) resolves through `client:clients(...)`. +func TestRelatedOrderMatchesAlias(t *testing.T) { + m := embedModel() + q := &ir.Query{ + Relation: ir.Ref{Schema: "public", Name: "films"}, + Select: []ir.SelectItem{ + ir.Column{Path: []string{"title"}}, + ir.EmbedRef{Index: 0}, + }, + Embeds: []ir.Embed{{ + Cardinality: ir.CardToOne, + Alias: "director", + OutKey: "director", + Target: ir.Ref{Schema: "public", Name: "directors"}, + Rel: relate(t, m, "films", "directors"), + Query: ir.Query{Select: []ir.SelectItem{ir.Column{Path: []string{"name"}}}}, + }}, + Order: []ir.OrderTerm{{Rel: "director", Path: []string{"name"}, Desc: true}}, + } + got := compileEmbed(t, q).SQL + want := ` ORDER BY (SELECT o2."name" FROM "public"."directors" o2 ` + + `WHERE o2."id" = t0."director_id") DESC NULLS FIRST` + if !strings.Contains(got, want) { + t.Errorf("aliased related order missing\n want %q\n in %q", want, got) + } +} + +// nullsfirst/nullslast still apply to a related order: the parent's NULL (no +// related row) sorts where the client asks, not where the default lands. +func TestRelatedOrderHonorsNullsPlacement(t *testing.T) { + m := embedModel() + nf := true + q := &ir.Query{ + Relation: ir.Ref{Schema: "public", Name: "films"}, + Select: []ir.SelectItem{ + ir.Column{Path: []string{"title"}}, + ir.EmbedRef{Index: 0}, + }, + Embeds: []ir.Embed{{ + Cardinality: ir.CardToOne, + OutKey: "directors", + Target: ir.Ref{Schema: "public", Name: "directors"}, + Rel: relate(t, m, "films", "directors"), + Query: ir.Query{Select: []ir.SelectItem{ir.Column{Path: []string{"name"}}}}, + }}, + Order: []ir.OrderTerm{{Rel: "directors", Path: []string{"name"}, NullsFirst: &nf}}, + } + got := compileEmbed(t, q).SQL + if !strings.Contains(got, `ASC NULLS FIRST`) { + t.Errorf("related order did not honor nullsfirst\n in %q", got) + } +} diff --git a/backend/sqlgen/representation_test.go b/backend/sqlgen/representation_test.go new file mode 100644 index 0000000..421c1a4 --- /dev/null +++ b/backend/sqlgen/representation_test.go @@ -0,0 +1,216 @@ +package sqlgen + +import ( + "testing" + + "github.com/tamnd/dbrest/ir" +) + +// colorRep mirrors the live _p11dr fixture: a "color" domain over integer with a +// cast function per direction (to-json formats, from-text parses a filter literal, +// from-json parses a write value), all in schema _p11dr. +var colorRep = ir.Rep{ + ToJSONSchema: "_p11dr", ToJSONFunc: "json", + FromTextSchema: "_p11dr", FromTextFunc: "color", + FromJSONSchema: "_p11dr", FromJSONFunc: "color", +} + +func TestRepReadAppliesToJSON(t *testing.T) { + st := compile(t, &ir.Query{ + Relation: ir.Ref{Name: "shirts"}, + Select: []ir.SelectItem{col("id"), col("c")}, + Reps: map[string]ir.Rep{"c": colorRep}, + }) + want := `SELECT "id", "_p11dr"."json"("c") AS "c" FROM "shirts"` + if st.SQL != want { + t.Errorf("SQL = %q, want %q", st.SQL, want) + } +} + +func TestRepFilterAppliesFromText(t *testing.T) { + where := ir.Cond(ir.Compare{Path: []string{"c"}, Op: ir.OpEq, Value: ir.Value{Text: "#ff0000"}}) + st, err := CompileRead(stub{}, &ir.Query{ + Relation: ir.Ref{Name: "shirts"}, + Where: &where, + Reps: map[string]ir.Rep{"c": colorRep}, + }) + if err != nil { + t.Fatalf("CompileRead: %v", err) + } + want := `SELECT * FROM "shirts" WHERE "c" = "_p11dr"."color"($1::text)` + if st.SQL != want { + t.Errorf("SQL = %q, want %q", st.SQL, want) + } + if len(st.Args) != 1 || st.Args[0] != "#ff0000" { + t.Errorf("Args = %#v, want [#ff0000]", st.Args) + } +} + +func TestRepOrderingFilterAppliesFromText(t *testing.T) { + where := ir.Cond(ir.Compare{Path: []string{"c"}, Op: ir.OpGte, Value: ir.Value{Text: "#000080"}}) + st, err := CompileRead(stub{}, &ir.Query{ + Relation: ir.Ref{Name: "shirts"}, + Where: &where, + Reps: map[string]ir.Rep{"c": colorRep}, + }) + if err != nil { + t.Fatalf("CompileRead: %v", err) + } + want := `SELECT * FROM "shirts" WHERE "c" >= "_p11dr"."color"($1::text)` + if st.SQL != want { + t.Errorf("SQL = %q, want %q", st.SQL, want) + } +} + +func TestRepInsertAppliesFromJSON(t *testing.T) { + st, err := CompileInsert(stub{}, &ir.Query{ + Relation: ir.Ref{Name: "shirts"}, + Write: &ir.WriteSpec{ + Columns: []string{"c"}, + Rows: []map[string]ir.Value{{"c": ir.Value{JSON: "#0000ff"}}}, + }, + Reps: map[string]ir.Rep{"c": colorRep}, + }, nil) + if err != nil { + t.Fatalf("CompileInsert: %v", err) + } + want := `INSERT INTO "shirts" ("c") VALUES ("_p11dr"."color"($1::json))` + if st.SQL != want { + t.Errorf("SQL = %q, want %q", st.SQL, want) + } + if len(st.Args) != 1 || st.Args[0] != `"#0000ff"` { + t.Errorf("Args = %#v, want [\"#0000ff\"]", st.Args) + } +} + +func TestRepUpdateAppliesFromJSON(t *testing.T) { + st, err := CompileUpdate(stub{}, &ir.Query{ + Relation: ir.Ref{Name: "shirts"}, + Write: &ir.WriteSpec{ + Set: map[string]ir.Value{"c": {JSON: "#00ff00"}}, + }, + Reps: map[string]ir.Rep{"c": colorRep}, + }, nil) + if err != nil { + t.Fatalf("CompileUpdate: %v", err) + } + want := `UPDATE "shirts" SET "c" = "_p11dr"."color"($1::json)` + if st.SQL != want { + t.Errorf("SQL = %q, want %q", st.SQL, want) + } + if len(st.Args) != 1 || st.Args[0] != `"#00ff00"` { + t.Errorf("Args = %#v, want [\"#00ff00\"]", st.Args) + } +} + +func TestRepInsertReturningAppliesToJSON(t *testing.T) { + st, err := CompileInsert(stub{}, &ir.Query{ + Relation: ir.Ref{Name: "shirts"}, + Write: &ir.WriteSpec{ + Columns: []string{"id", "c"}, + Rows: []map[string]ir.Value{{"id": jnum("1"), "c": ir.Value{JSON: "#0000ff"}}}, + Return: ir.ReturnRepresentation, + }, + Reps: map[string]ir.Rep{"c": colorRep}, + }, []string{"id", "c"}) + if err != nil { + t.Fatalf("CompileInsert: %v", err) + } + want := `INSERT INTO "shirts" ("id", "c") VALUES ($1, "_p11dr"."color"($2::json)) ` + + `RETURNING "id", "_p11dr"."json"("c") AS "c"` + if st.SQL != want { + t.Errorf("SQL = %q, want %q", st.SQL, want) + } +} + +// repWhere compiles a single-filter read over the shirts fixture and returns the +// statement, keeping the operator tests short. +func repWhere(t *testing.T, d Dialect, c ir.Compare) *Statement { + t.Helper() + where := ir.Cond(c) + st, err := CompileRead(d, &ir.Query{ + Relation: ir.Ref{Name: "shirts"}, + Where: &where, + Reps: map[string]ir.Rep{"c": colorRep}, + }) + if err != nil { + t.Fatalf("CompileRead: %v", err) + } + return st +} + +// TestRepFilterMatchAppliesFromText: a regex match on a represented column parses +// the pattern through the from-text cast, matching PostgREST's match/imatch path. +func TestRepFilterMatchAppliesFromText(t *testing.T) { + st := repWhere(t, stub{}, ir.Compare{Path: []string{"c"}, Op: ir.OpMatch, Value: ir.Value{Text: "#ff"}}) + want := `SELECT * FROM "shirts" WHERE "c" ~ "_p11dr"."color"($1::text)` + if st.SQL != want { + t.Errorf("SQL = %q, want %q", st.SQL, want) + } +} + +// TestRepFilterContainsAppliesFromText: the simple-operator path (cs/cd/ov, range) +// parses the operand through the from-text cast, as PostgREST does. +func TestRepFilterContainsAppliesFromText(t *testing.T) { + st := repWhere(t, stub{}, ir.Compare{Path: []string{"c"}, Op: ir.OpContains, Value: ir.Value{Text: "{1,2}"}}) + want := `SELECT * FROM "shirts" WHERE "c" @> "_p11dr"."color"($1::text)` + if st.SQL != want { + t.Errorf("SQL = %q, want %q", st.SQL, want) + } +} + +// TestRepFilterLikeStaysRaw: like/ilike carry a wildcard pattern, so PostgREST +// binds them raw even on a represented column. Confirm dbrest does not wrap them. +func TestRepFilterLikeStaysRaw(t *testing.T) { + st := repWhere(t, stub{}, ir.Compare{Path: []string{"c"}, Op: ir.OpLike, Value: ir.Value{Text: "#ff%"}}) + want := `SELECT * FROM "shirts" WHERE "c" LIKE $1` + if st.SQL != want { + t.Errorf("SQL = %q, want %q", st.SQL, want) + } +} + +// TestRepFilterInPerElement: on an engine without a native = ANY list bind, IN +// expands per element and each element parses through the from-text cast. +func TestRepFilterInPerElement(t *testing.T) { + st := repWhere(t, stub{}, ir.Compare{Path: []string{"c"}, Op: ir.OpIn, Value: ir.Value{List: []string{"#0000ff", "#00ff00"}}}) + want := `SELECT * FROM "shirts" WHERE "c" IN (` + + `"_p11dr"."color"($1::text), "_p11dr"."color"($2::text))` + if st.SQL != want { + t.Errorf("SQL = %q, want %q", st.SQL, want) + } +} + +// anyListStub is a stub whose InList binds the list as one = ANY array, like the +// PostgreSQL dialect, so the IN representation path can be asserted in its native +// unnest form. +type anyListStub struct{ stub } + +func (anyListStub) InList(col string) (string, bool) { + return col + " = ANY(" + PatternMark + ")", true +} + +// TestRepFilterInAnyUnnest: on a = ANY engine the represented IN list parses each +// element over the unpacked array, matching PostgREST's pgFmtArrayLiteralForField. +func TestRepFilterInAnyUnnest(t *testing.T) { + st := repWhere(t, anyListStub{}, ir.Compare{Path: []string{"c"}, Op: ir.OpIn, Value: ir.Value{List: []string{"#0000ff", "#00ff00"}}}) + want := `SELECT * FROM "shirts" WHERE "c" = ANY(` + + `(SELECT "_p11dr"."color"(unnest($1::text[]))))` + if st.SQL != want { + t.Errorf("SQL = %q, want %q", st.SQL, want) + } +} + +// TestRepReadExplicitCastOptsOut confirms an explicit client cast (col::type) +// suppresses the to-json representation: the client asked for a specific +// rendering, so the domain's formatter is not applied. +func TestRepReadExplicitCastOptsOut(t *testing.T) { + st := compile(t, &ir.Query{ + Relation: ir.Ref{Name: "shirts"}, + Select: []ir.SelectItem{col("id"), ir.Column{Path: []string{"c"}, Cast: "text"}}, + Reps: map[string]ir.Rep{"c": colorRep}, + }) + want := `SELECT "id", CAST("c" AS text) AS "c" FROM "shirts"` + if st.SQL != want { + t.Errorf("SQL = %q, want %q", st.SQL, want) + } +} diff --git a/backend/sqlite/aggregate_test.go b/backend/sqlite/aggregate_test.go new file mode 100644 index 0000000..072eeb8 --- /dev/null +++ b/backend/sqlite/aggregate_test.go @@ -0,0 +1,104 @@ +package sqlite + +import ( + "context" + "fmt" + "strings" + "testing" + + "github.com/tamnd/dbrest/backend/sqlgen" + "github.com/tamnd/dbrest/ir" + "github.com/tamnd/dbrest/plan" +) + +// openSales seeds a sales table with a category and an amount so an aggregate +// has something to fold over. +func openSales(t *testing.T) *Backend { + t.Helper() + dsn := "file:" + strings.ReplaceAll(t.Name(), "/", "_") + "?mode=memory&cache=shared" + b, err := Open(dsn) + if err != nil { + t.Fatalf("Open: %v", err) + } + t.Cleanup(func() { b.Close() }) + _, err = b.DB().Exec(` + CREATE TABLE sales (id INTEGER PRIMARY KEY, category TEXT NOT NULL, amount INTEGER NOT NULL); + INSERT INTO sales (id, category, amount) VALUES + (1, 'a', 10), (2, 'a', 20), (3, 'b', 5); + `) + if err != nil { + t.Fatalf("seed: %v", err) + } + return b +} + +// planAgg parses and plans a sales read with aggregates enabled. +func planAgg(t *testing.T, b *Backend, query string) *ir.Query { + t.Helper() + q, perr := ir.ParseRead("sales", query, nil) + if perr != nil { + t.Fatalf("ParseRead: %v", perr) + } + model, err := b.Introspect(context.Background()) + if err != nil { + t.Fatalf("Introspect: %v", err) + } + pl, perr := plan.Read(model, q, nil, plan.Options{AggregatesEnabled: true}) + if perr != nil { + t.Fatalf("plan.Read: %v", perr) + } + return pl.Query +} + +func TestExecuteBareCount(t *testing.T) { + b := openSales(t) + q := planAgg(t, b, "select=count()") + rows := execRead(t, b, q) + if len(rows) != 1 { + t.Fatalf("got %d rows, want 1", len(rows)) + } + if got := fmt.Sprint(rows[0]["count"]); got != "3" { + t.Errorf("count = %v, want 3", rows[0]["count"]) + } +} + +func TestExecuteGroupedSum(t *testing.T) { + b := openSales(t) + q := planAgg(t, b, "select=category,amount.sum()&order=category") + rows := execRead(t, b, q) + if len(rows) != 2 { + t.Fatalf("got %d rows, want 2 (one per category)", len(rows)) + } + got := map[string]string{} + for _, r := range rows { + cat, _ := asString(r["category"]) + got[cat] = fmt.Sprint(r["sum"]) + } + if got["a"] != "30" || got["b"] != "5" { + t.Errorf("sums = %v, want a:30 b:5", got) + } +} + +func TestExecuteAggregateWithAlias(t *testing.T) { + b := openSales(t) + q := planAgg(t, b, "select=category,total:amount.sum()&order=category") + rows := execRead(t, b, q) + if _, ok := rows[0]["total"]; !ok { + t.Fatalf("expected a 'total' key, got %v", rows[0]) + } +} + +// TestCompileGroupedSumSQL pins the GROUP BY shape the grouped aggregate lowers +// to on SQLite. +func TestCompileGroupedSumSQL(t *testing.T) { + b := openSales(t) + q := planAgg(t, b, "select=category,amount.sum()") + st, perr := sqlgen.CompileRead(dialect{}, q) + if perr != nil { + t.Fatalf("CompileRead: %v", perr) + } + want := `SELECT "category", sum("amount") AS "sum" FROM "sales" GROUP BY "category"` + if st.SQL != want { + t.Errorf("SQL = %q, want %q", st.SQL, want) + } +} diff --git a/backend/sqlite/dialect.go b/backend/sqlite/dialect.go index a1c8c1b..e36dd77 100644 --- a/backend/sqlite/dialect.go +++ b/backend/sqlite/dialect.go @@ -146,11 +146,99 @@ func (dialect) SessionRead(string) string { return "" } // SessionWrite reports ok=false: there is no engine setting to write. func (dialect) SessionWrite(string) (string, bool) { return "", false } -// ArrayOp returns false; SQLite has no array types or containment operators. -func (dialect) ArrayOp(_, _, _ string) (string, bool) { return "", false } +// InList reports ok=false: SQLite has no array-bound ANY, so the compiler emits +// the expanded col IN ($1, $2, ...) form. +func (dialect) InList(_ string) (string, bool) { return "", false } + +// ArrayLiteral converts a PostgreSQL {a,b} array literal to a JSON array +// ["a","b"] so json_each() in ArrayOp can iterate over it. +func (dialect) ArrayLiteral(pgText string) string { + s := strings.TrimSpace(pgText) + if len(s) < 2 || s[0] != '{' || s[len(s)-1] != '}' { + return pgText // already JSON or empty; pass through + } + inner := s[1 : len(s)-1] + if inner == "" { + return "[]" + } + parts := strings.Split(inner, ",") + quoted := make([]string, len(parts)) + for i, p := range parts { + p = strings.TrimSpace(p) + if len(p) >= 2 && p[0] == '"' && p[len(p)-1] == '"' { + quoted[i] = p // already JSON-quoted + } else { + quoted[i] = `"` + strings.ReplaceAll(p, `"`, `\"`) + `"` + } + } + return "[" + strings.Join(quoted, ",") + "]" +} + +// ArrayArg stores a payload array as its JSON text: SQLite has no array +// columns, so a JSON-typed column holds the array and reads it back as JSON. +// A PostgreSQL {a,b} literal here would corrupt the column. +func (dialect) ArrayArg(elems []any, _ string) any { return sqlgen.JSONArrayArg(elems) } + +// JSONPath lowers a JSON sub-path to SQLite's -> / ->> operators over a single +// JSON path argument. SQLite's ->> returns the SQL text scalar and -> returns +// the JSON representation, matching PostgreSQL's ->>/-> typing. Object keys +// become quoted "label" segments and digit hops become [n] array subscripts, so +// data->phones->0->>number renders as data ->> '$."phones"[0]."number"'. +func (dialect) JSONPath(base string, hops []string, asText bool) (string, bool) { + var p strings.Builder + p.WriteString("$") + for _, h := range hops { + if sqlgen.IsJSONArrayIndex(h) { + p.WriteString("[" + h + "]") + } else { + p.WriteString(`."` + strings.ReplaceAll(h, `"`, `""`) + `"`) + } + } + op := "->" + if asText { + op = "->>" + } + return base + " " + op + " '" + strings.ReplaceAll(p.String(), "'", "''") + "'", true +} + +// ArrayOp implements array containment/overlap via SQLite's json_each(). The +// column must be declared as JSON type and store a JSON array (e.g. +// '["cat","work"]'). For any other column type the operator is unsupported +// (ok=false) so the compiler raises PGRST127. op is one of "@>" (contains), +// "<@" (contained-by), "&&" (overlaps). +func (dialect) ArrayOp(col, op, val, colType string) (string, bool) { + if colType != "json" && colType != "jsonb" { + return "", false + } + switch op { + case "@>": // contains: every element of val appears in col + return "NOT EXISTS (SELECT 1 FROM json_each(" + val + ") AS f WHERE f.value NOT IN (SELECT value FROM json_each(" + col + ")))", true + case "<@": // contained-by: every element of col appears in val + return "NOT EXISTS (SELECT 1 FROM json_each(" + col + ") AS f WHERE f.value NOT IN (SELECT value FROM json_each(" + val + ")))", true + case "&&": // overlaps: at least one common element + return "EXISTS (SELECT 1 FROM json_each(" + col + ") AS f WHERE f.value IN (SELECT value FROM json_each(" + val + ")))", true + } + return "", false +} + +// RangeOp declines: SQLite has no range types, so sl/sr/nxr/nxl/adj are PGRST127. +func (dialect) RangeOp(_, _, _ string) (string, bool) { return "", false } + +// ILike folds case explicitly with lower() on both sides. Plain LIKE cannot be +// relied on for case-insensitivity because the pool sets case_sensitive_like = +// ON (so the like operator stays case-sensitive like PostgreSQL); lower() folds +// ASCII, which is the documented best-effort, leaving non-ASCII folding as a gap. +func (dialect) ILike(col, val string) (string, bool) { + return "lower(" + col + ") LIKE lower(" + val + ")", true +} + +// IsBool falls back to the generic "IS 1"/"IS 0" form; SQLite's IS operator is +// a NULL-safe equality that works with any value. +func (dialect) IsBool(string, bool) (string, bool) { return "", false } -// ILike uses plain LIKE which is case-insensitive for ASCII in SQLite. -func (dialect) ILike(col, val string) (string, bool) { return col + " LIKE " + val, true } +// IsUnknown has no SQLite spelling; the compiler falls back to "col IS NULL", +// which selects the same rows for a boolean column. +func (dialect) IsUnknown(string) (string, bool) { return "", false } // BoolValue renders a boolean as 1/0; SQLite has no native boolean. func (dialect) BoolValue(v bool) string { diff --git a/backend/sqlite/dialect_test.go b/backend/sqlite/dialect_test.go index 2c6d1d3..1c5a98f 100644 --- a/backend/sqlite/dialect_test.go +++ b/backend/sqlite/dialect_test.go @@ -168,14 +168,20 @@ func TestMapErrorNilAndNonDriver(t *testing.T) { func TestMapErrorConstraintCodes(t *testing.T) { b := openConstraintDB(t) cases := []struct { - name string - exec string - code string - status int + name string + exec string + code string + status int + message string }{ - {"not-null", `INSERT INTO widgets (id, name) VALUES (1, NULL)`, pgerr.CodeNotNullViolation, 400}, - {"check", `INSERT INTO widgets (id, name, qty) VALUES (2, 'a', -1)`, pgerr.CodeCheckViolation, 400}, - {"foreign-key", `INSERT INTO parts (id, widget_id) VALUES (1, 999)`, pgerr.CodeForeignKeyViolation, 409}, + {"not-null", `INSERT INTO widgets (id, name) VALUES (1, NULL)`, + pgerr.CodeNotNullViolation, 400, + `null value in column "name" of relation "widgets" violates not-null constraint`}, + {"check", `INSERT INTO widgets (id, name, qty) VALUES (2, 'a', -1)`, + pgerr.CodeCheckViolation, 400, ""}, + {"foreign-key", `INSERT INTO parts (id, widget_id) VALUES (1, 999)`, + pgerr.CodeForeignKeyViolation, 409, + "insert or update on table violates foreign key constraint"}, } for _, c := range cases { t.Run(c.name, func(t *testing.T) { @@ -187,6 +193,14 @@ func TestMapErrorConstraintCodes(t *testing.T) { if api == nil || api.Code != c.code || api.HTTPStatus != c.status { t.Fatalf("MapError = %#v, want %s/%d", api, c.code, c.status) } + // The synthesized message is PostgreSQL's wording; the native SQLite + // text never leaks into details on any arm. + if c.message != "" && api.Message != c.message { + t.Errorf("message = %q, want %q", api.Message, c.message) + } + if api.Details != nil { + t.Errorf("details = %q, want no leaked native text", *api.Details) + } }) } } diff --git a/backend/sqlite/embed_test.go b/backend/sqlite/embed_test.go index aca38a2..7333ad0 100644 --- a/backend/sqlite/embed_test.go +++ b/backend/sqlite/embed_test.go @@ -2,6 +2,7 @@ package sqlite import ( "context" + "encoding/json" "strings" "testing" @@ -9,6 +10,7 @@ import ( "github.com/tamnd/dbrest/ir" "github.com/tamnd/dbrest/plan" "github.com/tamnd/dbrest/reqctx" + "github.com/tamnd/dbrest/schema" ) // openEmbed seeds two related tables (directors and films, with a films->directors @@ -66,6 +68,49 @@ func TestIntrospectForeignKey(t *testing.T) { } } +// TestIntrospectUniqueConstraint covers 01.8 end-to-end on SQLite: a UNIQUE +// constraint on a foreign-key column is read from PRAGMA index_list/index_info, +// recorded on the relation, and makes the reverse embed one-to-one so it renders +// as an object rather than an array. +func TestIntrospectUniqueConstraint(t *testing.T) { + dsn := "file:" + strings.ReplaceAll(t.Name(), "/", "_") + "?mode=memory&cache=shared" + b, err := Open(dsn) + if err != nil { + t.Fatalf("Open: %v", err) + } + t.Cleanup(func() { b.Close() }) + + _, err = b.DB().Exec(` + CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT NOT NULL); + CREATE TABLE profiles ( + id INTEGER PRIMARY KEY, + user_id INTEGER NOT NULL UNIQUE REFERENCES users(id), + bio TEXT + ); + `) + if err != nil { + t.Fatalf("seed: %v", err) + } + + model, err := b.Introspect(context.Background()) + if err != nil { + t.Fatalf("Introspect: %v", err) + } + profiles, _ := model.Lookup("profiles", nil) + if len(profiles.Unique) != 1 || profiles.Unique[0][0] != "user_id" { + t.Fatalf("profiles.Unique = %v, want [[user_id]]", profiles.Unique) + } + + users, _ := model.Lookup("users", nil) + cands, _ := model.Relationships(users, "profiles", nil) + if len(cands) != 1 { + t.Fatalf("got %d candidates, want 1", len(cands)) + } + if cands[0].Card != schema.CardToOne { + t.Errorf("Card = %v, want to-one (user_id is unique)", cands[0].Card) + } +} + // planEmbed parses, plans, and returns the resolved query for a films read with // an embed expressed as a select string. func planEmbed(t *testing.T, b *Backend, relation, query string) *ir.Query { @@ -78,7 +123,7 @@ func planEmbed(t *testing.T, b *Backend, relation, query string) *ir.Query { if err != nil { t.Fatalf("Introspect: %v", err) } - pl, perr := plan.Read(model, q, nil) + pl, perr := plan.Read(model, q, nil, plan.Options{}) if perr != nil { t.Fatalf("plan.Read: %v", perr) } @@ -153,6 +198,92 @@ func TestExecuteEmbedToManyValue(t *testing.T) { } } +// openEmbedNull seeds directors where one (Welles) has no films, so an +// embed-existence filter has something to include and exclude. +func openEmbedNull(t *testing.T) *Backend { + t.Helper() + dsn := "file:" + strings.ReplaceAll(t.Name(), "/", "_") + "?mode=memory&cache=shared" + b, err := Open(dsn) + if err != nil { + t.Fatalf("Open: %v", err) + } + t.Cleanup(func() { b.Close() }) + _, err = b.DB().Exec(` + CREATE TABLE directors (id INTEGER PRIMARY KEY, name TEXT NOT NULL); + CREATE TABLE films ( + id INTEGER PRIMARY KEY, + title TEXT NOT NULL, + director_id INTEGER REFERENCES directors(id) + ); + INSERT INTO directors (id, name) VALUES (1, 'Lang'), (2, 'Scott'), (3, 'Welles'); + INSERT INTO films (id, title, director_id) VALUES + (1, 'Metropolis', 1), (2, 'Blade Runner', 2); + `) + if err != nil { + t.Fatalf("seed: %v", err) + } + return b +} + +// names pulls the name column out of a result set in row order. +func names(rows []map[string]any) []string { + out := make([]string, len(rows)) + for i, r := range rows { + out[i], _ = asString(r["name"]) + } + return out +} + +// directors?select=name,films(title)&films=not.is.null keeps only directors with +// at least one film: a semi-join over the relationship (item 01.12). +func TestExecuteEmbedNotIsNull(t *testing.T) { + b := openEmbedNull(t) + q := planEmbed(t, b, "directors", "select=name,films(title)&films=not.is.null&order=id") + got := names(execReadResolved(t, b, q)) + if len(got) != 2 || got[0] != "Lang" || got[1] != "Scott" { + t.Errorf("not.is.null directors = %v, want [Lang Scott]", got) + } +} + +// directors?...&films=is.null is the anti-join: only the director with no films. +func TestExecuteEmbedIsNull(t *testing.T) { + b := openEmbedNull(t) + q := planEmbed(t, b, "directors", "select=name,films(title)&films=is.null&order=id") + got := names(execReadResolved(t, b, q)) + if len(got) != 1 || got[0] != "Welles" { + t.Errorf("is.null directors = %v, want [Welles]", got) + } +} + +// The predicate composes under or=(...): directors with a film OR named Welles is +// everyone here, exercising the EXISTS as one disjunct alongside a column compare. +func TestExecuteEmbedNullInsideOr(t *testing.T) { + b := openEmbedNull(t) + q := planEmbed(t, b, "directors", "select=name,films(title)&or=(films.not.is.null,name.eq.Welles)&order=id") + got := names(execReadResolved(t, b, q)) + if len(got) != 3 { + t.Errorf("or= directors = %v, want all three", got) + } +} + +// A count alongside the windowed read must apply the same semi-join, so the +// total reflects only the directors that have films. +func TestExecuteEmbedNullCount(t *testing.T) { + b := openEmbedNull(t) + q := planEmbed(t, b, "directors", "select=name,films(title)&films=not.is.null") + st, perr := sqlgen.CompileCount(dialect{}, q) + if perr != nil { + t.Fatalf("CompileCount: %v", perr) + } + var n int + if err := b.DB().QueryRow(st.SQL, st.Args...).Scan(&n); err != nil { + t.Fatalf("count query: %v", err) + } + if n != 2 { + t.Errorf("count = %d, want 2", n) + } +} + // execReadResolved executes an already-planned read and returns the rows. The // query's relation reference is bound by the planner, so it runs as-is. func execReadResolved(t *testing.T, b *Backend, q *ir.Query) []map[string]any { @@ -164,3 +295,65 @@ func execReadResolved(t *testing.T, b *Backend, q *ir.Query) []map[string]any { } return readAll(t, res) } + +// TestExecuteDeclaredRecursiveEmbed covers 01.10 end-to-end: a declared computed +// relationship names one direction of a self-referential foreign key, so the +// recursive embed compiles and executes, returning each comment's children. +func TestExecuteDeclaredRecursiveEmbed(t *testing.T) { + dsn := "file:" + strings.ReplaceAll(t.Name(), "/", "_") + "?mode=memory&cache=shared" + b, err := Open(dsn) + if err != nil { + t.Fatalf("Open: %v", err) + } + t.Cleanup(func() { b.Close() }) + + _, err = b.DB().Exec(` + CREATE TABLE comments ( + id INTEGER PRIMARY KEY, + parent_id INTEGER REFERENCES comments(id), + body TEXT NOT NULL + ); + INSERT INTO comments (id, parent_id, body) VALUES + (1, NULL, 'root'), (2, 1, 'first reply'), (3, 1, 'second reply'); + `) + if err != nil { + t.Fatalf("seed: %v", err) + } + + model, err := b.Introspect(context.Background()) + if err != nil { + t.Fatalf("Introspect: %v", err) + } + model.AddDeclaredRelationship(schema.DeclaredRel{ + Name: "children", + ParentSchema: "", ParentName: "comments", + TargetSchema: "", TargetName: "comments", + Card: schema.CardToMany, + Local: []string{"id"}, + Foreign: []string{"parent_id"}, + }) + + q, perr := ir.ParseRead("comments", "select=body,children:comments!children(body)&id=eq.1", nil) + if perr != nil { + t.Fatalf("ParseRead: %v", perr) + } + pl, perr := plan.Read(model, q, nil, plan.Options{}) + if perr != nil { + t.Fatalf("plan.Read: %v", perr) + } + rows := execReadResolved(t, b, pl.Query) + if len(rows) != 1 { + t.Fatalf("got %d rows, want 1", len(rows)) + } + raw, ok := asString(rows[0]["children"]) + if !ok { + t.Fatalf("children = %T, want JSON array text", rows[0]["children"]) + } + var kids []map[string]any + if err := json.Unmarshal([]byte(raw), &kids); err != nil { + t.Fatalf("children is not a JSON array: %v (%q)", err, raw) + } + if len(kids) != 2 { + t.Errorf("got %d children, want 2", len(kids)) + } +} diff --git a/backend/sqlite/embedempty_test.go b/backend/sqlite/embedempty_test.go new file mode 100644 index 0000000..3e8acfd --- /dev/null +++ b/backend/sqlite/embedempty_test.go @@ -0,0 +1,66 @@ +package sqlite + +import ( + "strings" + "testing" + + "github.com/tamnd/dbrest/backend/sqlgen" +) + +// 07.8: an empty-parenthesis embed, directors!inner(), joins the relation to +// filter the parent while projecting no embed key. With the !inner modifier and +// an embed-scoped filter, only films whose director matches survive, and the +// row carries title alone. +func TestExecuteEmptyEmbedInnerFilters(t *testing.T) { + b := openEmbed(t) + q := planEmbed(t, b, "films", "select=title,directors!inner()&directors.name=eq.Scott&order=id") + rows := execReadResolved(t, b, q) + if len(rows) != 1 { + t.Fatalf("got %d rows, want 1 (only Scott's film)", len(rows)) + } + if got, _ := asString(rows[0]["title"]); got != "Blade Runner" { + t.Errorf("title = %q, want Blade Runner", got) + } + if _, present := rows[0]["directors"]; present { + t.Errorf("row carries a directors key %#v, want it hidden", rows[0]["directors"]) + } +} + +// A left empty embed hides the key without filtering: every film comes back, and +// none of them carries the embed key. +func TestExecuteEmptyEmbedLeftHidesKey(t *testing.T) { + b := openEmbed(t) + q := planEmbed(t, b, "films", "select=title,directors()&order=id") + rows := execReadResolved(t, b, q) + if len(rows) != 3 { + t.Fatalf("got %d rows, want all 3 films", len(rows)) + } + for i, r := range rows { + if _, present := r["directors"]; present { + t.Errorf("row %d carries a directors key, want it hidden", i) + } + if _, present := r["title"]; !present { + t.Errorf("row %d missing title", i) + } + } +} + +// The compiled SQL for the !inner empty embed restricts the parent through an +// EXISTS but never projects the embed: no json_object and no AS for its key. +func TestCompileEmptyEmbedNoProjection(t *testing.T) { + b := openEmbed(t) + q := planEmbed(t, b, "films", "select=title,directors!inner()&directors.name=eq.Scott") + st, perr := sqlgen.CompileRead(dialect{}, q) + if perr != nil { + t.Fatalf("CompileRead: %v", perr) + } + if !strings.Contains(st.SQL, "EXISTS") { + t.Errorf("SQL missing the !inner EXISTS\n got: %s", st.SQL) + } + if strings.Contains(st.SQL, `AS "directors"`) { + t.Errorf("SQL projects the hidden embed key\n got: %s", st.SQL) + } + if strings.Contains(st.SQL, "json_object") { + t.Errorf("SQL assembles the hidden embed object\n got: %s", st.SQL) + } +} diff --git a/backend/sqlite/fulltext.go b/backend/sqlite/fulltext.go index dec78d5..fb20efe 100644 --- a/backend/sqlite/fulltext.go +++ b/backend/sqlite/fulltext.go @@ -19,7 +19,7 @@ import ( // (spec 21); the divergence is documented, not an error. The bound value is the // query text translated to FTS5 query syntax for the variant. The col argument is // unused: the join goes through the index's rowid, not the base column directly. -func (dialect) FullText(_ string, idx *sqlgen.FullTextRef, variant ir.FTSVariant, _, value string) (string, string, bool) { +func (dialect) FullText(_, _ string, idx *sqlgen.FullTextRef, variant ir.FTSVariant, _, value string) (string, string, bool) { if idx == nil { return "", "", false } diff --git a/backend/sqlite/fulltext_test.go b/backend/sqlite/fulltext_test.go index a982ed5..011e354 100644 --- a/backend/sqlite/fulltext_test.go +++ b/backend/sqlite/fulltext_test.go @@ -44,7 +44,7 @@ func TestFTS5QuoteEscapes(t *testing.T) { func TestFullTextLowering(t *testing.T) { ref := &sqlgen.FullTextRef{Table: `"films_fts"`, RowidRef: `"films"."id"`} - frag, bind, ok := dialect{}.FullText("", ref, ir.FTSPlain, "english", "cat") + frag, bind, ok := dialect{}.FullText("", "", ref, ir.FTSPlain, "english", "cat") if !ok { t.Fatal("FullText ok = false, want true with an index") } @@ -60,7 +60,7 @@ func TestFullTextLowering(t *testing.T) { // TestFullTextNoIndex is the missing-structure case: with no covering FTS5 table // the dialect reports ok=false so the compiler raises PGRST127 instead of scanning. func TestFullTextNoIndex(t *testing.T) { - if _, _, ok := (dialect{}).FullText("col", nil, ir.FTSPlain, "", "cat"); ok { + if _, _, ok := (dialect{}).FullText("col", "", nil, ir.FTSPlain, "", "cat"); ok { t.Error("FullText with a nil index ok = true, want false") } } @@ -177,7 +177,7 @@ func planRead(t *testing.T, b *Backend, query string) *ir.Plan { if perr != nil { t.Fatalf("ParseRead: %v", perr) } - pl, perr := plan.Read(model, q, nil) + pl, perr := plan.Read(model, q, nil, plan.Options{}) if perr != nil { t.Fatalf("plan.Read: %v", perr) } @@ -215,7 +215,7 @@ func TestFTSMissingIndexErrors(t *testing.T) { if perr != nil { t.Fatalf("ParseRead: %v", perr) } - pl, perr := plan.Read(model, q, nil) + pl, perr := plan.Read(model, q, nil, plan.Options{}) if perr != nil { t.Fatalf("plan.Read: %v", perr) } @@ -238,7 +238,7 @@ func TestRegexBackreferenceErrors(t *testing.T) { if perr != nil { t.Fatalf("ParseRead: %v", perr) } - pl, perr := plan.Read(model, q, nil) + pl, perr := plan.Read(model, q, nil, plan.Options{}) if perr != nil { t.Fatalf("plan.Read: %v", perr) } diff --git a/backend/sqlite/introspect.go b/backend/sqlite/introspect.go index 1b6ec6e..76e25c8 100644 --- a/backend/sqlite/introspect.go +++ b/backend/sqlite/introspect.go @@ -25,6 +25,8 @@ func (b *Backend) Introspect(ctx context.Context) (*schema.Model, error) { ftsByContent, excluded := classifyFTS(rels) out := make([]*schema.Relation, 0, len(rels)) + colsByName := make(map[string][]*schema.Column, len(rels)) + ddlByName := make(map[string]string, len(rels)) for _, r := range rels { if excluded[r.name] { continue @@ -37,15 +39,44 @@ func (b *Backend) Introspect(ctx context.Context) (*schema.Model, error) { if err != nil { return nil, err } + uniq, err := b.uniques(ctx, r.name) + if err != nil { + return nil, err + } + colsByName[r.name] = cols + ddlByName[r.name] = r.sql out = append(out, &schema.Relation{ Name: r.name, Kind: r.kind, Columns: cols, PrimaryKey: pk, + Unique: uniq, ForeignKeys: fks, FullText: ftsByContent[r.name], }) } + // Second pass: parse each view's definition into the base-column mapping the + // model projects foreign keys through (spec 09). It runs after the first pass + // so a view referencing a base table defined later in the catalog still + // resolves. The parser is conservative: it maps only the views it can trace to + // plain base columns and leaves the rest empty, so the model inherits nothing + // where provenance is uncertain, the same as PostgREST skips a UNION. + baseCols := func(name string) ([]string, bool) { + cols, ok := colsByName[name] + if !ok { + return nil, false + } + names := make([]string, len(cols)) + for i, c := range cols { + names[i] = c.Name + } + return names, true + } + for _, r := range out { + if r.Kind == schema.KindView { + r.ViewColumns = parseViewColumns(ddlByName[r.Name], baseCols) + } + } return schema.NewModel(out), nil } @@ -377,6 +408,98 @@ func (b *Backend) foreignKeys(ctx context.Context, table string) ([]*schema.Fore return out, nil } +// uniques reads the relation's unique constraints from PRAGMA index_list and +// index_info, returning each as a set of column names. Only constraint-backed +// indexes are returned: origin "u" (a UNIQUE table constraint) and origin "c" +// (a CREATE UNIQUE INDEX) when the index is unique. The primary key (origin +// "pk") is omitted because table_info already reports it, and the planner tests +// the primary key separately when deciding one-to-one cardinality (spec 09). +func (b *Backend) uniques(ctx context.Context, table string) ([][]string, error) { + // The table name comes from sqlite_master, not user input; quote and inline it. + rows, err := b.db.QueryContext(ctx, `PRAGMA index_list(`+dialect{}.QuoteIdent(table)+`)`) + if err != nil { + return nil, err + } + type idxInfo struct { + name string + origin string + } + var indexes []idxInfo + for rows.Next() { + var ( + seq, unique, partial int + name, origin string + ) + if err := rows.Scan(&seq, &name, &unique, &origin, &partial); err != nil { + rows.Close() + return nil, err + } + // A partial unique index does not constrain the whole column, so it cannot + // make a foreign key one-to-one; skip it, as PostgREST does. + if unique == 1 && origin != "pk" && partial == 0 { + indexes = append(indexes, idxInfo{name: name, origin: origin}) + } + } + if err := rows.Err(); err != nil { + rows.Close() + return nil, err + } + rows.Close() + + var out [][]string + for _, idx := range indexes { + cols, err := b.indexColumns(ctx, idx.name) + if err != nil { + return nil, err + } + if len(cols) > 0 { + out = append(out, cols) + } + } + return out, nil +} + +// indexColumns reads the column names of one index from PRAGMA index_info, in +// key order. A NULL column name marks an expression index column, which cannot +// participate in a foreign-key match, so such an index is dropped by returning +// no columns for it. +func (b *Backend) indexColumns(ctx context.Context, index string) ([]string, error) { + rows, err := b.db.QueryContext(ctx, `PRAGMA index_info(`+dialect{}.QuoteIdent(index)+`)`) + if err != nil { + return nil, err + } + defer rows.Close() + + type entry struct { + seq int + name string + } + var entries []entry + for rows.Next() { + var ( + seqno, cid int + name any + ) + if err := rows.Scan(&seqno, &cid, &name); err != nil { + return nil, err + } + s, ok := toString(name) + if !ok { + return nil, nil // expression column: not usable for a key match + } + entries = append(entries, entry{seq: seqno, name: s}) + } + if err := rows.Err(); err != nil { + return nil, err + } + sort.Slice(entries, func(i, j int) bool { return entries[i].seq < entries[j].seq }) + cols := make([]string, len(entries)) + for i, e := range entries { + cols[i] = e.name + } + return cols, nil +} + // toString coerces a scalar from PRAGMA into a string, reporting false for NULL. func toString(v any) (string, bool) { switch s := v.(type) { diff --git a/backend/sqlite/isunknown_test.go b/backend/sqlite/isunknown_test.go new file mode 100644 index 0000000..f108528 --- /dev/null +++ b/backend/sqlite/isunknown_test.go @@ -0,0 +1,87 @@ +package sqlite + +import ( + "context" + "testing" + + "github.com/tamnd/dbrest/ir" + "github.com/tamnd/dbrest/reqctx" +) + +// openFlags seeds a table with a boolean-style column (SQLite stores 0/1/NULL) +// and a text column that literally holds the word "true", so the type-driven +// coercion of item 07.4 can be exercised against a real engine. +func openFlags(t *testing.T) *Backend { + t.Helper() + b := openSeeded(t) + _, err := b.DB().Exec(` + CREATE TABLE flags ( + id INTEGER PRIMARY KEY, + done INTEGER, + label TEXT + ); + INSERT INTO flags (id, done, label) VALUES + (1, 1, 'true'), + (2, 0, 'false'), + (3, NULL, 'unset'); + `) + if err != nil { + t.Fatalf("seed flags: %v", err) + } + return b +} + +func flagIDs(t *testing.T, b *Backend, where ir.Cond) []any { + t.Helper() + q := &ir.Query{ + Relation: ir.Ref{Name: "flags"}, + Select: []ir.SelectItem{ir.Column{Path: []string{"id"}}}, + Order: []ir.OrderTerm{{Path: []string{"id"}}}, + Where: &where, + } + plan := &ir.Plan{Query: q, ReadOnly: true} + rc := &reqctx.Context{Role: "anon", Method: "GET"} + res, err := b.Execute(context.Background(), plan, rc) + if err != nil { + t.Fatalf("Execute: %v", err) + } + rows := readAll(t, res) + ids := make([]any, 0, len(rows)) + for _, r := range rows { + ids = append(ids, r["id"]) + } + return ids +} + +// 07.4 task 1: is.unknown matches the NULL row through the "col IS NULL" +// fallback SQLite uses. +func TestSQLiteIsUnknownMatchesNull(t *testing.T) { + b := openFlags(t) + ids := flagIDs(t, b, ir.Compare{Path: []string{"done"}, Op: ir.OpIs, Value: ir.Value{Text: "unknown"}}) + if len(ids) != 1 || ids[0].(int64) != 3 { + t.Errorf("is.unknown matched %v, want [3]", ids) + } +} + +// 07.4 task 2: eq.true against the boolean column matches the 1 row. +func TestSQLiteEqTrueBooleanColumn(t *testing.T) { + b := openFlags(t) + ids := flagIDs(t, b, ir.Compare{ + Path: []string{"done"}, Op: ir.OpEq, ColumnType: "bool", Value: ir.Value{Text: "true"}, + }) + if len(ids) != 1 || ids[0].(int64) != 1 { + t.Errorf("done=eq.true matched %v, want [1]", ids) + } +} + +// 07.4 task 2: eq.true against the text column matches the row literally holding +// the word "true", not a coerced boolean. +func TestSQLiteEqTrueTextColumn(t *testing.T) { + b := openFlags(t) + ids := flagIDs(t, b, ir.Compare{ + Path: []string{"label"}, Op: ir.OpEq, ColumnType: "text", Value: ir.Value{Text: "true"}, + }) + if len(ids) != 1 || ids[0].(int64) != 1 { + t.Errorf("label=eq.true matched %v, want [1] (the row holding the word)", ids) + } +} diff --git a/backend/sqlite/jsonpath_test.go b/backend/sqlite/jsonpath_test.go new file mode 100644 index 0000000..645608d --- /dev/null +++ b/backend/sqlite/jsonpath_test.go @@ -0,0 +1,124 @@ +package sqlite + +import ( + "context" + "encoding/json" + "testing" + + "github.com/tamnd/dbrest/ir" + "github.com/tamnd/dbrest/plan" + "github.com/tamnd/dbrest/reqctx" +) + +// openDocs seeds a table with a JSON column holding a nested document, so the +// 07.1 JSON-path lowering can be exercised against the real SQLite engine. +func openDocs(t *testing.T) *Backend { + t.Helper() + b := openSeeded(t) + _, err := b.DB().Exec(` + CREATE TABLE docs (id INTEGER PRIMARY KEY, data JSON); + INSERT INTO docs (id, data) VALUES + (1, '{"blood_type":"A-","phones":[{"number":"555"}],"meta":{"k":1}}'), + (2, '{"blood_type":"O+","phones":[{"number":"999"}],"meta":{"k":2}}'), + (3, '{"blood_type":"A-","flag":"true","phones":[],"meta":{"k":3}}'); + `) + if err != nil { + t.Fatalf("seed docs: %v", err) + } + return b +} + +func planDocs(t *testing.T, b *Backend, query string) *ir.Query { + t.Helper() + q, perr := ir.ParseRead("docs", query, nil) + if perr != nil { + t.Fatalf("ParseRead: %v", perr) + } + model, err := b.Introspect(context.Background()) + if err != nil { + t.Fatalf("Introspect: %v", err) + } + pl, perr := plan.Read(model, q, nil, plan.Options{}) + if perr != nil { + t.Fatalf("plan.Read: %v", perr) + } + return pl.Query +} + +func runDocs(t *testing.T, b *Backend, q *ir.Query) []map[string]any { + t.Helper() + pl := &ir.Plan{Query: q, ReadOnly: true} + res, err := b.Execute(context.Background(), pl, &reqctx.Context{Role: "anon"}) + if err != nil { + t.Fatalf("Execute: %v", err) + } + return readAll(t, res) +} + +// A ->> filter selects rows by the text scalar at the path. +func TestJSONPathFilterText(t *testing.T) { + b := openDocs(t) + rows := runDocs(t, b, planDocs(t, b, "select=id&data->>blood_type=eq.A-&order=id")) + if len(rows) != 2 || rows[0]["id"].(int64) != 1 || rows[1]["id"].(int64) != 3 { + t.Errorf("rows = %v, want ids [1 3]", rows) + } +} + +// A ->> filter reaches through an array index and a nested key. +func TestJSONPathFilterArrayIndex(t *testing.T) { + b := openDocs(t) + rows := runDocs(t, b, planDocs(t, b, "select=id&data->phones->0->>number=eq.999")) + if len(rows) != 1 || rows[0]["id"].(int64) != 2 { + t.Errorf("rows = %v, want id [2]", rows) + } +} + +// A ->> projection returns the text scalar under the last hop's name. +func TestJSONPathProjectionText(t *testing.T) { + b := openDocs(t) + rows := runDocs(t, b, planDocs(t, b, "select=id,data->>blood_type&order=id")) + if len(rows) != 3 { + t.Fatalf("rows = %d, want 3", len(rows)) + } + if got, _ := asString(rows[0]["blood_type"]); got != "A-" { + t.Errorf("row0 blood_type = %q, want A-", got) + } +} + +// A -> projection returns the engine's JSON text for the path. The backend +// surfaces it as the JSON string {"k":1}; the renderer splices it verbatim +// (proved at the httpapi layer, where embedKeys flags the -> column). +func TestJSONPathProjectionJSON(t *testing.T) { + b := openDocs(t) + rows := runDocs(t, b, planDocs(t, b, "select=id,data->meta&order=id&id=eq.1")) + if len(rows) != 1 { + t.Fatalf("rows = %d, want 1", len(rows)) + } + got, _ := asString(rows[0]["meta"]) + var m map[string]any + if err := json.Unmarshal([]byte(got), &m); err != nil { + t.Fatalf("meta not JSON: %v (%q)", err, got) + } + if m["k"] != float64(1) { + t.Errorf("meta.k = %v, want 1", m["k"]) + } +} + +// eq.true on a ->> extract compares the literal word: the row whose JSON flag +// holds the string "true" matches, proving the access is text, not boolean. +func TestJSONPathEqTrueIsText(t *testing.T) { + b := openDocs(t) + rows := runDocs(t, b, planDocs(t, b, "select=id&data->>flag=eq.true")) + if len(rows) != 1 || rows[0]["id"].(int64) != 3 { + t.Errorf("rows = %v, want id [3]", rows) + } +} + +// Ordering by a ->> extract sorts by the path's text value. +func TestJSONPathOrder(t *testing.T) { + b := openDocs(t) + rows := runDocs(t, b, planDocs(t, b, "select=id&order=data->meta->>k.desc")) + if len(rows) != 3 || rows[0]["id"].(int64) != 3 || rows[2]["id"].(int64) != 1 { + t.Errorf("rows = %v, want ids [3 2 1]", rows) + } +} diff --git a/backend/sqlite/like_test.go b/backend/sqlite/like_test.go new file mode 100644 index 0000000..1f2e9de --- /dev/null +++ b/backend/sqlite/like_test.go @@ -0,0 +1,40 @@ +package sqlite + +import "testing" + +// SQLite's LIKE folds ASCII case by default, which would make the like operator +// silently case-insensitive and return different rows than PostgreSQL. The pool +// sets PRAGMA case_sensitive_like = ON to fix that. A lowercase pattern must not +// match the title-cased row through like. Finding 01-M08. +func TestLikeIsCaseSensitive(t *testing.T) { + b := openSeeded(t) + pl := planRead(t, b, "title=like.blade*") + rows := execRead(t, b, pl.Query) + if len(rows) != 0 { + t.Fatalf("like.blade* matched %d rows, want 0 (case-sensitive); got %v", len(rows), rows) + } + + pl = planRead(t, b, "title=like.Blade*") + rows = execRead(t, b, pl.Query) + if len(rows) != 1 { + t.Fatalf("like.Blade* matched %d rows, want 1", len(rows)) + } + if got := rows[0]["title"]; got != "Blade Runner" { + t.Fatalf("like.Blade* matched %v, want Blade Runner", got) + } +} + +// ilike stays case-insensitive even though case_sensitive_like is ON, because +// the dialect folds both sides with lower(). A lowercase pattern matches the +// title-cased row. Finding 01-M08. +func TestILikeIsCaseInsensitive(t *testing.T) { + b := openSeeded(t) + pl := planRead(t, b, "title=ilike.blade*") + rows := execRead(t, b, pl.Query) + if len(rows) != 1 { + t.Fatalf("ilike.blade* matched %d rows, want 1", len(rows)) + } + if got := rows[0]["title"]; got != "Blade Runner" { + t.Fatalf("ilike.blade* matched %v, want Blade Runner", got) + } +} diff --git a/backend/sqlite/result.go b/backend/sqlite/result.go index 622e744..f1c09df 100644 --- a/backend/sqlite/result.go +++ b/backend/sqlite/result.go @@ -2,7 +2,9 @@ package sqlite import ( "database/sql" + "encoding/json" "io" + "strings" "github.com/tamnd/dbrest/backend" "github.com/tamnd/dbrest/reqctx" @@ -36,6 +38,8 @@ type writeResult struct { rows [][]any affected int64 hasAff bool + count int64 + hasCount bool controls *reqctx.ResponseControls } @@ -43,7 +47,7 @@ func (r *writeResult) Body() io.Reader { return nil } func (r *writeResult) Rows() backend.RowStream { return &bufStream{cols: r.cols, rows: r.rows, i: -1} } -func (r *writeResult) Count() (int64, bool) { return 0, false } +func (r *writeResult) Count() (int64, bool) { return r.count, r.hasCount } func (r *writeResult) Affected() (int64, bool) { return r.affected, r.hasAff } func (r *writeResult) ResponseControls() *reqctx.ResponseControls { return r.controls } @@ -64,8 +68,9 @@ func (s *bufStream) Close() error { return nil } // rowStream is a forward-only cursor over the result rows. Values decode each // row into a []any the renderer maps to JSON by column name. type rowStream struct { - rows *sql.Rows - cols []string + rows *sql.Rows + cols []string + colTypes []*sql.ColumnType // lazily populated on first call to Values } func (s *rowStream) Columns() []string { return s.cols } @@ -74,9 +79,19 @@ func (s *rowStream) Err() error { return s.rows.Err() } func (s *rowStream) Close() error { return s.rows.Close() } // Values scans the current row into Go values. SQLite returns int64, float64, -// string, []byte, or nil; []byte is normalized to string so text columns render -// as JSON strings rather than base64. +// string, []byte, or nil. Post-scan coercions: +// - []byte → string so text columns render as JSON strings rather than base64. +// - BOOLEAN/BOOL declared columns: int64 0/1 → false/true so JSON marshals +// correctly as false/true rather than 0/1. +// - JSON declared columns: string → json.RawMessage so the JSON encoder embeds +// the value verbatim rather than quoting it as a string. func (s *rowStream) Values() ([]any, error) { + if s.colTypes == nil { + ct, err := s.rows.ColumnTypes() + if err == nil { + s.colTypes = ct + } + } holders := make([]any, len(s.cols)) ptrs := make([]any, len(s.cols)) for i := range holders { @@ -87,7 +102,20 @@ func (s *rowStream) Values() ([]any, error) { } for i, v := range holders { if b, ok := v.([]byte); ok { - holders[i] = string(b) + v = string(b) + holders[i] = v + } + if s.colTypes != nil && i < len(s.colTypes) { + switch strings.ToUpper(s.colTypes[i].DatabaseTypeName()) { + case "BOOLEAN", "BOOL": + if n, ok := v.(int64); ok { + holders[i] = n != 0 + } + case "JSON": + if str, ok := v.(string); ok && json.Valid([]byte(str)) { + holders[i] = json.RawMessage(str) + } + } } } return holders, nil diff --git a/backend/sqlite/sqlite.go b/backend/sqlite/sqlite.go index e817051..6b55bd2 100644 --- a/backend/sqlite/sqlite.go +++ b/backend/sqlite/sqlite.go @@ -4,9 +4,11 @@ import ( "context" "database/sql" "database/sql/driver" + "encoding/json" "errors" "fmt" "regexp" + "strings" sqlitedrv "modernc.org/sqlite" sqlite3 "modernc.org/sqlite/lib" @@ -68,6 +70,21 @@ func Open(dsn string) (*Backend, error) { if err != nil { return nil, err } + // SQLite does not enforce FK constraints by default, and its LIKE folds ASCII + // case by default, which makes the like operator silently case-insensitive + // unlike PostgreSQL. Pin to one connection so both PRAGMAs stay in effect for + // the lifetime of the pool. + db.SetMaxOpenConns(1) + if _, err := db.Exec("PRAGMA foreign_keys = ON"); err != nil { + db.Close() + return nil, err + } + // case_sensitive_like = ON makes the like operator case-sensitive to match + // PostgreSQL; ilike folds case explicitly in the dialect (lower() LIKE lower()). + if _, err := db.Exec("PRAGMA case_sensitive_like = ON"); err != nil { + db.Close() + return nil, err + } if err := db.Ping(); err != nil { db.Close() return nil, err @@ -122,8 +139,13 @@ func (b *Backend) Close() error { return b.db.Close() } // MapError turns a driver error into the unified envelope. A SQLite constraint // violation maps to the PostgreSQL SQLSTATE PostgREST would report (so clients -// see the same code on every backend) with the matching HTTP status; anything -// else is surfaced as internal. +// see the same code on every backend) with the matching HTTP status, and to a +// PG-shaped message synthesized from what SQLite reports. SQLite names the +// relation and column in its NOT NULL and UNIQUE text, so those reconstruct +// PostgreSQL's own wording; it gives no constraint name for a unique key and no +// offending value, so neither is invented (an emulation limitation, not a +// fabricated wire contract). The native text is never leaked into details. +// Anything else is surfaced as internal. func (b *Backend) MapError(err error) *pgerr.APIError { if err == nil { return nil @@ -132,21 +154,57 @@ func (b *Backend) MapError(err error) *pgerr.APIError { // The primary result code is the low byte; the rest is the extended code. switch se.Code() { case sqlite3.SQLITE_CONSTRAINT_UNIQUE, sqlite3.SQLITE_CONSTRAINT_PRIMARYKEY: - return pgerr.ErrUniqueViolation(se.Error()) + return pgerr.ErrConstraintViolation(pgerr.CodeUniqueViolation, + "duplicate key value violates unique constraint", "", "") case sqlite3.SQLITE_CONSTRAINT_NOTNULL: - return pgerr.ErrNotNullViolation(se.Error()) + return pgerr.ErrConstraintViolation(pgerr.CodeNotNullViolation, + notNullMessage(se.Error()), "", "") case sqlite3.SQLITE_CONSTRAINT_FOREIGNKEY: - return pgerr.ErrForeignKeyViolation(se.Error()) + return pgerr.ErrConstraintViolation(pgerr.CodeForeignKeyViolation, + "insert or update on table violates foreign key constraint", "", "") case sqlite3.SQLITE_CONSTRAINT_CHECK: - return pgerr.ErrCheckViolation(se.Error()) + return pgerr.ErrConstraintViolation(pgerr.CodeCheckViolation, + checkMessage(se.Error()), "", "") } if se.Code()&0xff == sqlite3.SQLITE_CONSTRAINT { - return pgerr.ErrCheckViolation(se.Error()) + return pgerr.ErrConstraintViolation(pgerr.CodeCheckViolation, + "new row violates check constraint", "", "") } } return pgerr.ErrInternal(err.Error()) } +// constraintTarget matches the "table.column" SQLite names after the colon in a +// constraint failure ("NOT NULL constraint failed: films.title"). +var constraintTarget = regexp.MustCompile(`([^\s,.]+)\.([^\s,.]+)`) + +// notNullMessage reconstructs PostgreSQL's not-null wording from SQLite's "NOT +// NULL constraint failed: relation.column" text. PostgreSQL reports `null value +// in column "c" of relation "t" violates not-null constraint`; SQLite supplies +// both names, so the message matches verbatim. When the text does not parse the +// generic message stands. +func notNullMessage(text string) string { + if m := constraintTarget.FindStringSubmatch(text); m != nil { + return fmt.Sprintf( + "null value in column %q of relation %q violates not-null constraint", + m[2], m[1]) + } + return "null value violates not-null constraint" +} + +// checkMessage reconstructs PostgreSQL's check wording from SQLite's "CHECK +// constraint failed: name" text. PostgreSQL names the constraint (`new row for +// relation "t" violates check constraint "c"`); SQLite gives only the +// constraint name (or the expression for an anonymous check), so the name rides +// through when present and the generic message stands otherwise. +func checkMessage(text string) string { + const prefix = "CHECK constraint failed: " + if name := strings.TrimPrefix(text, prefix); name != text && name != "" { + return fmt.Sprintf("new row violates check constraint %q", name) + } + return "new row violates check constraint" +} + // Execute lowers a resolved plan to SQLite operations and returns a streamable // result. Reads stream from an open cursor; writes run in a short transaction // and buffer their returned rows. RPC arrives with its subsystem. @@ -173,7 +231,7 @@ func (b *Backend) Execute(ctx context.Context, plan *ir.Plan, rc *reqctx.Context // (or roll back under Prefer: tx=rollback). The returned rows carry the // function's output for the renderer to shape by return kind. func (b *Backend) executeCall(ctx context.Context, plan *ir.Plan, rc *reqctx.Context) (backend.Result, error) { - st, apiErr := sqlgen.CompileCall(dialect{}, plan.Call, plan.Func) + st, apiErr := sqlgen.CompileCall(dialect{}, plan.Call, plan.Func, sqlgen.ContextArgs(rc)) if apiErr != nil { return nil, apiErr } @@ -182,7 +240,7 @@ func (b *Backend) executeCall(ctx context.Context, plan *ir.Plan, rc *reqctx.Con res := &result{controls: rc.Controls()} // A count over a read-only function runs as its own statement, like a read. if plan.Call.Count != ir.CountNone { - cst, apiErr := sqlgen.CompileCallCount(dialect{}, plan.Call, plan.Func) + cst, apiErr := sqlgen.CompileCallCount(dialect{}, plan.Call, plan.Func, sqlgen.ContextArgs(rc)) if apiErr != nil { return nil, apiErr } @@ -200,6 +258,21 @@ func (b *Backend) executeCall(ctx context.Context, plan *ir.Plan, rc *reqctx.Con rows.Close() return nil, b.MapError(err) } + // A portable function that steers the response projects reserved + // response-control columns. Buffer only then, so the common streaming + // path is untouched, lift the controls out, and strip them from the body. + if backend.HasResponseControlCols(cols) { + buf, err := drain(rows, len(cols)) + rows.Close() + if err != nil { + return nil, b.MapError(err) + } + cols, buf, apiErr := backend.LiftResponseControls(cols, buf, res.controls) + if apiErr != nil { + return nil, apiErr + } + return &writeResult{cols: cols, rows: buf, count: res.count, hasCount: res.hasCount, controls: res.controls}, nil + } res.rows, res.cols = rows, cols return res, nil } @@ -225,7 +298,17 @@ func (b *Backend) executeCall(ctx context.Context, plan *ir.Plan, rc *reqctx.Con return nil, b.MapError(err) } - res := &writeResult{cols: cols, rows: buf, controls: rc.Controls()} + // A volatile function steers the response the same way a read-only one does: + // reserved columns lift into the controls and drop out of the body. An invalid + // status or header set fails the call before commit, so the deferred rollback + // discards the mutation. + controls := rc.Controls() + cols, buf, apiErr = backend.LiftResponseControls(cols, buf, controls) + if apiErr != nil { + return nil, apiErr + } + + res := &writeResult{cols: cols, rows: buf, controls: controls} if plan.Call.Prefer.Tx != nil && *plan.Call.Prefer.Tx == ir.TxRollback { return res, nil } @@ -278,6 +361,19 @@ func (b *Backend) executeWrite(ctx context.Context, plan *ir.Plan, rc *reqctx.Co q := plan.Query returning := returningCols(q, plan.Rel) + // An empty column set (POST with an empty array, PATCH with an empty object) + // is a no-op: nothing is compiled or run, the affected count is zero, and the + // representation is the empty array. The HTTP layer turns that into 201/[] for + // an insert and 204 or 200/[] for an update. + if backend.IsNoOpMutation(q) { + return &writeResult{ + controls: rc.Controls(), + cols: returning, + affected: 0, + hasAff: true, + }, nil + } + st, apiErr := compileWrite(q, returning) if apiErr != nil { return nil, apiErr @@ -292,6 +388,19 @@ func (b *Backend) executeWrite(ctx context.Context, plan *ir.Plan, rc *reqctx.Co defer func() { _ = tx.Rollback() }() res := &writeResult{controls: rc.Controls()} + + // An upsert's 200-vs-201 status turns on whether any row updated an existing + // one. SQLite has no xmax to read back (the PostgreSQL signal), so before the + // write we check, in the same transaction, whether any payload row's + // conflict-target key already exists; if none does the upsert is all-insert. + if q.Kind == ir.Upsert { + if inserted, ok, derr := detectUpsertInsert(ctx, tx, q, plan.Rel); derr != nil { + return nil, b.MapError(derr) + } else if ok { + res.controls.UpsertStatusKnown = true + res.controls.InsertedRows = inserted + } + } if len(returning) > 0 { rows, err := tx.QueryContext(ctx, st.SQL, st.Args...) if err != nil { @@ -307,8 +416,11 @@ func (b *Backend) executeWrite(ctx context.Context, plan *ir.Plan, rc *reqctx.Co if err != nil { return nil, b.MapError(err) } - res.cols, res.rows = cols, buf + // The affected count is the full mutated set, taken before the + // representation is shaped: order/limit/offset bound only the returned + // body, not the mutation (v13 dropped limited update/delete). res.affected, res.hasAff = int64(len(buf)), true + res.cols, res.rows = cols, backend.ShapeWriteRepresentation(cols, buf, q) } else { out, err := tx.ExecContext(ctx, st.SQL, st.Args...) if err != nil { @@ -318,6 +430,18 @@ func (b *Backend) executeWrite(ctx context.Context, plan *ir.Plan, rc *reqctx.Co res.affected, res.hasAff = n, true } + // Prefer: max-affected rolls an over-broad write back instead of committing. + if apiErr := backend.EnforceMaxAffected(q.Write, res.affected, res.hasAff); apiErr != nil { + return nil, apiErr + } + + // A singular write (vnd.pgrst.object+json) that touched zero or many rows + // fails closed before commit, so the deferred rollback discards it rather + // than the renderer rejecting an already-durable mutation. + if apiErr := backend.EnforceSingularWrite(q.Singular, res.affected, res.hasAff); apiErr != nil { + return nil, apiErr + } + // Prefer: tx=rollback returns the computed representation but discards the // work; leaving the transaction for the deferred rollback does exactly that. if q.Write != nil && q.Write.Tx == ir.TxRollback { @@ -329,6 +453,74 @@ func (b *Backend) executeWrite(ctx context.Context, plan *ir.Plan, rc *reqctx.Co return res, nil } +// detectUpsertInsert counts how many of the payload rows the upsert will insert +// as new (those whose conflict-target key does not already exist) so the HTTP +// layer can choose 200 vs 201. It runs inside the write transaction, before the +// upsert statement, and returns ok=false when the target columns are unknown (no +// explicit on_conflict and no primary key), leaving the status to the default. +// The conflict target defaults to the relation's primary key, matching the +// upsert's own ON CONFLICT. +func detectUpsertInsert(ctx context.Context, tx *sql.Tx, q *ir.Query, rel *schema.Relation) (inserted int, ok bool, err error) { + if q.Write == nil || len(q.Write.Rows) == 0 { + return 0, false, nil + } + // Only merge-duplicates can turn into an update; an ignore-duplicates upsert + // (ON CONFLICT DO NOTHING) is a no-op insert on a conflict, which PostgreSQL + // reports through RETURNING as all-insert and PostgREST renders as 201. So a + // PUT (no Conflict spec) and a merge upsert run detection; an ignore upsert + // keeps the 201 default. + if q.Write.Conflict != nil && q.Write.Conflict.Resolution == ir.ConflictIgnore { + return 0, false, nil + } + target := rel.PrimaryKey + if q.Write.Conflict != nil && len(q.Write.Conflict.Target) > 0 { + target = q.Write.Conflict.Target + } + if len(target) == 0 { + return 0, false, nil + } + + d := dialect{} + var where strings.Builder + for i, c := range target { + if i > 0 { + where.WriteString(" AND ") + } + where.WriteString(d.QuoteIdent(c)) + where.WriteString(" = ?") + } + query := "SELECT 1 FROM " + d.QuoteIdent(rel.Name) + " WHERE " + where.String() + " LIMIT 1" + + for _, row := range q.Write.Rows { + args := make([]any, len(target)) + for i, c := range target { + // A payload missing a key column cannot match an existing row by it; + // treat that row as an insert and move on. + v, present := row[c] + if !present { + args = nil + break + } + args[i] = sqlgen.WriteArg(d, v, q.Write.ColumnTypes[c]) + } + if args == nil { + inserted++ + continue + } + var dummy int + switch scanErr := tx.QueryRowContext(ctx, query, args...).Scan(&dummy); scanErr { + case nil: + // This row matches an existing key: an ON CONFLICT update, not an insert. + case sql.ErrNoRows: + // No existing row: this one is a new insert. + inserted++ + default: + return 0, false, scanErr + } + } + return inserted, true, nil +} + // compileWrite dispatches to the right compiler for the mutation kind. func compileWrite(q *ir.Query, returning []string) (*sqlgen.Statement, *pgerr.APIError) { switch q.Kind { @@ -349,6 +541,9 @@ func compileWrite(q *ir.Query, returning []string) (*sqlgen.Statement, *pgerr.AP // nothing and runs as a plain affected-rows statement. func returningCols(q *ir.Query, rel *schema.Relation) []string { if q.Write != nil && q.Write.Return == ir.ReturnRepresentation { + if cols := q.ProjectedColumns(); cols != nil { + return cols + } return rel.ColumnNames() } if q.Kind == ir.Insert || q.Kind == ir.Upsert { @@ -357,9 +552,11 @@ func returningCols(q *ir.Query, rel *schema.Relation) []string { return nil } -// drain reads every row of a returning cursor into memory, normalizing []byte to -// string so text columns render as JSON strings. +// drain reads every row of a returning cursor into memory, applying the same +// type coercions as rowStream.Values: []byte→string, BOOLEAN int64→bool, +// JSON string→json.RawMessage. func drain(rows *sql.Rows, ncols int) ([][]any, error) { + colTypes, _ := rows.ColumnTypes() var out [][]any for rows.Next() { holders := make([]any, ncols) @@ -372,7 +569,20 @@ func drain(rows *sql.Rows, ncols int) ([][]any, error) { } for i, v := range holders { if bs, ok := v.([]byte); ok { - holders[i] = string(bs) + v = string(bs) + holders[i] = v + } + if colTypes != nil && i < len(colTypes) { + switch strings.ToUpper(colTypes[i].DatabaseTypeName()) { + case "BOOLEAN", "BOOL": + if n, ok := v.(int64); ok { + holders[i] = n != 0 + } + case "JSON": + if str, ok := v.(string); ok && json.Valid([]byte(str)) { + holders[i] = json.RawMessage(str) + } + } } } out = append(out, holders) diff --git a/backend/sqlite/sqlite_test.go b/backend/sqlite/sqlite_test.go index 1077367..5b51f0b 100644 --- a/backend/sqlite/sqlite_test.go +++ b/backend/sqlite/sqlite_test.go @@ -415,6 +415,73 @@ func TestMapErrorUniqueViolation(t *testing.T) { if api == nil || api.Code != pgerr.CodeUniqueViolation || api.HTTPStatus != 409 { t.Fatalf("err = %#v, want 23505/409", api) } + // The message is PostgreSQL's wording, not SQLite's native text, and the + // native "UNIQUE constraint failed" string never leaks into details. + if api.Message != "duplicate key value violates unique constraint" { + t.Errorf("message = %q, want PG unique wording", api.Message) + } + if api.Details != nil { + t.Errorf("details = %q, want no leaked native text", *api.Details) + } +} + +// A NOT NULL violation reconstructs PostgreSQL's exact wording from the +// relation and column SQLite names in its error text. +func TestMapErrorNotNullViolation(t *testing.T) { + b := openSeeded(t) + pl := &ir.Plan{Query: &ir.Query{ + Kind: ir.Insert, + Relation: ir.Ref{Name: "films"}, + Write: &ir.WriteSpec{ + Return: ir.ReturnMinimal, + Columns: []string{"id", "title"}, + Rows: []map[string]ir.Value{{"id": ir.Value{JSON: json.Number("9")}, "title": ir.Value{JSON: nil}}}, + }, + }} + rel, _ := mustModel(t, b).Lookup("films", nil) + pl.Rel = rel + _, err := b.Execute(context.Background(), pl, &reqctx.Context{Role: "anon"}) + if err == nil { + t.Fatal("want a constraint error") + } + api := pgerr.As(err) + if api == nil || api.Code != pgerr.CodeNotNullViolation || api.HTTPStatus != 400 { + t.Fatalf("err = %#v, want 23502/400", api) + } + want := `null value in column "title" of relation "films" violates not-null constraint` + if api.Message != want { + t.Errorf("message = %q, want %q", api.Message, want) + } + if api.Details != nil { + t.Errorf("details = %q, want no leaked native text", *api.Details) + } +} + +// The synthesis helpers reconstruct PG wording from SQLite's text directly, +// including the constraint name for a CHECK and a graceful fallback when the +// text does not parse. +func TestConstraintMessageSynthesis(t *testing.T) { + cases := []struct { + name string + got string + want string + }{ + {"notnull", notNullMessage("NOT NULL constraint failed: films.title"), + `null value in column "title" of relation "films" violates not-null constraint`}, + {"notnull-unparsed", notNullMessage("garbage"), + "null value violates not-null constraint"}, + {"check-named", checkMessage("CHECK constraint failed: rating_valid"), + `new row violates check constraint "rating_valid"`}, + {"check-bare", checkMessage("CHECK constraint failed: "), + "new row violates check constraint"}, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + if c.got != c.want { + t.Errorf("got %q, want %q", c.got, c.want) + } + }) + } } func TestExecuteNullsOrdering(t *testing.T) { diff --git a/backend/sqlite/viewparse.go b/backend/sqlite/viewparse.go new file mode 100644 index 0000000..8630f10 --- /dev/null +++ b/backend/sqlite/viewparse.go @@ -0,0 +1,255 @@ +package sqlite + +import ( + "strings" + + "github.com/tamnd/dbrest/schema" +) + +// parseViewColumns traces a SQLite view's output columns back to the base-table +// columns they project, so the schema model can carry the base table's foreign +// keys onto the view (spec 09). SQLite keeps no column-provenance catalog, so the +// CREATE VIEW text from sqlite_master is parsed here. +// +// The parser is deliberately conservative and recognizes only the shape whose +// provenance is unambiguous: a single base relation in the FROM clause, no join, +// no set operation, and a select list of bare or aliased column references. Any +// expression column, function call, join, or UNION yields no mapping, so the +// model inherits nothing rather than guessing, the same way PostgREST skips a +// view it cannot resolve. baseCols returns a base relation's column names, used +// to expand `SELECT *`. +func parseViewColumns(ddl string, baseCols func(name string) ([]string, bool)) []schema.ViewColumn { + sel, ok := viewSelectBody(ddl) + if !ok { + return nil + } + // A set operation (UNION, INTERSECT, EXCEPT) makes provenance ambiguous. + if hasTopLevelKeyword(sel, "union", "intersect", "except") { + return nil + } + listText, fromText, ok := splitSelectFrom(sel) + if !ok { + return nil + } + base, ok := singleBaseTable(fromText) + if !ok { + return nil + } + + items := splitArgs(listText) + var out []schema.ViewColumn + for _, item := range items { + item = strings.TrimSpace(item) + if item == "" { + continue + } + if item == "*" || strings.HasSuffix(item, ".*") { + cols, ok := baseCols(base) + if !ok { + return nil + } + for _, c := range cols { + out = append(out, schema.ViewColumn{Name: c, BaseRelation: base, BaseColumn: c}) + } + continue + } + vc, ok := parseSelectColumn(item, base) + if !ok { + return nil // an expression column: do not project this view + } + out = append(out, vc) + } + return out +} + +// viewSelectBody extracts the SELECT body of a CREATE VIEW statement, dropping +// the "CREATE VIEW name AS" prefix. It reports ok=false for any other DDL. +func viewSelectBody(ddl string) (string, bool) { + low := strings.ToLower(ddl) + if !strings.Contains(low, "create") || !strings.Contains(low, "view") { + return "", false + } + as := indexWord(low, "as") + if as < 0 { + return "", false + } + return strings.TrimSpace(ddl[as+2:]), true +} + +// splitSelectFrom splits a SELECT body into its select list and FROM clause. It +// drops a leading SELECT (and DISTINCT) and cuts at the top-level FROM keyword, +// reporting ok=false when there is no FROM. +func splitSelectFrom(sel string) (list, from string, ok bool) { + low := strings.ToLower(sel) + if !strings.HasPrefix(low, "select") { + return "", "", false + } + sel = strings.TrimSpace(sel[len("select"):]) + if low := strings.ToLower(sel); strings.HasPrefix(low, "distinct") { + sel = strings.TrimSpace(sel[len("distinct"):]) + } + at := topLevelKeyword(sel, "from") + if at < 0 { + return "", "", false + } + list = strings.TrimSpace(sel[:at]) + from = strings.TrimSpace(sel[at+len("from"):]) + return list, from, true +} + +// singleBaseTable returns the lone base relation named in a FROM clause, or +// ok=false when the clause has a join, a comma, a subquery, or trailing clauses +// the parser will not reason about (WHERE, GROUP BY, and the rest). +func singleBaseTable(from string) (string, bool) { + // Cut anything after the table reference: a WHERE/GROUP/ORDER/LIMIT tail. + for _, kw := range []string{"where", "group", "order", "limit", "having", "window"} { + if at := topLevelKeyword(from, kw); at >= 0 { + from = strings.TrimSpace(from[:at]) + } + } + if from == "" { + return "", false + } + // A join or a comma-separated list is more than one base relation. + if strings.Contains(strings.ToLower(from), " join ") || strings.ContainsAny(from, ",(") { + return "", false + } + fields := strings.Fields(from) + // Accept "base" or "base alias"; reject "base AS alias" forms beyond two words + // only when they introduce something other than a plain alias. + if len(fields) == 0 { + return "", false + } + return unquoteIdent(fields[0]), true +} + +// parseSelectColumn parses one select-list item that is a bare or aliased column +// reference, returning the view column to base column mapping. It reports +// ok=false for an expression (a function call, an operator, a literal), which the +// caller treats as a reason not to project the view. +func parseSelectColumn(item, base string) (schema.ViewColumn, bool) { + // Split off an alias: "expr AS name" or "expr name". + expr, alias := splitColumnAlias(item) + // The expression must be a plain column reference: an identifier, optionally + // qualified by a table. Anything with an operator, call, or literal is out. + if !isColumnRef(expr) { + return schema.ViewColumn{}, false + } + baseCol := expr + if dot := strings.LastIndexByte(expr, '.'); dot >= 0 { + baseCol = expr[dot+1:] + } + baseCol = unquoteIdent(strings.TrimSpace(baseCol)) + name := baseCol + if alias != "" { + name = unquoteIdent(alias) + } + return schema.ViewColumn{Name: name, BaseRelation: base, BaseColumn: baseCol}, true +} + +// splitColumnAlias separates a select item into its expression and column alias. +// It handles "expr AS alias" and the bare "expr alias" form, and returns an empty +// alias when the item is a single token. +func splitColumnAlias(item string) (expr, alias string) { + if at := indexWord(strings.ToLower(item), "as"); at >= 0 { + return strings.TrimSpace(item[:at]), strings.TrimSpace(item[at+2:]) + } + fields := strings.Fields(item) + if len(fields) == 2 { + return fields[0], fields[1] + } + return strings.TrimSpace(item), "" +} + +// isColumnRef reports whether s is a plain (optionally table-qualified) column +// reference: identifier characters, quotes, and a single dot, with no operator, +// parenthesis, or whitespace that would mark an expression. +func isColumnRef(s string) bool { + s = strings.TrimSpace(s) + if s == "" { + return false + } + for i := 0; i < len(s); i++ { + c := s[i] + switch { + case c >= 'a' && c <= 'z', c >= 'A' && c <= 'Z', c >= '0' && c <= '9': + case c == '_' || c == '.' || c == '"' || c == '`' || c == '[' || c == ']' || c == '$': + default: + return false + } + } + return true +} + +// indexWord finds the byte offset of a standalone lowercase word in s (which the +// caller has already lowercased where needed), requiring word boundaries so that +// "as" does not match inside "class". It returns -1 when absent. +func indexWord(s, word string) int { + from := 0 + for { + at := strings.Index(s[from:], word) + if at < 0 { + return -1 + } + at += from + if wordBoundary(s, at, len(word)) { + return at + } + from = at + len(word) + } +} + +// topLevelKeyword finds a standalone keyword in s outside any parentheses or +// quotes, the boundary a clause splitter needs so a keyword inside a subquery or +// string does not match. It matches case-insensitively and returns -1 when absent. +func topLevelKeyword(s, keyword string) int { + low := strings.ToLower(s) + depth := 0 + var quote byte + for i := 0; i < len(low); i++ { + c := low[i] + switch { + case quote != 0: + if c == quote { + quote = 0 + } + case c == '\'' || c == '"' || c == '`': + quote = c + case c == '(' || c == '[': + depth++ + case c == ')' || c == ']': + depth-- + case depth == 0 && c == keyword[0] && strings.HasPrefix(low[i:], keyword) && wordBoundary(low, i, len(keyword)): + return i + } + } + return -1 +} + +// hasTopLevelKeyword reports whether any of the keywords appears at the top level. +func hasTopLevelKeyword(s string, keywords ...string) bool { + for _, kw := range keywords { + if topLevelKeyword(s, kw) >= 0 { + return true + } + } + return false +} + +// wordBoundary reports whether the substring at [at, at+n) in s is bounded by +// non-identifier characters on both sides. +func wordBoundary(s string, at, n int) bool { + if at > 0 && isIdentByte(s[at-1]) { + return false + } + end := at + n + if end < len(s) && isIdentByte(s[end]) { + return false + } + return true +} + +// isIdentByte reports whether c can appear inside an unquoted SQL identifier. +func isIdentByte(c byte) bool { + return c >= 'a' && c <= 'z' || c >= 'A' && c <= 'Z' || c >= '0' && c <= '9' || c == '_' +} diff --git a/backend/sqlite/viewparse_test.go b/backend/sqlite/viewparse_test.go new file mode 100644 index 0000000..f034d0e --- /dev/null +++ b/backend/sqlite/viewparse_test.go @@ -0,0 +1,120 @@ +package sqlite + +import ( + "context" + "testing" +) + +// baseColsStub answers the base-column lookup the view parser uses to expand a +// star projection, for the unit tests that exercise the parser directly. +func baseColsStub(byTable map[string][]string) func(string) ([]string, bool) { + return func(name string) ([]string, bool) { + cols, ok := byTable[name] + return cols, ok + } +} + +func TestParseViewColumnsStarProjection(t *testing.T) { + ddl := `CREATE VIEW film_view AS SELECT * FROM films` + got := parseViewColumns(ddl, baseColsStub(map[string][]string{ + "films": {"id", "title", "director_id"}, + })) + if len(got) != 3 { + t.Fatalf("got %d view columns, want 3", len(got)) + } + if got[2].Name != "director_id" || got[2].BaseColumn != "director_id" || got[2].BaseRelation != "films" { + t.Errorf("third column = %+v, want director_id<-films.director_id", got[2]) + } +} + +func TestParseViewColumnsExplicitListWithAlias(t *testing.T) { + ddl := `CREATE VIEW v AS SELECT id, director_id AS dir FROM films` + got := parseViewColumns(ddl, baseColsStub(nil)) + if len(got) != 2 { + t.Fatalf("got %d view columns, want 2", len(got)) + } + if got[1].Name != "dir" || got[1].BaseColumn != "director_id" { + t.Errorf("aliased column = %+v, want dir<-director_id", got[1]) + } +} + +func TestParseViewColumnsQualifiedReference(t *testing.T) { + ddl := `CREATE VIEW v AS SELECT f.id, f.director_id FROM films f` + got := parseViewColumns(ddl, baseColsStub(nil)) + if len(got) != 2 { + t.Fatalf("got %d view columns, want 2", len(got)) + } + if got[1].BaseColumn != "director_id" || got[1].Name != "director_id" { + t.Errorf("qualified column = %+v, want director_id", got[1]) + } +} + +func TestParseViewColumnsRejectsJoin(t *testing.T) { + ddl := `CREATE VIEW v AS SELECT f.id FROM films f JOIN directors d ON d.id = f.director_id` + if got := parseViewColumns(ddl, baseColsStub(nil)); got != nil { + t.Errorf("a joined view should not project, got %v", got) + } +} + +func TestParseViewColumnsRejectsUnion(t *testing.T) { + ddl := `CREATE VIEW v AS SELECT id FROM films UNION SELECT id FROM directors` + if got := parseViewColumns(ddl, baseColsStub(nil)); got != nil { + t.Errorf("a union view should not project, got %v", got) + } +} + +func TestParseViewColumnsRejectsExpression(t *testing.T) { + ddl := `CREATE VIEW v AS SELECT id, upper(title) AS t FROM films` + if got := parseViewColumns(ddl, baseColsStub(nil)); got != nil { + t.Errorf("an expression column should stop projection, got %v", got) + } +} + +// TestExecuteEmbedThroughView covers 01.11 end-to-end on SQLite: a view defined +// as SELECT over a base table inherits the base foreign key, so the view embeds +// the referenced table as a to-one and returns the nested object. +func TestExecuteEmbedThroughView(t *testing.T) { + dsn := "file:" + t.Name() + "?mode=memory&cache=shared" + b, err := Open(dsn) + if err != nil { + t.Fatalf("Open: %v", err) + } + t.Cleanup(func() { b.Close() }) + + _, err = b.DB().Exec(` + CREATE TABLE directors (id INTEGER PRIMARY KEY, name TEXT NOT NULL); + CREATE TABLE films ( + id INTEGER PRIMARY KEY, + title TEXT NOT NULL, + director_id INTEGER REFERENCES directors(id) + ); + CREATE VIEW film_view AS SELECT id, title, director_id FROM films; + INSERT INTO directors (id, name) VALUES (1, 'Lang'); + INSERT INTO films (id, title, director_id) VALUES (1, 'Metropolis', 1); + `) + if err != nil { + t.Fatalf("seed: %v", err) + } + + model, err := b.Introspect(context.Background()) + if err != nil { + t.Fatalf("Introspect: %v", err) + } + view, _ := model.Lookup("film_view", nil) + if len(view.ViewColumns) != 3 { + t.Fatalf("film_view has %d view columns, want 3", len(view.ViewColumns)) + } + cands, _ := model.Relationships(view, "directors", nil) + if len(cands) != 1 { + t.Fatalf("got %d relationships film_view->directors, want 1", len(cands)) + } + q := planEmbed(t, b, "film_view", "select=title,director:directors(name)") + rows := execReadResolved(t, b, q) + if len(rows) != 1 { + t.Fatalf("got %d rows, want 1", len(rows)) + } + dir, ok := asString(rows[0]["director"]) + if !ok || dir == "" { + t.Fatalf("director = %v, want a nested object", rows[0]["director"]) + } +} diff --git a/backend/sqlite/writerep_test.go b/backend/sqlite/writerep_test.go new file mode 100644 index 0000000..7bc22c4 --- /dev/null +++ b/backend/sqlite/writerep_test.go @@ -0,0 +1,119 @@ +package sqlite + +import ( + "testing" + + "github.com/tamnd/dbrest/ir" +) + +// A bulk update that asks for a representation with order+limit writes every +// matching row (v13 dropped limited update/delete) but returns only the ordered, +// limited slice. The affected count stays the full set, and a re-read confirms +// no row escaped the mutation. +func TestExecuteUpdateRepresentationOrderedLimited(t *testing.T) { + b := openSeeded(t) + limit := 2 + res := execWrite(t, b, &ir.Query{ + Kind: ir.Update, + Relation: ir.Ref{Name: "films"}, + Order: []ir.OrderTerm{{Path: []string{"id"}, Desc: true}}, + Limit: &limit, + Write: &ir.WriteSpec{ + Return: ir.ReturnRepresentation, + Set: map[string]ir.Value{"rating": {JSON: "X"}}, + }, + }) + + // The affected count is the full mutated set, not the limited body. + if n, ok := res.Affected(); !ok || n != 4 { + t.Errorf("Affected = %d,%v want 4,true", n, ok) + } + + // The body is the top two rows by id descending. + rows := readAll(t, res) + if len(rows) != 2 { + t.Fatalf("body rows = %d, want 2", len(rows)) + } + if rows[0]["id"].(int64) != 4 || rows[1]["id"].(int64) != 3 { + t.Errorf("body ids = [%v %v], want [4 3]", rows[0]["id"], rows[1]["id"]) + } + + // Every row was updated, including the two the representation omitted. + all := execRead(t, b, &ir.Query{ + Relation: ir.Ref{Name: "films"}, + Select: []ir.SelectItem{ir.Column{Path: []string{"rating"}}}, + }) + if len(all) != 4 { + t.Fatalf("after update, count = %d, want 4", len(all)) + } + for _, r := range all { + if r["rating"] != "X" { + t.Errorf("a row escaped the update: rating = %v, want X", r["rating"]) + } + } +} + +// 07.12: a no-op mutation (an insert with no rows, an update with no column +// assignments) short-circuits before any SQL runs. The affected count is zero, +// the representation is empty, and the table is untouched. +func TestExecuteNoOpInsertRunsNoSQL(t *testing.T) { + b := openSeeded(t) + res := execWrite(t, b, &ir.Query{ + Kind: ir.Insert, + Relation: ir.Ref{Name: "films"}, + Write: &ir.WriteSpec{Return: ir.ReturnRepresentation, Rows: nil}, + }) + if n, ok := res.Affected(); !ok || n != 0 { + t.Errorf("Affected = %d,%v want 0,true", n, ok) + } + if rows := readAll(t, res); len(rows) != 0 { + t.Errorf("body rows = %d, want 0", len(rows)) + } + if all := execRead(t, b, &ir.Query{Relation: ir.Ref{Name: "films"}}); len(all) != 4 { + t.Errorf("row count = %d, want 4 (nothing inserted)", len(all)) + } +} + +func TestExecuteNoOpUpdateRunsNoSQL(t *testing.T) { + b := openSeeded(t) + res := execWrite(t, b, &ir.Query{ + Kind: ir.Update, + Relation: ir.Ref{Name: "films"}, + Write: &ir.WriteSpec{Return: ir.ReturnRepresentation, Set: map[string]ir.Value{}}, + }) + if n, ok := res.Affected(); !ok || n != 0 { + t.Errorf("Affected = %d,%v want 0,true", n, ok) + } + if rows := readAll(t, res); len(rows) != 0 { + t.Errorf("body rows = %d, want 0", len(rows)) + } +} + +// offset on a delete representation skips rows in the returned body while still +// deleting every matching row. +func TestExecuteDeleteRepresentationOffset(t *testing.T) { + b := openSeeded(t) + offset := 1 + res := execWrite(t, b, &ir.Query{ + Kind: ir.Delete, + Relation: ir.Ref{Name: "films"}, + Order: []ir.OrderTerm{{Path: []string{"id"}}}, + Offset: &offset, + Write: &ir.WriteSpec{Return: ir.ReturnRepresentation}, + }) + + if n, ok := res.Affected(); !ok || n != 4 { + t.Errorf("Affected = %d,%v want 4,true", n, ok) + } + rows := readAll(t, res) + if len(rows) != 3 { + t.Fatalf("body rows = %d, want 3 (offset 1 of 4)", len(rows)) + } + if rows[0]["id"].(int64) != 2 { + t.Errorf("first body id = %v, want 2", rows[0]["id"]) + } + // The whole table is gone regardless of the body window. + if all := execRead(t, b, &ir.Query{Relation: ir.Ref{Name: "films"}}); len(all) != 0 { + t.Errorf("after delete, count = %d, want 0", len(all)) + } +} diff --git a/backend/sqlserver/dialect.go b/backend/sqlserver/dialect.go index 8fac96c..c860ade 100644 --- a/backend/sqlserver/dialect.go +++ b/backend/sqlserver/dialect.go @@ -50,6 +50,11 @@ func (Dialect) Placeholder(n int) string { return "@p" + strconv.Itoa(n) } // used uniformly. func (Dialect) LimitOffset(limit, offset *int, hasOrder bool) string { if limit == nil && offset == nil { + if hasOrder { + // ORDER BY in a derived table requires OFFSET even when no paging is + // requested; OFFSET 0 ROWS keeps all rows while making the ORDER BY valid. + return "OFFSET 0 ROWS" + } return "" } off := 0 @@ -135,12 +140,13 @@ func (Dialect) JSONObject(pairs []sqlgen.Pair) string { return "JSON_OBJECT(" + strings.Join(parts, ", ") + ")" } -// JSONAgg aggregates rows with the SQL Server 2022 JSON_ARRAYAGG. The aggregate -// takes no ORDER BY argument, so a requested embed order is applied on the -// derived table feeding the aggregate, not here; orderBy is therefore unused and -// the row order within the array is best-effort (spec 06). +// JSONAgg aggregates rows into a JSON array using STRING_AGG. JSON_ARRAYAGG was +// only added in SQL Server 2025 (version 17); for 2022 compatibility the dialect +// constructs the array manually: '[' + STRING_AGG(elem,',') + ']'. The elements +// are cast to NVARCHAR(MAX) so STRING_AGG accepts them. orderBy is unused; a +// requested embed order is applied on the derived table feeding the aggregate. func (Dialect) JSONAgg(elem, _ string) string { - return "JSON_ARRAYAGG(" + elem + ")" + return "'['+STRING_AGG(CAST((" + elem + ") AS NVARCHAR(MAX)),',')+']'" } // Cast translates a canonical type to a T-SQL CAST target. SQL Server has no @@ -218,8 +224,38 @@ func (Dialect) SessionWrite(key string) (string, bool) { return "EXEC sp_set_session_context N'" + strings.ReplaceAll(key, "'", "''") + "', " + sqlgen.PatternMark, true } -// ArrayOp returns false; SQL Server has no array types or containment operators. -func (Dialect) ArrayOp(_, _, _ string) (string, bool) { return "", false } +// ArrayOp implements array containment/overlap operators using OPENJSON, which +// parses the JSON array argument and the JSON array column for element-level +// comparisons. val is a bound placeholder (@pN) whose value is a JSON array +// string (converted from PostgreSQL {a,b} syntax by ArrayLiteral). +func (Dialect) ArrayOp(col, op, val, _ string) (string, bool) { + switch op { + case "@>": + // col contains every element of val + return "NOT EXISTS(SELECT [value] FROM OPENJSON(" + val + ") WHERE [value] NOT IN (SELECT [value] FROM OPENJSON(" + col + ")))", true + case "<@": + // every element of col exists in val + return "NOT EXISTS(SELECT [value] FROM OPENJSON(" + col + ") WHERE [value] NOT IN (SELECT [value] FROM OPENJSON(" + val + ")))", true + case "&&": + // at least one element in common + return "EXISTS(SELECT 1 FROM OPENJSON(" + col + ") a WHERE a.[value] IN (SELECT [value] FROM OPENJSON(" + val + ")))", true + } + return "", false +} + +// RangeOp declines: SQL Server has no range types, so sl/sr/nxr/nxl/adj are +// PGRST127. +func (Dialect) RangeOp(_, _, _ string) (string, bool) { return "", false } + +// IsBool renders "col = 1" or "col = 0" for SQL Server BIT columns. SQL +// Server's IS operator only accepts NULL/UNKNOWN, not integer literals. +func (Dialect) IsBool(col string, v bool) (string, bool) { + return col + " = " + Dialect{}.BoolValue(v), true +} + +// IsUnknown falls back to "col IS NULL"; a BIT boolean column's UNKNOWN state is +// its NULL, so the row set matches. +func (Dialect) IsUnknown(string) (string, bool) { return "", false } // ILike uses plain LIKE; SQL Server's default collation is case-insensitive. func (Dialect) ILike(col, val string) (string, bool) { return col + " LIKE " + val, true } @@ -232,3 +268,46 @@ func (Dialect) BoolValue(v bool) string { } return "0" } + +// InList reports ok=false: SQL Server has no array-bound ANY, so the compiler +// emits the expanded col IN ($1, $2, ...) form. +func (Dialect) InList(_ string) (string, bool) { return "", false } + +// ArrayLiteral converts a PostgreSQL {a,b} array literal to a JSON array +// ["a","b"] so OPENJSON in ArrayOp can iterate over it. +func (Dialect) ArrayLiteral(pgText string) string { + s := strings.TrimSpace(pgText) + if len(s) < 2 || s[0] != '{' || s[len(s)-1] != '}' { + return pgText + } + inner := s[1 : len(s)-1] + if inner == "" { + return "[]" + } + parts := strings.Split(inner, ",") + quoted := make([]string, len(parts)) + for i, p := range parts { + p = strings.TrimSpace(p) + switch { + case p == "NULL": + quoted[i] = "null" + case len(p) >= 2 && p[0] == '"' && p[len(p)-1] == '"': + // PostgreSQL double-quote escaping: "foo" is already valid JSON; pass through. + quoted[i] = p + default: + quoted[i] = `"` + strings.ReplaceAll(p, `"`, `\"`) + `"` + } + } + return "[" + strings.Join(quoted, ",") + "]" +} + +// ArrayArg stores a payload array as its JSON text: SQL Server has no array +// columns, so an nvarchar column holds the array and reads it back as JSON. +// A PostgreSQL {a,b} literal here would corrupt the column. +func (Dialect) ArrayArg(elems []any, _ string) any { return sqlgen.JSONArrayArg(elems) } + +// JSONPath reports ok=false so the compiler raises PGRST127. SQL Server expresses +// JSON access through JSON_VALUE/JSON_QUERY rather than ->/->>, and lowering them +// to match PostgREST's typing needs a live server to verify; until then JSON +// paths are an honest capability gap, the per-driver remainder. +func (Dialect) JSONPath(string, []string, bool) (string, bool) { return "", false } diff --git a/backend/sqlserver/dialect_test.go b/backend/sqlserver/dialect_test.go index ea41b2f..c4f29f5 100644 --- a/backend/sqlserver/dialect_test.go +++ b/backend/sqlserver/dialect_test.go @@ -98,7 +98,7 @@ func TestJSON(t *testing.T) { if obj != "JSON_OBJECT('name': d.[name], 'year': d.[year])" { t.Errorf("JSONObject = %q", obj) } - if got := d.JSONAgg("t", "t.[id] DESC"); got != "JSON_ARRAYAGG(t)" { + if got := d.JSONAgg("t", "t.[id] DESC"); got != "'['+STRING_AGG(CAST((t) AS NVARCHAR(MAX)),',')+']'" { t.Errorf("JSONAgg = %q", got) } } diff --git a/backend/sqlserver/execute.go b/backend/sqlserver/execute.go index b5ed5d1..0b4a4ac 100644 --- a/backend/sqlserver/execute.go +++ b/backend/sqlserver/execute.go @@ -3,6 +3,8 @@ package sqlserver import ( "context" "database/sql" + "encoding/json" + "strconv" "strings" "github.com/tamnd/dbrest/backend" @@ -110,6 +112,19 @@ func (b *Backend) executeWrite(ctx context.Context, plan *ir.Plan, rc *reqctx.Co q := plan.Query returning := returningCols(q, plan.Rel) + // An empty column set (POST with an empty array, PATCH with an empty object) + // is a no-op: nothing is compiled or run, the affected count is zero, and the + // representation is the empty array. The HTTP layer turns that into 201/[] for + // an insert and 204 or 200/[] for an update. + if backend.IsNoOpMutation(q) { + return &writeResult{ + controls: rc.Controls(), + cols: returning, + affected: 0, + hasAff: true, + }, nil + } + tx, err := b.db.BeginTx(ctx, nil) if err != nil { return nil, b.MapError(err) @@ -132,6 +147,11 @@ func (b *Backend) executeWrite(ctx context.Context, plan *ir.Plan, rc *reqctx.Co return nil, b.MapError(err) } + // Prefer: max-affected rolls an over-broad write back instead of committing. + if apiErr := backend.EnforceMaxAffected(q.Write, res.affected, res.hasAff); apiErr != nil { + return nil, apiErr + } + if q.Write != nil && q.Write.Tx == ir.TxRollback { return res, nil } @@ -145,11 +165,17 @@ func (b *Backend) executeWrite(ctx context.Context, plan *ir.Plan, rc *reqctx.Co // The compiler emits: INSERT INTO [t] ([c1],[c2]) VALUES (@p1,@p2) // The data plane rewrites to: INSERT INTO [t] ([c1],[c2]) OUTPUT INSERTED.[c1],... VALUES (@p1,@p2) // by injecting the OUTPUT fragment before the " VALUES " marker. +// Upsert (on_conflict) is routed to executeUpsert instead of the single-statement +// compiler, which returns errUpsertMultiStatement. func (b *Backend) executeInsert( ctx context.Context, tx *sql.Tx, q *ir.Query, returning []string, rel *schema.Relation, res *writeResult, ) error { + if q.Kind == ir.Upsert { + return b.executeUpsert(ctx, tx, q, returning, rel, res) + } + st, apiErr := sqlgen.CompileInsert(Dialect{}, q, nil) if apiErr != nil { return apiErr @@ -187,6 +213,209 @@ func (b *Backend) executeInsert( return nil } +// executeUpsert implements the SQL Server upsert as a single MERGE statement per +// batch. MERGE avoids the semicolon-separated multi-statement pattern that +// go-mssqldb rejects when sent via sp_executesql. +// +// All rows are merged in one statement: the source is a VALUES(...) table with +// one row-tuple per input row; the ON clause matches the conflict (primary-key) +// columns; WHEN MATCHED updates non-key columns; WHEN NOT MATCHED inserts. +// The OUTPUT clause captures written rows when returning is requested. +func (b *Backend) executeUpsert( + ctx context.Context, tx *sql.Tx, + q *ir.Query, returning []string, rel *schema.Relation, + res *writeResult, +) error { + w := q.Write + if len(w.Rows) == 0 { + res.affected, res.hasAff = 0, true + return nil + } + + d := Dialect{} + sch := q.Relation.Schema + if sch == "" { + sch = b.schema + if sch == "" { + sch = "dbo" + } + } + tableName := d.QuoteIdent(sch) + "." + d.QuoteIdent(q.Relation.Name) + + conflictCols := w.Conflict.Target + conflictSet := make(map[string]bool, len(conflictCols)) + for _, c := range conflictCols { + conflictSet[c] = true + } + nonConflictCols := make([]string, 0, len(w.Columns)) + for _, c := range w.Columns { + if !conflictSet[c] { + nonConflictCols = append(nonConflictCols, c) + } + } + + // Collect args; @pN bind positions match the order we append. + raw := []any{} + argN := 0 + bind := func(v any) string { + argN++ + raw = append(raw, v) + return "@p" + strconv.Itoa(argN) + } + + // Build the source alias column names: s0, s1, ... + srcCols := make([]string, len(w.Columns)) + for i := range w.Columns { + srcCols[i] = "s" + strconv.Itoa(i) + } + + var sb strings.Builder + + // MERGE INTO target USING (VALUES (...),(...)) AS src(s0,s1,...) + sb.WriteString("MERGE INTO ") + sb.WriteString(tableName) + sb.WriteString(" WITH (HOLDLOCK) AS [_target] USING (VALUES ") + for ri, row := range w.Rows { + if ri > 0 { + sb.WriteString(",") + } + sb.WriteString("(") + for ci, c := range w.Columns { + if ci > 0 { + sb.WriteString(",") + } + sb.WriteString(bind(sqlgen.WriteArg(d, row[c], w.ColumnTypes[c]))) + } + sb.WriteString(")") + } + sb.WriteString(") AS [_src](") + for i, sc := range srcCols { + if i > 0 { + sb.WriteString(",") + } + sb.WriteString(d.QuoteIdent(sc)) + } + sb.WriteString(") ON (") + // ON conflict columns match + for i, c := range conflictCols { + if i > 0 { + sb.WriteString(" AND ") + } + ci := colIndex(w.Columns, c) + sb.WriteString("[_target]." + d.QuoteIdent(c) + "=[_src]." + d.QuoteIdent(srcCols[ci])) + } + sb.WriteString(")") + + // WHEN MATCHED THEN UPDATE (skip if ignore or no non-conflict cols) + if w.Conflict.Resolution != ir.ConflictIgnore && len(nonConflictCols) > 0 { + sb.WriteString(" WHEN MATCHED THEN UPDATE SET ") + for i, c := range nonConflictCols { + if i > 0 { + sb.WriteString(",") + } + ci := colIndex(w.Columns, c) + sb.WriteString("[_target]." + d.QuoteIdent(c) + "=[_src]." + d.QuoteIdent(srcCols[ci])) + } + } + + // WHEN NOT MATCHED THEN INSERT + sb.WriteString(" WHEN NOT MATCHED THEN INSERT (") + for i, c := range w.Columns { + if i > 0 { + sb.WriteString(",") + } + sb.WriteString(d.QuoteIdent(c)) + } + sb.WriteString(") VALUES (") + for i, sc := range srcCols { + if i > 0 { + sb.WriteString(",") + } + sb.WriteString("[_src]." + d.QuoteIdent(sc)) + } + sb.WriteString(")") + + // OUTPUT clause when returning is requested + if len(returning) > 0 { + sb.WriteString(" OUTPUT ") + for i, c := range returning { + if i > 0 { + sb.WriteString(",") + } + sb.WriteString("INSERTED." + d.QuoteIdent(c)) + } + } + + // MERGE requires a terminating semicolon. + sb.WriteString(";") + + // When any conflict column is an IDENTITY column and the user provided + // an explicit value, SQL Server requires IDENTITY_INSERT to be ON. + needIdentityInsert := rel != nil && hasIdentityConflictCol(rel, conflictCols, w.Columns) + if needIdentityInsert { + if _, err := tx.ExecContext(ctx, "SET IDENTITY_INSERT "+tableName+" ON"); err != nil { + return err + } + defer func() { _, _ = tx.ExecContext(ctx, "SET IDENTITY_INSERT "+tableName+" OFF") }() + } + + if len(returning) > 0 { + rows, err := tx.QueryContext(ctx, sb.String(), namedArgs(raw)...) + if err != nil { + return err + } + cols, err := rows.Columns() + if err != nil { + rows.Close() + return err + } + jsonIdx, timeIdx := buildColMaps(rows, nil) + buf, err := drain(rows, cols, jsonIdx, timeIdx) + rows.Close() + if err != nil { + return err + } + res.cols, res.rows = cols, buf + res.affected, res.hasAff = int64(len(buf)), true + return nil + } + + out, err := tx.ExecContext(ctx, sb.String(), namedArgs(raw)...) + if err != nil { + return err + } + n, _ := out.RowsAffected() + res.affected, res.hasAff = n, true + return nil +} + +// hasIdentityConflictCol reports whether any conflict column is an identity +// column AND is present in payloadCols (the user provided an explicit value). +// When IDENTITY_INSERT is ON, SQL Server requires an explicit value, so we only +// enable it when the identity column is actually in the payload. +func hasIdentityConflictCol(rel *schema.Relation, conflictCols, payloadCols []string) bool { + payload := make(map[string]bool, len(payloadCols)) + for _, c := range payloadCols { + payload[c] = true + } + for _, c := range conflictCols { + if col, ok := rel.Column(c); ok && col.Identity && payload[c] { + return true + } + } + return false +} + +// colIndex returns the position of name in cols, or 0 as a safe fallback. +func colIndex(cols []string, name string) int { + for i, c := range cols { + if c == name { + return i + } + } + return 0 +} + // executeUpdate runs UPDATE [t] SET ... OUTPUT INSERTED.* WHERE ... // Compiler emits: UPDATE [t] SET [c]=@p1 WHERE [id]=@p2 // Rewritten to: UPDATE [t] SET [c]=@p1 OUTPUT INSERTED.[c],... WHERE [id]=@p2 @@ -279,15 +508,27 @@ func (b *Backend) executeDelete( // executeCall runs a stored procedure or portable RPC function. func (b *Backend) executeCall(ctx context.Context, plan *ir.Plan, rc *reqctx.Context) (backend.Result, error) { - st, apiErr := sqlgen.CompileCall(Dialect{}, plan.Call, plan.Func) + var st *sqlgen.Statement + var apiErr *pgerr.APIError + if plan.Func != nil { + // Portable registry function: the function body is a parameterised SQL + // statement whose :name placeholders are bound by CompileCall. + st, apiErr = sqlgen.CompileCall(Dialect{}, plan.Call, plan.Func, sqlgen.ContextArgs(rc)) + } else { + // Native RPC (NativeRPC=true): no registry function — generate EXEC + // [schema].[name] @param = @pN from the call's argument map. + st, apiErr = b.compileNativeCall(plan.Call) + } if apiErr != nil { return nil, apiErr } if plan.ReadOnly { res := &result{controls: rc.Controls()} - if plan.Call.Count != ir.CountNone { - cst, apiErr := sqlgen.CompileCallCount(Dialect{}, plan.Call, plan.Func) + // count=exact is only supported for portable registry functions; native + // stored procedures cannot be wrapped in SELECT count(*) in T-SQL. + if plan.Call.Count != ir.CountNone && plan.Func != nil { + cst, apiErr := sqlgen.CompileCallCount(Dialect{}, plan.Call, plan.Func, sqlgen.ContextArgs(rc)) if apiErr != nil { return nil, apiErr } @@ -379,6 +620,9 @@ func injectBeforeWhere(sqlStr, fragment string) string { // returningCols decides which columns to read back after a write. func returningCols(q *ir.Query, rel *schema.Relation) []string { if q.Write != nil && q.Write.Return == ir.ReturnRepresentation { + if cols := q.ProjectedColumns(); cols != nil { + return cols + } return rel.ColumnNames() } if q.Kind == ir.Insert || q.Kind == ir.Upsert { @@ -387,6 +631,57 @@ func returningCols(q *ir.Query, rel *schema.Relation) []string { return nil } +// compileNativeCall generates EXEC [schema].[name] @arg1 = @p1, @arg2 = @p2 for +// the NativeRPC path (plan.Func == nil). SQL Server stored procedures accept +// named parameters in any order, so the argument map can be emitted as-is. +// Scalar stored procedures should SELECT the result in a column named after the +// function (e.g. SELECT @a + @b AS [add]) so renderCall can detect scalar return +// by seeing a single column whose name matches the function name. +func (b *Backend) compileNativeCall(c *ir.Call) (*sqlgen.Statement, *pgerr.APIError) { + sch := b.schema + if sch == "" { + sch = "dbo" + } + d := Dialect{} + var sb strings.Builder + sb.WriteString("EXEC ") + sb.WriteString(d.QuoteIdent(sch)) + sb.WriteString(".") + sb.WriteString(d.QuoteIdent(c.Function.Name)) + + args := make([]any, 0, len(c.Args)) + i := 1 + for name, val := range c.Args { + if i == 1 { + sb.WriteString(" ") + } else { + sb.WriteString(", ") + } + sb.WriteString("@" + name + " = @p" + strconv.Itoa(i)) + // A POST arg has a decoded JSON value; a GET arg is raw text. + if val.JSON != nil { + args = append(args, nativeArgValue(val.JSON)) + } else { + args = append(args, val.Text) + } + i++ + } + return &sqlgen.Statement{SQL: sb.String(), Args: args}, nil +} + +// nativeArgValue converts a decoded JSON argument value to a driver-ready type. +// Scalars (string, float64, bool, nil) pass through; composite values are +// re-encoded as JSON text so the stored procedure can receive them as NVARCHAR. +func nativeArgValue(v any) any { + switch v.(type) { + case string, float64, bool, nil: + return v + default: + b, _ := json.Marshal(v) + return string(b) + } +} + // _ is a compile-time check that Backend implements backend.DB. var _ interface { Execute(context.Context, *ir.Plan, *reqctx.Context) (backend.Result, error) diff --git a/backend/sqlserver/fulltext.go b/backend/sqlserver/fulltext.go index 84956c8..cc68ea1 100644 --- a/backend/sqlserver/fulltext.go +++ b/backend/sqlserver/fulltext.go @@ -22,7 +22,7 @@ import ( // inflection the way plainto_tsquery's dictionary normalization does. // - the other variants map to CONTAINS, whose AND / OR / AND NOT / NEAR // operators give the explicit set semantics to_tsquery has. -func (Dialect) FullText(col string, _ *sqlgen.FullTextRef, variant ir.FTSVariant, _, value string) (string, string, bool) { +func (Dialect) FullText(col, _ string, _ *sqlgen.FullTextRef, variant ir.FTSVariant, _, value string) (string, string, bool) { if variant == ir.FTSPlainText { // FREETEXT takes a natural-language string, no operators; collapse runs of // whitespace so the bound value is clean. diff --git a/backend/sqlserver/fulltext_test.go b/backend/sqlserver/fulltext_test.go index a42118e..b860883 100644 --- a/backend/sqlserver/fulltext_test.go +++ b/backend/sqlserver/fulltext_test.go @@ -9,12 +9,12 @@ import ( // ftsFrag and fts return the predicate wrapper and the translated query value a // variant lowers to. func ftsFrag(v ir.FTSVariant) string { - frag, _, _ := Dialect{}.FullText("[c]", nil, v, "", "x") + frag, _, _ := Dialect{}.FullText("[c]", "", nil, v, "", "x") return frag } func fts(v ir.FTSVariant, value string) string { - _, q, _ := Dialect{}.FullText("[c]", nil, v, "", value) + _, q, _ := Dialect{}.FullText("[c]", "", nil, v, "", value) return q } diff --git a/backend/sqlserver/introspect.go b/backend/sqlserver/introspect.go index e9f961b..1cd8138 100644 --- a/backend/sqlserver/introspect.go +++ b/backend/sqlserver/introspect.go @@ -81,7 +81,8 @@ func (b *Backend) columns(ctx context.Context, table string) ([]*schema.Column, c.IS_NULLABLE, c.COLUMN_DEFAULT, CASE WHEN k.COLUMN_NAME IS NOT NULL THEN 1 ELSE 0 END AS is_pk, - ISNULL(k.ORDINAL_POSITION, 0) AS pk_ord + ISNULL(k.ORDINAL_POSITION, 0) AS pk_ord, + COLUMNPROPERTY(OBJECT_ID(SCHEMA_NAME()+'.'+c.TABLE_NAME), c.COLUMN_NAME, 'IsIdentity') AS is_identity FROM INFORMATION_SCHEMA.COLUMNS c LEFT JOIN ( SELECT kcu.COLUMN_NAME, kcu.ORDINAL_POSITION @@ -113,16 +114,17 @@ func (b *Backend) columns(ctx context.Context, table string) ([]*schema.Column, for rows.Next() { var name, dataType, isNullable string var colDefault sql.NullString - var isPK, pkOrd int - if err := rows.Scan(&name, &dataType, &isNullable, &colDefault, &isPK, &pkOrd); err != nil { + var isPK, pkOrd, isIdentity int + if err := rows.Scan(&name, &dataType, &isNullable, &colDefault, &isPK, &pkOrd, &isIdentity); err != nil { return nil, nil, err } - hasDefault := isPK == 1 || colDefault.Valid + hasDefault := isPK == 1 || colDefault.Valid || isIdentity == 1 col := &schema.Column{ Name: name, Type: sqlServerCanonicalType(dataType), Nullable: isNullable == "YES", HasDefault: hasDefault, + Identity: isIdentity == 1, } colRows = append(colRows, colRow{col: col, isPK: isPK == 1, pkOrd: pkOrd}) } diff --git a/backend/sqlserver/sqlserver.go b/backend/sqlserver/sqlserver.go index c31afb6..7192d64 100644 --- a/backend/sqlserver/sqlserver.go +++ b/backend/sqlserver/sqlserver.go @@ -120,14 +120,21 @@ func (b *Backend) MapError(err error) *pgerr.APIError { } // mapSQLServerError builds the unified API error from a SQL Server error. -func mapSQLServerError(me *mssql.Error) *pgerr.APIError { +func mapSQLServerError(me mssql.Error) *pgerr.APIError { + // Class-23 violations carry PostgreSQL's wording, not the native SQL Server + // text: the driver gives no constraint name or offending value in a form that + // reconstructs PG's message, so neither is invented and the native text is not + // leaked into details (an emulation limitation, documented in the spec). switch me.Number { case 2627, 2601: // unique constraint / unique index violation - return pgerr.ErrUniqueViolation(me.Message) + return pgerr.ErrConstraintViolation(pgerr.CodeUniqueViolation, + "duplicate key value violates unique constraint", "", "") case 515: // cannot insert NULL - return pgerr.ErrNotNullViolation(me.Message) + return pgerr.ErrConstraintViolation(pgerr.CodeNotNullViolation, + "null value violates not-null constraint", "", "") case 547: // FK constraint violation - return pgerr.ErrForeignKeyViolation(me.Message) + return pgerr.ErrConstraintViolation(pgerr.CodeForeignKeyViolation, + "insert or update on table violates foreign key constraint", "", "") case 207: // invalid column name return pgerr.New(400, "42703", me.Message) case 208: // invalid object name (table not found) @@ -206,9 +213,10 @@ func sqlServerCanonicalType(dataType string) string { } } -// asMSSQLError unwraps err as a *mssql.Error. -func asMSSQLError(err error) (*mssql.Error, bool) { - var me *mssql.Error +// asMSSQLError unwraps err as a mssql.Error. mssql.Error implements error via a +// value receiver so errors.As requires a value target, not a pointer. +func asMSSQLError(err error) (mssql.Error, bool) { + var me mssql.Error ok := errors.As(err, &me) return me, ok } diff --git a/backend/writerep.go b/backend/writerep.go new file mode 100644 index 0000000..92845a0 --- /dev/null +++ b/backend/writerep.go @@ -0,0 +1,215 @@ +package backend + +import ( + "encoding/json" + "sort" + + "github.com/tamnd/dbrest/ir" +) + +// IsNoOpMutation reports whether a write resolves to an empty column set that +// PostgREST treats as a zero-row no-op rather than an error. A POST with an empty +// array body (no rows) and a PATCH with an empty object body (no assignments) both +// touch nothing: the mutation runs against no data, the affected count is zero, and +// the representation is the empty array. Upsert/PUT and DELETE are excluded; an +// empty PUT body is a distinct shape, and DELETE carries no payload to be empty. +func IsNoOpMutation(q *ir.Query) bool { + if q == nil || q.Write == nil { + return false + } + switch q.Kind { + case ir.Insert: + return len(q.Write.Rows) == 0 + case ir.Update: + return len(q.Write.Set) == 0 + default: + return false + } +} + +// ShapeWriteRepresentation orders and paginates the rows a mutation returns for +// its representation. PostgREST v13 dropped limited update/delete (#3013), so +// order, limit and offset never bound the mutation itself: every matching row +// is written. The affected count the caller has already taken stays the full +// total, so Prefer: max-affected and the write Content-Range are unchanged. +// These query parameters only shape the returned body, matching v14 where the +// mutation's RETURNING is wrapped in an ordered, limited outer select (see +// PostgREST UpdateSpec "with ordering on top-level resource"). +// +// Ordering compares the buffered values directly, which matches an engine's +// binary/C collation; under a locale-aware text collation a column's order can +// differ, a representation-layer divergence. A term whose column is not in the +// returned projection is skipped, since the buffered representation cannot carry +// a value it never selected. +func ShapeWriteRepresentation(cols []string, rows [][]any, q *ir.Query) [][]any { + if q == nil || len(rows) == 0 { + return rows + } + rows = orderWriteRows(cols, rows, q.Order) + rows = pageWriteRows(rows, q.Limit, q.Offset) + return rows +} + +// orderSortKey binds an order term to the column index it sorts on. +type orderSortKey struct { + idx int + desc bool + nullsFirst bool +} + +// orderWriteRows stably sorts the buffered rows by the plain-column order terms +// that name a returned column. JSON-path terms and terms whose column is absent +// from the projection are skipped (the representation does not carry them). +func orderWriteRows(cols []string, rows [][]any, terms []ir.OrderTerm) [][]any { + if len(terms) == 0 { + return rows + } + index := make(map[string]int, len(cols)) + for i, c := range cols { + index[c] = i + } + var keys []orderSortKey + for _, t := range terms { + if len(t.Path) != 1 || t.Last != ir.JSONNone { + continue // a JSON sub-path is not a plain returned column + } + i, ok := index[t.Path[0]] + if !ok { + continue + } + // PostgreSQL default: NULLS LAST for ascending, NULLS FIRST for + // descending; an explicit nullsfirst/nullslast modifier overrides it. + nullsFirst := t.Desc + if t.NullsFirst != nil { + nullsFirst = *t.NullsFirst + } + keys = append(keys, orderSortKey{idx: i, desc: t.Desc, nullsFirst: nullsFirst}) + } + if len(keys) == 0 { + return rows + } + sort.SliceStable(rows, func(a, b int) bool { + for _, k := range keys { + av, bv := rows[a][k.idx], rows[b][k.idx] + aNull, bNull := av == nil, bv == nil + if aNull || bNull { + if aNull && bNull { + continue + } + // NULL placement is absolute: descending reverses the value + // order but not which end the NULLs land on. + if aNull { + return k.nullsFirst + } + return !k.nullsFirst + } + cmp := compareCells(av, bv) + if cmp == 0 { + continue + } + if k.desc { + return cmp > 0 + } + return cmp < 0 + } + return false + }) + return rows +} + +// compareCells orders two non-NULL buffered cell values: numbers compare +// numerically, booleans false-before-true, and everything else by its text form +// (matching binary/C collation). +func compareCells(a, b any) int { + if af, aok := cellFloat(a); aok { + if bf, bok := cellFloat(b); bok { + switch { + case af < bf: + return -1 + case af > bf: + return 1 + default: + return 0 + } + } + } + if ab, aok := a.(bool); aok { + if bb, bok := b.(bool); bok { + switch { + case ab == bb: + return 0 + case !ab: + return -1 + default: + return 1 + } + } + } + as, bs := cellString(a), cellString(b) + switch { + case as < bs: + return -1 + case as > bs: + return 1 + default: + return 0 + } +} + +// cellFloat reports a numeric value for the integer and float types the drivers +// decode into, so numeric columns sort by magnitude rather than text. +func cellFloat(v any) (float64, bool) { + switch n := v.(type) { + case int64: + return float64(n), true + case int: + return float64(n), true + case int32: + return float64(n), true + case float64: + return n, true + case float32: + return float64(n), true + case json.Number: + f, err := n.Float64() + return f, err == nil + } + return 0, false +} + +// cellString renders a cell to its text form for collation and as the fallback +// when two cells are of unlike kinds. +func cellString(v any) string { + switch s := v.(type) { + case string: + return s + case []byte: + return string(s) + case json.RawMessage: + return string(s) + } + b, err := json.Marshal(v) + if err != nil { + return "" + } + return string(b) +} + +// pageWriteRows applies offset then limit to the ordered representation. A nil +// bound leaves that end open; an offset past the end yields no rows. +func pageWriteRows(rows [][]any, limit, offset *int) [][]any { + if offset != nil { + o := max(*offset, 0) + if o >= len(rows) { + return rows[:0] + } + rows = rows[o:] + } + if limit != nil { + l := max(*limit, 0) + if l < len(rows) { + rows = rows[:l] + } + } + return rows +} diff --git a/backend/writerep_test.go b/backend/writerep_test.go new file mode 100644 index 0000000..b2328af --- /dev/null +++ b/backend/writerep_test.go @@ -0,0 +1,107 @@ +package backend + +import ( + "reflect" + "testing" + + "github.com/tamnd/dbrest/ir" +) + +func intp(n int) *int { return &n } + +func boolp(b bool) *bool { return &b } + +// Ordering a write representation sorts the buffered rows by a plain column, +// numerically for an integer column and honouring descending. +func TestShapeWriteRepresentationOrders(t *testing.T) { + cols := []string{"id", "name"} + rows := [][]any{ + {int64(2), "b"}, + {int64(10), "j"}, + {int64(1), "a"}, + } + q := &ir.Query{Order: []ir.OrderTerm{{Path: []string{"id"}, Desc: true}}} + got := ShapeWriteRepresentation(cols, rows, q) + want := [][]any{{int64(10), "j"}, {int64(2), "b"}, {int64(1), "a"}} + if !reflect.DeepEqual(got, want) { + t.Errorf("ordered rows = %v, want %v", got, want) + } +} + +// limit and offset bound the returned body after ordering. +func TestShapeWriteRepresentationPaginates(t *testing.T) { + cols := []string{"id"} + rows := [][]any{{int64(1)}, {int64(2)}, {int64(3)}, {int64(4)}} + q := &ir.Query{ + Order: []ir.OrderTerm{{Path: []string{"id"}}}, + Offset: intp(1), + Limit: intp(2), + } + got := ShapeWriteRepresentation(cols, rows, q) + want := [][]any{{int64(2)}, {int64(3)}} + if !reflect.DeepEqual(got, want) { + t.Errorf("paged rows = %v, want %v", got, want) + } +} + +// An offset past the end yields an empty body, not an error. +func TestShapeWriteRepresentationOffsetPastEnd(t *testing.T) { + cols := []string{"id"} + rows := [][]any{{int64(1)}, {int64(2)}} + q := &ir.Query{Offset: intp(5)} + if got := ShapeWriteRepresentation(cols, rows, q); len(got) != 0 { + t.Errorf("rows = %v, want empty", got) + } +} + +// NULLs sort last on ascending by default and first on descending, matching +// PostgreSQL; an explicit nullsfirst overrides the default. +func TestShapeWriteRepresentationNullsPlacement(t *testing.T) { + cols := []string{"v"} + mk := func() [][]any { return [][]any{{int64(2)}, {nil}, {int64(1)}} } + + asc := ShapeWriteRepresentation(cols, mk(), &ir.Query{Order: []ir.OrderTerm{{Path: []string{"v"}}}}) + if asc[2][0] != nil { + t.Errorf("asc default: null should sort last, got %v", asc) + } + + desc := ShapeWriteRepresentation(cols, mk(), &ir.Query{Order: []ir.OrderTerm{{Path: []string{"v"}, Desc: true}}}) + if desc[0][0] != nil { + t.Errorf("desc default: null should sort first, got %v", desc) + } + + nf := ShapeWriteRepresentation(cols, mk(), &ir.Query{Order: []ir.OrderTerm{{Path: []string{"v"}, NullsFirst: boolp(true)}}}) + if nf[0][0] != nil { + t.Errorf("asc nullsfirst: null should sort first, got %v", nf) + } +} + +// A term naming a column outside the projection is skipped: the representation +// cannot order by a value it never carried. The rows keep their order. +func TestShapeWriteRepresentationSkipsAbsentColumn(t *testing.T) { + cols := []string{"id"} + rows := [][]any{{int64(3)}, {int64(1)}, {int64(2)}} + q := &ir.Query{Order: []ir.OrderTerm{{Path: []string{"name"}}}} + got := ShapeWriteRepresentation(cols, rows, q) + want := [][]any{{int64(3)}, {int64(1)}, {int64(2)}} + if !reflect.DeepEqual(got, want) { + t.Errorf("rows = %v, want unchanged %v", got, want) + } +} + +// Shaping never alters the affected count: the caller takes it from the full +// buffer before shaping. This guards the contract that order/limit bound only +// the body, not the mutation. Here the full set is 4 rows; the shaped body is 1. +func TestShapeWriteRepresentationLeavesCallerCountAlone(t *testing.T) { + cols := []string{"id"} + full := [][]any{{int64(1)}, {int64(2)}, {int64(3)}, {int64(4)}} + affected := int64(len(full)) + q := &ir.Query{Order: []ir.OrderTerm{{Path: []string{"id"}}}, Limit: intp(1)} + body := ShapeWriteRepresentation(cols, full, q) + if affected != 4 { + t.Errorf("affected = %d, want 4 (the full mutated set)", affected) + } + if len(body) != 1 { + t.Errorf("body rows = %d, want 1", len(body)) + } +} diff --git a/cmd/dbrest-conformance/main.go b/cmd/dbrest-conformance/main.go index 38b42a3..142a3ff 100644 --- a/cmd/dbrest-conformance/main.go +++ b/cmd/dbrest-conformance/main.go @@ -4,9 +4,11 @@ // runs the capability self-consistency check. It is the local reproduction of // what the CI matrix does per backend (spec 22 section 10). // -// Only the SQLite backend is wired today, with the films fixture; another -// backend joins by adding its fixture and capabilities here once its driver -// lands. +// The SQLite and PostgreSQL backends are wired, each with a films fixture; +// another backend joins by adding its fixture and capabilities here once its +// driver lands. The postgres pass needs a live server, read from DBREST_PG_DSN +// or the -dsn flag, and it is the reference backend, so its corpus golden is the +// upstream PostgreSQL output and its allowlist documents no divergence. package main import ( @@ -14,9 +16,11 @@ import ( "flag" "fmt" "log" + "os" "time" "github.com/tamnd/dbrest/backend" + "github.com/tamnd/dbrest/backend/postgres" "github.com/tamnd/dbrest/backend/sqlite" "github.com/tamnd/dbrest/conformance" "github.com/tamnd/dbrest/httpapi" @@ -30,22 +34,57 @@ func main() { func run() error { var ( - backendName = flag.String("backend", "sqlite", "backend to run the conformance pass against") - corpusPath = flag.String("corpus", "conformance/testdata/sqlite/corpus.json", "request corpus file") - allowPath = flag.String("allowlist", "conformance/testdata/sqlite/allowlist.json", "allowlist file") + backendName = flag.String("backend", "sqlite", "backend to run the conformance pass against (sqlite or postgres)") + corpusPath = flag.String("corpus", "", "request corpus file (defaults to the backend's testdata corpus)") + allowPath = flag.String("allowlist", "", "allowlist file (defaults to the backend's testdata allowlist)") + dsn = flag.String("dsn", "", "postgres DSN; falls back to DBREST_PG_DSN") ) flag.Parse() - if *backendName != "sqlite" { - return fmt.Errorf("backend %q is not wired into the harness yet; only sqlite is available", *backendName) + if *corpusPath == "" { + *corpusPath = fmt.Sprintf("conformance/testdata/%s/corpus.json", *backendName) + } + if *allowPath == "" { + *allowPath = fmt.Sprintf("conformance/testdata/%s/allowlist.json", *backendName) } - srv, be, err := sqliteFixture() - if err != nil { - return err + var ( + srv *httpapi.Server + caps backend.Capabilities + closeBE func() + tiers map[string]backend.Tier + err error + ) + switch *backendName { + case "sqlite": + var be *sqlite.Backend + srv, be, err = sqliteFixture() + if err != nil { + return err + } + closeBE = func() { _ = be.Close() } + caps = be.Capabilities() + tiers = featureTiers(caps) + case "postgres": + conn := *dsn + if conn == "" { + conn = os.Getenv("DBREST_PG_DSN") + } + if conn == "" { + return fmt.Errorf("postgres backend needs a DSN: pass -dsn or set DBREST_PG_DSN") + } + var be *postgres.Backend + srv, be, err = postgresFixture(conn) + if err != nil { + return err + } + closeBE = func() { _ = be.Close() } + caps = be.Capabilities() + tiers = featureTiers(caps) + default: + return fmt.Errorf("backend %q is not wired into the harness; available: sqlite, postgres", *backendName) } - defer func() { _ = be.Close() }() - caps := be.Capabilities() + defer closeBE() cases, err := conformance.LoadCorpus(*corpusPath) if err != nil { @@ -55,7 +94,7 @@ func run() error { if err != nil { return err } - if err := allow.CheckMatrix(featureTiers(caps)); err != nil { + if err := allow.CheckMatrix(tiers); err != nil { return err } @@ -85,7 +124,9 @@ func sqliteFixture() (*httpapi.Server, *sqlite.Backend, error) { if err != nil { return nil, nil, fmt.Errorf("introspect: %w", err) } - return httpapi.NewServer(be, model, nil), be, nil + srv := httpapi.NewServer(be, model, nil) + srv.SetDefaultRole("anon") + return srv, be, nil } const fixtureDDL = ` @@ -103,6 +144,57 @@ CREATE VIRTUAL TABLE films_fts USING fts5(title, content='films', content_rowid= INSERT INTO films_fts (rowid, title) SELECT id, title FROM films; ` +// postgresFixture builds the films fixture on a live PostgreSQL server in a +// dedicated schema and returns a server over it. Postgres is the reference +// backend: its corpus golden is the upstream output, so the fixture mirrors the +// sqlite one plus a tags array column, since arrays are Native here where SQLite +// has none. The anon role is created so the server's SET LOCAL ROLE has an +// identity to assume, matching the role-emulation path a real deployment uses. +func postgresFixture(dsn string) (*httpapi.Server, *postgres.Backend, error) { + be, err := postgres.Open(dsn) + if err != nil { + return nil, nil, fmt.Errorf("open: %w", err) + } + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + if _, err := be.Pool().Exec(ctx, pgFixtureDDL); err != nil { + _ = be.Close() + return nil, nil, fmt.Errorf("load fixture: %w", err) + } + be.SetSchemas([]string{"_dbrest_conf"}) + model, err := be.Introspect(ctx) + if err != nil { + _ = be.Close() + return nil, nil, fmt.Errorf("introspect: %w", err) + } + srv := httpapi.NewServer(be, model, []string{"_dbrest_conf"}) + srv.SetDefaultRole("anon") + return srv, be, nil +} + +const pgFixtureDDL = ` +DROP SCHEMA IF EXISTS _dbrest_conf CASCADE; +CREATE SCHEMA _dbrest_conf; +DO $$ BEGIN + IF NOT EXISTS (SELECT FROM pg_roles WHERE rolname = 'anon') THEN + CREATE ROLE anon NOLOGIN; + END IF; +END $$; +GRANT USAGE ON SCHEMA _dbrest_conf TO anon; +CREATE TABLE _dbrest_conf.films ( + id integer PRIMARY KEY, + title text NOT NULL, + year integer, + rating text, + tags text[] +); +INSERT INTO _dbrest_conf.films (id, title, year, rating, tags) VALUES + (1, 'Metropolis', 1927, 'NR', '{sci-fi,silent}'), + (2, 'Blade Runner', 1982, 'R', '{sci-fi,noir}'), + (3, 'Arrival', 2016, 'PG13', '{sci-fi,drama}'); +GRANT SELECT ON _dbrest_conf.films TO anon; +` + // featureTiers maps the allowlist's feature labels to the tier each resolves to // on this backend, so the allowlist can be reconciled with the live matrix. func featureTiers(caps backend.Capabilities) map[string]backend.Tier { diff --git a/cmd/dbrest/cli.go b/cmd/dbrest/cli.go new file mode 100644 index 0000000..e6b1d4f --- /dev/null +++ b/cmd/dbrest/cli.go @@ -0,0 +1,112 @@ +// The PostgREST-shaped command line: the config file is a positional argument +// (`dbrest /etc/dbrest.conf`), and the maintenance verbs --version, --example, +// --dump-config, --dump-schema, and --ready mirror upstream's. +package main + +import ( + "fmt" + "net/http" + "runtime/debug" + "time" + + "github.com/tamnd/dbrest/config" +) + +// resolveConfigPath reconciles the -config flag with the positional argument. +// Either spelling works; giving both with different paths is an error rather +// than a silent pick. +func resolveConfigPath(flagPath string, args []string) (string, error) { + if len(args) > 1 { + return "", fmt.Errorf("expected at most one config file argument, got %d", len(args)) + } + if len(args) == 0 { + return flagPath, nil + } + if flagPath != "" && flagPath != args[0] { + return "", fmt.Errorf("config file given twice: -config %s and argument %s", flagPath, args[0]) + } + return args[0], nil +} + +// versionString is the module version when built with one, "dev" otherwise. +func versionString() string { + if bi, ok := debug.ReadBuildInfo(); ok && bi.Main.Version != "" && bi.Main.Version != "(devel)" { + return bi.Main.Version + } + return "dev" +} + +// probeReady asks a running instance's admin server whether it is ready, the +// --ready verb orchestrators use as a health command. A non-200 answer or an +// unreachable admin server is an error, which main turns into exit status 1. +func probeReady(cfg *config.Config) error { + if !cfg.AdminEnabled() { + return fmt.Errorf("--ready needs admin-server-port to be configured") + } + url := "http://" + probeAddr(cfg.AdminServerHost, cfg.AdminServerPort) + "/ready" + client := &http.Client{Timeout: 5 * time.Second} + resp, err := client.Get(url) + if err != nil { + return fmt.Errorf("ready probe: %w", err) + } + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("ready probe: %s answered %d", url, resp.StatusCode) + } + return nil +} + +// exampleConfig is the --example output: a commented config covering the +// options most deployments touch, in the file syntax Load reads. +const exampleConfig = `## dbrest example configuration +## Every option also reads from the environment as PGRST_ or +## DBREST_, with the DBREST_ spelling winning. + +## The engine behind the API: postgres, sqlite, mysql, sqlserver, or mongodb. +db-backend = "sqlite" + +## The connection string, in the engine's own syntax. +db-uri = "file:dbrest.db" + +## The database schemas to expose, comma-separated. The first is the default. +# db-schemas = "public" + +## The role used for requests that carry no JWT. +# db-anon-role = "web_anon" + +## Hard cap on the rows a read or RPC response may return. Unset means no cap. +# db-max-rows = 1000 + +## How the request transaction ends: commit (default), commit-allow-override, +## rollback, or rollback-allow-override. +# db-tx-end = "commit" + +## Secret for validating JWTs (HS256). Longer than 32 characters. +# jwt-secret = "reallyreallyreallyreallyverysafe" + +## Where the API listens. Besides a literal address, the host takes the +## PostgREST special values: "*" (any host, either stack), "*4"/"*6" +## (prefer one stack, fall back to the other), "!4"/"!6" (require it). +# server-host = "!4" +# server-port = 3000 + +## The admin server with /live, /ready, /schema_cache, and /metrics. +## Disabled until a port is set; it must differ from server-port. +# admin-server-port = 3001 + +## Connection pool sizing. +# db-pool = 10 + +## OpenAPI output: follow-privileges (default), ignore-privileges, disabled. +# openapi-mode = "follow-privileges" + +## Logging: crit, error (default), warn, info, or debug. +# log-level = "error" + +## CORS. Unset serves the permissive default; a comma-separated list +## restricts the allowed origins. +# server-cors-allowed-origins = "https://example.com" + +## Settings forwarded to the backend as transaction settings. +# app.settings.tenant = "acme" +` diff --git a/cmd/dbrest/cli_test.go b/cmd/dbrest/cli_test.go new file mode 100644 index 0000000..27ab37b --- /dev/null +++ b/cmd/dbrest/cli_test.go @@ -0,0 +1,125 @@ +package main + +import ( + "context" + "errors" + "net/http/httptest" + "net/url" + "os" + "path/filepath" + "strconv" + "testing" + + "github.com/tamnd/dbrest/adminapi" + "github.com/tamnd/dbrest/config" +) + +// TestResolveConfigPath covers the positional/flag reconciliation: either +// spelling alone, agreement, disagreement, and too many arguments. +func TestResolveConfigPath(t *testing.T) { + cases := []struct { + flag string + args []string + want string + wantErr bool + }{ + {"", nil, "", false}, + {"a.conf", nil, "a.conf", false}, + {"", []string{"b.conf"}, "b.conf", false}, + {"a.conf", []string{"a.conf"}, "a.conf", false}, + {"a.conf", []string{"b.conf"}, "", true}, + {"", []string{"a.conf", "b.conf"}, "", true}, + } + for _, tc := range cases { + got, err := resolveConfigPath(tc.flag, tc.args) + if (err != nil) != tc.wantErr { + t.Errorf("resolveConfigPath(%q, %v): err = %v, wantErr %v", tc.flag, tc.args, err, tc.wantErr) + continue + } + if got != tc.want { + t.Errorf("resolveConfigPath(%q, %v) = %q, want %q", tc.flag, tc.args, got, tc.want) + } + } +} + +// TestExampleConfigLoads pins the --example output to something Load accepts. +func TestExampleConfigLoads(t *testing.T) { + path := filepath.Join(t.TempDir(), "example.conf") + if err := os.WriteFile(path, []byte(exampleConfig), 0o644); err != nil { + t.Fatal(err) + } + cfg, err := config.Load(path, nil) + if err != nil { + t.Fatalf("the example config does not load: %v", err) + } + if cfg.Backend != "sqlite" || cfg.DBURI != "file:dbrest.db" { + t.Errorf("example values not applied: backend=%q uri=%q", cfg.Backend, cfg.DBURI) + } +} + +// TestProbeReady exercises the --ready verb against a real admin server in +// both the ready and the not-ready state, plus the unconfigured error. +func TestProbeReady(t *testing.T) { + cfgFor := func(t *testing.T, admin *adminapi.Server) *config.Config { + t.Helper() + ts := httptest.NewServer(admin) + t.Cleanup(ts.Close) + u, err := url.Parse(ts.URL) + if err != nil { + t.Fatal(err) + } + port, err := strconv.Atoi(u.Port()) + if err != nil { + t.Fatal(err) + } + cfg, err := config.FromMap(map[string]string{"db-uri": "x"}) + if err != nil { + t.Fatal(err) + } + cfg.AdminServerHost = u.Hostname() + cfg.AdminServerPort = port + return cfg + } + + ready := cfgFor(t, &adminapi.Server{}) + if err := probeReady(ready); err != nil { + t.Errorf("ready instance: %v", err) + } + + pending := cfgFor(t, &adminapi.Server{ + Ready: func(context.Context) error { return errors.New("pending") }, + }) + if err := probeReady(pending); err == nil { + t.Error("pending instance: expected an error") + } + + ready.AdminServerPort = 0 + if err := probeReady(ready); err == nil { + t.Error("no admin port: expected an error") + } +} + +// TestListenFirstBindsSpecialHosts checks each special host value yields a +// bindable listener on this machine (port 0 picks a free port), and that the +// fallback order engages when the first candidate cannot bind. +func TestListenFirstBindsSpecialHosts(t *testing.T) { + for _, host := range []string{"*", "*4", "!4", "*6", "!6"} { + specs := listenSpecsFor(t, host) + ln, err := listenFirst(specs) + if err != nil { + t.Errorf("host %q: %v", host, err) + continue + } + ln.Close() + } +} + +func listenSpecsFor(t *testing.T, host string) []config.ListenSpec { + t.Helper() + cfg, err := config.FromMap(map[string]string{"db-uri": "x", "server-port": "0"}) + if err != nil { + t.Fatal(err) + } + cfg.ServerHost = host + return cfg.Listeners() +} diff --git a/cmd/dbrest/jwtsecret_test.go b/cmd/dbrest/jwtsecret_test.go new file mode 100644 index 0000000..4d18b5c --- /dev/null +++ b/cmd/dbrest/jwtsecret_test.go @@ -0,0 +1,65 @@ +package main + +import ( + "bytes" + "testing" + + "github.com/tamnd/dbrest/config" +) + +// TestJWTSecretBytes pins the jwt-secret-is-base64 contract: off means the +// literal bytes, on means URL-safe base64 with optional padding, and an +// undecodable value is an error rather than a wrong key. +func TestJWTSecretBytes(t *testing.T) { + cases := []struct { + name string + secret string + isB64 bool + want []byte + bad bool + }{ + {name: "plain passthrough", secret: "reallysafe", want: []byte("reallysafe")}, + {name: "unpadded url-safe", secret: "c2VjcmV0LWJ5dGVz", isB64: true, want: []byte("secret-bytes")}, + {name: "padded url-safe", secret: "c2VjcmV0IQ==", isB64: true, want: []byte("secret!")}, + {name: "url alphabet", secret: "_-7-", isB64: true, want: []byte{0xff, 0xee, 0xfe}}, + {name: "surrounding space", secret: " c2VjcmV0IQ== ", isB64: true, want: []byte("secret!")}, + {name: "not base64", secret: "definitely not base64!!", isB64: true, bad: true}, + {name: "standard alphabet rejected", secret: "/+7+", isB64: true, bad: true}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + cfg := &config.Config{JWTSecret: tc.secret, JWTSecretIsBase64: tc.isB64} + got, err := jwtSecretBytes(cfg) + if tc.bad { + if err == nil { + t.Fatalf("decoded %q without error to %q", tc.secret, got) + } + return + } + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(got, tc.want) { + t.Errorf("got %q, want %q", got, tc.want) + } + }) + } +} + +// TestBase64OptionNoLongerWarns checks the option left the unenforced list +// when its behavior landed. +func TestBase64OptionNoLongerWarns(t *testing.T) { + cfg, err := config.FromMap(map[string]string{ + "db-uri": "x", + "jwt-secret": "c2VjcmV0IQ==", + "jwt-secret-is-base64": "true", + }) + if err != nil { + t.Fatal(err) + } + for _, w := range cfg.Warnings { + if bytes.Contains([]byte(w), []byte("jwt-secret-is-base64")) { + t.Errorf("unexpected warning: %s", w) + } + } +} diff --git a/cmd/dbrest/logging.go b/cmd/dbrest/logging.go new file mode 100644 index 0000000..607e406 --- /dev/null +++ b/cmd/dbrest/logging.go @@ -0,0 +1,67 @@ +// Request logging, filtered by the log-level option the way PostgREST filters +// its own request log: crit logs no requests, error logs server failures +// (5xx), warn adds client failures (4xx), and info and debug log every +// request. The level is read per request so a SIGUSR2 config reload takes +// effect without a restart. +package main + +import ( + "log" + "net" + "net/http" + "time" +) + +// shouldLog decides whether a response status is logged at the given level. +func shouldLog(level string, status int) bool { + switch level { + case "crit": + return false + case "error": + return status >= 500 + case "warn": + return status >= 400 + default: // info, debug + return true + } +} + +// statusWriter records the response status for the log line. +type statusWriter struct { + http.ResponseWriter + status int +} + +func (w *statusWriter) WriteHeader(code int) { + w.status = code + w.ResponseWriter.WriteHeader(code) +} + +// Flush passes through so streaming responses keep working behind the logger. +func (w *statusWriter) Flush() { + if f, ok := w.ResponseWriter.(http.Flusher); ok { + f.Flush() + } +} + +// requestLogger wraps the API handler with the filtered request log. level is +// consulted on every request, so it follows config reloads. +type requestLogger struct { + next http.Handler + level func() string + out *log.Logger +} + +func (l *requestLogger) ServeHTTP(w http.ResponseWriter, r *http.Request) { + sw := &statusWriter{ResponseWriter: w, status: http.StatusOK} + started := time.Now() + l.next.ServeHTTP(sw, r) + if !shouldLog(l.level(), sw.status) { + return + } + remote := r.RemoteAddr + if host, _, err := net.SplitHostPort(remote); err == nil { + remote = host + } + l.out.Printf("%s - %q %d - %s", remote, r.Method+" "+r.URL.RequestURI()+" "+r.Proto, sw.status, time.Since(started).Round(time.Microsecond)) +} diff --git a/cmd/dbrest/logging_test.go b/cmd/dbrest/logging_test.go new file mode 100644 index 0000000..8c50e89 --- /dev/null +++ b/cmd/dbrest/logging_test.go @@ -0,0 +1,97 @@ +package main + +import ( + "bytes" + "log" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +// TestShouldLog pins the per-level filter to PostgREST's: crit logs nothing, +// error logs 5xx, warn adds 4xx, info and debug log everything. +func TestShouldLog(t *testing.T) { + cases := []struct { + level string + status int + want bool + }{ + {"crit", 500, false}, + {"crit", 200, false}, + {"error", 500, true}, + {"error", 404, false}, + {"error", 200, false}, + {"warn", 500, true}, + {"warn", 404, true}, + {"warn", 200, false}, + {"info", 200, true}, + {"info", 404, true}, + {"debug", 200, true}, + } + for _, tc := range cases { + if got := shouldLog(tc.level, tc.status); got != tc.want { + t.Errorf("shouldLog(%q, %d) = %v, want %v", tc.level, tc.status, got, tc.want) + } + } +} + +// TestRequestLoggerFiltersAndFormats runs requests through the middleware at +// different levels and checks what reaches the log. +func TestRequestLoggerFiltersAndFormats(t *testing.T) { + level := "error" + var buf bytes.Buffer + rl := &requestLogger{ + next: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/boom": + w.WriteHeader(http.StatusInternalServerError) + case "/missing": + w.WriteHeader(http.StatusNotFound) + default: + w.Write([]byte("ok")) // implicit 200 + } + }), + level: func() string { return level }, + out: log.New(&buf, "", 0), + } + + get := func(path string) { + rec := httptest.NewRecorder() + rl.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, path, nil)) + } + + get("/films") + get("/missing") + if buf.Len() != 0 { + t.Errorf("level error logged a non-5xx request: %q", buf.String()) + } + + get("/boom") + line := buf.String() + if !strings.Contains(line, `"GET /boom HTTP/1.1" 500`) { + t.Errorf("5xx line = %q, want method, path, and status", line) + } + + buf.Reset() + level = "warn" + get("/missing") + get("/films") + if !strings.Contains(buf.String(), "404") || strings.Contains(buf.String(), "200") { + t.Errorf("level warn: %q, want the 404 and not the 200", buf.String()) + } + + buf.Reset() + level = "info" + get("/films") + if !strings.Contains(buf.String(), `"GET /films HTTP/1.1" 200`) { + t.Errorf("level info missed a 200: %q", buf.String()) + } + + buf.Reset() + level = "crit" + get("/boom") + if buf.Len() != 0 { + t.Errorf("level crit logged: %q", buf.String()) + } +} diff --git a/cmd/dbrest/main.go b/cmd/dbrest/main.go index 8fa4fcb..716c0f4 100644 --- a/cmd/dbrest/main.go +++ b/cmd/dbrest/main.go @@ -5,14 +5,24 @@ package main import ( "context" + "database/sql" + "encoding/base64" + "encoding/json" + "errors" "flag" "fmt" + "io/fs" "log" + "net" "net/http" "os" + "strconv" + "strings" "time" + "github.com/tamnd/dbrest/adminapi" "github.com/tamnd/dbrest/auth" + "github.com/tamnd/dbrest/authz" "github.com/tamnd/dbrest/backend" _ "github.com/tamnd/dbrest/backend/mongo" _ "github.com/tamnd/dbrest/backend/mysql" @@ -21,6 +31,7 @@ import ( _ "github.com/tamnd/dbrest/backend/sqlserver" "github.com/tamnd/dbrest/config" "github.com/tamnd/dbrest/httpapi" + "github.com/tamnd/dbrest/rpc" ) func main() { @@ -32,14 +43,52 @@ func main() { // run holds the real entry point so deferred cleanup (closing the backend) runs // on every exit path; main only translates a returned error into a fatal log. func run() error { - var configPath string + var ( + configPath string + showVersion bool + example bool + dumpConfig bool + dumpSchema bool + ready bool + ) flag.StringVar(&configPath, "config", "", "path to the configuration file (env-only if omitted)") + flag.BoolVar(&showVersion, "version", false, "print the version and exit") + flag.BoolVar(&showVersion, "v", false, "print the version and exit (shorthand)") + flag.BoolVar(&example, "example", false, "print an example configuration file and exit") + flag.BoolVar(&example, "e", false, "print an example configuration file and exit (shorthand)") + flag.BoolVar(&dumpConfig, "dump-config", false, "print the resolved configuration and exit") + flag.BoolVar(&dumpSchema, "dump-schema", false, "print the schema cache as JSON and exit") + flag.BoolVar(&ready, "ready", false, "probe a running instance's admin /ready and exit 0 or 1") flag.Parse() + configPath, err := resolveConfigPath(configPath, flag.Args()) + if err != nil { + return err + } + if showVersion { + fmt.Println("dbrest " + versionString()) + return nil + } + if example { + fmt.Print(exampleConfig) + return nil + } + cfg, err := config.Load(configPath, os.Environ()) if err != nil { return err } + for _, w := range cfg.Warnings { + log.Printf("dbrest: warning: %s", w) + } + + if dumpConfig { + fmt.Print(cfg.Dump()) + return nil + } + if ready { + return probeReady(cfg) + } be, err := openBackend(cfg) if err != nil { @@ -47,50 +96,276 @@ func run() error { } defer func() { _ = be.Close() }() - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - model, err := be.Introspect(ctx) - cancel() - if err != nil { + if dumpSchema { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + model, err := be.Introspect(ctx) + if err != nil { + return fmt.Errorf("introspect: %w", err) + } + out, err := json.MarshalIndent(map[string]any{"relations": model.Relations()}, "", " ") + if err != nil { + return err + } + fmt.Println(string(out)) + return nil + } + + metrics := adminapi.NewMetrics(cfg.DBPool) + a := &app{cfgPath: configPath, be: be, cfg: cfg, metrics: metrics} + if err := a.reloadSchema(); err != nil { return fmt.Errorf("introspect: %w", err) } + a.watchSignals() + a.watchDBChannel(context.Background()) - srv := httpapi.NewServer(be, model, cfg.Schemas) - srv.SetDefaultRole(cfg.AnonRole) - srv.SetOpenAPI(cfg.OpenAPIMode, cfg.OpenAPIServerProxyURI) - if err := attachAuth(srv, cfg); err != nil { - return err + if cfg.AdminEnabled() { + startAdmin(cfg, be, a, metrics) } - log.Printf("dbrest listening on %s (backend %s, %d relations)", cfg.ServerAddr(), cfg.Backend, model.Len()) - if err := http.ListenAndServe(cfg.ServerAddr(), srv); err != nil { + ln, err := listenAPI(cfg) + if err != nil { + where := cfg.ServerAddr() + if cfg.ServerUnixSocket != "" { + where = cfg.ServerUnixSocket + } + return fmt.Errorf("listen on %s: %w", where, err) + } + log.Printf("dbrest listening on %s (backend %s, %d relations)", ln.Addr(), cfg.Backend, a.Model().Len()) + logged := &requestLogger{next: a, level: a.logLevel, out: log.Default()} + if err := http.Serve(ln, logged); err != nil { return fmt.Errorf("serve: %w", err) } return nil } +// listenAPI binds the API listener. With server-unix-socket set the socket +// replaces TCP entirely, the upstream behavior: a stale socket file from a +// previous run is removed, the socket is bound, and server-unix-socket-mode +// (already validated at load) is applied to it. +func listenAPI(cfg *config.Config) (net.Listener, error) { + if cfg.ServerUnixSocket == "" { + return listenFirst(cfg.Listeners()) + } + if err := os.Remove(cfg.ServerUnixSocket); err != nil && !errors.Is(err, fs.ErrNotExist) { + return nil, err + } + ln, err := net.Listen("unix", cfg.ServerUnixSocket) + if err != nil { + return nil, err + } + mode, err := strconv.ParseUint(cfg.ServerUnixSocketMode, 8, 32) + if err != nil { + _ = ln.Close() + return nil, fmt.Errorf("server-unix-socket-mode: %w", err) + } + if err := os.Chmod(cfg.ServerUnixSocket, os.FileMode(mode)); err != nil { + _ = ln.Close() + return nil, err + } + return ln, nil +} + +// listenFirst binds the first candidate that works, in the preference order +// the host option encodes (the *4/*6 fallback story). +func listenFirst(specs []config.ListenSpec) (net.Listener, error) { + var firstErr error + for _, s := range specs { + ln, err := net.Listen(s.Network, s.Addr) + if err == nil { + return ln, nil + } + if firstErr == nil { + firstErr = err + } + } + return nil, firstErr +} + +// startAdmin runs the admin listener (admin-server-port) next to the API: the +// /live and /ready probes, the /schema_cache dump, and /metrics. The liveness +// check dials the API socket the way PostgREST's admin server does. +func startAdmin(cfg *config.Config, be backend.Backend, a *app, metrics *adminapi.Metrics) { + apiAddr := probeAddr(cfg.ServerHost, cfg.ServerPort) + admin := &adminapi.Server{ + Live: func(ctx context.Context) error { + d := net.Dialer{Timeout: time.Second} + conn, err := d.DialContext(ctx, "tcp", apiAddr) + if err != nil { + return err + } + return conn.Close() + }, + Ready: func(ctx context.Context) error { + // A backend that can check its connection exposes Ping; one that + // cannot (an embedded engine) is ready once the cache is loaded. + if p, ok := be.(interface{ Ping(context.Context) error }); ok { + return p.Ping(ctx) + } + return nil + }, + SchemaCache: func() ([]byte, error) { + return json.Marshal(map[string]any{"relations": a.Model().Relations()}) + }, + Metrics: metrics, + } + go func() { + ln, err := listenFirst(cfg.AdminListeners()) + if err != nil { + log.Printf("dbrest: admin server: listen on %s: %v", cfg.AdminAddr(), err) + return + } + log.Printf("dbrest admin listening on %s", ln.Addr()) + if err := http.Serve(ln, admin); err != nil { + log.Printf("dbrest: admin server: %v", err) + } + }() +} + +// probeAddr is the address the liveness probe dials. A wildcard listen host is +// not dialable as written, so the probe goes through loopback. +func probeAddr(host string, port int) string { + switch host { + case "", "0.0.0.0", "*", "*4", "!4": + host = "127.0.0.1" + case "::", "*6", "!6": + host = "::1" + } + return net.JoinHostPort(host, fmt.Sprint(port)) +} + // openBackend opens the engine the configuration selected. // Each backend driver self-registers via its package init function; this file // imports them as blank imports so their init functions run. func openBackend(cfg *config.Config) (backend.Backend, error) { - be, err := backend.Open(cfg.Backend, cfg.DBURI) + prepared := cfg.DBPreparedStatements + be, err := backend.OpenWith(cfg.Backend, cfg.DBURI, backend.OpenOptions{PreparedStatements: &prepared}) if err != nil { return nil, fmt.Errorf("open database: %w", err) } + applyPoolConfig(be, cfg) + applySchemaConfig(be, cfg) + // Wire declared function registry for backends that cannot discover + // functions from an engine catalog (NativeRPC=false: SQLite, MySQL, …). + if cfg.FunctionRegistry != "" { + reg, err := rpc.ParseRegistry(cfg.FunctionRegistry) + if err != nil { + return nil, fmt.Errorf("function-registry: %w", err) + } + if r, ok := be.(interface{ Register(rpc.Registry) }); ok { + r.Register(reg) + } + } + return be, nil +} + +// attachPreRequest wires the db-pre-request option. The function name rides the +// request context so the backend can invoke it after the session settings and +// before the main statement (spec 13). A backend that cannot honor it must not +// silently drop the option, since deployments use db-pre-request for blocking +// and custom auth; with no backend support declared, startup is refused. +func attachPreRequest(srv *httpapi.Server, be backend.Backend, cfg *config.Config) error { + if cfg.PreRequest == "" { + return nil + } + if pr, ok := be.(interface{ SupportsPreRequest() bool }); ok && pr.SupportsPreRequest() { + srv.SetPreRequest(cfg.PreRequest) + return nil + } + return fmt.Errorf("db-pre-request: the %s backend cannot run a pre-request function; unset the option", cfg.Backend) +} + +// attachAuthz wires the emulated authorization layer from the policy-registry +// option. With no registry configured the gate stays off, which mirrors a +// database where every role holds every privilege; declaring a registry flips +// the model, and from then on the absence of a grant is a denial. A registry +// the parser cannot fully understand is a startup error, never a silently +// thinner rule set. Postgres delegates privileges and RLS to the engine, so a +// registry configured there is a misconfiguration and is refused too. +func attachAuthz(srv *httpapi.Server, cfg *config.Config) error { + if cfg.PolicyRegistry == "" { + return nil + } + if cfg.Backend == "postgres" { + return fmt.Errorf("policy-registry: the postgres backend enforces grants and RLS natively; manage them in the database and unset the option") + } + reg, err := authz.ParseRegistry(cfg.PolicyRegistry) + if err != nil { + return err + } + srv.SetAuthz(reg) + return nil +} + +// applySchemaConfig pushes the schema-shaped options onto a backend that +// accepts them: the exposed schemas and db-extra-search-path, which extends +// type and function resolution without exposing the schemas. It runs at open +// and again on a config reload. Backends that have no schema notion ignore +// both by not implementing the setters. +func applySchemaConfig(be any, cfg *config.Config) { if sc, ok := be.(interface{ SetSchemas([]string) }); ok { sc.SetSchemas(cfg.Schemas) } - return be, nil + if sp, ok := be.(interface{ SetExtraSearchPath([]string) }); ok { + sp.SetExtraSearchPath(cfg.ExtraSearchPath) + } + if h, ok := be.(interface{ SetHoistedTxSettings([]string) }); ok { + h.SetHoistedTxSettings(cfg.HoistedTxSettings) + } +} + +// applyPoolConfig sizes the connection pool on the engines built over +// database/sql (mysql, sqlserver). SQLite is left alone: its backend pins the +// pool to one connection so the foreign-key PRAGMA stays in effect, and +// resizing or recycling that connection would silently drop FK enforcement. +// The pgx-based postgres backend builds its pool inside Open and the +// acquisition timeout has no database/sql knob; both stay with the backend +// drivers as the per-driver remainder of the pool item. +func applyPoolConfig(be backend.Backend, cfg *config.Config) { + if cfg.Backend == config.BackendSQLite { + return + } + d, ok := be.(interface{ DB() *sql.DB }) + if !ok { + return + } + db := d.DB() + db.SetMaxOpenConns(cfg.DBPool) + db.SetConnMaxIdleTime(cfg.DBPoolMaxIdleTime) + db.SetConnMaxLifetime(cfg.DBPoolMaxLifetime) } -// attachAuth wires a JWT verifier onto the server when a key is configured. -// With no key material the server runs every request as the static anon role, -// which is the PostgREST behavior for an unconfigured jwt-secret. +// jwtSecretBytes returns the key material configured in jwt-secret. With +// jwt-secret-is-base64 set, the value is URL-safe base64 (padding optional) +// and an undecodable value is a boot error; silently keying the verifier with +// the wrong bytes would lock every valid token out. +func jwtSecretBytes(cfg *config.Config) ([]byte, error) { + if !cfg.JWTSecretIsBase64 { + return []byte(cfg.JWTSecret), nil + } + trimmed := strings.TrimRight(strings.TrimSpace(cfg.JWTSecret), "=") + b, err := base64.RawURLEncoding.DecodeString(trimmed) + if err != nil { + return nil, fmt.Errorf("jwt-secret-is-base64 is set but jwt-secret is not valid URL-safe base64: %w", err) + } + return b, nil +} + +// attachAuth wires a JWT verifier onto the server. The verifier is always +// attached so the server fails closed the way PostgREST does: with no key +// material a presented token is a 500 PGRST300, and with no anon role a +// tokenless request is a 401 PGRST302. The jwt-secret value is read the +// PostgREST way (JWK Set, JWK, or text secret), base64-decoded first when +// jwt-secret-is-base64 is set, and an unusable key configuration is a startup +// error. func attachAuth(srv *httpapi.Server, cfg *config.Config) error { - if cfg.JWTSecret == "" && cfg.JWKSet == "" { - return nil + secret, err := jwtSecretBytes(cfg) + if err != nil { + return err } v, err := auth.NewVerifier(auth.Config{ - Secret: []byte(cfg.JWTSecret), + Secret: secret, + JWKSet: cfg.JWKSet, Audience: cfg.JWTAud, RoleClaimKey: cfg.JWTRoleClaimKey, AnonRole: cfg.AnonRole, diff --git a/cmd/dbrest/main_test.go b/cmd/dbrest/main_test.go new file mode 100644 index 0000000..c39cb3c --- /dev/null +++ b/cmd/dbrest/main_test.go @@ -0,0 +1,112 @@ +package main + +import ( + "strings" + "testing" + + "github.com/tamnd/dbrest/backend" + "github.com/tamnd/dbrest/backend/sqlite" + "github.com/tamnd/dbrest/config" + "github.com/tamnd/dbrest/httpapi" +) + +// openTestBackend opens an in-memory SQLite backend for the wiring tests. +func openTestBackend(t *testing.T) backend.Backend { + t.Helper() + dsn := "file:" + strings.ReplaceAll(t.Name(), "/", "_") + "?mode=memory&cache=shared" + be, err := sqlite.Open(dsn) + if err != nil { + t.Fatalf("open: %v", err) + } + t.Cleanup(func() { _ = be.Close() }) + return be +} + +// preRequestBackend declares pre-request support over a real backend, standing +// in for a driver that runs the function inside its transaction. +type preRequestBackend struct{ backend.Backend } + +func (preRequestBackend) SupportsPreRequest() bool { return true } + +func TestAttachPreRequestNoopWhenUnset(t *testing.T) { + be := openTestBackend(t) + srv := httpapi.NewServer(be, nil, nil) + if err := attachPreRequest(srv, be, &config.Config{Backend: "sqlite"}); err != nil { + t.Fatalf("attachPreRequest with no option = %v, want nil", err) + } +} + +func TestAttachPreRequestRefusesUnsupportedBackend(t *testing.T) { + // No backend declares pre-request support yet, so a configured + // db-pre-request must refuse startup rather than silently drop the + // function (deployments use it for blocking and custom auth). + be := openTestBackend(t) + srv := httpapi.NewServer(be, nil, nil) + cfg := &config.Config{Backend: "sqlite", PreRequest: "api.check_request"} + err := attachPreRequest(srv, be, cfg) + if err == nil { + t.Fatal("attachPreRequest = nil, want a startup refusal on a backend without pre-request support") + } + if !strings.Contains(err.Error(), "db-pre-request") { + t.Errorf("error %q does not name the db-pre-request option", err) + } +} + +func TestAttachPreRequestAcceptsSupportingBackend(t *testing.T) { + be := preRequestBackend{openTestBackend(t)} + srv := httpapi.NewServer(be, nil, nil) + cfg := &config.Config{Backend: "sqlite", PreRequest: "api.check_request"} + if err := attachPreRequest(srv, be, cfg); err != nil { + t.Fatalf("attachPreRequest on a supporting backend = %v, want nil", err) + } +} + +func TestAttachAuthzNoopWhenUnset(t *testing.T) { + be := openTestBackend(t) + srv := httpapi.NewServer(be, nil, nil) + if err := attachAuthz(srv, &config.Config{Backend: "sqlite"}); err != nil { + t.Fatalf("attachAuthz with no registry = %v, want nil", err) + } +} + +func TestAttachAuthzWiresParsedRegistry(t *testing.T) { + be := openTestBackend(t) + srv := httpapi.NewServer(be, nil, nil) + cfg := &config.Config{Backend: "sqlite", PolicyRegistry: `{ + "grants": [{"role": "web_user", "relation": "todos", "actions": ["select"]}] + }`} + if err := attachAuthz(srv, cfg); err != nil { + t.Fatalf("attachAuthz with a valid registry = %v, want nil", err) + } +} + +func TestAttachAuthzRefusesBadRegistry(t *testing.T) { + // The registry is the security boundary on the emulated backends, so a + // declaration the parser cannot fully understand must stop the boot. + be := openTestBackend(t) + srv := httpapi.NewServer(be, nil, nil) + cfg := &config.Config{Backend: "sqlite", PolicyRegistry: `{"grants": [{"role": "r"}]}`} + err := attachAuthz(srv, cfg) + if err == nil { + t.Fatal("attachAuthz = nil, want a startup refusal on an unparseable registry") + } + if !strings.Contains(err.Error(), "policy-registry") { + t.Errorf("error %q does not name the policy-registry option", err) + } +} + +func TestAttachAuthzRefusesPostgres(t *testing.T) { + // Postgres enforces grants and RLS in the engine; a registry there would + // suggest a second enforcement layer that does not exist, so it is a + // misconfiguration rather than a silent no-op. + be := openTestBackend(t) + srv := httpapi.NewServer(be, nil, nil) + cfg := &config.Config{Backend: "postgres", PolicyRegistry: `{}`} + err := attachAuthz(srv, cfg) + if err == nil { + t.Fatal("attachAuthz = nil, want a refusal on the postgres backend") + } + if !strings.Contains(err.Error(), "policy-registry") { + t.Errorf("error %q does not name the policy-registry option", err) + } +} diff --git a/cmd/dbrest/pool_test.go b/cmd/dbrest/pool_test.go new file mode 100644 index 0000000..c15ac52 --- /dev/null +++ b/cmd/dbrest/pool_test.go @@ -0,0 +1,66 @@ +package main + +import ( + "database/sql" + "testing" + + "github.com/tamnd/dbrest/backend" + "github.com/tamnd/dbrest/config" +) + +func sqlPool(t *testing.T, be backend.Backend) *sql.DB { + t.Helper() + d, ok := be.(interface{ DB() *sql.DB }) + if !ok { + t.Fatal("backend does not expose its sql.DB") + } + return d.DB() +} + +// TestApplyPoolConfigSizesTheSQLPool checks the database/sql settings reach +// the pool. The sqlite driver is just a convenient pool carrier here; the +// config names another engine so the resize branch runs. +func TestApplyPoolConfigSizesTheSQLPool(t *testing.T) { + be, err := backend.Open("sqlite", ":memory:") + if err != nil { + t.Fatal(err) + } + defer be.Close() + + cfg, err := config.FromMap(map[string]string{ + "db-uri": "x", "db-backend": "mysql", + "db-pool": "7", "db-pool-max-idletime": "60", "db-pool-max-lifetime": "120", + }) + if err != nil { + t.Fatal(err) + } + applyPoolConfig(be, cfg) + + if got := sqlPool(t, be).Stats().MaxOpenConnections; got != 7 { + t.Errorf("MaxOpenConnections = %d, want the configured db-pool 7", got) + } +} + +// TestApplyPoolConfigLeavesSQLiteAlone pins the exemption: the sqlite backend +// runs on one pinned connection so its foreign-key PRAGMA holds, and the pool +// options must not resize it. +func TestApplyPoolConfigLeavesSQLiteAlone(t *testing.T) { + be, err := backend.Open("sqlite", ":memory:") + if err != nil { + t.Fatal(err) + } + defer be.Close() + + cfg, err := config.FromMap(map[string]string{"db-uri": "x", "db-pool": "7"}) + if err != nil { + t.Fatal(err) + } + if cfg.Backend != config.BackendSQLite { + t.Fatalf("default backend = %q, want sqlite", cfg.Backend) + } + applyPoolConfig(be, cfg) + + if got := sqlPool(t, be).Stats().MaxOpenConnections; got != 1 { + t.Errorf("MaxOpenConnections = %d, the sqlite single-connection pin was lost", got) + } +} diff --git a/cmd/dbrest/reload.go b/cmd/dbrest/reload.go new file mode 100644 index 0000000..09bfee3 --- /dev/null +++ b/cmd/dbrest/reload.go @@ -0,0 +1,228 @@ +// Reload plumbing: PostgREST re-reads its schema cache on SIGUSR1 and its +// configuration on SIGUSR2, without dropping the listener. dbrest does the +// same by keeping the HTTP frontend behind an atomic handler and rebuilding +// it from the new inputs; an in-flight request keeps the snapshot it started +// with. A failed reload logs and keeps serving with the previous state, the +// upstream behavior. The per-driver paths (LISTEN on db-channel, db-config, +// db-pre-config) live with each backend and are not wired here yet. +package main + +import ( + "context" + "log" + "net/http" + "os" + "os/signal" + "sync" + "sync/atomic" + "syscall" + "time" + + "github.com/tamnd/dbrest/adminapi" + "github.com/tamnd/dbrest/backend" + "github.com/tamnd/dbrest/config" + "github.com/tamnd/dbrest/httpapi" + "github.com/tamnd/dbrest/schema" +) + +// app owns the pieces a reload swaps: the configuration, the schema cache, +// and the frontend built from them. +type app struct { + cfgPath string + be backend.Backend + metrics *adminapi.Metrics + + mu sync.Mutex // serializes reloads; guards cfg and model + cfg *config.Config + model *schema.Model + + handler atomic.Value // always a *httpapi.Server +} + +func (a *app) ServeHTTP(w http.ResponseWriter, r *http.Request) { + a.handler.Load().(http.Handler).ServeHTTP(w, r) +} + +// Model is the schema cache currently being served. +func (a *app) Model() *schema.Model { + a.mu.Lock() + defer a.mu.Unlock() + return a.model +} + +// logLevel is the log-level currently in force; the request logger reads it +// per request so a config reload changes it live. +func (a *app) logLevel() string { + a.mu.Lock() + defer a.mu.Unlock() + return a.cfg.LogLevel +} + +// reloadSchema re-introspects the database and swaps in a frontend built on +// the fresh cache. It is both the boot-time load and the SIGUSR1 handler; on +// failure the old cache stays in service. +func (a *app) reloadSchema() error { + a.mu.Lock() + defer a.mu.Unlock() + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + started := time.Now() + model, err := a.be.Introspect(ctx) + cancel() + a.metrics.ObserveSchemaCacheLoad(time.Since(started), err) + if err != nil { + return err + } + a.model = model + return a.rebuildLocked() +} + +// reloadConfig re-reads the configuration sources and applies the reloadable +// subset, logging every boot-time option whose change had to be ignored. It +// is the SIGUSR2 handler; a config that does not load leaves the old one in +// service. +func (a *app) reloadConfig(environ []string) error { + next, err := config.Load(a.cfgPath, environ) + if err != nil { + return err + } + a.mu.Lock() + defer a.mu.Unlock() + merged, kept := a.cfg.MergeReloadable(next) + for _, msg := range kept { + log.Printf("dbrest: reload: %s", msg) + } + for _, w := range merged.Warnings { + log.Printf("dbrest: warning: %s", w) + } + a.cfg = merged + applySchemaConfig(a.be, merged) + return a.rebuildLocked() +} + +// rebuildLocked builds the frontend from the current cfg and model and swaps +// it in. The caller holds a.mu. +func (a *app) rebuildLocked() error { + srv := httpapi.NewServer(a.be, a.model, a.cfg.Schemas) + srv.SetDefaultRole(a.cfg.AnonRole) + srv.SetOpenAPI(a.cfg.OpenAPIMode, a.cfg.OpenAPIServerProxyURI, a.cfg.OpenAPISecurityActive) + srv.SetRootSpec(a.cfg.RootSpec) + srv.SetCORSAllowedOrigins(a.cfg.CORSAllowedOrigins) + srv.SetMaxRows(a.cfg.MaxRows) + srv.SetMaxRequestBody(a.cfg.MaxRequestBody) + srv.SetServerTimingEnabled(a.cfg.ServerTimingEnabled) + srv.SetTxEnd(a.cfg.TxEnd) + srv.SetPlanEnabled(a.cfg.PlanEnabled) + srv.SetAggregatesEnabled(a.cfg.AggregatesEnabled) + srv.SetPreRequest(a.cfg.PreRequest) + srv.SetAppSettings(a.cfg.AppSettings) + srv.SetLogQuery(a.cfg.LogQuery) + if err := attachAuth(srv, a.cfg); err != nil { + return err + } + if err := attachPreRequest(srv, a.be, a.cfg); err != nil { + return err + } + if err := attachAuthz(srv, a.cfg); err != nil { + return err + } + a.handler.Store(srv) + return nil +} + +// reloadAction is what a db-channel notification asks the server to do. +type reloadAction int + +const ( + reloadNone reloadAction = iota + reloadActionSchema + reloadActionConfig +) + +// dbNotifyAction decodes a db-channel NOTIFY payload into a reload action, +// implementing PostgREST's contract: an empty payload or "reload schema" reloads +// the schema cache, "reload config" reloads the configuration, and any other +// payload is ignored. +func dbNotifyAction(payload string) reloadAction { + switch payload { + case "", "reload schema": + return reloadActionSchema + case "reload config": + return reloadActionConfig + default: + return reloadNone + } +} + +// handleDBNotify applies the reload a db-channel payload asks for. Like the +// signal handlers, a failed reload logs and keeps the previous state. +func (a *app) handleDBNotify(payload string) { + switch dbNotifyAction(payload) { + case reloadActionSchema: + log.Printf("dbrest: db-channel notification, reloading the schema cache") + if err := a.reloadSchema(); err != nil { + log.Printf("dbrest: schema cache reload failed, keeping the old cache: %v", err) + } + case reloadActionConfig: + log.Printf("dbrest: db-channel notification, reloading the configuration") + if err := a.reloadConfig(os.Environ()); err != nil { + log.Printf("dbrest: config reload failed, keeping the old config: %v", err) + } + } +} + +// watchDBChannel starts PostgREST's db-channel listener when db-channel-enabled +// is set and the backend can listen. A NOTIFY drives reloads through +// handleDBNotify; a reconnect refreshes the schema cache because notifications +// sent while the listener was down are lost. A backend with no LISTEN support is +// silently skipped, leaving signal-driven reloads in place. +func (a *app) watchDBChannel(ctx context.Context) { + a.mu.Lock() + enabled := a.cfg.DBChannelEnabled + channel := a.cfg.DBChannel + a.mu.Unlock() + if !enabled { + return + } + l, ok := a.be.(backend.Listener) + if !ok { + log.Printf("dbrest: backend does not support db-channel; reloads are signal-driven only") + return + } + h := backend.ListenHandler{ + OnNotify: a.handleDBNotify, + OnReconnect: func() { + log.Printf("dbrest: db-channel %q reconnected, reloading the schema cache", channel) + if err := a.reloadSchema(); err != nil { + log.Printf("dbrest: schema cache reload failed, keeping the old cache: %v", err) + } + }, + } + go func() { + if err := l.Listen(ctx, channel, h); err != nil && ctx.Err() == nil { + log.Printf("dbrest: db-channel listener stopped: %v", err) + } + }() +} + +// watchSignals installs the two reload signals. Reload failures log and keep +// the previous state; they never terminate the process. +func (a *app) watchSignals() { + ch := make(chan os.Signal, 1) + signal.Notify(ch, syscall.SIGUSR1, syscall.SIGUSR2) + go func() { + for s := range ch { + switch s { + case syscall.SIGUSR1: + log.Printf("dbrest: received SIGUSR1, reloading the schema cache") + if err := a.reloadSchema(); err != nil { + log.Printf("dbrest: schema cache reload failed, keeping the old cache: %v", err) + } + case syscall.SIGUSR2: + log.Printf("dbrest: received SIGUSR2, reloading the configuration") + if err := a.reloadConfig(os.Environ()); err != nil { + log.Printf("dbrest: config reload failed, keeping the old config: %v", err) + } + } + } + }() +} diff --git a/cmd/dbrest/reload_test.go b/cmd/dbrest/reload_test.go new file mode 100644 index 0000000..3703006 --- /dev/null +++ b/cmd/dbrest/reload_test.go @@ -0,0 +1,224 @@ +package main + +import ( + "context" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + + "github.com/tamnd/dbrest/adminapi" + "github.com/tamnd/dbrest/backend/sqlite" + "github.com/tamnd/dbrest/config" +) + +// newApp boots an app over an in-memory sqlite with one table, the way run() +// does, minus the listeners. +func newApp(t *testing.T, cfg *config.Config) *app { + t.Helper() + dsn := "file:reload_" + t.Name() + "?mode=memory&cache=shared" + be, err := sqlite.Open(dsn) + if err != nil { + t.Fatalf("open sqlite: %v", err) + } + t.Cleanup(func() { be.Close() }) + if _, err := be.DB().Exec(`CREATE TABLE films (id INTEGER PRIMARY KEY, title TEXT); + INSERT INTO films (title) VALUES ('Metropolis'), ('Sunrise');`); err != nil { + t.Fatalf("seed: %v", err) + } + a := &app{be: be, cfg: cfg, metrics: adminapi.NewMetrics(cfg.DBPool)} + if err := a.reloadSchema(); err != nil { + t.Fatalf("initial load: %v", err) + } + return a +} + +func mustConfig(t *testing.T, raw map[string]string) *config.Config { + t.Helper() + if raw["db-uri"] == "" { + raw["db-uri"] = "x" + } + cfg, err := config.FromMap(raw) + if err != nil { + t.Fatal(err) + } + return cfg +} + +func status(t *testing.T, a *app, target string) int { + t.Helper() + rec := httptest.NewRecorder() + a.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, target, nil)) + return rec.Code +} + +// TestReloadSchemaPicksUpNewTable is the SIGUSR1 path: a table created after +// boot is 404 until the schema cache reload, then served. +func TestReloadSchemaPicksUpNewTable(t *testing.T) { + a := newApp(t, mustConfig(t, map[string]string{"db-anon-role": "web_anon"})) + + if got := status(t, a, "/directors"); got != http.StatusNotFound { + t.Fatalf("before reload: status = %d, want 404", got) + } + if _, err := a.be.(*sqlite.Backend).DB().Exec(`CREATE TABLE directors (id INTEGER PRIMARY KEY, name TEXT)`); err != nil { + t.Fatal(err) + } + if err := a.reloadSchema(); err != nil { + t.Fatalf("reload: %v", err) + } + if got := status(t, a, "/directors"); got != http.StatusOK { + t.Errorf("after reload: status = %d, want 200", got) + } +} + +// TestReloadConfigAppliesReloadableSubset is the SIGUSR2 path: a new max-rows +// takes effect on the next request, while the request keeps flowing. +func TestReloadConfigAppliesReloadableSubset(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "dbrest.conf") + write := func(body string) { + if err := os.WriteFile(path, []byte(body), 0o644); err != nil { + t.Fatal(err) + } + } + write(`db-uri = "x"` + "\n" + `db-anon-role = "web_anon"` + "\n") + cfg, err := config.Load(path, nil) + if err != nil { + t.Fatal(err) + } + a := newApp(t, cfg) + a.cfgPath = path + + rec := httptest.NewRecorder() + a.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/films", nil)) + if got := rec.Header().Get("Content-Range"); got != "0-1/*" { + t.Fatalf("before reload: Content-Range = %q, want 0-1/*", got) + } + + write(`db-uri = "x"` + "\n" + `db-anon-role = "web_anon"` + "\n" + `db-max-rows = 1` + "\n") + if err := a.reloadConfig(nil); err != nil { + t.Fatalf("reload: %v", err) + } + rec = httptest.NewRecorder() + a.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/films", nil)) + if got := rec.Header().Get("Content-Range"); got != "0-0/*" { + t.Errorf("after reload: Content-Range = %q, want 0-0/* (max-rows applied)", got) + } +} + +// TestReloadConfigKeepsServingOnBadFile checks the failure mode: a config that +// no longer loads is rejected and the old one stays in service. +func TestReloadConfigKeepsServingOnBadFile(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "dbrest.conf") + if err := os.WriteFile(path, []byte(`db-uri = "x"`+"\n"+`db-anon-role = "web_anon"`+"\n"), 0o644); err != nil { + t.Fatal(err) + } + cfg, err := config.Load(path, nil) + if err != nil { + t.Fatal(err) + } + a := newApp(t, cfg) + a.cfgPath = path + + if err := os.WriteFile(path, []byte(`db-tx-end = "explode"`+"\n"), 0o644); err != nil { + t.Fatal(err) + } + if err := a.reloadConfig(nil); err == nil { + t.Fatal("expected the bad config to be rejected") + } + if got := status(t, a, "/films"); got != http.StatusOK { + t.Errorf("after failed reload: status = %d, want 200", got) + } +} + +// TestDBNotifyAction covers PostgREST's db-channel payload contract: an empty +// payload and "reload schema" both reload the schema cache, "reload config" +// reloads the configuration, and anything else is ignored. +func TestDBNotifyAction(t *testing.T) { + cases := []struct { + payload string + want reloadAction + }{ + {"", reloadActionSchema}, + {"reload schema", reloadActionSchema}, + {"reload config", reloadActionConfig}, + {"reload everything", reloadNone}, + {"RELOAD SCHEMA", reloadNone}, // case-sensitive, matching upstream + {" reload schema ", reloadNone}, + } + for _, c := range cases { + if got := dbNotifyAction(c.payload); got != c.want { + t.Errorf("dbNotifyAction(%q) = %d, want %d", c.payload, got, c.want) + } + } +} + +// TestHandleDBNotifyReloadsSchema is the db-channel schema path: a table created +// after boot is 404 until a "reload schema" notification (and equally an empty +// payload), then served, the same effect as SIGUSR1. +func TestHandleDBNotifyReloadsSchema(t *testing.T) { + a := newApp(t, mustConfig(t, map[string]string{"db-anon-role": "web_anon"})) + + if got := status(t, a, "/directors"); got != http.StatusNotFound { + t.Fatalf("before notify: status = %d, want 404", got) + } + if _, err := a.be.(*sqlite.Backend).DB().Exec(`CREATE TABLE directors (id INTEGER PRIMARY KEY, name TEXT)`); err != nil { + t.Fatal(err) + } + a.handleDBNotify("reload schema") + if got := status(t, a, "/directors"); got != http.StatusOK { + t.Errorf("after notify: status = %d, want 200", got) + } +} + +// TestHandleDBNotifyReloadsConfig is the db-channel config path: a "reload +// config" notification applies the reloadable subset, the same effect as +// SIGUSR2. +func TestHandleDBNotifyReloadsConfig(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "dbrest.conf") + write := func(body string) { + if err := os.WriteFile(path, []byte(body), 0o644); err != nil { + t.Fatal(err) + } + } + write(`db-uri = "x"` + "\n" + `db-anon-role = "web_anon"` + "\n") + cfg, err := config.Load(path, nil) + if err != nil { + t.Fatal(err) + } + a := newApp(t, cfg) + a.cfgPath = path + + write(`db-uri = "x"` + "\n" + `db-anon-role = "web_anon"` + "\n" + `db-max-rows = 1` + "\n") + a.handleDBNotify("reload config") + + rec := httptest.NewRecorder() + a.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/films", nil)) + if got := rec.Header().Get("Content-Range"); got != "0-0/*" { + t.Errorf("after notify: Content-Range = %q, want 0-0/* (max-rows applied)", got) + } +} + +// TestWatchDBChannelSkipsBackendWithoutListener confirms a backend that cannot +// LISTEN (sqlite) is skipped without blocking or panicking, leaving signal-only +// reloads. The call returns immediately because no goroutine is spawned. +func TestWatchDBChannelSkipsBackendWithoutListener(t *testing.T) { + a := newApp(t, mustConfig(t, map[string]string{"db-anon-role": "web_anon"})) + a.watchDBChannel(context.Background()) // sqlite has no Listener; must not block +} + +// TestSchemaReloadFailureKeepsOldCache mirrors upstream: when re-introspection +// fails the old cache keeps serving. +func TestSchemaReloadFailureKeepsOldCache(t *testing.T) { + a := newApp(t, mustConfig(t, map[string]string{})) + a.be.(*sqlite.Backend).Close() + if err := a.reloadSchema(); err == nil { + t.Skip("introspection on a closed handle did not fail; nothing to assert") + } + if a.Model() == nil || a.Model().Len() == 0 { + t.Error("old schema cache was dropped on a failed reload") + } +} diff --git a/cmd/dbrest/schemaconfig_test.go b/cmd/dbrest/schemaconfig_test.go new file mode 100644 index 0000000..80bfddc --- /dev/null +++ b/cmd/dbrest/schemaconfig_test.go @@ -0,0 +1,42 @@ +package main + +import ( + "slices" + "testing" + + "github.com/tamnd/dbrest/config" +) + +// schemaRecorder is a minimal stand-in for a backend that accepts both +// schema-shaped setters. +type schemaRecorder struct { + schemas []string + extra []string +} + +func (s *schemaRecorder) SetSchemas(v []string) { s.schemas = v } +func (s *schemaRecorder) SetExtraSearchPath(v []string) { s.extra = v } + +// TestApplySchemaConfig checks both options reach a backend that accepts +// them, and that a backend without the setters is simply left alone. +func TestApplySchemaConfig(t *testing.T) { + cfg, err := config.FromMap(map[string]string{ + "db-uri": "x", + "db-schemas": "api,private", + "db-extra-search-path": "extensions,util", + }) + if err != nil { + t.Fatal(err) + } + + rec := &schemaRecorder{} + applySchemaConfig(rec, cfg) + if !slices.Equal(rec.schemas, []string{"api", "private"}) { + t.Errorf("schemas = %v", rec.schemas) + } + if !slices.Equal(rec.extra, []string{"extensions", "util"}) { + t.Errorf("extra search path = %v", rec.extra) + } + + applySchemaConfig(struct{}{}, cfg) // must not panic on a bare backend +} diff --git a/cmd/dbrest/socket_test.go b/cmd/dbrest/socket_test.go new file mode 100644 index 0000000..47273a3 --- /dev/null +++ b/cmd/dbrest/socket_test.go @@ -0,0 +1,101 @@ +package main + +import ( + "context" + "io" + "net" + "net/http" + "os" + "path/filepath" + "testing" + "time" + + "github.com/tamnd/dbrest/config" +) + +// shortSocketPath returns a socket path short enough for sun_path. t.TempDir +// can exceed the limit on macOS, so this uses a fresh small temp dir. +func shortSocketPath(t *testing.T) string { + t.Helper() + dir, err := os.MkdirTemp("", "dbrest") + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { os.RemoveAll(dir) }) + return filepath.Join(dir, "api.sock") +} + +// TestListenAPIUnixSocket covers the server-unix-socket listener: it replaces +// TCP, gets the configured mode, survives a stale socket file from a previous +// run, and serves HTTP. +func TestListenAPIUnixSocket(t *testing.T) { + path := shortSocketPath(t) + // A stale file at the path must not block the bind. + if err := os.WriteFile(path, nil, 0o600); err != nil { + t.Fatal(err) + } + + cfg, err := config.FromMap(map[string]string{ + "db-uri": "x", "server-unix-socket": path, "server-unix-socket-mode": "600", + }) + if err != nil { + t.Fatal(err) + } + ln, err := listenAPI(cfg) + if err != nil { + t.Fatalf("listenAPI: %v", err) + } + defer ln.Close() + if ln.Addr().Network() != "unix" { + t.Fatalf("network = %q, want unix", ln.Addr().Network()) + } + + fi, err := os.Stat(path) + if err != nil { + t.Fatal(err) + } + if perm := fi.Mode().Perm(); perm != 0o600 { + t.Errorf("socket mode = %o, want 600", perm) + } + + srv := &http.Server{Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + io.WriteString(w, "over the socket") + })} + go srv.Serve(ln) + defer srv.Close() + + client := &http.Client{ + Transport: &http.Transport{ + DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) { + var d net.Dialer + return d.DialContext(ctx, "unix", path) + }, + }, + Timeout: 2 * time.Second, + } + resp, err := client.Get("http://unix/anything") + if err != nil { + t.Fatalf("GET over the socket: %v", err) + } + defer resp.Body.Close() + body, _ := io.ReadAll(resp.Body) + if string(body) != "over the socket" { + t.Errorf("body = %q", body) + } +} + +// TestListenAPITCPDefault keeps the TCP path intact when no socket is set. +func TestListenAPITCPDefault(t *testing.T) { + cfg, err := config.FromMap(map[string]string{"db-uri": "x", "server-port": "0"}) + if err != nil { + t.Fatal(err) + } + ln, err := listenAPI(cfg) + if err != nil { + t.Fatalf("listenAPI: %v", err) + } + defer ln.Close() + if ln.Addr().Network() != "tcp" { + t.Errorf("network = %q, want tcp", ln.Addr().Network()) + } +} diff --git a/compat/auth_v14_test.go b/compat/auth_v14_test.go new file mode 100644 index 0000000..ddd9a17 --- /dev/null +++ b/compat/auth_v14_test.go @@ -0,0 +1,156 @@ +// Auth wire-compat cases against PostgREST v14: the PGRST301/302/303 code +// assignments, the WWW-Authenticate challenges, and the claim validation +// behavior. Each case is sent to both servers and the status, the JSON error +// envelope, and the WWW-Authenticate header must agree byte for byte. +// +// The servers come from the same compose stacks as compat_test.go and share +// the jwt-secret "reallyreallyreallyreallyverysafe"; tokens are minted here so +// the time claims are relative to the test run. +package compat + +import ( + "net/http" + "testing" + "time" + + "github.com/golang-jwt/jwt/v5" +) + +// compatSecret is the jwt-secret both compose stacks are configured with. +var compatSecret = []byte("reallyreallyreallyreallyverysafe") + +// mintHS signs an HS256 token with the shared compat secret. +func mintHS(t *testing.T, claims jwt.MapClaims) string { + t.Helper() + tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + s, err := tok.SignedString(compatSecret) + if err != nil { + t.Fatalf("sign: %v", err) + } + return s +} + +// authCase is one auth wire comparison: the request is sent to both servers +// and status, JSON body, and WWW-Authenticate must match across them. +type authCase struct { + name string + method string + path string + token string // Authorization: Bearer when non-empty + header map[string]string + + wantStatus int // when > 0 both servers must return exactly this +} + +// runAuthCases drives the cross-server comparison for a case list. +func runAuthCases(t *testing.T, cases []authCase) { + t.Helper() + pgrest, dbrest := urls(t) + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + headers := map[string]string{} + for k, v := range c.header { + headers[k] = v + } + if c.token != "" { + headers["Authorization"] = "Bearer " + c.token + } + cc := compatCase{method: c.method, path: c.path, headers: headers} + pgResp := doRequest(t, pgrest, cc) + dbResp := doRequest(t, dbrest, cc) + + if pgResp.status != dbResp.status { + t.Errorf("status: postgrest=%d dbrest=%d", pgResp.status, dbResp.status) + } + if c.wantStatus != 0 && dbResp.status != c.wantStatus { + t.Errorf("dbrest status = %d, want %d", dbResp.status, c.wantStatus) + } + pgWWW := pgResp.header.Get("WWW-Authenticate") + dbWWW := dbResp.header.Get("WWW-Authenticate") + if pgWWW != dbWWW { + t.Errorf("WWW-Authenticate: postgrest=%q dbrest=%q", pgWWW, dbWWW) + } + if dbResp.status >= 400 { + compareJSON(t, pgResp, dbResp) + } + }) + } +} + +// The group-3 code assignments and the WWW-Authenticate surface (item 03.1): +// PGRST301 for decode failures with per-cause messages, PGRST303 for claim +// validation failures, the invalid_token challenge on both. +func TestV14AuthErrorSurface(t *testing.T) { + expired := mintHS(t, jwt.MapClaims{ + "role": "web_user", + "exp": time.Now().Add(-time.Hour).Unix(), + }) + notYet := mintHS(t, jwt.MapClaims{ + "role": "web_user", + "nbf": time.Now().Add(time.Hour).Unix(), + }) + good := mintHS(t, jwt.MapClaims{ + "role": "web_user", + "exp": time.Now().Add(time.Hour).Unix(), + }) + badSig := good[:len(good)-2] + "qq" + + runAuthCases(t, []authCase{ + {name: "expired token is 401 PGRST303", method: http.MethodGet, path: "/todos", + token: expired, wantStatus: 401}, + {name: "not-yet-valid token is 401 PGRST303", method: http.MethodGet, path: "/todos", + token: notYet, wantStatus: 401}, + {name: "one-part token reports the part count", method: http.MethodGet, path: "/todos", + token: "garbage", wantStatus: 401}, + {name: "two-part token reports the part count", method: http.MethodGet, path: "/todos", + token: "two.parts", wantStatus: 401}, + {name: "bad signature is 401 PGRST301", method: http.MethodGet, path: "/todos", + token: badSig, wantStatus: 401}, + {name: "valid token reads fine", method: http.MethodGet, path: "/todos", + token: good, wantStatus: 200}, + }) +} + +// The claim validation surface (item 03.5): iat is validated with skew, type +// errors carry their own PGRST303 messages, and a token without aud (or with a +// foreign aud, since neither stack configures jwt-aud) is accepted. +func TestV14ClaimValidation(t *testing.T) { + iatFuture := mintHS(t, jwt.MapClaims{ + "role": "web_user", + "iat": time.Now().Add(time.Hour).Unix(), + }) + expString := mintHS(t, jwt.MapClaims{"role": "web_user", "exp": "soon"}) + iatString := mintHS(t, jwt.MapClaims{"role": "web_user", "iat": "x"}) + foreignAud := mintHS(t, jwt.MapClaims{"role": "web_user", "aud": "other"}) + emptyAud := mintHS(t, jwt.MapClaims{"role": "web_user", "aud": []string{}}) + + runAuthCases(t, []authCase{ + {name: "future iat is 401 PGRST303", method: http.MethodGet, path: "/todos", + token: iatFuture, wantStatus: 401}, + {name: "non-number exp is a type error", method: http.MethodGet, path: "/todos", + token: expString, wantStatus: 401}, + {name: "non-number iat is a type error", method: http.MethodGet, path: "/todos", + token: iatString, wantStatus: 401}, + {name: "foreign aud passes with no jwt-aud configured", method: http.MethodGet, path: "/todos", + token: foreignAud, wantStatus: 200}, + {name: "empty aud array passes", method: http.MethodGet, path: "/todos", + token: emptyAud, wantStatus: 200}, + }) +} + +// The token edge cases (item 03.7): a bearer scheme with an empty token is a +// 401 PGRST301, never a silent anon downgrade, while credentials of another +// scheme are not a token at all and run anonymous. The non-string role claim +// half of the item is covered in auth's unit tests: its observable error +// ("role ... does not exist") comes from the engine's role catalog, which the +// emulated stack does not have. +func TestV14TokenEdgeCases(t *testing.T) { + runAuthCases(t, []authCase{ + {name: "empty bearer is 401 PGRST301", method: http.MethodGet, path: "/todos", + header: map[string]string{"Authorization": "Bearer"}, wantStatus: 401}, + {name: "blank bearer is 401 PGRST301", method: http.MethodGet, path: "/todos", + header: map[string]string{"Authorization": "Bearer "}, wantStatus: 401}, + {name: "basic credentials run anonymous", method: http.MethodGet, path: "/todos", + header: map[string]string{"Authorization": "Basic d2ViX3VzZXI6cHc="}, wantStatus: 200}, + }) +} diff --git a/compat/compat_test.go b/compat/compat_test.go index 2a64778..aff23fc 100644 --- a/compat/compat_test.go +++ b/compat/compat_test.go @@ -106,6 +106,13 @@ type compatCase struct { // for responses whose status code depends on non-deterministic state (e.g. // planner statistics that differ between two independent database instances). skipStatusMatch bool + + // cleanupPath: if non-empty, a DELETE is issued to this path on the target + // server immediately before the case runs. Both servers share one database, + // so an insert into a UNIQUE-constrained column would collide on the second + // server's turn; clearing the row first lets each server's write be tested + // against a clean slot. + cleanupPath string } // All test cases, grouped by the compat matrix sections. @@ -225,30 +232,38 @@ var cases = []compatCase{ bodyMode: "schema"}, // ── Group 10: Inserts ───────────────────────────────────────────────── + // todos.task is UNIQUE in the shared rig DB; cleanupPath clears each row so + // both servers insert into a clean slot (a no-op on the two-DB docker rig). {name: "10.1 insert minimal 201", method: "POST", path: "/todos", - headers: map[string]string{"Content-Type": "application/json"}, - body: `{"task":"compat insert minimal"}`, - wantStatus: 201, bodyMode: "empty"}, + headers: map[string]string{"Content-Type": "application/json"}, + body: `{"task":"compat insert minimal"}`, + cleanupPath: "/todos?task=eq.compat%20insert%20minimal", + wantStatus: 201, bodyMode: "empty"}, {name: "10.2 insert return=representation 201", method: "POST", path: "/todos", - headers: map[string]string{"Content-Type": "application/json", "Prefer": "return=representation"}, - body: `{"task":"compat insert repr"}`, - wantStatus: 201, bodyMode: "schema"}, + headers: map[string]string{"Content-Type": "application/json", "Prefer": "return=representation"}, + body: `{"task":"compat insert repr"}`, + cleanupPath: "/todos?task=eq.compat%20insert%20repr", + wantStatus: 201, bodyMode: "schema"}, {name: "10.3 insert return=headers-only 201", method: "POST", path: "/todos", - headers: map[string]string{"Content-Type": "application/json", "Prefer": "return=headers-only"}, - body: `{"task":"compat insert headers-only"}`, - wantStatus: 201, bodyMode: "empty", + headers: map[string]string{"Content-Type": "application/json", "Prefer": "return=headers-only"}, + body: `{"task":"compat insert headers-only"}`, + cleanupPath: "/todos?task=eq.compat%20insert%20headers-only", + wantStatus: 201, bodyMode: "empty", wantLocationPrefix: "/todos?id=eq."}, {name: "10.4 bulk insert 201", method: "POST", path: "/todos", - headers: map[string]string{"Content-Type": "application/json", "Prefer": "return=representation"}, - body: `[{"task":"bulk a"},{"task":"bulk b"}]`, - wantStatus: 201, bodyMode: "schema"}, + headers: map[string]string{"Content-Type": "application/json", "Prefer": "return=representation"}, + body: `[{"task":"bulk a"},{"task":"bulk b"}]`, + cleanupPath: "/todos?task=like.bulk%20*", + wantStatus: 201, bodyMode: "schema"}, {name: "10.5 insert missing=default 201", method: "POST", path: "/todos", - headers: map[string]string{"Content-Type": "application/json", "Prefer": "missing=default,return=representation"}, - body: `{"task":"compat missing default"}`, - wantStatus: 201, bodyMode: "schema"}, + headers: map[string]string{"Content-Type": "application/json", "Prefer": "missing=default,return=representation"}, + body: `{"task":"compat missing default"}`, + cleanupPath: "/todos?task=eq.compat%20missing%20default", + wantStatus: 201, bodyMode: "schema"}, {name: "10.6 insert Location header", method: "POST", path: "/todos", headers: map[string]string{"Content-Type": "application/json"}, body: `{"task":"compat location test"}`, + cleanupPath: "/todos?task=eq.compat%20location%20test", wantStatus: 201, bodyMode: "empty", wantLocationPrefix: "/todos?id=eq."}, @@ -306,9 +321,10 @@ var cases = []compatCase{ body: `{"id":1,"task":"put upsert","done":false}`, wantStatus: 200, bodyMode: "schema"}, {name: "13.5 PUT upsert new 201", method: "PUT", path: "/todos?id=eq.9999", - headers: map[string]string{"Content-Type": "application/json", "Prefer": "return=representation"}, - body: `{"id":9999,"task":"new via put","done":false}`, - wantStatus: 201, bodyMode: "schema"}, + headers: map[string]string{"Content-Type": "application/json", "Prefer": "return=representation"}, + body: `{"id":9999,"task":"new via put","done":false}`, + cleanupPath: "/todos?id=eq.9999", + wantStatus: 201, bodyMode: "schema"}, // ── Group 15: tx=rollback ───────────────────────────────────────────── {name: "15.1 tx=rollback insert", method: "POST", path: "/todos", @@ -327,9 +343,10 @@ var cases = []compatCase{ headers: map[string]string{"Content-Type": "application/json"}, body: `{}`, wantStatus: 200, bodyMode: "status"}, {name: "16.3 volatile fn POST", method: "POST", path: "/rpc/add_todo", - headers: map[string]string{"Content-Type": "application/json"}, - body: `{"task":"rpc insert"}`, - wantStatus: 200, bodyMode: "status"}, + headers: map[string]string{"Content-Type": "application/json"}, + body: `{"task":"rpc insert"}`, + cleanupPath: "/todos?task=eq.rpc%20insert", + wantStatus: 200, bodyMode: "status"}, {name: "16.4 volatile tx=rollback", method: "POST", path: "/rpc/add_todo", headers: map[string]string{"Content-Type": "application/json", "Prefer": "tx=rollback"}, body: `{"task":"rpc rollback"}`, @@ -365,7 +382,9 @@ var cases = []compatCase{ body: `{"task":"no content type"}`, // PostgREST v14.13 infers JSON when Content-Type is absent and body looks // like JSON; it no longer rejects with 415. Test that both servers agree. - bodyMode: "status"}, + // todos.task is UNIQUE in the shared rig DB, so clear the row first. + cleanupPath: "/todos?task=eq.no%20content%20type", + bodyMode: "status"}, {name: "20.3 bad value for int col 400", method: "GET", path: "/todos?id=eq.notanint", wantStatus: 400, bodyMode: "status"}, {name: "20.4 bad operator 400", method: "GET", path: "/todos?id=badop.1", @@ -395,6 +414,7 @@ var cases = []compatCase{ {name: "22.1 pref-applied return=representation", method: "POST", path: "/todos", headers: map[string]string{"Content-Type": "application/json", "Prefer": "return=representation"}, body: `{"task":"pref applied test"}`, + cleanupPath: "/todos?task=eq.pref%20applied%20test", wantStatus: 201, bodyMode: "schema", wantPrefApplied: "return=representation"}, @@ -766,6 +786,13 @@ type response struct { func doRequest(t *testing.T, base string, c compatCase) response { t.Helper() + if c.cleanupPath != "" { + req, _ := http.NewRequest(http.MethodDelete, base+c.cleanupPath, nil) + if resp, err := http.DefaultClient.Do(req); err == nil { + io.Copy(io.Discard, resp.Body) + resp.Body.Close() + } + } var bodyReader io.Reader if c.body != "" { bodyReader = strings.NewReader(c.body) diff --git a/compat/config_v14_test.go b/compat/config_v14_test.go new file mode 100644 index 0000000..7f5f94d --- /dev/null +++ b/compat/config_v14_test.go @@ -0,0 +1,117 @@ +// HTTP-level checks for the v14 configuration surface (review items 05.x). +// Unlike the main suite, these start an in-process dbrest built from the +// current tree, so the behavior under test is the working copy's, and compare +// it against a live PostgREST when one is reachable. +package compat + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/tamnd/dbrest/backend/sqlite" + "github.com/tamnd/dbrest/httpapi" +) + +// localDBREST starts an in-process dbrest over a seeded sqlite database and +// returns its base URL. The schema mirrors the todos table of the compat seed +// closely enough for header-level comparisons. +func localDBREST(t *testing.T) (*httptest.Server, *httpapi.Server) { + t.Helper() + dsn := "file:compat_" + t.Name() + "?mode=memory&cache=shared" + be, err := sqlite.Open(dsn) + if err != nil { + t.Fatalf("open sqlite: %v", err) + } + t.Cleanup(func() { be.Close() }) + if _, err := be.DB().Exec(` + CREATE TABLE todos (id INTEGER PRIMARY KEY, task TEXT, done BOOLEAN, due TIMESTAMP); + INSERT INTO todos (id, task, done) VALUES (1, 'do laundry', 0); + `); err != nil { + t.Fatalf("seed: %v", err) + } + model, err := be.Introspect(context.Background()) + if err != nil { + t.Fatalf("introspect: %v", err) + } + api := httpapi.NewServer(be, model, nil) + // Mirror the live PostgREST rig's db-anon-role=web_anon so tokenless reads + // run as anon rather than failing closed with 401. + api.SetDefaultRole("web_anon") + ts := httptest.NewServer(api) + t.Cleanup(ts.Close) + return ts, api +} + +// livePostgREST returns the base URL of a reachable PostgREST, or skips. +func livePostgREST(t *testing.T) string { + t.Helper() + base := envOr("COMPAT_POSTGREST_URL", "http://localhost:3000") + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + if !pingOK(ctx, base) { + t.Skipf("PostgREST not reachable at %s; set COMPAT_POSTGREST_URL or start docker/postgrest/compose.yaml", base) + } + return base +} + +// corsHeaders are the response headers compared between the two servers. +var corsHeaders = []string{ + "Access-Control-Allow-Origin", + "Access-Control-Allow-Credentials", + "Access-Control-Allow-Methods", + "Access-Control-Allow-Headers", + "Access-Control-Max-Age", + "Access-Control-Expose-Headers", +} + +// TestV14CORSPreflight compares the default preflight answer (item 05.2) +// against a live PostgREST: wildcard origin, the full method list, the +// requested headers reflected, and the one-day max age. +func TestV14CORSPreflight(t *testing.T) { + pgrest := livePostgREST(t) + local, _ := localDBREST(t) + + c := compatCase{ + method: "OPTIONS", path: "/todos", + headers: map[string]string{ + "Origin": "http://example.com", + "Access-Control-Request-Method": "POST", + "Access-Control-Request-Headers": "Foo,Bar", + }, + } + pg := doRequest(t, pgrest, c) + db := doRequest(t, local.URL, c) + if pg.status != db.status { + t.Errorf("preflight status: postgrest %d, dbrest %d", pg.status, db.status) + } + for _, h := range corsHeaders { + if pgv, dbv := pg.header.Get(h), db.header.Get(h); pgv != dbv { + t.Errorf("preflight %s: postgrest %q, dbrest %q", h, pgv, dbv) + } + } +} + +// TestV14CORSSimpleRequest compares the cross-origin headers on a plain read +// (item 05.2): wildcard origin plus the exposed-headers list. +func TestV14CORSSimpleRequest(t *testing.T) { + pgrest := livePostgREST(t) + local, _ := localDBREST(t) + + c := compatCase{ + method: "GET", path: "/todos", + headers: map[string]string{"Origin": "http://example.com"}, + } + pg := doRequest(t, pgrest, c) + db := doRequest(t, local.URL, c) + if pg.status != http.StatusOK || db.status != http.StatusOK { + t.Fatalf("status: postgrest %d, dbrest %d", pg.status, db.status) + } + for _, h := range []string{"Access-Control-Allow-Origin", "Access-Control-Expose-Headers"} { + if pgv, dbv := pg.header.Get(h), db.header.Get(h); pgv != dbv { + t.Errorf("%s: postgrest %q, dbrest %q", h, pgv, dbv) + } + } +} diff --git a/compat/errors_v14_test.go b/compat/errors_v14_test.go new file mode 100644 index 0000000..719b634 --- /dev/null +++ b/compat/errors_v14_test.go @@ -0,0 +1,111 @@ +// PostgREST v14 error-vocabulary conformance checks (review item series 04.x). +// These run only when both a live PostgREST and a live dbrest are reachable, +// using the same harness as compat_test.go. +package compat + +import ( + "encoding/json" + "net/http" + "testing" +) + +// errEnvelope is the four-key PostgREST error body. +type errEnvelope struct { + Code string `json:"code"` + Message string `json:"message"` + Details json.RawMessage `json:"details"` + Hint json.RawMessage `json:"hint"` +} + +func decodeEnvelope(t *testing.T, r response) errEnvelope { + t.Helper() + var e errEnvelope + if err := json.Unmarshal(r.body, &e); err != nil { + t.Fatalf("error body is not a JSON envelope: %v: %s", err, r.body) + } + return e +} + +// TestSingularEnvelope compares the PGRST116 envelope byte-for-byte between +// the servers: v14 says "Cannot coerce the result to a single JSON object" +// with the row count in details (review item 04.3). +func TestSingularEnvelope(t *testing.T) { + pgrest, dbrest := urls(t) + for _, c := range []compatCase{ + {name: "singular zero rows", method: "GET", path: "/todos?id=eq.999999", + headers: map[string]string{"Accept": "application/vnd.pgrst.object+json"}}, + {name: "singular many rows", method: "GET", path: "/todos?id=lte.2", + headers: map[string]string{"Accept": "application/vnd.pgrst.object+json"}}, + } { + t.Run(c.name, func(t *testing.T) { + pgResp := doRequest(t, pgrest, c) + dbResp := doRequest(t, dbrest, c) + if pgResp.status != http.StatusNotAcceptable || dbResp.status != http.StatusNotAcceptable { + t.Errorf("status: postgrest=%d dbrest=%d, want 406", pgResp.status, dbResp.status) + } + compareJSON(t, pgResp, dbResp) + }) + } +} + +// TestProxyStatusOnErrors checks that every error response names its code in +// the Proxy-Status header the way v14 does ("PostgREST; error=PGRST205"), +// which is how a HEAD request identifies the failure (review item 04.11). +func TestProxyStatusOnErrors(t *testing.T) { + pgrest, dbrest := urls(t) + for _, c := range []compatCase{ + {name: "unknown table", method: "GET", path: "/definitely_not_a_table"}, + {name: "head unknown table", method: "HEAD", path: "/definitely_not_a_table"}, + {name: "singular zero rows", method: "GET", path: "/todos?id=eq.999999", + headers: map[string]string{"Accept": "application/vnd.pgrst.object+json"}}, + } { + t.Run(c.name, func(t *testing.T) { + pgResp := doRequest(t, pgrest, c) + dbResp := doRequest(t, dbrest, c) + pgPS := pgResp.header.Get("Proxy-Status") + dbPS := dbResp.header.Get("Proxy-Status") + if pgPS == "" || pgPS != dbPS { + t.Errorf("Proxy-Status: postgrest=%q dbrest=%q", pgPS, dbPS) + } + }) + } + + // A successful response carries no Proxy-Status. + ok := doRequest(t, dbrest, compatCase{method: "GET", path: "/todos?id=eq.1"}) + if ps := ok.header.Get("Proxy-Status"); ps != "" { + t.Errorf("Proxy-Status on success = %q, want absent", ps) + } +} + +// TestContentTypeContract locks the request Content-Type error contract +// (review item 04.1 task 4). The published v14 error table still carries a +// stale PGRST107/415 row for an invalid request Content-Type; live v14 +// actually answers 400 PGRST102 "Content-Type not acceptable: ", which +// this probe verified against a running PostgREST. The probe pins the live +// behavior on both servers so a regression on either side is caught. +func TestContentTypeContract(t *testing.T) { + pgrest, dbrest := urls(t) + c := compatCase{ + name: "unsupported request content-type", + method: "POST", + path: "/todos", + headers: map[string]string{ + "Content-Type": "application/yaml", + }, + body: "task: write tests", + } + + for name, base := range map[string]string{"postgrest": pgrest, "dbrest": dbrest} { + resp := doRequest(t, base, c) + env := decodeEnvelope(t, resp) + if resp.status != http.StatusBadRequest { + t.Errorf("%s status = %d, want 400", name, resp.status) + } + if env.Code != "PGRST102" { + t.Errorf("%s code = %q, want PGRST102", name, env.Code) + } + if want := "Content-Type not acceptable: application/yaml"; env.Message != want { + t.Errorf("%s message = %q, want %q", name, env.Message, want) + } + } +} diff --git a/compat/openapi_v14_test.go b/compat/openapi_v14_test.go new file mode 100644 index 0000000..39d9830 --- /dev/null +++ b/compat/openapi_v14_test.go @@ -0,0 +1,469 @@ +// openapi_v14_test.go holds the v14 conformance tests for the OpenAPI root and +// the schema-profile machinery (audit topic 06): profile negotiation and +// PGRST106, root content negotiation, the document shape, and the schema +// cache. Each test runs against both live servers with the same harness as +// compat_test.go and asserts the exact v14 wire behavior, verified against +// PostgREST v14 directly. +package compat + +import ( + "encoding/json" + "net/http" + "strings" + "testing" +) + +// errBody is the PostgREST error envelope. +type errBody struct { + Code string `json:"code"` + Message string `json:"message"` + Hint string `json:"hint"` + Details any `json:"details"` +} + +func decodeErr(t *testing.T, body []byte) errBody { + t.Helper() + var e errBody + if err := json.Unmarshal(body, &e); err != nil { + t.Fatalf("error body is not JSON: %v\n%s", err, body) + } + return e +} + +// onBoth runs fn once per live server, as a subtest named for it. +func onBoth(t *testing.T, fn func(t *testing.T, base string)) { + pgrest, dbrest := urls(t) + for name, base := range map[string]string{"postgrest": pgrest, "dbrest": dbrest} { + t.Run(name, func(t *testing.T) { fn(t, base) }) + } +} + +// ── 06.1 profile headers and the active schema ───────────────────────────── + +func TestProfileUnknownSchemaGET(t *testing.T) { + onBoth(t, func(t *testing.T, base string) { + res := doRequest(t, base, compatCase{method: "GET", path: "/todos", + headers: map[string]string{"Accept-Profile": "nonexistent"}}) + if res.status != http.StatusNotAcceptable { + t.Fatalf("status = %d, want 406\n%s", res.status, res.body) + } + e := decodeErr(t, res.body) + if e.Code != "PGRST106" { + t.Errorf("code = %q, want PGRST106", e.Code) + } + if e.Message != "Invalid schema: nonexistent" { + t.Errorf("message = %q, want %q", e.Message, "Invalid schema: nonexistent") + } + if e.Hint != "Only the following schemas are exposed: api, private" { + t.Errorf("hint = %q, want the exposed-schema list", e.Hint) + } + if h := res.header.Get("Content-Profile"); h != "" { + t.Errorf("Content-Profile = %q on an error, want unset", h) + } + }) +} + +func TestProfileUnknownSchemaPOST(t *testing.T) { + onBoth(t, func(t *testing.T, base string) { + res := doRequest(t, base, compatCase{method: "POST", path: "/todos", + headers: map[string]string{"Content-Profile": "nope", "Content-Type": "application/json"}, + body: "{}"}) + if res.status != http.StatusNotAcceptable { + t.Fatalf("status = %d, want 406\n%s", res.status, res.body) + } + e := decodeErr(t, res.body) + if e.Code != "PGRST106" || e.Message != "Invalid schema: nope" { + t.Errorf("got %q %q, want PGRST106 / Invalid schema: nope", e.Code, e.Message) + } + }) +} + +// A write reads Content-Profile, never Accept-Profile: a bogus Accept-Profile +// on a DELETE is ignored. +func TestProfileWriteIgnoresAcceptProfile(t *testing.T) { + onBoth(t, func(t *testing.T, base string) { + res := doRequest(t, base, compatCase{method: "DELETE", path: "/todos?id=eq.999999", + headers: map[string]string{"Accept-Profile": "nonexistent"}}) + if res.status != http.StatusNoContent { + t.Fatalf("status = %d, want 204 (Accept-Profile ignored on DELETE)\n%s", res.status, res.body) + } + }) +} + +func TestProfileSelectsSchemaAndEchoes(t *testing.T) { + onBoth(t, func(t *testing.T, base string) { + res := doRequest(t, base, compatCase{method: "GET", path: "/items", + headers: map[string]string{"Accept-Profile": "private"}}) + if res.status != http.StatusOK { + t.Fatalf("status = %d, want 200\n%s", res.status, res.body) + } + if h := res.header.Get("Content-Profile"); h != "private" { + t.Errorf("Content-Profile = %q, want private", h) + } + }) +} + +// With no profile header on a multi-schema deployment the first exposed schema +// is active and is echoed in Content-Profile. +func TestProfileDefaultSchemaEchoed(t *testing.T) { + onBoth(t, func(t *testing.T, base string) { + res := doRequest(t, base, compatCase{method: "GET", path: "/todos"}) + if res.status != http.StatusOK { + t.Fatalf("status = %d, want 200\n%s", res.status, res.body) + } + if h := res.header.Get("Content-Profile"); h != "api" { + t.Errorf("Content-Profile = %q, want api (first exposed schema)", h) + } + }) +} + +// A failed request carries no Content-Profile even when the profile was valid. +func TestProfileNotEchoedOnError(t *testing.T) { + onBoth(t, func(t *testing.T, base string) { + res := doRequest(t, base, compatCase{method: "GET", path: "/no_such_table", + headers: map[string]string{"Accept-Profile": "api"}}) + if res.status != http.StatusNotFound { + t.Fatalf("status = %d, want 404\n%s", res.status, res.body) + } + if h := res.header.Get("Content-Profile"); h != "" { + t.Errorf("Content-Profile = %q on an error, want unset", h) + } + }) +} + +// The root document is scoped to the active schema: under Accept-Profile: +// private it describes private's relations, not api's. +func TestRootScopedToActiveSchema(t *testing.T) { + onBoth(t, func(t *testing.T, base string) { + res := doRequest(t, base, compatCase{method: "GET", path: "/", + headers: map[string]string{"Accept-Profile": "private"}}) + if res.status != http.StatusOK { + t.Fatalf("status = %d, want 200\n%s", res.status, res.body) + } + if h := res.header.Get("Content-Profile"); h != "private" { + t.Errorf("Content-Profile = %q, want private", h) + } + var doc struct { + Paths map[string]json.RawMessage `json:"paths"` + } + if err := json.Unmarshal(res.body, &doc); err != nil { + t.Fatalf("root is not JSON: %v", err) + } + if _, ok := doc.Paths["/items"]; !ok { + t.Errorf("paths lack /items; private schema not described: %v", pathKeys(doc.Paths)) + } + if _, ok := doc.Paths["/todos"]; ok { + t.Errorf("paths include /todos from the api schema; root not scoped: %v", pathKeys(doc.Paths)) + } + }) +} + +func pathKeys(m map[string]json.RawMessage) []string { + out := make([]string, 0, len(m)) + for k := range m { + out = append(out, k) + } + return out +} + +// ── 06.2 root content negotiation and charset ────────────────────────────── + +func TestRootContentTypeCharset(t *testing.T) { + onBoth(t, func(t *testing.T, base string) { + for _, accept := range []string{"", "application/json", "application/openapi+json", "*/*"} { + c := compatCase{method: "GET", path: "/"} + if accept != "" { + c.headers = map[string]string{"Accept": accept} + } + res := doRequest(t, base, c) + if res.status != http.StatusOK { + t.Fatalf("Accept %q: status = %d, want 200", accept, res.status) + } + if ct := res.header.Get("Content-Type"); ct != "application/openapi+json; charset=utf-8" { + t.Errorf("Accept %q: Content-Type = %q, want application/openapi+json; charset=utf-8", accept, ct) + } + } + }) +} + +func TestRootUnacceptableAcceptIs406(t *testing.T) { + onBoth(t, func(t *testing.T, base string) { + res := doRequest(t, base, compatCase{method: "GET", path: "/", + headers: map[string]string{"Accept": "text/csv"}}) + if res.status != http.StatusNotAcceptable { + t.Fatalf("status = %d, want 406\n%s", res.status, res.body) + } + e := decodeErr(t, res.body) + if e.Code != "PGRST107" { + t.Errorf("code = %q, want PGRST107", e.Code) + } + if e.Message != "None of these media types are available: text/csv" { + t.Errorf("message = %q, want the requested type echoed", e.Message) + } + }) +} + +// A path segment is a bare name inside the active schema, never a qualified +// reference into another one. +func TestDottedPathDoesNotEscapeSchema(t *testing.T) { + onBoth(t, func(t *testing.T, base string) { + res := doRequest(t, base, compatCase{method: "GET", path: "/private.items"}) + if res.status != http.StatusNotFound { + t.Fatalf("status = %d, want 404 (no cross-schema escape)\n%s", res.status, res.body) + } + if !strings.Contains(string(res.body), "PGRST205") { + t.Errorf("body = %s, want PGRST205", res.body) + } + }) +} + +// ── 06.5 root verb handling ──────────────────────────────────────────────── + +// A verb the root does not serve is 405 PGRST117 naming the method. +func TestRootUnsupportedVerb(t *testing.T) { + onBoth(t, func(t *testing.T, base string) { + for _, method := range []string{"DELETE", "PATCH", "PUT", "TRACE"} { + res := doRequest(t, base, compatCase{method: method, path: "/"}) + if res.status != http.StatusMethodNotAllowed { + t.Fatalf("%s /: status = %d, want 405\n%s", method, res.status, res.body) + } + e := decodeErr(t, res.body) + if e.Code != "PGRST117" { + t.Errorf("%s /: code = %q, want PGRST117", method, e.Code) + } + if e.Message != "Unsupported HTTP method: "+method { + t.Errorf("%s /: message = %q", method, e.Message) + } + } + }) +} + +// OPTIONS on the root answers 200 with the verb set in Allow and no body. +func TestRootOptionsAllow(t *testing.T) { + onBoth(t, func(t *testing.T, base string) { + res := doRequest(t, base, compatCase{method: "OPTIONS", path: "/"}) + if res.status != http.StatusOK { + t.Fatalf("status = %d, want 200\n%s", res.status, res.body) + } + if allow := res.header.Get("Allow"); allow != "OPTIONS,GET,HEAD" { + t.Errorf("Allow = %q, want OPTIONS,GET,HEAD", allow) + } + if len(res.body) != 0 { + t.Errorf("body = %q, want empty", res.body) + } + }) +} + +// ── 06.6 document shape ──────────────────────────────────────────────────── + +// TestRootDocumentFraming pins the framing both servers must share: the v14 +// info defaults, the externalDocs pointer, the vendor media types, and the "/" +// entry describing the document itself. +func TestRootDocumentFraming(t *testing.T) { + onBoth(t, func(t *testing.T, base string) { + res := doRequest(t, base, compatCase{method: "GET", path: "/"}) + if res.status != http.StatusOK { + t.Fatalf("status = %d, want 200", res.status) + } + var doc struct { + Info struct { + Title string `json:"title"` + Description string `json:"description"` + } `json:"info"` + ExternalDocs struct { + Description string `json:"description"` + URL string `json:"url"` + } `json:"externalDocs"` + Consumes []string `json:"consumes"` + Produces []string `json:"produces"` + Paths map[string]map[string]struct { + Summary string `json:"summary"` + Tags []string `json:"tags"` + Produces []string `json:"produces"` + } `json:"paths"` + } + if err := json.Unmarshal(res.body, &doc); err != nil { + t.Fatalf("decode: %v", err) + } + if doc.Info.Title != "PostgREST API" { + t.Errorf("info title = %q", doc.Info.Title) + } + if doc.Info.Description != "This is a dynamic API generated by PostgREST" { + t.Errorf("info description = %q", doc.Info.Description) + } + if doc.ExternalDocs.URL != "https://postgrest.org/en/v14/references/api.html" { + t.Errorf("externalDocs url = %q", doc.ExternalDocs.URL) + } + if doc.ExternalDocs.Description != "PostgREST Documentation" { + t.Errorf("externalDocs description = %q", doc.ExternalDocs.Description) + } + want := []string{ + "application/json", + "application/vnd.pgrst.object+json;nulls=stripped", + "application/vnd.pgrst.object+json", + "text/csv", + } + if strings.Join(doc.Consumes, " ") != strings.Join(want, " ") { + t.Errorf("consumes = %v", doc.Consumes) + } + if strings.Join(doc.Produces, " ") != strings.Join(want, " ") { + t.Errorf("produces = %v", doc.Produces) + } + root, ok := doc.Paths["/"] + if !ok { + t.Fatal(`document lacks the "/" path entry`) + } + get := root["get"] + if get.Summary != "OpenAPI description (this document)" { + t.Errorf("root summary = %q", get.Summary) + } + if len(get.Tags) != 1 || get.Tags[0] != "Introspection" { + t.Errorf("root tags = %v", get.Tags) + } + if len(get.Produces) != 2 || get.Produces[0] != "application/openapi+json" || get.Produces[1] != "application/json" { + t.Errorf("root produces = %v", get.Produces) + } + }) +} + +// TestRootRelationLayout pins how a table is laid out: operations reference +// the shared rowFilter and reserved parameters, GET answers 200 with an +// array-of-definition schema plus a 206, writes carry the v14 single-status +// responses, and the primary key carries the pk marker in its definition. +func TestRootRelationLayout(t *testing.T) { + onBoth(t, func(t *testing.T, base string) { + res := doRequest(t, base, compatCase{method: "GET", path: "/"}) + if res.status != http.StatusOK { + t.Fatalf("status = %d, want 200", res.status) + } + var doc struct { + Paths map[string]map[string]struct { + Parameters []struct { + Ref string `json:"$ref"` + } `json:"parameters"` + Responses map[string]struct { + Description string `json:"description"` + Schema *struct { + Type string `json:"type"` + Items *struct { + Ref string `json:"$ref"` + } `json:"items"` + } `json:"schema"` + } `json:"responses"` + } `json:"paths"` + Parameters map[string]struct { + Name string `json:"name"` + In string `json:"in"` + Required *bool `json:"required"` + Default string `json:"default"` + } `json:"parameters"` + Definitions map[string]struct { + Properties map[string]struct { + Description string `json:"description"` + } `json:"properties"` + Required []string `json:"required"` + } `json:"definitions"` + } + if err := json.Unmarshal(res.body, &doc); err != nil { + t.Fatalf("decode: %v", err) + } + todos, ok := doc.Paths["/todos"] + if !ok { + t.Fatal("document lacks /todos") + } + + // GET: rowFilter refs for the columns, then the fixed read block. + get := todos["get"] + readBlock := []string{ + "#/parameters/select", "#/parameters/order", "#/parameters/range", + "#/parameters/rangeUnit", "#/parameters/offset", "#/parameters/limit", + "#/parameters/preferCount", + } + if len(get.Parameters) < len(readBlock)+1 { + t.Fatalf("get parameters = %v", get.Parameters) + } + head := get.Parameters[:len(get.Parameters)-len(readBlock)] + for i, p := range head { + if !strings.HasPrefix(p.Ref, "#/parameters/rowFilter.todos.") { + t.Errorf("get parameter %d = %q, want a rowFilter.todos ref", i, p.Ref) + } + } + tail := get.Parameters[len(get.Parameters)-len(readBlock):] + for i, want := range readBlock { + if tail[i].Ref != want { + t.Errorf("get read block[%d] = %q, want %q", i, tail[i].Ref, want) + } + } + + // GET responses: 200 carries the array-of-definition schema, plus a 206. + ok200, present := get.Responses["200"] + if !present || ok200.Schema == nil || ok200.Schema.Type != "array" || + ok200.Schema.Items == nil || ok200.Schema.Items.Ref != "#/definitions/todos" { + t.Errorf("get 200 = %+v, want an array of #/definitions/todos", ok200) + } + if p206, present := get.Responses["206"]; !present || p206.Description != "Partial Content" { + t.Errorf("get 206 = %+v, want Partial Content", get.Responses["206"]) + } + + // Writes: POST 201 only; PATCH and DELETE 204 only. + for op, want := range map[string]string{"post": "201", "patch": "204", "delete": "204"} { + r := todos[op].Responses + if len(r) != 1 { + t.Errorf("%s responses = %v, want only %s", op, r, want) + } + if _, present := r[want]; !present { + t.Errorf("%s responses lack %s", op, want) + } + } + + // POST opens with the shared body parameter. + if post := todos["post"]; len(post.Parameters) == 0 || post.Parameters[0].Ref != "#/parameters/body.todos" { + t.Errorf("post parameters = %v, want body.todos first", post.Parameters) + } + body, present := doc.Parameters["body.todos"] + if !present || body.Name != "todos" || body.In != "body" { + t.Errorf("body.todos = %+v", body) + } + if body.Required == nil || *body.Required { + t.Errorf("body.todos required = %v, want explicit false", body.Required) + } + if ru, present := doc.Parameters["rangeUnit"]; !present || ru.Default != "items" { + t.Errorf("rangeUnit = %+v, want default items", doc.Parameters["rangeUnit"]) + } + + // The primary key column carries the v14 pk marker. + def, present := doc.Definitions["todos"] + if !present { + t.Fatal("definitions lack todos") + } + id, present := def.Properties["id"] + if !present || !strings.Contains(id.Description, "Note:\nThis is a Primary Key.") { + t.Errorf("todos.id description = %q, want the pk marker", id.Description) + } + if len(def.Required) == 0 { + t.Error("todos definition lists no required columns") + } + }) +} + +// TestRootSecurityInactiveByDefault pins the default openapi-security-active +// shape: with it off the document carries neither securityDefinitions nor a +// security requirement, even though both servers authenticate JWTs. +func TestRootSecurityInactiveByDefault(t *testing.T) { + onBoth(t, func(t *testing.T, base string) { + res := doRequest(t, base, compatCase{method: "GET", path: "/"}) + if res.status != http.StatusOK { + t.Fatalf("status = %d, want 200", res.status) + } + var doc map[string]json.RawMessage + if err := json.Unmarshal(res.body, &doc); err != nil { + t.Fatalf("decode: %v", err) + } + if _, ok := doc["securityDefinitions"]; ok { + t.Error("securityDefinitions should be absent by default") + } + if _, ok := doc["security"]; ok { + t.Error("security should be absent by default") + } + }) +} diff --git a/config/config.go b/config/config.go index 97512cb..7c536d1 100644 --- a/config/config.go +++ b/config/config.go @@ -14,6 +14,9 @@ package config import ( "fmt" "maps" + "net" + "os" + "strconv" "strings" "time" ) @@ -51,6 +54,12 @@ var knownLogLevels = map[string]bool{ "crit": true, "error": true, "warn": true, "info": true, "debug": true, } +// Transaction termination modes (db-tx-end). +var knownTxEnds = map[string]bool{ + "commit": true, "commit-allow-override": true, + "rollback": true, "rollback-allow-override": true, +} + // Config is the resolved option set. Fields are grouped by the spec's sections. // A zero value is not valid; build one through Load, which applies defaults and // validates. @@ -60,38 +69,70 @@ type Config struct { DBURI string // Exposed surface (section 3). - Schemas []string - AnonRole string - PreRequest string - ExtraSearchPath []string - MaxRows int // 0 means no cap + Schemas []string + AnonRole string + PreRequest string + ExtraSearchPath []string + MaxRows int // 0 means no cap + AggregatesEnabled bool + RootSpec string + MaxRequestBody int // bytes; 0 means unlimited, matching PostgREST + + // Transaction behavior. + TxEnd string // commit / commit-allow-override / rollback / rollback-allow-override + HoistedTxSettings []string + + // Application settings forwarded to the backend as transaction settings + // (the app.settings.* namespace). Keys are stored without the prefix. + AppSettings map[string]string // Auth, a frontend concern identical on every backend (spec 13). JWTSecret string + JWTSecretIsBase64 bool JWTAud string JWTRoleClaimKey string JWKSet string JWTCacheMaxEntries int // Servers (section 5). - ServerHost string - ServerPort int - ServerUnixSocket string - AdminServerHost string - AdminServerPort int // 0 disables the admin server + ServerHost string + ServerPort int + ServerUnixSocket string + ServerUnixSocketMode string + AdminServerHost string + AdminServerPort int // 0 disables the admin server // Pooling and limits (section 7). DBPool int DBPoolAcquisitionTimeout time.Duration + DBPoolMaxIdleTime time.Duration + DBPoolMaxLifetime time.Duration + DBPoolAutomaticRecovery bool + + // Reload and in-database configuration. + DBChannel string + DBChannelEnabled bool + DBConfig bool + DBPreConfig string + DBPreparedStatements bool // OpenAPI (spec 19). OpenAPIMode string OpenAPIServerProxyURI string + OpenAPISecurityActive bool // Observability and CORS (section 8). - LogLevel string - LogQuery bool - CORSAllowedOrigins []string + LogLevel string + LogQuery bool + CORSAllowedOrigins []string + PlanEnabled bool + ServerTraceHeader string + ServerTimingEnabled bool + + // Warnings collected while loading: accepted-but-unenforced options, + // unknown keys, and risky postures. The command logs them at startup; + // none of them is fatal. + Warnings []string // dbrest-specific declared registries (section 4). Carried as raw text here; // each is parsed by the subsystem that consumes it (introspection, RPC, @@ -104,19 +145,47 @@ type Config struct { CapabilityOverrides string } +// defaultSchemas is the exposed-schema default for an unset db-schemas: the +// engine's natural namespace rather than a hardcoded value. PostgreSQL gets +// upstream's "public"; the engines whose default namespace is the connected +// database itself get the empty marker, matching their introspection contract +// of unqualified relations. SQLite is one of those: its main database maps to +// the unqualified namespace, and attached databases become named schemas when +// that subsystem lands (spec 08). +func defaultSchemas(backendName string) []string { + if backendName == BackendPostgres { + return []string{"public"} + } + return []string{""} +} + // defaults returns a Config carrying the unambiguous PostgREST defaults, before // the file and environment are layered on. func defaults() *Config { return &Config{ Backend: BackendSQLite, - Schemas: []string{""}, JWTRoleClaimKey: ".role", JWTCacheMaxEntries: 1000, - ServerHost: "0.0.0.0", + ServerHost: "!4", ServerPort: 3000, DBPool: 10, OpenAPIMode: OpenAPIFollowPrivileges, + ExtraSearchPath: []string{"public"}, LogLevel: "error", + TxEnd: "commit", + HoistedTxSettings: []string{ + "statement_timeout", "plan_filter.statement_cost_limit", + "default_transaction_isolation", + }, + DBChannel: "pgrst", + DBChannelEnabled: true, + DBConfig: true, + DBPreparedStatements: true, + DBPoolAcquisitionTimeout: 10 * time.Second, + DBPoolMaxIdleTime: 30 * time.Second, + DBPoolMaxLifetime: 1800 * time.Second, + DBPoolAutomaticRecovery: true, + ServerUnixSocketMode: "660", } } @@ -126,15 +195,22 @@ func defaults() *Config { // the PGRST_* and DBREST_* spellings are read, with DBREST_* winning. func Load(path string, environ []string) (*Config, error) { raw := map[string]string{} + var warnings []string if path != "" { - fileRaw, err := parseFile(path) + fileRaw, fileWarnings, err := parseFile(path) if err != nil { return nil, err } + warnings = append(warnings, fileWarnings...) maps.Copy(raw, fileRaw) } - overlayEnv(raw, environ) - return fromRaw(raw) + warnings = append(warnings, overlayEnv(raw, environ)...) + c, err := fromRaw(raw) + if err != nil { + return nil, err + } + c.Warnings = append(warnings, c.Warnings...) + return c, nil } // FromMap builds a Config from an already-merged option map, applying defaults @@ -157,29 +233,78 @@ func fromRaw(raw map[string]string) (*Config, error) { if v, ok := get("db-uri"); ok { c.DBURI = v } - if v, ok := get("db-schemas"); ok { - c.Schemas = splitList(v) + // PostgREST defaults db-uri to "postgresql://", an empty URI libpq fills + // from PGHOST/PGUSER/PGDATABASE and friends, so a bare server with PG* + // environment variables just works. Only the postgres backend can + // self-configure that way; every other engine keeps db-uri mandatory. + if strings.TrimSpace(c.DBURI) == "" && c.Backend == BackendPostgres { + c.DBURI = "postgresql://" + } + // An @path value loads the option from a file, the documented way to keep + // secrets out of config files. Upstream supports it for exactly two + // options: db-uri (trimmed of surrounding whitespace) and jwt-secret + // (one trailing newline chomped), with the path read relative to the + // working directory. + if path, ok := strings.CutPrefix(c.DBURI, "@"); ok { + if data, err := os.ReadFile(path); err != nil { + errs = append(errs, fmt.Sprintf("db-uri: reading %s: %v", path, err)) + } else { + c.DBURI = strings.TrimSpace(string(data)) + } + } + schemasSet := false + for _, key := range []string{"db-schemas", "db-schema"} { + if v, ok := get(key); ok { + c.Schemas = splitList(v) + schemasSet = true + break + } + } + if !schemasSet { + c.Schemas = defaultSchemas(c.Backend) } if v, ok := get("db-anon-role"); ok { c.AnonRole = v } - if v, ok := get("db-pre-request"); ok { - c.PreRequest = v - } + c.PreRequest = pickString(raw, c.PreRequest, "db-pre-request", "pre-request") if v, ok := get("db-extra-search-path"); ok { c.ExtraSearchPath = splitList(v) } c.MaxRows = pickInt(raw, &errs, c.MaxRows, "db-max-rows", "max-rows") + c.MaxRequestBody = pickInt(raw, &errs, c.MaxRequestBody, "max-request-body") + c.AggregatesEnabled = pickBool(raw, &errs, c.AggregatesEnabled, "db-aggregates-enabled") + c.RootSpec = pickString(raw, c.RootSpec, "db-root-spec", "root-spec") + + if v, ok := get("db-tx-end"); ok { + c.TxEnd = strings.ToLower(strings.TrimSpace(v)) + } + if v, ok := get("db-hoisted-tx-settings"); ok { + c.HoistedTxSettings = splitList(v) + } + for key, v := range raw { + if name, ok := strings.CutPrefix(key, "app.settings."); ok && name != "" { + if c.AppSettings == nil { + c.AppSettings = map[string]string{} + } + c.AppSettings[name] = v + } + } if v, ok := get("jwt-secret"); ok { c.JWTSecret = v } + if path, ok := strings.CutPrefix(c.JWTSecret, "@"); ok { + if data, err := os.ReadFile(path); err != nil { + errs = append(errs, fmt.Sprintf("jwt-secret: reading %s: %v", path, err)) + } else { + c.JWTSecret = strings.TrimSuffix(string(data), "\n") + } + } + c.JWTSecretIsBase64 = pickBool(raw, &errs, c.JWTSecretIsBase64, "jwt-secret-is-base64", "secret-is-base64") if v, ok := get("jwt-aud"); ok { c.JWTAud = v } - if v, ok := get("jwt-role-claim-key"); ok { - c.JWTRoleClaimKey = v - } + c.JWTRoleClaimKey = pickString(raw, c.JWTRoleClaimKey, "jwt-role-claim-key", "role-claim-key") if v, ok := get("jwk-set"); ok { c.JWKSet = v } @@ -192,13 +317,33 @@ func fromRaw(raw map[string]string) (*Config, error) { if v, ok := get("server-unix-socket"); ok { c.ServerUnixSocket = v } + if v, ok := get("server-unix-socket-mode"); ok { + c.ServerUnixSocketMode = strings.TrimSpace(v) + } if v, ok := get("admin-server-host"); ok { c.AdminServerHost = v } c.AdminServerPort = pickInt(raw, &errs, c.AdminServerPort, "admin-server-port") + if c.AdminServerHost == "" { + // Upstream defaults the admin host to the API host. + c.AdminServerHost = c.ServerHost + } c.DBPool = pickInt(raw, &errs, c.DBPool, "db-pool") - c.DBPoolAcquisitionTimeout = pickDuration(raw, &errs, c.DBPoolAcquisitionTimeout, "db-pool-acquisition-timeout") + c.DBPoolAcquisitionTimeout = pickSeconds(raw, &errs, c.DBPoolAcquisitionTimeout, "db-pool-acquisition-timeout") + c.DBPoolMaxIdleTime = pickSeconds(raw, &errs, c.DBPoolMaxIdleTime, "db-pool-max-idletime", "db-pool-timeout") + c.DBPoolMaxLifetime = pickSeconds(raw, &errs, c.DBPoolMaxLifetime, "db-pool-max-lifetime") + c.DBPoolAutomaticRecovery = pickBool(raw, &errs, c.DBPoolAutomaticRecovery, "db-pool-automatic-recovery") + + if v, ok := get("db-channel"); ok { + c.DBChannel = v + } + c.DBChannelEnabled = pickBool(raw, &errs, c.DBChannelEnabled, "db-channel-enabled") + c.DBConfig = pickBool(raw, &errs, c.DBConfig, "db-config") + if v, ok := get("db-pre-config"); ok { + c.DBPreConfig = v + } + c.DBPreparedStatements = pickBool(raw, &errs, c.DBPreparedStatements, "db-prepared-statements") if v, ok := get("openapi-mode"); ok { c.OpenAPIMode = strings.ToLower(strings.TrimSpace(v)) @@ -206,6 +351,7 @@ func fromRaw(raw map[string]string) (*Config, error) { if v, ok := get("openapi-server-proxy-uri"); ok { c.OpenAPIServerProxyURI = strings.TrimSpace(v) } + c.OpenAPISecurityActive = pickBool(raw, &errs, c.OpenAPISecurityActive, "openapi-security-active") if v, ok := get("log-level"); ok { c.LogLevel = strings.ToLower(strings.TrimSpace(v)) @@ -214,6 +360,22 @@ func fromRaw(raw map[string]string) (*Config, error) { if v, ok := get("server-cors-allowed-origins"); ok { c.CORSAllowedOrigins = splitList(v) } + c.PlanEnabled = pickBool(raw, &errs, c.PlanEnabled, "db-plan-enabled") + if v, ok := get("server-trace-header"); ok { + c.ServerTraceHeader = v + } + c.ServerTimingEnabled = pickBool(raw, &errs, c.ServerTimingEnabled, "server-timing-enabled") + + c.Warnings = append(c.Warnings, unenforcedWarnings(raw)...) + + // Anonymous access should be a choice, not an accident. With neither an + // anon role nor JWT key material every request runs anonymously with no + // role at all, so say so loudly at startup; upstream's docs treat this + // posture as something the operator confirms explicitly. + if c.AnonRole == "" && c.JWTSecret == "" && c.JWKSet == "" { + c.Warnings = append(c.Warnings, + "neither db-anon-role nor a JWT key (jwt-secret, jwk-set) is configured; every request will run anonymously with no role") + } c.DeclaredSchema = raw["declared-schema"] c.DeclaredRelationships = raw["declared-relationships"] @@ -250,19 +412,155 @@ func (c *Config) validate(errs *[]string) { if c.AdminServerPort < 0 || c.AdminServerPort > 65535 { *errs = append(*errs, fmt.Sprintf("admin-server-port %d is out of range", c.AdminServerPort)) } + if c.AdminServerPort != 0 && c.AdminServerPort == c.ServerPort { + *errs = append(*errs, "admin-server-port cannot be the same as server-port") + } if c.MaxRows < 0 { *errs = append(*errs, "db-max-rows must not be negative") } + if c.MaxRequestBody < 0 { + *errs = append(*errs, "max-request-body must not be negative") + } if c.JWTCacheMaxEntries < 0 { *errs = append(*errs, "jwt-cache-max-entries must not be negative") } + if len(c.Schemas) == 0 { + *errs = append(*errs, "db-schemas must name at least one schema") + } + if !knownTxEnds[c.TxEnd] { + *errs = append(*errs, fmt.Sprintf("db-tx-end %q is not one of commit/commit-allow-override/rollback/rollback-allow-override", c.TxEnd)) + } + if mode, err := strconv.ParseUint(c.ServerUnixSocketMode, 8, 32); err != nil { + *errs = append(*errs, fmt.Sprintf("server-unix-socket-mode %q is not an octal", c.ServerUnixSocketMode)) + } else if mode < 0o600 || mode > 0o777 { + *errs = append(*errs, fmt.Sprintf("server-unix-socket-mode %q needs to be between 600 and 777", c.ServerUnixSocketMode)) + } +} + +// unenforcedOptions are options dbrest parses for PostgREST compatibility but +// whose behavior has not landed yet. Setting one is accepted with a warning so +// a working postgrest.conf boots, but the operator is told the knob does not +// turn anything yet. An entry leaves this list when its subsystem ships. +var unenforcedOptions = []string{ + "db-aggregates-enabled", "db-channel", "db-channel-enabled", "db-config", + "db-extra-search-path", "db-hoisted-tx-settings", + "db-pool-acquisition-timeout", "db-pool-automatic-recovery", + "db-pre-config", "db-pre-request", "pre-request", + "db-prepared-statements", "db-root-spec", "root-spec", "db-tx-end", + "log-query", + "openapi-security-active", "server-trace-header", "server-timing-enabled", +} + +// unenforcedWarnings returns one warning per explicitly set option that parses +// but is not yet enforced. +func unenforcedWarnings(raw map[string]string) []string { + var out []string + for _, key := range unenforcedOptions { + if _, ok := raw[key]; ok { + out = append(out, fmt.Sprintf("option %s is accepted but not enforced yet", key)) + } + } + return out } -// ServerAddr is the API listen address in host:port form. +// MergeReloadable layers a freshly loaded configuration over the running one, +// the way PostgREST applies a SIGUSR2 reload: every option takes its new value +// except the ones fixed at boot (the connection, the pool, the listeners, and +// the function registry wired at backend open). The returned messages name +// each boot-time option whose new value had to be ignored, one log line per +// option. +func (c *Config) MergeReloadable(next *Config) (*Config, []string) { + merged := *next + var kept []string + note := func(name string, changed bool) { + if changed { + kept = append(kept, fmt.Sprintf("%s changed but cannot be reloaded; keeping the boot value", name)) + } + } + + note("db-backend", merged.Backend != c.Backend) + merged.Backend = c.Backend + note("db-uri", merged.DBURI != c.DBURI) + merged.DBURI = c.DBURI + note("db-pool", merged.DBPool != c.DBPool) + merged.DBPool = c.DBPool + note("db-pool-acquisition-timeout", merged.DBPoolAcquisitionTimeout != c.DBPoolAcquisitionTimeout) + merged.DBPoolAcquisitionTimeout = c.DBPoolAcquisitionTimeout + note("db-pool-max-idletime", merged.DBPoolMaxIdleTime != c.DBPoolMaxIdleTime) + merged.DBPoolMaxIdleTime = c.DBPoolMaxIdleTime + note("db-pool-max-lifetime", merged.DBPoolMaxLifetime != c.DBPoolMaxLifetime) + merged.DBPoolMaxLifetime = c.DBPoolMaxLifetime + note("db-pool-automatic-recovery", merged.DBPoolAutomaticRecovery != c.DBPoolAutomaticRecovery) + merged.DBPoolAutomaticRecovery = c.DBPoolAutomaticRecovery + note("server-host", merged.ServerHost != c.ServerHost) + merged.ServerHost = c.ServerHost + note("server-port", merged.ServerPort != c.ServerPort) + merged.ServerPort = c.ServerPort + note("server-unix-socket", merged.ServerUnixSocket != c.ServerUnixSocket) + merged.ServerUnixSocket = c.ServerUnixSocket + note("server-unix-socket-mode", merged.ServerUnixSocketMode != c.ServerUnixSocketMode) + merged.ServerUnixSocketMode = c.ServerUnixSocketMode + note("admin-server-host", merged.AdminServerHost != c.AdminServerHost) + merged.AdminServerHost = c.AdminServerHost + note("admin-server-port", merged.AdminServerPort != c.AdminServerPort) + merged.AdminServerPort = c.AdminServerPort + note("function-registry", merged.FunctionRegistry != c.FunctionRegistry) + merged.FunctionRegistry = c.FunctionRegistry + + return &merged, kept +} + +// ServerAddr is the API listen address in host:port form. With one of the +// special hosts the result is for display only; the listener is built from +// Listeners. func (c *Config) ServerAddr() string { return fmt.Sprintf("%s:%d", c.ServerHost, c.ServerPort) } +// ListenSpec is one candidate listener: the net.Listen network and address. +type ListenSpec struct { + Network string + Addr string +} + +// listenSpecs maps a host option to ordered listener candidates, implementing +// PostgREST's special values: * is any host on either stack, *4 and *6 prefer +// one stack and fall back to the other, !4 and !6 require their stack. Any +// other value is a literal address. The caller takes the first candidate that +// binds. +func listenSpecs(host string, port int) []ListenSpec { + p := strconv.Itoa(port) + switch host { + case "*": + return []ListenSpec{{"tcp", ":" + p}} + case "*4": + return []ListenSpec{{"tcp4", "0.0.0.0:" + p}, {"tcp6", "[::]:" + p}} + case "!4": + return []ListenSpec{{"tcp4", "0.0.0.0:" + p}} + case "*6": + return []ListenSpec{{"tcp6", "[::]:" + p}, {"tcp4", "0.0.0.0:" + p}} + case "!6": + return []ListenSpec{{"tcp6", "[::]:" + p}} + default: + return []ListenSpec{{"tcp", net.JoinHostPort(host, p)}} + } +} + +// Listeners are the API listener candidates, in preference order. Setting +// server-unix-socket replaces the TCP listener entirely, as it does upstream; +// the admin server stays on TCP either way. +func (c *Config) Listeners() []ListenSpec { + if c.ServerUnixSocket != "" { + return []ListenSpec{{"unix", c.ServerUnixSocket}} + } + return listenSpecs(c.ServerHost, c.ServerPort) +} + +// AdminListeners are the admin listener candidates, in preference order. +func (c *Config) AdminListeners() []ListenSpec { + return listenSpecs(c.AdminServerHost, c.AdminServerPort) +} + // AdminEnabled reports whether the admin server should run. func (c *Config) AdminEnabled() bool { return c.AdminServerPort != 0 } diff --git a/config/config_test.go b/config/config_test.go index 748ceeb..2f6ebf6 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -4,6 +4,7 @@ import ( "os" "path/filepath" "slices" + "strings" "testing" "time" ) @@ -43,6 +44,26 @@ func TestDefaultsApplied(t *testing.T) { } } +func TestOpenAPISecurityActiveParsed(t *testing.T) { + c, err := FromMap(map[string]string{"db-uri": "file:x.db"}) + if err != nil { + t.Fatal(err) + } + if c.OpenAPISecurityActive { + t.Error("openapi-security-active should default to false") + } + c, err = FromMap(map[string]string{"db-uri": "file:x.db", "openapi-security-active": "true"}) + if err != nil { + t.Fatal(err) + } + if !c.OpenAPISecurityActive { + t.Error("openapi-security-active = true not parsed") + } + if _, err = FromMap(map[string]string{"db-uri": "file:x.db", "openapi-security-active": "banana"}); err == nil { + t.Error("a non-boolean openapi-security-active should abort boot") + } +} + func TestDBURIRequired(t *testing.T) { _, err := FromMap(map[string]string{}) if err == nil { @@ -50,6 +71,34 @@ func TestDBURIRequired(t *testing.T) { } } +// TestDBURIDefaultsOnPostgres pins the upstream stock workflow: with the +// postgres backend an unset db-uri becomes "postgresql://", the empty URI the +// driver fills from the PG* environment. Every other engine keeps the hard +// requirement. +func TestDBURIDefaultsOnPostgres(t *testing.T) { + c, err := FromMap(map[string]string{"db-backend": "postgres"}) + if err != nil { + t.Fatalf("postgres without db-uri should boot: %v", err) + } + if c.DBURI != "postgresql://" { + t.Errorf("db-uri = %q, want postgresql://", c.DBURI) + } + + for _, be := range []string{"sqlite", "mysql", "sqlserver", "mongodb"} { + if _, err := FromMap(map[string]string{"db-backend": be}); err == nil { + t.Errorf("%s without db-uri should be rejected", be) + } + } + + c, err = FromMap(map[string]string{"db-backend": "postgres", "db-uri": "postgresql://u@h/db"}) + if err != nil { + t.Fatal(err) + } + if c.DBURI != "postgresql://u@h/db" { + t.Errorf("explicit db-uri lost: %q", c.DBURI) + } +} + func TestFileParsing(t *testing.T) { path := writeConf(t, ` # dbrest configuration @@ -111,10 +160,165 @@ reviews.film_id -> films.id } } -func TestUnknownOptionIsError(t *testing.T) { +func TestUnknownFileOptionWarnsAndBoots(t *testing.T) { + // PostgREST ignores config keys it does not own, so a postgrest.conf + // carrying someone else's keys must boot. dbrest keeps a warning so the + // typo is visible. path := writeConf(t, "db-uri = \"x\"\ndb-ury = \"typo\"") - if _, err := Load(path, nil); err == nil { - t.Fatal("expected error for unknown option") + c, err := Load(path, nil) + if err != nil { + t.Fatalf("unknown file option must not abort: %v", err) + } + if len(c.Warnings) == 0 || !strings.Contains(strings.Join(c.Warnings, "\n"), "db-ury") { + t.Errorf("expected a warning naming db-ury, got %q", c.Warnings) + } +} + +func TestUnknownEnvKeyWarns(t *testing.T) { + // The env path matches the file path: an unrecognized PGRST-namespaced + // variable warns instead of being silently dropped. + c, err := Load("", []string{"PGRST_DB_URY=typo", "DBREST_DB_URI=file:real.db"}) + if err != nil { + t.Fatal(err) + } + if len(c.Warnings) == 0 || !strings.Contains(strings.Join(c.Warnings, "\n"), "PGRST_DB_URY") { + t.Errorf("expected a warning naming PGRST_DB_URY, got %q", c.Warnings) + } +} + +func TestV14KeySetAccepted(t *testing.T) { + // Every documented v14 option a real postgrest.conf may carry must parse. + path := writeConf(t, ` +db-uri = "file:demo.db" +app.settings.jwt_lifetime = "3600" +app.settings.name = "demo" +db-aggregates-enabled = true +db-channel = "custom" +db-channel-enabled = false +db-config = false +db-hoisted-tx-settings = "statement_timeout" +db-plan-enabled = true +db-pool-automatic-recovery = false +db-pool-max-idletime = 60 +db-pool-max-lifetime = 600 +db-pre-config = "postgrest.pre_config" +db-prepared-statements = false +db-root-spec = "root" +db-tx-end = "rollback-allow-override" +jwt-secret-is-base64 = true +openapi-security-active = true +server-trace-header = "X-Request-Id" +server-timing-enabled = true +server-unix-socket-mode = "770" +`) + c, err := Load(path, nil) + if err != nil { + t.Fatalf("v14 key set rejected: %v", err) + } + if c.AppSettings["jwt_lifetime"] != "3600" || c.AppSettings["name"] != "demo" { + t.Errorf("app.settings = %v", c.AppSettings) + } + if !c.AggregatesEnabled || !c.PlanEnabled || !c.OpenAPISecurityActive || !c.ServerTimingEnabled || !c.JWTSecretIsBase64 { + t.Error("boolean options did not parse") + } + if c.DBChannel != "custom" || c.DBChannelEnabled || c.DBConfig || c.DBPreparedStatements || c.DBPoolAutomaticRecovery { + t.Error("channel/config/pool options did not parse") + } + if c.DBPoolMaxIdleTime != 60*time.Second || c.DBPoolMaxLifetime != 600*time.Second { + t.Errorf("pool times = %v/%v", c.DBPoolMaxIdleTime, c.DBPoolMaxLifetime) + } + if c.TxEnd != "rollback-allow-override" { + t.Errorf("db-tx-end = %q", c.TxEnd) + } + if !slices.Equal(c.HoistedTxSettings, []string{"statement_timeout"}) { + t.Errorf("db-hoisted-tx-settings = %v", c.HoistedTxSettings) + } + if c.RootSpec != "root" || c.DBPreConfig != "postgrest.pre_config" { + t.Errorf("root-spec/pre-config = %q/%q", c.RootSpec, c.DBPreConfig) + } + if c.ServerTraceHeader != "X-Request-Id" || c.ServerUnixSocketMode != "770" { + t.Errorf("trace header/socket mode = %q/%q", c.ServerTraceHeader, c.ServerUnixSocketMode) + } +} + +func TestV14Defaults(t *testing.T) { + c, err := FromMap(map[string]string{"db-uri": "x"}) + if err != nil { + t.Fatal(err) + } + if c.DBChannel != "pgrst" || !c.DBChannelEnabled || !c.DBConfig || !c.DBPreparedStatements { + t.Error("channel/config defaults wrong") + } + if c.DBPoolMaxIdleTime != 30*time.Second || c.DBPoolMaxLifetime != 1800*time.Second || !c.DBPoolAutomaticRecovery { + t.Error("pool defaults wrong") + } + if c.DBPoolAcquisitionTimeout != 10*time.Second { + t.Errorf("db-pool-acquisition-timeout default = %v, want 10s", c.DBPoolAcquisitionTimeout) + } + if c.TxEnd != "commit" || c.ServerUnixSocketMode != "660" { + t.Errorf("tx-end/socket-mode defaults = %q/%q", c.TxEnd, c.ServerUnixSocketMode) + } + if c.PlanEnabled || c.AggregatesEnabled || c.ServerTimingEnabled || c.OpenAPISecurityActive || c.JWTSecretIsBase64 { + t.Error("boolean defaults should be false") + } + if !slices.Equal(c.HoistedTxSettings, []string{"statement_timeout", "plan_filter.statement_cost_limit", "default_transaction_isolation"}) { + t.Errorf("hoisted settings default = %v", c.HoistedTxSettings) + } +} + +func TestV14Aliases(t *testing.T) { + c, err := FromMap(map[string]string{ + "db-uri": "x", "pre-request": "fn", "root-spec": "rs", + "db-schema": "api", "role-claim-key": ".r", + "secret-is-base64": "true", "db-pool-timeout": "55", + }) + if err != nil { + t.Fatal(err) + } + if c.PreRequest != "fn" || c.RootSpec != "rs" || c.JWTRoleClaimKey != ".r" { + t.Errorf("string aliases = %q/%q/%q", c.PreRequest, c.RootSpec, c.JWTRoleClaimKey) + } + if !slices.Equal(c.Schemas, []string{"api"}) { + t.Errorf("db-schema alias = %v", c.Schemas) + } + if !c.JWTSecretIsBase64 || c.DBPoolMaxIdleTime != 55*time.Second { + t.Error("secret-is-base64 or db-pool-timeout alias did not parse") + } +} + +func TestAppSettingsFromEnv(t *testing.T) { + c, err := Load("", []string{ + "DBREST_DB_URI=x", + "PGRST_APP_SETTINGS_JWT_LIFETIME=1800", + "DBREST_APP_SETTINGS_LOCAL=yes", + }) + if err != nil { + t.Fatal(err) + } + if c.AppSettings["jwt_lifetime"] != "1800" || c.AppSettings["local"] != "yes" { + t.Errorf("app settings from env = %v", c.AppSettings) + } +} + +func TestBadTxEndAndSocketMode(t *testing.T) { + if _, err := FromMap(map[string]string{"db-uri": "x", "db-tx-end": "explode"}); err == nil { + t.Error("expected error for bad db-tx-end") + } + if _, err := FromMap(map[string]string{"db-uri": "x", "server-unix-socket-mode": "555"}); err == nil { + t.Error("expected error for socket mode below 600") + } + if _, err := FromMap(map[string]string{"db-uri": "x", "server-unix-socket-mode": "9x"}); err == nil { + t.Error("expected error for non-octal socket mode") + } +} + +func TestUnenforcedOptionWarns(t *testing.T) { + c, err := FromMap(map[string]string{"db-uri": "x", "db-tx-end": "rollback"}) + if err != nil { + t.Fatal(err) + } + if !strings.Contains(strings.Join(c.Warnings, "\n"), "db-tx-end") { + t.Errorf("expected an unenforced warning for db-tx-end, got %q", c.Warnings) } } @@ -152,6 +356,76 @@ func TestMaxRowsAlias(t *testing.T) { } } +// TestNativeKeysScopedToDBREST pins the namespace split: a dbrest extension +// does not bind from the PGRST_ environment prefix (a future PostgREST +// release adding the same name must not change dbrest behavior), it warns +// there instead, and the DBREST_ spelling keeps working. +func TestNativeKeysScopedToDBREST(t *testing.T) { + c, err := Load("", []string{"DBREST_DB_URI=x", "PGRST_DB_BACKEND=postgres", "PGRST_MAX_ROWS=5"}) + if err != nil { + t.Fatal(err) + } + if c.Backend != BackendSQLite { + t.Errorf("backend = %q, PGRST_DB_BACKEND must not bind", c.Backend) + } + if c.MaxRows != 0 { + t.Errorf("max-rows = %d, PGRST_MAX_ROWS must not bind", c.MaxRows) + } + joined := strings.Join(c.Warnings, "\n") + for _, name := range []string{"PGRST_DB_BACKEND", "PGRST_MAX_ROWS"} { + if !strings.Contains(joined, name) { + t.Errorf("expected a warning naming %s, got %q", name, c.Warnings) + } + } + + c, err = Load("", []string{"DBREST_DB_URI=x", "DBREST_DB_BACKEND=postgres"}) + if err != nil { + t.Fatal(err) + } + if c.Backend != BackendPostgres { + t.Errorf("backend = %q, DBREST_DB_BACKEND should bind", c.Backend) + } +} + +// TestNativeKeysFilePrefix covers the explicit dbrest. file spelling: it maps +// onto the bare extension key, and a non-extension name under the prefix gets +// the unknown-option warning. +func TestNativeKeysFilePrefix(t *testing.T) { + path := writeConf(t, ` +db-uri = "x" +dbrest.max-rows = 25 +dbrest.function-registry = "fns.json" +dbrest.server-port = 9999 +`) + c, err := Load(path, nil) + if err != nil { + t.Fatal(err) + } + if c.MaxRows != 25 || c.FunctionRegistry != "fns.json" { + t.Errorf("dbrest. prefixed keys did not bind: max-rows=%d registry=%q", c.MaxRows, c.FunctionRegistry) + } + if c.ServerPort != 3000 { + t.Errorf("server-port = %d, dbrest.server-port is not an extension and must not bind", c.ServerPort) + } + if !strings.Contains(strings.Join(c.Warnings, "\n"), "dbrest.server-port") { + t.Errorf("expected an unknown-option warning for dbrest.server-port, got %q", c.Warnings) + } +} + +// TestNativeKeysAreKnown guards the extension list: every native key must be +// a real option, so a rename cannot silently orphan the scoping. +func TestNativeKeysAreKnown(t *testing.T) { + known := map[string]bool{} + for _, k := range optionKeys { + known[k] = true + } + for _, k := range nativeOptionKeys { + if !known[k] { + t.Errorf("native key %q is not in optionKeys", k) + } + } +} + func TestUnknownEnvKeyIgnored(t *testing.T) { // A typo in the variable name is not a known option, so it must not leak in. c, err := Load("", []string{"PGRST_DB_URY=typo", "DBREST_DB_URI=file:real.db"}) @@ -229,3 +503,298 @@ func BenchmarkLoad(b *testing.B) { } } } + +// TestAllAnonymousPostureWarns covers the startup validation gap: a config +// with neither db-anon-role nor JWT key material boots, but says what that +// means. Configuring either side silences the warning. +func TestAllAnonymousPostureWarns(t *testing.T) { + hasAnonWarning := func(c *Config) bool { + return strings.Contains(strings.Join(c.Warnings, "\n"), "anonymously with no role") + } + + c, err := FromMap(map[string]string{"db-uri": "x"}) + if err != nil { + t.Fatal(err) + } + if !hasAnonWarning(c) { + t.Errorf("expected the all-anonymous warning, got %q", c.Warnings) + } + + silenced := []map[string]string{ + {"db-uri": "x", "db-anon-role": "web_anon"}, + {"db-uri": "x", "jwt-secret": "reallyreallyreallyreallyverysafe"}, + {"db-uri": "x", "jwk-set": `{"keys":[]}`}, + } + for _, raw := range silenced { + c, err := FromMap(raw) + if err != nil { + t.Fatal(err) + } + if hasAnonWarning(c) { + t.Errorf("warning should be silent for %v, got %q", raw, c.Warnings) + } + } +} + +// TestAdminPortCannotEqualServerPort mirrors the upstream boot failure: the +// admin server cannot share the API port. +func TestAdminPortCannotEqualServerPort(t *testing.T) { + if _, err := FromMap(map[string]string{"db-uri": "x", "server-port": "3000", "admin-server-port": "3000"}); err == nil { + t.Error("expected error for admin-server-port == server-port") + } + if _, err := FromMap(map[string]string{"db-uri": "x", "server-port": "3000", "admin-server-port": "3001"}); err != nil { + t.Errorf("distinct ports should boot: %v", err) + } +} + +// TestAdminHostDefaultsToServerHost checks the upstream default: an unset +// admin-server-host follows server-host. +func TestAdminHostDefaultsToServerHost(t *testing.T) { + c, err := FromMap(map[string]string{"db-uri": "x", "server-host": "127.0.0.5", "admin-server-port": "3001"}) + if err != nil { + t.Fatal(err) + } + if c.AdminServerHost != "127.0.0.5" { + t.Errorf("admin-server-host = %q, want the server-host 127.0.0.5", c.AdminServerHost) + } + c, err = FromMap(map[string]string{"db-uri": "x", "admin-server-host": "10.0.0.1", "admin-server-port": "3001"}) + if err != nil { + t.Fatal(err) + } + if c.AdminServerHost != "10.0.0.1" { + t.Errorf("admin-server-host = %q, explicit value lost", c.AdminServerHost) + } +} + +// TestMergeReloadable checks the SIGUSR2 merge: runtime options follow the new +// config, boot-time options stay put and are reported. +func TestMergeReloadable(t *testing.T) { + old, err := FromMap(map[string]string{"db-uri": "file:a.db", "db-max-rows": "100", "server-port": "3000"}) + if err != nil { + t.Fatal(err) + } + next, err := FromMap(map[string]string{"db-uri": "file:b.db", "db-max-rows": "50", "server-port": "4000", "db-anon-role": "web_anon"}) + if err != nil { + t.Fatal(err) + } + merged, kept := old.MergeReloadable(next) + if merged.MaxRows != 50 || merged.AnonRole != "web_anon" { + t.Errorf("reloadable fields not applied: max-rows=%d anon=%q", merged.MaxRows, merged.AnonRole) + } + if merged.DBURI != "file:a.db" || merged.ServerPort != 3000 { + t.Errorf("boot-time fields changed: db-uri=%q port=%d", merged.DBURI, merged.ServerPort) + } + joined := strings.Join(kept, "\n") + for _, name := range []string{"db-uri", "server-port"} { + if !strings.Contains(joined, name) { + t.Errorf("expected a kept-value message for %s, got %q", name, kept) + } + } + if strings.Contains(joined, "db-max-rows") { + t.Errorf("db-max-rows is reloadable, should not be reported: %q", kept) + } +} + +// TestDumpRoundTrips pins the --dump-config format: the output is valid +// config-file syntax and loads back to the same resolved values. +func TestDumpRoundTrips(t *testing.T) { + first, err := FromMap(map[string]string{ + "db-uri": "file:dump.db", + "db-schemas": "public,api", + "db-anon-role": "web_anon", + "db-max-rows": "500", + "db-tx-end": "rollback", + "app.settings.tenant": "acme", + }) + if err != nil { + t.Fatal(err) + } + path := writeConf(t, first.Dump()) + second, err := Load(path, nil) + if err != nil { + t.Fatalf("dump output does not load: %v", err) + } + if second.DBURI != first.DBURI || second.AnonRole != first.AnonRole || + second.MaxRows != first.MaxRows || second.TxEnd != first.TxEnd { + t.Errorf("round trip drifted: %+v vs %+v", second, first) + } + if len(second.Schemas) != 2 || second.Schemas[0] != "public" { + t.Errorf("schemas drifted: %v", second.Schemas) + } + if second.AppSettings["tenant"] != "acme" { + t.Errorf("app settings drifted: %v", second.AppSettings) + } + if second.Dump() != first.Dump() { + t.Error("Dump is not a fixed point of Load(Dump)") + } +} + +// TestEnvInterpolation covers $(NAME) in file string values: an environment +// variable, an earlier config key, the $$ escape, and the hard error on an +// unset name, all upstream configurator behavior. +func TestEnvInterpolation(t *testing.T) { + t.Setenv("DBREST_TEST_SECRET", "from-env") + path := writeConf(t, ` +db-uri = "file:interp.db" +db-anon-role = "web_anon" +jwt-secret = "$(DBREST_TEST_SECRET)" +db-pre-request = "check_$(db-anon-role)" +app.settings.cost = "5$$ per row" +`) + c, err := Load(path, nil) + if err != nil { + t.Fatal(err) + } + if c.JWTSecret != "from-env" { + t.Errorf("jwt-secret = %q, want the env value", c.JWTSecret) + } + if c.PreRequest != "check_web_anon" { + t.Errorf("pre-request = %q, earlier config key did not resolve", c.PreRequest) + } + if c.AppSettings["cost"] != "5$ per row" { + t.Errorf("$$ escape: got %q", c.AppSettings["cost"]) + } + + bad := writeConf(t, `jwt-secret = "$(DBREST_TEST_UNSET_VAR)"`) + if _, err := Load(bad, nil); err == nil || !strings.Contains(err.Error(), "no such variable") { + t.Errorf("unset variable should be a hard error, got %v", err) + } +} + +// TestEnvValuesAreNotInterpolated pins the asymmetry: only file values +// expand; an env-sourced value keeps its dollars verbatim. +func TestEnvValuesAreNotInterpolated(t *testing.T) { + c, err := Load("", []string{"PGRST_DB_URI=x", "PGRST_JWT_SECRET=pa$(ss)word"}) + if err != nil { + t.Fatal(err) + } + if c.JWTSecret != "pa$(ss)word" { + t.Errorf("jwt-secret = %q, env values must stay literal", c.JWTSecret) + } +} + +// TestAtFileReferences covers the @path form for the two options that support +// it: jwt-secret (one trailing newline chomped) and db-uri (whitespace +// trimmed), plus the error on a missing file. +func TestAtFileReferences(t *testing.T) { + dir := t.TempDir() + secretPath := filepath.Join(dir, "secret") + if err := os.WriteFile(secretPath, []byte("hush hush hush hush hush hush 32\n"), 0o600); err != nil { + t.Fatal(err) + } + uriPath := filepath.Join(dir, "uri") + if err := os.WriteFile(uriPath, []byte(" file:from-file.db \n"), 0o600); err != nil { + t.Fatal(err) + } + c, err := FromMap(map[string]string{ + "db-uri": "@" + uriPath, + "jwt-secret": "@" + secretPath, + }) + if err != nil { + t.Fatal(err) + } + if c.DBURI != "file:from-file.db" { + t.Errorf("db-uri = %q, want the trimmed file contents", c.DBURI) + } + if c.JWTSecret != "hush hush hush hush hush hush 32" { + t.Errorf("jwt-secret = %q, want the file contents with one newline chomped", c.JWTSecret) + } + + if _, err := FromMap(map[string]string{"db-uri": "@" + filepath.Join(dir, "missing")}); err == nil { + t.Error("missing @file should be an error") + } +} + +// TestListenSpecs pins the special host values to their candidate lists, and +// the default host to upstream's !4. +func TestListenSpecs(t *testing.T) { + c, err := FromMap(map[string]string{"db-uri": "x"}) + if err != nil { + t.Fatal(err) + } + if c.ServerHost != "!4" { + t.Errorf("default server-host = %q, want !4", c.ServerHost) + } + + cases := []struct { + host string + want []ListenSpec + }{ + {"*", []ListenSpec{{"tcp", ":3000"}}}, + {"*4", []ListenSpec{{"tcp4", "0.0.0.0:3000"}, {"tcp6", "[::]:3000"}}}, + {"!4", []ListenSpec{{"tcp4", "0.0.0.0:3000"}}}, + {"*6", []ListenSpec{{"tcp6", "[::]:3000"}, {"tcp4", "0.0.0.0:3000"}}}, + {"!6", []ListenSpec{{"tcp6", "[::]:3000"}}}, + {"127.0.0.1", []ListenSpec{{"tcp", "127.0.0.1:3000"}}}, + {"::1", []ListenSpec{{"tcp", "[::1]:3000"}}}, + } + for _, tc := range cases { + c.ServerHost = tc.host + got := c.Listeners() + if len(got) != len(tc.want) { + t.Errorf("%s: %v, want %v", tc.host, got, tc.want) + continue + } + for i := range got { + if got[i] != tc.want[i] { + t.Errorf("%s[%d]: %v, want %v", tc.host, i, got[i], tc.want[i]) + } + } + } +} + +// TestUnixSocketReplacesTCP pins the listener selection: with +// server-unix-socket set the only candidate is the socket, and the admin +// listeners stay TCP. +func TestUnixSocketReplacesTCP(t *testing.T) { + c, err := FromMap(map[string]string{ + "db-uri": "x", "server-unix-socket": "/tmp/dbrest.sock", "admin-server-port": "3001", + }) + if err != nil { + t.Fatal(err) + } + got := c.Listeners() + if len(got) != 1 || got[0] != (ListenSpec{"unix", "/tmp/dbrest.sock"}) { + t.Errorf("Listeners = %v, want the unix socket only", got) + } + for _, spec := range c.AdminListeners() { + if spec.Network == "unix" { + t.Errorf("admin listener went to the socket: %v", spec) + } + } +} + +// TestSchemasDefaultFollowsBackend pins the engine-aware db-schemas default: +// public on postgres, main on sqlite, the backend's own default elsewhere, +// with an explicit value always winning and an explicitly empty one rejected. +func TestSchemasDefaultFollowsBackend(t *testing.T) { + cases := []struct { + backend string + want string + }{ + {"postgres", "public"}, + {"sqlite", ""}, + {"mysql", ""}, + } + for _, tc := range cases { + c, err := FromMap(map[string]string{"db-uri": "x", "db-backend": tc.backend}) + if err != nil { + t.Fatalf("%s: %v", tc.backend, err) + } + if len(c.Schemas) != 1 || c.Schemas[0] != tc.want { + t.Errorf("%s: schemas = %v, want [%q]", tc.backend, c.Schemas, tc.want) + } + } + + c, err := FromMap(map[string]string{"db-uri": "x", "db-backend": "postgres", "db-schemas": "api,private"}) + if err != nil { + t.Fatal(err) + } + if len(c.Schemas) != 2 || c.Schemas[0] != "api" { + t.Errorf("explicit schemas lost: %v", c.Schemas) + } + + if _, err := FromMap(map[string]string{"db-uri": "x", "db-schemas": ""}); err == nil { + t.Error("explicitly empty db-schemas should be rejected") + } +} diff --git a/config/dump.go b/config/dump.go new file mode 100644 index 0000000..5018bcb --- /dev/null +++ b/config/dump.go @@ -0,0 +1,92 @@ +package config + +import ( + "fmt" + "sort" + "strconv" + "strings" + "time" +) + +// seconds renders a pool timeout the way upstream's config does: a bare +// integer of seconds, falling back to the duration extension form below one +// second so the value round-trips. +func seconds(d time.Duration) string { + if d%time.Second == 0 { + return strconv.Itoa(int(d / time.Second)) + } + return strconv.Quote(d.String()) +} + +// Dump renders the resolved configuration in the config-file syntax, the +// answer to --dump-config: every option with its effective value, defaults +// included, sorted by key. The output parses back to the same configuration, +// which is also how the tests pin it. +func (c *Config) Dump() string { + q := strconv.Quote + pairs := map[string]string{ + "db-backend": q(c.Backend), + "db-uri": q(c.DBURI), + "db-schemas": q(strings.Join(c.Schemas, ",")), + "db-anon-role": q(c.AnonRole), + "db-pre-request": q(c.PreRequest), + "db-extra-search-path": q(strings.Join(c.ExtraSearchPath, ",")), + "db-max-rows": strconv.Itoa(c.MaxRows), + "dbrest.max-request-body": strconv.Itoa(c.MaxRequestBody), + "db-aggregates-enabled": strconv.FormatBool(c.AggregatesEnabled), + "db-root-spec": q(c.RootSpec), + "db-tx-end": q(c.TxEnd), + "db-hoisted-tx-settings": q(strings.Join(c.HoistedTxSettings, ",")), + "db-plan-enabled": strconv.FormatBool(c.PlanEnabled), + "db-channel": q(c.DBChannel), + "db-channel-enabled": strconv.FormatBool(c.DBChannelEnabled), + "db-config": strconv.FormatBool(c.DBConfig), + "db-pre-config": q(c.DBPreConfig), + "db-prepared-statements": strconv.FormatBool(c.DBPreparedStatements), + "db-pool": strconv.Itoa(c.DBPool), + "db-pool-acquisition-timeout": seconds(c.DBPoolAcquisitionTimeout), + "db-pool-max-idletime": seconds(c.DBPoolMaxIdleTime), + "db-pool-max-lifetime": seconds(c.DBPoolMaxLifetime), + "db-pool-automatic-recovery": strconv.FormatBool(c.DBPoolAutomaticRecovery), + "jwt-secret": q(c.JWTSecret), + "jwt-secret-is-base64": strconv.FormatBool(c.JWTSecretIsBase64), + "jwt-aud": q(c.JWTAud), + "jwt-role-claim-key": q(c.JWTRoleClaimKey), + "jwk-set": q(c.JWKSet), + "jwt-cache-max-entries": strconv.Itoa(c.JWTCacheMaxEntries), + "server-host": q(c.ServerHost), + "server-port": strconv.Itoa(c.ServerPort), + "server-unix-socket": q(c.ServerUnixSocket), + "server-unix-socket-mode": q(c.ServerUnixSocketMode), + "admin-server-host": q(c.AdminServerHost), + "admin-server-port": strconv.Itoa(c.AdminServerPort), + "openapi-mode": q(c.OpenAPIMode), + "openapi-security-active": strconv.FormatBool(c.OpenAPISecurityActive), + "openapi-server-proxy-uri": q(c.OpenAPIServerProxyURI), + "log-level": q(c.LogLevel), + "log-query": strconv.FormatBool(c.LogQuery), + "server-cors-allowed-origins": q(strings.Join(c.CORSAllowedOrigins, ",")), + "server-trace-header": q(c.ServerTraceHeader), + "server-timing-enabled": strconv.FormatBool(c.ServerTimingEnabled), + "declared-schema": q(c.DeclaredSchema), + "declared-relationships": q(c.DeclaredRelationships), + "function-registry": q(c.FunctionRegistry), + "policy-registry": q(c.PolicyRegistry), + "capability-overrides": q(c.CapabilityOverrides), + } + for name, v := range c.AppSettings { + pairs["app.settings."+name] = q(v) + } + + keys := make([]string, 0, len(pairs)) + for k := range pairs { + keys = append(keys, k) + } + sort.Strings(keys) + + var b strings.Builder + for _, k := range keys { + fmt.Fprintf(&b, "%s = %s\n", k, pairs[k]) + } + return b.String() +} diff --git a/config/parse.go b/config/parse.go index 45f422f..189fce3 100644 --- a/config/parse.go +++ b/config/parse.go @@ -14,60 +14,142 @@ import ( // only consulted for a key we actually understand, so a typo in PGRST_DB_URY is // ignored rather than silently dropped into a catch-all map. var optionKeys = []string{ - "db-backend", "db-uri", "db-schemas", "db-anon-role", "db-pre-request", - "db-extra-search-path", "db-max-rows", "max-rows", - "jwt-secret", "jwt-aud", "jwt-role-claim-key", "jwk-set", "jwt-cache-max-entries", - "server-host", "server-port", "server-unix-socket", + "db-backend", "db-uri", "db-schemas", "db-schema", "db-anon-role", + "db-pre-request", "pre-request", + "db-extra-search-path", "db-max-rows", "max-rows", "max-request-body", + "db-aggregates-enabled", "db-root-spec", "root-spec", + "db-tx-end", "db-hoisted-tx-settings", + "db-channel", "db-channel-enabled", "db-config", "db-pre-config", + "db-prepared-statements", "db-plan-enabled", + "jwt-secret", "jwt-secret-is-base64", "secret-is-base64", "jwt-aud", + "jwt-role-claim-key", "role-claim-key", "jwk-set", "jwt-cache-max-entries", + "server-host", "server-port", "server-unix-socket", "server-unix-socket-mode", "admin-server-host", "admin-server-port", "db-pool", "db-pool-acquisition-timeout", - "openapi-mode", "openapi-server-proxy-uri", + "db-pool-max-idletime", "db-pool-timeout", "db-pool-max-lifetime", + "db-pool-automatic-recovery", + "openapi-mode", "openapi-server-proxy-uri", "openapi-security-active", "log-level", "log-query", "server-cors-allowed-origins", + "server-trace-header", "server-timing-enabled", "declared-schema", "declared-relationships", "function-registry", "policy-registry", "capability-overrides", } +// nativeOptionKeys are the dbrest extensions inside optionKeys: options +// PostgREST does not have. They stay out of the PGRST_ environment namespace +// so a future upstream release adding a key with the same name cannot +// silently change dbrest behavior; the environment spelling is DBREST_ only, +// and the file accepts both the bare name (the documented extension list) and +// an explicit dbrest. prefix. +var nativeOptionKeys = []string{ + "db-backend", "jwk-set", "max-rows", "max-request-body", + "declared-schema", "declared-relationships", + "function-registry", "policy-registry", "capability-overrides", +} + +// isNativeKey reports whether key is a dbrest extension rather than a +// PostgREST-compatible option. +var isNativeKey = func() map[string]bool { + m := make(map[string]bool, len(nativeOptionKeys)) + for _, k := range nativeOptionKeys { + m[k] = true + } + return m +}() + +// nativeFilePrefix is the explicit file spelling for a dbrest extension, +// "dbrest.max-rows" for "max-rows". +const nativeFilePrefix = "dbrest." + +// appSettingsPrefix is the dynamic option namespace: any app.settings. +// key is accepted and carried to the backend as a transaction setting. +const appSettingsPrefix = "app.settings." + +// appSettingsEnvPrefix is the env-suffix spelling of the same namespace: +// PGRST_APP_SETTINGS_FOO maps to app.settings.foo. +const appSettingsEnvPrefix = "APP_SETTINGS_" + // envSuffix turns an option key into the variable suffix shared by both // prefixes: "db-uri" becomes "DB_URI", read as PGRST_DB_URI or DBREST_DB_URI. func envSuffix(key string) string { return strings.ToUpper(strings.ReplaceAll(key, "-", "_")) } -// overlayEnv layers the environment over raw. For each known key it reads the -// PGRST_ spelling first, then the DBREST_ spelling, so DBREST_ wins on a -// conflict; either present overrides the file. environ is os.Environ() form. -func overlayEnv(raw map[string]string, environ []string) { +// overlayEnv layers the environment over raw and returns warnings for +// namespaced variables that match no known option. For each known key it reads +// the PGRST_ spelling first, then the DBREST_ spelling, so DBREST_ wins on a +// conflict; either present overrides the file. The dynamic +// PGRST_APP_SETTINGS_* / DBREST_APP_SETTINGS_* namespace maps to +// app.settings.* keys with a lowercased name. environ is os.Environ() form. +func overlayEnv(raw map[string]string, environ []string) []string { env := map[string]string{} for _, kv := range environ { if k, v, ok := strings.Cut(kv, "="); ok { env[k] = v } } - for _, key := range optionKeys { - suffix := envSuffix(key) - if v, ok := env["PGRST_"+suffix]; ok { - raw[key] = v + known := map[string]bool{} + for _, k := range optionKeys { + known[envSuffix(k)] = true + } + nativeSuffix := map[string]bool{} + for _, k := range nativeOptionKeys { + nativeSuffix[envSuffix(k)] = true + } + var warnings []string + for _, prefix := range []string{"PGRST_", "DBREST_"} { + for _, key := range optionKeys { + // dbrest extensions never bind from the PGRST namespace, so a + // future upstream option with the same name cannot change dbrest + // behavior through an existing deployment's environment. + if prefix == "PGRST_" && isNativeKey[key] { + continue + } + if v, ok := env[prefix+envSuffix(key)]; ok { + raw[key] = v + } } - if v, ok := env["DBREST_"+suffix]; ok { - raw[key] = v + // The dynamic namespace and the unknown-suffix warnings need a scan + // over what is actually set, not over what we expect. + for name, v := range env { + suffix, ok := strings.CutPrefix(name, prefix) + if !ok { + continue + } + if setting, ok := strings.CutPrefix(suffix, appSettingsEnvPrefix); ok && setting != "" { + raw[appSettingsPrefix+strings.ToLower(setting)] = v + continue + } + if prefix == "PGRST_" && nativeSuffix[suffix] { + warnings = append(warnings, fmt.Sprintf("ignoring %s: %q is a dbrest extension; set DBREST_%s instead", name, strings.ToLower(strings.ReplaceAll(suffix, "_", "-")), suffix)) + continue + } + if !known[suffix] { + warnings = append(warnings, fmt.Sprintf("ignoring %s: no option named %q", name, strings.ToLower(strings.ReplaceAll(suffix, "_", "-")))) + } } } + return warnings } // parseFile reads a PostgREST-style flat configuration file into a raw map. The // format is one "key = value" per line; values are bare, double-quoted, or // triple-quoted for multi-line strings; '#' begins a comment outside a quoted -// value; blank lines are skipped. Unknown keys are an error, so a mistyped -// option fails loudly at startup rather than being ignored. -func parseFile(path string) (map[string]string, error) { +// value; blank lines are skipped. An unknown key is kept out of the map and +// reported as a warning, matching PostgREST, which ignores keys it does not +// own; the same posture applies to unknown namespaced environment variables in +// overlayEnv, so the two sources fail symmetrically. +func parseFile(path string) (map[string]string, []string, error) { data, err := os.ReadFile(path) if err != nil { - return nil, fmt.Errorf("config: reading %s: %w", path, err) + return nil, nil, fmt.Errorf("config: reading %s: %w", path, err) } known := map[string]bool{} for _, k := range optionKeys { known[k] = true } raw := map[string]string{} + var warnings []string lines := strings.Split(string(data), "\n") for i := 0; i < len(lines); i++ { line := strings.TrimSpace(stripComment(lines[i])) @@ -76,25 +158,92 @@ func parseFile(path string) (map[string]string, error) { } rawKey, rawVal, ok := strings.Cut(line, "=") if !ok { - return nil, fmt.Errorf("config: %s line %d: expected key = value", path, i+1) + return nil, nil, fmt.Errorf("config: %s line %d: expected key = value", path, i+1) } key := strings.TrimSpace(rawKey) val := strings.TrimSpace(rawVal) - if !known[key] { - return nil, fmt.Errorf("config: %s line %d: unknown option %q", path, i+1, key) + // "dbrest." is the explicit file spelling for an extension; it + // maps onto the bare key, and a non-extension name under the prefix + // falls through to the unknown-option warning below. + if name, ok := strings.CutPrefix(key, nativeFilePrefix); ok && isNativeKey[name] { + key = name + } + if !known[key] && !strings.HasPrefix(key, appSettingsPrefix) { + warnings = append(warnings, fmt.Sprintf("%s line %d: ignoring unknown option %q", path, i+1, key)) + key = "" } if strings.HasPrefix(val, `"""`) { block, used, err := readTripleQuoted(lines, i, val) if err != nil { - return nil, fmt.Errorf("config: %s line %d: %w", path, i+1, err) + return nil, nil, fmt.Errorf("config: %s line %d: %w", path, i+1, err) + } + if key != "" { + expanded, err := interpolate(block, raw) + if err != nil { + return nil, nil, fmt.Errorf("config: %s line %d: %w", path, i+1, err) + } + raw[key] = expanded } - raw[key] = block i = used continue } - raw[key] = unquote(val) + if key != "" { + v := val + if quoted := strings.HasPrefix(val, `"`); quoted { + // Only quoted strings interpolate, as in upstream's config + // format; a bare number or boolean is taken verbatim. + expanded, err := interpolate(unquote(val), raw) + if err != nil { + return nil, nil, fmt.Errorf("config: %s line %d: %w", path, i+1, err) + } + v = expanded + } + raw[key] = v + } + } + return raw, warnings, nil +} + +// interpolate expands $(NAME) inside a config-file string value, the upstream +// configurator behavior: NAME resolves against the options bound earlier in +// the file first, then the process environment, and an unset name is a hard +// error rather than an empty string. The sequence $$ collapses to a literal +// dollar. Environment-sourced option values are never interpolated; this runs +// only on file values. +func interpolate(v string, raw map[string]string) (string, error) { + if !strings.Contains(v, "$") { + return v, nil + } + var b strings.Builder + for i := 0; i < len(v); i++ { + if v[i] != '$' { + b.WriteByte(v[i]) + continue + } + if i+1 < len(v) && v[i+1] == '$' { + b.WriteByte('$') + i++ + continue + } + if i+1 < len(v) && v[i+1] == '(' { + end := strings.IndexByte(v[i+2:], ')') + if end < 0 { + return "", fmt.Errorf("unterminated $( in %q", v) + } + name := v[i+2 : i+2+end] + if prior, ok := raw[name]; ok { + b.WriteString(prior) + } else if env, ok := os.LookupEnv(name); ok { + b.WriteString(env) + } else { + return "", fmt.Errorf("no such variable %q", name) + } + i += 2 + end + continue + } + b.WriteByte('$') } - return raw, nil + return b.String(), nil } // stripComment removes a trailing '#' comment from a line, leaving '#' that sits @@ -156,6 +305,19 @@ func splitList(v string) []string { return out } +// pickString reads the first present key among aliases as a string, falling +// back to def when none is set. PostgREST keeps a handful of pre-rename +// aliases (pre-request, root-spec, db-schema, role-claim-key) working; this is +// the string side of that. +func pickString(raw map[string]string, def string, keys ...string) string { + for _, key := range keys { + if v, ok := raw[key]; ok { + return v + } + } + return def +} + // pickInt reads the first present key among aliases as an integer, recording a // validation error on a malformed value and falling back to def. func pickInt(raw map[string]string, errs *[]string, def int, keys ...string) int { @@ -174,32 +336,44 @@ func pickInt(raw map[string]string, errs *[]string, def int, keys ...string) int return def } -// pickBool reads key as a boolean (true/false, 1/0), recording a validation -// error on a malformed value and falling back to def. -func pickBool(raw map[string]string, errs *[]string, def bool, key string) bool { - v, ok := raw[key] - if !ok { - return def - } - b, err := strconv.ParseBool(strings.TrimSpace(v)) - if err != nil { - *errs = append(*errs, fmt.Sprintf("%s %q is not a boolean", key, v)) - return def +// pickBool reads the first present key among aliases as a boolean (true/false, +// 1/0), recording a validation error on a malformed value and falling back to +// def. +func pickBool(raw map[string]string, errs *[]string, def bool, keys ...string) bool { + for _, key := range keys { + v, ok := raw[key] + if !ok { + continue + } + b, err := strconv.ParseBool(strings.TrimSpace(v)) + if err != nil { + *errs = append(*errs, fmt.Sprintf("%s %q is not a boolean", key, v)) + return def + } + return b } - return b + return def } -// pickDuration reads key as a Go duration (for example "10s"), recording a -// validation error on a malformed value and falling back to def. -func pickDuration(raw map[string]string, errs *[]string, def time.Duration, key string) time.Duration { - v, ok := raw[key] - if !ok { - return def - } - d, err := time.ParseDuration(strings.TrimSpace(v)) - if err != nil { - *errs = append(*errs, fmt.Sprintf("%s %q is not a duration", key, v)) +// pickSeconds reads the first present key as an integer number of seconds, +// the unit upstream uses for the pool timeouts (`db-pool-acquisition-timeout +// = 10`). A Go duration string ("500ms") is also accepted, a dbrest extension +// for sub-second values. +func pickSeconds(raw map[string]string, errs *[]string, def time.Duration, keys ...string) time.Duration { + for _, key := range keys { + v, ok := raw[key] + if !ok { + continue + } + s := strings.TrimSpace(v) + if n, err := strconv.Atoi(s); err == nil { + return time.Duration(n) * time.Second + } + if d, err := time.ParseDuration(s); err == nil { + return d + } + *errs = append(*errs, fmt.Sprintf("%s %q is not a number of seconds", key, v)) return def } - return d + return def } diff --git a/conformance/live_test.go b/conformance/live_test.go index ec4c5a5..6a0e8b5 100644 --- a/conformance/live_test.go +++ b/conformance/live_test.go @@ -46,7 +46,9 @@ func fixtureServer(t *testing.T) (*httpapi.Server, sqliteCaps) { if err != nil { t.Fatalf("introspect: %v", err) } - return httpapi.NewServer(be, model, nil), sqliteCaps{be} + srv := httpapi.NewServer(be, model, nil) + srv.SetDefaultRole("anon") + return srv, sqliteCaps{be} } // sqliteCaps wraps the backend so a test can read its declared capabilities. diff --git a/conformance/postgres_live_test.go b/conformance/postgres_live_test.go new file mode 100644 index 0000000..06ffda7 --- /dev/null +++ b/conformance/postgres_live_test.go @@ -0,0 +1,112 @@ +package conformance_test + +import ( + "context" + "os" + "testing" + "time" + + "github.com/tamnd/dbrest/backend" + "github.com/tamnd/dbrest/backend/postgres" + "github.com/tamnd/dbrest/conformance" + "github.com/tamnd/dbrest/httpapi" +) + +// pgFixtureDDL mirrors cmd/dbrest-conformance: the films fixture in a dedicated +// schema, plus a text[] column so the array operators exercise the Native tier +// that SQLite lacks, and the anon role the server's SET LOCAL ROLE assumes. +const pgFixtureDDL = ` +DROP SCHEMA IF EXISTS _dbrest_conf CASCADE; +CREATE SCHEMA _dbrest_conf; +DO $$ BEGIN + IF NOT EXISTS (SELECT FROM pg_roles WHERE rolname = 'anon') THEN + CREATE ROLE anon NOLOGIN; + END IF; +END $$; +GRANT USAGE ON SCHEMA _dbrest_conf TO anon; +CREATE TABLE _dbrest_conf.films ( + id integer PRIMARY KEY, + title text NOT NULL, + year integer, + rating text, + tags text[] +); +INSERT INTO _dbrest_conf.films (id, title, year, rating, tags) VALUES + (1, 'Metropolis', 1927, 'NR', '{sci-fi,silent}'), + (2, 'Blade Runner', 1982, 'R', '{sci-fi,noir}'), + (3, 'Arrival', 2016, 'PG13', '{sci-fi,drama}'); +GRANT SELECT ON _dbrest_conf.films TO anon; +` + +// TestPostgresConformanceCorpus replays the checked-in postgres corpus against a +// live PostgreSQL fixture under the postgres allowlist, and reconciles the +// allowlist against the live capability matrix. It is the in-process twin of the +// dbrest-conformance CLI's postgres pass, gated on DBREST_PG_DSN so the suite +// stays green without a server. Postgres is the reference backend: every case +// passes natively and the allowlist documents no divergence. +func TestPostgresConformanceCorpus(t *testing.T) { + dsn := os.Getenv("DBREST_PG_DSN") + if dsn == "" { + t.Skip("DBREST_PG_DSN not set; skipping postgres conformance corpus") + } + + be, err := postgres.Open(dsn) + if err != nil { + t.Fatalf("open: %v", err) + } + t.Cleanup(func() { _ = be.Close() }) + + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + if _, err := be.Pool().Exec(ctx, pgFixtureDDL); err != nil { + t.Fatalf("load fixture: %v", err) + } + t.Cleanup(func() { + _, _ = be.Pool().Exec(context.Background(), "DROP SCHEMA IF EXISTS _dbrest_conf CASCADE") + }) + + be.SetSchemas([]string{"_dbrest_conf"}) + model, err := be.Introspect(ctx) + if err != nil { + t.Fatalf("introspect: %v", err) + } + srv := httpapi.NewServer(be, model, []string{"_dbrest_conf"}) + srv.SetDefaultRole("anon") + + cases, err := conformance.LoadCorpus("testdata/postgres/corpus.json") + if err != nil { + t.Fatalf("load corpus: %v", err) + } + allow, err := conformance.LoadAllowlist("testdata/postgres/allowlist.json") + if err != nil { + t.Fatalf("load allowlist: %v", err) + } + + caps := be.Capabilities() + ft := backend.Native + if caps.FullText == backend.FTNone { + ft = backend.Unsupported + } + tiers := map[string]backend.Tier{ + "regex": caps.Regex, + "fts": ft, + "array-contains": caps.ArrayRangeTypes, + "count-planned": caps.CountPlanned, + } + if err := allow.CheckMatrix(tiers); err != nil { + t.Fatalf("allowlist vs matrix: %v", err) + } + + rep := conformance.Replay(srv, cases, allow) + if !rep.OK() { + for _, r := range rep.Results { + if r.Verdict == conformance.Fail { + t.Errorf("case %q: %s %v", r.Name, r.Verdict, r.Diffs) + } + } + t.Fatalf("postgres corpus: %d passed, %d allowlisted, %d failed", rep.Passed, rep.Allowed, rep.Failed) + } + if rep.Failed != 0 || rep.Allowed != 0 { + t.Errorf("reference backend should pass every case natively: %d failed, %d allowlisted", rep.Failed, rep.Allowed) + } +} diff --git a/conformance/testdata/postgres/allowlist.json b/conformance/testdata/postgres/allowlist.json new file mode 100644 index 0000000..113a900 --- /dev/null +++ b/conformance/testdata/postgres/allowlist.json @@ -0,0 +1,4 @@ +{ + "backend": "postgres", + "entries": [] +} diff --git a/conformance/testdata/postgres/corpus.json b/conformance/testdata/postgres/corpus.json new file mode 100644 index 0000000..e36ed55 --- /dev/null +++ b/conformance/testdata/postgres/corpus.json @@ -0,0 +1,88 @@ +[ + { + "name": "every column, all rows, ordered", + "request": { "method": "GET", "path": "/films", "query": "order=id.asc" }, + "golden": { + "status": 200, + "body": "[{\"id\":1,\"title\":\"Metropolis\",\"year\":1927,\"rating\":\"NR\",\"tags\":[\"sci-fi\",\"silent\"]},{\"id\":2,\"title\":\"Blade Runner\",\"year\":1982,\"rating\":\"R\",\"tags\":[\"sci-fi\",\"noir\"]},{\"id\":3,\"title\":\"Arrival\",\"year\":2016,\"rating\":\"PG13\",\"tags\":[\"sci-fi\",\"drama\"]}]" + } + }, + { + "name": "single row by id", + "request": { "method": "GET", "path": "/films", "query": "select=id,title&id=eq.1" }, + "golden": { + "status": 200, + "body": "[{\"id\":1,\"title\":\"Metropolis\"}]" + } + }, + { + "name": "projection and filter", + "request": { "method": "GET", "path": "/films", "query": "select=title,year&year=gte.2000&order=year.asc" }, + "golden": { + "status": 200, + "body": "[{\"title\":\"Arrival\",\"year\":2016}]" + } + }, + { + "name": "empty match is an empty array, not a 404", + "request": { "method": "GET", "path": "/films", "query": "id=eq.999" }, + "golden": { "status": 200, "body": "[]" } + }, + { + "name": "unknown table is a PGRST205 envelope", + "request": { "method": "GET", "path": "/ghosts" }, + "golden": { + "status": 404, + "body": "{\"code\":\"PGRST205\",\"message\":\"\",\"details\":\"\",\"hint\":\"\"}" + }, + "mask": ["/message", "/details", "/hint"] + }, + { + "name": "full-text filter lowers to a native tsquery", + "feature": "fts", + "request": { "method": "GET", "path": "/films", "query": "select=id,title&title=fts.metropolis" }, + "golden": { + "status": 200, + "body": "[{\"id\":1,\"title\":\"Metropolis\"}]" + } + }, + { + "name": "regex filter lowers to a native POSIX match", + "feature": "regex", + "request": { "method": "GET", "path": "/films", "query": "select=id,title&title=match.^Bl&order=id.asc" }, + "golden": { + "status": 200, + "body": "[{\"id\":2,\"title\":\"Blade Runner\"}]" + } + }, + { + "name": "array-contains is native on a text[] column", + "feature": "array-contains", + "request": { "method": "GET", "path": "/films", "query": "select=id,title&tags=cs.{noir}" }, + "golden": { + "status": 200, + "body": "[{\"id\":2,\"title\":\"Blade Runner\"}]" + } + }, + { + "name": "like is case-sensitive: a lowercase pattern matches nothing", + "request": { "method": "GET", "path": "/films", "query": "title=like.bl*" }, + "golden": { "status": 200, "body": "[]" } + }, + { + "name": "like is case-sensitive: the title-cased pattern matches", + "request": { "method": "GET", "path": "/films", "query": "select=id,title&title=like.Bl*" }, + "golden": { + "status": 200, + "body": "[{\"id\":2,\"title\":\"Blade Runner\"}]" + } + }, + { + "name": "ilike folds case: a lowercase pattern matches the title-cased row", + "request": { "method": "GET", "path": "/films", "query": "select=id,title&title=ilike.bl*" }, + "golden": { + "status": 200, + "body": "[{\"id\":2,\"title\":\"Blade Runner\"}]" + } + } +] diff --git a/conformance/testdata/sqlite/corpus.json b/conformance/testdata/sqlite/corpus.json index 79eed87..0db973b 100644 --- a/conformance/testdata/sqlite/corpus.json +++ b/conformance/testdata/sqlite/corpus.json @@ -64,5 +64,26 @@ "body": "{\"code\":\"PGRST127\",\"message\":\"\",\"details\":\"\",\"hint\":\"\"}" }, "mask": ["/message", "/details", "/hint"] + }, + { + "name": "like is case-sensitive: a lowercase pattern matches nothing", + "request": { "method": "GET", "path": "/films", "query": "title=like.bl*" }, + "golden": { "status": 200, "body": "[]" } + }, + { + "name": "like is case-sensitive: the title-cased pattern matches", + "request": { "method": "GET", "path": "/films", "query": "title=like.Bl*" }, + "golden": { + "status": 200, + "body": "[{\"id\":2,\"title\":\"Blade Runner\",\"year\":1982,\"rating\":\"R\"}]" + } + }, + { + "name": "ilike folds case: a lowercase pattern matches the title-cased row", + "request": { "method": "GET", "path": "/films", "query": "title=ilike.bl*" }, + "golden": { + "status": 200, + "body": "[{\"id\":2,\"title\":\"Blade Runner\",\"year\":1982,\"rating\":\"R\"}]" + } } ] diff --git a/dbrest-conformance b/dbrest-conformance new file mode 100755 index 0000000..6f09546 Binary files /dev/null and b/dbrest-conformance differ diff --git a/httpapi/auth_test.go b/httpapi/auth_test.go index 0c849d6..102b0af 100644 --- a/httpapi/auth_test.go +++ b/httpapi/auth_test.go @@ -84,8 +84,12 @@ func TestExpiredTokenIsRejected(t *testing.T) { if err := json.NewDecoder(resp.Body).Decode(&body); err != nil { t.Fatalf("decode: %v", err) } - if body["code"] != "PGRST301" { - t.Errorf("code = %v, want PGRST301", body["code"]) + if body["code"] != "PGRST303" { + t.Errorf("code = %v, want PGRST303", body["code"]) + } + want := `Bearer error="invalid_token", error_description="JWT expired"` + if h := resp.Header.Get("WWW-Authenticate"); h != want { + t.Errorf("WWW-Authenticate = %q, want %q", h, want) } } @@ -99,8 +103,11 @@ func TestGarbageTokenIsRejected(t *testing.T) { } var body map[string]any json.NewDecoder(resp.Body).Decode(&body) - if body["code"] != "PGRST302" { - t.Errorf("code = %v, want PGRST302", body["code"]) + if body["code"] != "PGRST301" { + t.Errorf("code = %v, want PGRST301", body["code"]) + } + if h := resp.Header.Get("WWW-Authenticate"); !strings.Contains(h, `error="invalid_token"`) { + t.Errorf("WWW-Authenticate = %q, want the invalid_token challenge", h) } } @@ -130,6 +137,55 @@ func TestAuthRejectionShortCircuitsBeforeQuery(t *testing.T) { } } +func TestNoVerifierNoDefaultRoleFailsClosed(t *testing.T) { + // A bare NewServer has no identity source: every request is refused + // with 401 PGRST302 rather than served as a made-up role. + srv := newServerNoRole(t) + resp := do(t, srv, http.MethodGet, "/films", nil) + if resp.StatusCode != http.StatusUnauthorized { + t.Fatalf("status = %d, want 401", resp.StatusCode) + } + var body map[string]any + json.NewDecoder(resp.Body).Decode(&body) + if body["code"] != "PGRST302" { + t.Errorf("code = %v, want PGRST302", body["code"]) + } + if body["message"] != "Anonymous access is disabled" { + t.Errorf("message = %v, want the exact PGRST302 text", body["message"]) + } + if h := resp.Header.Get("WWW-Authenticate"); h != "Bearer" { + t.Errorf("WWW-Authenticate = %q, want Bearer", h) + } +} + +func TestTokenWithoutSecretIs500(t *testing.T) { + // A verifier with no key material refuses presented tokens with the + // PGRST300 misconfiguration error instead of running them as anon. + srv := authServer(t, auth.Config{Secret: []byte{}, AnonRole: "anon"}) + resp := do(t, srv, http.MethodGet, "/films", map[string]string{ + "Authorization": "Bearer some.jwt.token", + }) + if resp.StatusCode != http.StatusInternalServerError { + t.Fatalf("status = %d, want 500", resp.StatusCode) + } + var body map[string]any + json.NewDecoder(resp.Body).Decode(&body) + if body["code"] != "PGRST300" { + t.Errorf("code = %v, want PGRST300", body["code"]) + } + if body["message"] != "Server lacks JWT secret" { + t.Errorf("message = %v, want Server lacks JWT secret", body["message"]) + } + if h := resp.Header.Get("WWW-Authenticate"); h != "" { + t.Errorf("a PGRST300 must not carry a challenge, got %q", h) + } + // Tokenless requests on the same server still run as anon. + resp = do(t, srv, http.MethodGet, "/films?select=id", nil) + if resp.StatusCode != http.StatusOK { + t.Fatalf("tokenless status = %d, want 200", resp.StatusCode) + } +} + func TestSecretNeverEchoedInError(t *testing.T) { srv := authServer(t, auth.Config{}) resp := do(t, srv, http.MethodGet, "/films", map[string]string{ diff --git a/httpapi/authz_test.go b/httpapi/authz_test.go index 95d7f83..3fe26a2 100644 --- a/httpapi/authz_test.go +++ b/httpapi/authz_test.go @@ -69,14 +69,32 @@ func TestAuthzGrantedReadSucceeds(t *testing.T) { } } -func TestAuthzColumnGrantHidesUngrantedColumn(t *testing.T) { +func TestAuthzColumnGrantRejectsStarProjection(t *testing.T) { srv := authzServer(t, []authz.Grant{ {Role: "web_user", Relation: "films", Action: authz.Select, Columns: []string{"id", "title"}}, }, nil) - // A star projection is narrowed to the granted columns. + // No select parameter is SELECT *, which a column-limited grant does not + // cover: PostgreSQL raises 42501, PostgREST surfaces 403, and so do we. resp := do(t, srv, http.MethodGet, "/films?id=eq.2", map[string]string{ "Authorization": "Bearer " + userToken(t, "web_user", "alice"), }) + if resp.StatusCode != http.StatusForbidden { + t.Fatalf("status = %d, want 403", resp.StatusCode) + } + var env map[string]any + json.NewDecoder(resp.Body).Decode(&env) + if env["code"] != "42501" { + t.Errorf("code = %v, want 42501", env["code"]) + } +} + +func TestAuthzColumnGrantAllowsNamedColumns(t *testing.T) { + srv := authzServer(t, []authz.Grant{ + {Role: "web_user", Relation: "films", Action: authz.Select, Columns: []string{"id", "title"}}, + }, nil) + resp := do(t, srv, http.MethodGet, "/films?select=id,title&id=eq.2", map[string]string{ + "Authorization": "Bearer " + userToken(t, "web_user", "alice"), + }) if resp.StatusCode != http.StatusOK { t.Fatalf("status = %d, want 200", resp.StatusCode) } @@ -84,9 +102,6 @@ func TestAuthzColumnGrantHidesUngrantedColumn(t *testing.T) { if len(rows) != 1 { t.Fatalf("rows = %d, want 1", len(rows)) } - if _, ok := rows[0]["rating"]; ok { - t.Error("rating is not granted and must not appear") - } if _, ok := rows[0]["title"]; !ok { t.Error("title is granted and should appear") } diff --git a/httpapi/content_negotiation_test.go b/httpapi/content_negotiation_test.go index b4d5c31..dead0bc 100644 --- a/httpapi/content_negotiation_test.go +++ b/httpapi/content_negotiation_test.go @@ -176,12 +176,14 @@ func TestPostUnsupportedMediaType(t *testing.T) { resp := send(t, srv, http.MethodPost, "/films", "", map[string]string{ "Content-Type": "application/xml", }) - if resp.StatusCode != http.StatusUnsupportedMediaType { - t.Fatalf("status = %d, want 415", resp.StatusCode) + // Live v14 answers 400 PGRST102 for an unparseable request Content-Type; + // the docs' PGRST107/415 row is stale (see compat/errors_v14_test.go). + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("status = %d, want 400", resp.StatusCode) } env := decodeEnvelope(t, resp) - if env["code"] != "PGRST107" { - t.Errorf("code = %v, want PGRST107", env["code"]) + if env["code"] != "PGRST102" { + t.Errorf("code = %v, want PGRST102", env["code"]) } } diff --git a/httpapi/context_settings_test.go b/httpapi/context_settings_test.go new file mode 100644 index 0000000..81eccf7 --- /dev/null +++ b/httpapi/context_settings_test.go @@ -0,0 +1,44 @@ +package httpapi_test + +import ( + "net/http" + "testing" +) + +// TestContextCarriesConfiguredSettings checks that db-pre-request, +// app.settings.*, and log-query reach the backend on the per-request context, +// the seam each driver consumes them from. +func TestContextCarriesConfiguredSettings(t *testing.T) { + srv, cb := captureServer(t) + srv.SetPreRequest("check_request") + srv.SetAppSettings(map[string]string{"tenant": "acme"}) + srv.SetLogQuery(true) + + srv.ServeHTTP(newRecorder(), newReq(http.MethodGet, "/films")) + if cb.got == nil { + t.Fatal("backend never saw a request context") + } + if cb.got.PreRequest != "check_request" { + t.Errorf("PreRequest = %q, want check_request", cb.got.PreRequest) + } + if cb.got.AppSettings["tenant"] != "acme" { + t.Errorf("AppSettings = %v, want tenant=acme", cb.got.AppSettings) + } + if !cb.got.LogQuery { + t.Error("LogQuery did not reach the backend") + } +} + +// TestContextSettingsUnsetByDefault pins the unconfigured shape: no hook, no +// settings, no echo on the context. +func TestContextSettingsUnsetByDefault(t *testing.T) { + srv, cb := captureServer(t) + srv.ServeHTTP(newRecorder(), newReq(http.MethodGet, "/films")) + if cb.got == nil { + t.Fatal("backend never saw a request context") + } + if cb.got.PreRequest != "" || len(cb.got.AppSettings) != 0 || cb.got.LogQuery { + t.Errorf("unconfigured context carries settings: pre=%q app=%v logQuery=%v", + cb.got.PreRequest, cb.got.AppSettings, cb.got.LogQuery) + } +} diff --git a/httpapi/cors_test.go b/httpapi/cors_test.go new file mode 100644 index 0000000..c39a2a2 --- /dev/null +++ b/httpapi/cors_test.go @@ -0,0 +1,105 @@ +package httpapi_test + +import ( + "net/http" + "testing" +) + +// TestCORSPreflightDefault checks the permissive default: any origin gets a +// wildcard preflight answer with the PostgREST method and header lists. +func TestCORSPreflightDefault(t *testing.T) { + srv := newServer(t) + resp := do(t, srv, http.MethodOptions, "/films", map[string]string{ + "Origin": "http://example.com", + "Access-Control-Request-Method": "POST", + "Access-Control-Request-Headers": "Foo,Bar", + }) + if resp.StatusCode != http.StatusOK { + t.Fatalf("preflight status = %d, want 200", resp.StatusCode) + } + want := map[string]string{ + "Access-Control-Allow-Origin": "*", + "Access-Control-Allow-Methods": "GET, POST, PATCH, PUT, DELETE, OPTIONS, HEAD", + "Access-Control-Allow-Headers": "Authorization, Foo, Bar, Accept, Accept-Language, Content-Language", + "Access-Control-Max-Age": "86400", + } + for k, v := range want { + if got := resp.Header.Get(k); got != v { + t.Errorf("%s = %q, want %q", k, got, v) + } + } + if resp.Header.Get("Access-Control-Allow-Credentials") != "" { + t.Error("wildcard preflight must not carry Allow-Credentials") + } +} + +// TestCORSSimpleRequestDefault checks that a plain cross-origin read carries +// the wildcard origin and the exposed-headers list. +func TestCORSSimpleRequestDefault(t *testing.T) { + srv := newServer(t) + resp := do(t, srv, http.MethodGet, "/films", map[string]string{ + "Origin": "http://example.com", + }) + if resp.StatusCode != http.StatusOK { + t.Fatalf("status = %d, want 200", resp.StatusCode) + } + if got := resp.Header.Get("Access-Control-Allow-Origin"); got != "*" { + t.Errorf("Allow-Origin = %q, want *", got) + } + const expose = "Content-Encoding, Content-Location, Content-Range, Content-Type, " + + "Date, Location, Server, Transfer-Encoding, Range-Unit" + if got := resp.Header.Get("Access-Control-Expose-Headers"); got != expose { + t.Errorf("Expose-Headers = %q", got) + } +} + +// TestCORSRestrictedOrigins checks server-cors-allowed-origins semantics: a +// listed origin is reflected with credentials, an unlisted one gets no CORS +// headers but the request still runs. +func TestCORSRestrictedOrigins(t *testing.T) { + srv := newServer(t) + srv.SetCORSAllowedOrigins([]string{"http://allowed.example"}) + + resp := do(t, srv, http.MethodGet, "/films", map[string]string{ + "Origin": "http://allowed.example", + }) + if got := resp.Header.Get("Access-Control-Allow-Origin"); got != "http://allowed.example" { + t.Errorf("Allow-Origin = %q, want the reflected origin", got) + } + if got := resp.Header.Get("Access-Control-Allow-Credentials"); got != "true" { + t.Errorf("Allow-Credentials = %q, want true", got) + } + + resp = do(t, srv, http.MethodGet, "/films", map[string]string{ + "Origin": "http://denied.example", + }) + if resp.StatusCode != http.StatusOK { + t.Fatalf("unlisted origin must still be served, got %d", resp.StatusCode) + } + if got := resp.Header.Get("Access-Control-Allow-Origin"); got != "" { + t.Errorf("unlisted origin got Allow-Origin %q, want none", got) + } + + preflight := do(t, srv, http.MethodOptions, "/films", map[string]string{ + "Origin": "http://allowed.example", + "Access-Control-Request-Method": "POST", + }) + if got := preflight.Header.Get("Access-Control-Allow-Origin"); got != "http://allowed.example" { + t.Errorf("preflight Allow-Origin = %q", got) + } + if got := preflight.Header.Get("Access-Control-Allow-Headers"); got != "Authorization, Accept, Accept-Language, Content-Language" { + t.Errorf("preflight Allow-Headers = %q", got) + } +} + +// TestCORSNoOriginUntouched checks that a same-origin request gets no CORS +// headers at all. +func TestCORSNoOriginUntouched(t *testing.T) { + srv := newServer(t) + resp := do(t, srv, http.MethodGet, "/films", nil) + for _, k := range []string{"Access-Control-Allow-Origin", "Access-Control-Expose-Headers"} { + if got := resp.Header.Get(k); got != "" { + t.Errorf("%s = %q on a request without Origin", k, got) + } + } +} diff --git a/httpapi/embedding_test.go b/httpapi/embedding_test.go index 7bb32df..1e46f3c 100644 --- a/httpapi/embedding_test.go +++ b/httpapi/embedding_test.go @@ -66,7 +66,9 @@ func newEmbedServer(t testing.TB) *httpapi.Server { if err != nil { t.Fatalf("introspect: %v", err) } - return httpapi.NewServer(be, model, nil) + srv := httpapi.NewServer(be, model, nil) + srv.SetDefaultRole("anon") + return srv } func TestEmbedToOneObject(t *testing.T) { @@ -224,6 +226,12 @@ func TestEmbedNoRelationship(t *testing.T) { if env["code"] != "PGRST200" { t.Errorf("code = %v, want PGRST200", env["code"]) } + // The rendered body carries the searched-pair details, not a null (item 04.4). + details, _ := env["details"].(string) + want := "Searched for a foreign key relationship between 'films' and 'nonsense' in the schema 'public', but no matches were found." + if details != want { + t.Errorf("details = %q, want %q", details, want) + } } func TestEmbedColumnInCSV(t *testing.T) { @@ -248,6 +256,161 @@ func TestEmbedColumnInCSV(t *testing.T) { } } +// order=rel(col) sorts the parent by a to-one embed's column. Films sort by +// their director's name; the film with no director (a NULL key) lands last under +// the requested nullslast (item 07.6). +func TestRelatedOrderSortsParent(t *testing.T) { + srv := newEmbedServer(t) + resp := do(t, srv, http.MethodGet, + "/films?select=title,director:directors(name)&order=director(name).asc.nullslast", nil) + if resp.StatusCode != http.StatusOK { + t.Fatalf("status = %d, want 200", resp.StatusCode) + } + rows := decodeArray(t, resp) + got := make([]string, len(rows)) + for i, r := range rows { + got[i], _ = r["title"].(string) + } + // Lang < Scott < Villeneuve by name; Untitled has no director, so NULL last. + want := []string{"Metropolis", "Blade Runner", "Arrival", "Untitled"} + if len(got) != len(want) { + t.Fatalf("got %d rows %v, want %d", len(got), got, len(want)) + } + for i := range want { + if got[i] != want[i] { + t.Errorf("row %d title = %q, want %q (order %v)", i, got[i], want[i], got) + } + } +} + +// The same order, nullsfirst, floats the directorless film to the top. +func TestRelatedOrderNullsFirst(t *testing.T) { + srv := newEmbedServer(t) + resp := do(t, srv, http.MethodGet, + "/films?select=title,director:directors(name)&order=director(name).asc.nullsfirst", nil) + rows := decodeArray(t, resp) + if len(rows) == 0 { + t.Fatal("no rows") + } + if title, _ := rows[0]["title"].(string); title != "Untitled" { + t.Errorf("first title = %q, want Untitled (NULL director sorts first)", title) + } +} + +// Ordering by a relation the select never embedded is PGRST108. +func TestRelatedOrderNotEmbeddedHTTP(t *testing.T) { + srv := newEmbedServer(t) + resp := do(t, srv, http.MethodGet, "/films?select=title&order=director(name).asc", nil) + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("status = %d, want 400", resp.StatusCode) + } + env := decodeEnvelope(t, resp) + if env["code"] != "PGRST108" { + t.Errorf("code = %v, want PGRST108", env["code"]) + } +} + +// Ordering by a to-many embed is PGRST118: a director has many films, so it has +// no single film title to sort by. +func TestRelatedOrderToManyHTTP(t *testing.T) { + srv := newEmbedServer(t) + resp := do(t, srv, http.MethodGet, "/directors?select=name,films(title)&order=films(title).asc", nil) + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("status = %d, want 400", resp.StatusCode) + } + env := decodeEnvelope(t, resp) + if env["code"] != "PGRST118" { + t.Errorf("code = %v, want PGRST118", env["code"]) + } +} + +// An empty-parenthesis embed hides the key from the response while still joining +// the relation: director() returns films without a director field, the opposite +// of director(*) (item 07.8). +func TestRelatedEmptySelectHidesKey(t *testing.T) { + srv := newEmbedServer(t) + resp := do(t, srv, http.MethodGet, "/films?select=title,director:directors()&id=eq.1", nil) + if resp.StatusCode != http.StatusOK { + t.Fatalf("status = %d, want 200", resp.StatusCode) + } + rows := decodeArray(t, resp) + if len(rows) != 1 { + t.Fatalf("got %d rows, want 1", len(rows)) + } + if _, has := rows[0]["director"]; has { + t.Errorf("director() should hide the key, got %#v", rows[0]) + } + if rows[0]["title"] != "Metropolis" { + t.Errorf("title = %v, want Metropolis", rows[0]["title"]) + } +} + +// An !inner empty-parenthesis embed prunes parents with no related row while +// still projecting nothing: the directorless film drops out, and no director key +// appears on those that remain. +func TestRelatedEmptySelectInnerFilters(t *testing.T) { + srv := newEmbedServer(t) + resp := do(t, srv, http.MethodGet, "/films?select=title,directors!inner()&order=id", nil) + if resp.StatusCode != http.StatusOK { + t.Fatalf("status = %d, want 200", resp.StatusCode) + } + rows := decodeArray(t, resp) + // Four films, but only three have a director; the inner join drops Untitled. + if len(rows) != 3 { + t.Fatalf("got %d rows, want 3 (directorless film dropped)", len(rows)) + } + for _, r := range rows { + if _, has := r["directors"]; has { + t.Errorf("empty-paren embed should hide its key, got %#v", r) + } + if r["title"] == "Untitled" { + t.Errorf("Untitled has no director and should have been filtered out") + } + } +} + +// A to-one spread lifts the director's name straight onto the film row, with no +// nested director object (item 07.9). +func TestSpreadToOneLiftsColumn(t *testing.T) { + srv := newEmbedServer(t) + resp := do(t, srv, http.MethodGet, + "/films?select=title,...directors(director:name)&id=eq.1", nil) + if resp.StatusCode != http.StatusOK { + t.Fatalf("status = %d, want 200", resp.StatusCode) + } + rows := decodeArray(t, resp) + if len(rows) != 1 { + t.Fatalf("got %d rows, want 1", len(rows)) + } + if _, nested := rows[0]["directors"]; nested { + t.Errorf("a spread must not nest a directors object, got %#v", rows[0]) + } + if rows[0]["director"] != "Lang" { + t.Errorf("lifted director = %v, want Lang", rows[0]["director"]) + } +} + +// A to-many spread lifts the related column as an array onto the parent row. +func TestSpreadToManyLiftsArray(t *testing.T) { + srv := newEmbedServer(t) + resp := do(t, srv, http.MethodGet, + "/directors?select=name,...films(title)&id=eq.1", nil) + if resp.StatusCode != http.StatusOK { + t.Fatalf("status = %d, want 200", resp.StatusCode) + } + rows := decodeArray(t, resp) + if len(rows) != 1 { + t.Fatalf("got %d rows, want 1", len(rows)) + } + titles, ok := rows[0]["title"].([]any) + if !ok { + t.Fatalf("title = %#v, want an array", rows[0]["title"]) + } + if len(titles) != 1 || titles[0] != "Metropolis" { + t.Errorf("lifted titles = %v, want [Metropolis]", titles) + } +} + func BenchmarkEmbedToMany(b *testing.B) { srv := newEmbedServer(b) req := httptest.NewRequest(http.MethodGet, "/directors?select=name,films(title)&order=id", nil) diff --git a/httpapi/fts_test.go b/httpapi/fts_test.go index 621d57e..b61d866 100644 --- a/httpapi/fts_test.go +++ b/httpapi/fts_test.go @@ -45,7 +45,9 @@ func newFTSServer(t testing.TB) *httpapi.Server { if err != nil { t.Fatalf("introspect: %v", err) } - return httpapi.NewServer(be, model, nil) + srv := httpapi.NewServer(be, model, nil) + srv.SetDefaultRole("anon") + return srv } // TestFTSMatchSelectsRow exercises the full request path: an fts filter lowers to diff --git a/httpapi/jsonpath_test.go b/httpapi/jsonpath_test.go new file mode 100644 index 0000000..a6f0c3e --- /dev/null +++ b/httpapi/jsonpath_test.go @@ -0,0 +1,66 @@ +package httpapi_test + +import ( + "context" + "net/http" + "strings" + "testing" + + "github.com/tamnd/dbrest/backend/sqlite" + "github.com/tamnd/dbrest/httpapi" +) + +// newDocsServer seeds a table with a JSON column so the 07.1 JSON-path render +// contract can be checked end to end through the HTTP layer. +func newDocsServer(t testing.TB) *httpapi.Server { + t.Helper() + dsn := "file:" + strings.ReplaceAll(t.Name(), "/", "_") + "?mode=memory&cache=shared" + be, err := sqlite.Open(dsn) + if err != nil { + t.Fatalf("open: %v", err) + } + t.Cleanup(func() { be.Close() }) + + _, err = be.DB().Exec(` + CREATE TABLE docs (id INTEGER PRIMARY KEY, data JSON); + INSERT INTO docs (id, data) VALUES + (1, '{"blood_type":"A-","meta":{"k":1}}'); + `) + if err != nil { + t.Fatalf("seed: %v", err) + } + model, err := be.Introspect(context.Background()) + if err != nil { + t.Fatalf("introspect: %v", err) + } + srv := httpapi.NewServer(be, model, nil) + srv.SetDefaultRole("anon") + return srv +} + +// A final -> projection renders as raw JSON (decodes back to an object), while a +// final ->> projection renders as a plain JSON string. This is the renderer +// contract behind PostgREST's -> json / ->> text typing (07.1). +func TestJSONPathRenderTyping(t *testing.T) { + srv := newDocsServer(t) + resp := do(t, srv, http.MethodGet, "/docs?select=id,data->meta,data->>blood_type", nil) + if resp.StatusCode != http.StatusOK { + t.Fatalf("status = %d, want 200", resp.StatusCode) + } + rows := decodeArray(t, resp) + if len(rows) != 1 { + t.Fatalf("rows = %d, want 1", len(rows)) + } + // -> meta is spliced raw: it decodes to a JSON object, not a quoted string. + meta, ok := rows[0]["meta"].(map[string]any) + if !ok { + t.Fatalf("meta = %T (%v), want a JSON object spliced verbatim", rows[0]["meta"], rows[0]["meta"]) + } + if meta["k"] != float64(1) { + t.Errorf("meta.k = %v, want 1", meta["k"]) + } + // ->> blood_type is text: a plain JSON string. + if bt, ok := rows[0]["blood_type"].(string); !ok || bt != "A-" { + t.Errorf("blood_type = %v (%T), want the string A-", rows[0]["blood_type"], rows[0]["blood_type"]) + } +} diff --git a/httpapi/maperror_test.go b/httpapi/maperror_test.go new file mode 100644 index 0000000..b9346bc --- /dev/null +++ b/httpapi/maperror_test.go @@ -0,0 +1,36 @@ +package httpapi + +import ( + "net/http" + "testing" + + "github.com/tamnd/dbrest/pgerr" +) + +// 04.7: a native 42501 insufficient_privilege is 403 for an authenticated +// request and 401 (with a Bearer challenge) for an anonymous one. The base +// error carries 403; mapExecError lifts only the anonymous case. + +func TestMapExecError42501Split(t *testing.T) { + base := pgerr.New(http.StatusForbidden, pgerr.CodeInsufficientPrivilege, "permission denied for table films") + + authed := mapExecError(nil, base, false) + if authed.HTTPStatus != http.StatusForbidden { + t.Errorf("authenticated status = %d, want 403", authed.HTTPStatus) + } + if authed.WWWAuthenticate != "" { + t.Errorf("authenticated WWW-Authenticate = %q, want none", authed.WWWAuthenticate) + } + + anon := mapExecError(nil, base, true) + if anon.HTTPStatus != http.StatusUnauthorized { + t.Errorf("anonymous status = %d, want 401", anon.HTTPStatus) + } + if anon.WWWAuthenticate != "Bearer" { + t.Errorf("anonymous WWW-Authenticate = %q, want Bearer", anon.WWWAuthenticate) + } + // The lift must not mutate the shared base error. + if base.HTTPStatus != http.StatusForbidden { + t.Errorf("base mutated to %d", base.HTTPStatus) + } +} diff --git a/httpapi/maxaffected_test.go b/httpapi/maxaffected_test.go new file mode 100644 index 0000000..3f8924a --- /dev/null +++ b/httpapi/maxaffected_test.go @@ -0,0 +1,108 @@ +package httpapi_test + +import ( + "net/http" + "strings" + "testing" +) + +// 02.2: Prefer: max-affected caps the rows a write may affect under +// handling=strict. A violation is 400 PGRST124 and the whole transaction rolls +// back; under lenient handling the preference is ignored entirely. + +// TestPatchMaxAffectedExceededRollsBack: a PATCH whose filter matches more rows +// than max-affected fails with PGRST124 and leaves every row unchanged. +func TestPatchMaxAffectedExceededRollsBack(t *testing.T) { + srv := newServer(t) + // year >= 1900 matches films 1, 2, 3 (film 4 has a NULL year), three rows. + resp := send(t, srv, http.MethodPatch, "/films?year=gte.1900", `{"rating":"X"}`, map[string]string{ + "Prefer": "handling=strict, max-affected=1", + }) + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("status = %d, want 400", resp.StatusCode) + } + env := decodeEnvelope(t, resp) + if env["code"] != "PGRST124" { + t.Errorf("code = %v, want PGRST124", env["code"]) + } + if env["details"] != "The query affects 3 rows" { + t.Errorf("details = %v, want the affected count", env["details"]) + } + // The transaction rolled back: no row took the new rating. + after := do(t, srv, http.MethodGet, "/films?rating=eq.X&select=id", nil) + if rows := decodeArray(t, after); len(rows) != 0 { + t.Errorf("rollback failed, %d rows were updated", len(rows)) + } +} + +// TestDeleteMaxAffectedExceededRollsBack: a DELETE matching more rows than the +// bound fails with PGRST124 and deletes nothing. +func TestDeleteMaxAffectedExceededRollsBack(t *testing.T) { + srv := newServer(t) + // No filter: all four seed rows match. + resp := send(t, srv, http.MethodDelete, "/films", "", map[string]string{ + "Prefer": "handling=strict, max-affected=2", + }) + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("status = %d, want 400", resp.StatusCode) + } + if env := decodeEnvelope(t, resp); env["code"] != "PGRST124" { + t.Errorf("code = %v, want PGRST124", env["code"]) + } + after := do(t, srv, http.MethodGet, "/films?select=id", nil) + if rows := decodeArray(t, after); len(rows) != 4 { + t.Errorf("rollback failed, %d rows remain, want 4", len(rows)) + } +} + +// TestPatchMaxAffectedWithinBoundCommits: a write at or under the bound proceeds +// normally and persists. +func TestPatchMaxAffectedWithinBoundCommits(t *testing.T) { + srv := newServer(t) + resp := send(t, srv, http.MethodPatch, "/films?id=eq.2", `{"rating":"X"}`, map[string]string{ + "Prefer": "handling=strict, max-affected=1", + }) + if resp.StatusCode != http.StatusNoContent { + t.Fatalf("status = %d, want 204", resp.StatusCode) + } + after := do(t, srv, http.MethodGet, "/films?id=eq.2&select=rating", nil) + rows := decodeArray(t, after) + if len(rows) != 1 || rows[0]["rating"] != "X" { + t.Errorf("write did not persist: %v", rows) + } +} + +// TestPatchMaxAffectedLenientIgnored: without handling=strict the preference is +// ignored, so an over-broad write still commits and is not echoed. +func TestPatchMaxAffectedLenientIgnored(t *testing.T) { + srv := newServer(t) + resp := send(t, srv, http.MethodPatch, "/films?year=gte.1900", `{"rating":"X"}`, map[string]string{ + "Prefer": "max-affected=1", + }) + if resp.StatusCode != http.StatusNoContent { + t.Fatalf("status = %d, want 204", resp.StatusCode) + } + if pa := resp.Header.Get("Preference-Applied"); pa != "" { + t.Errorf("Preference-Applied = %q, want max-affected not echoed under lenient", pa) + } + after := do(t, srv, http.MethodGet, "/films?rating=eq.X&select=id", nil) + if rows := decodeArray(t, after); len(rows) != 3 { + t.Errorf("lenient write affected %d rows, want all 3", len(rows)) + } +} + +// TestMaxAffectedEchoedUnderStrict: a strict request that stays within the bound +// echoes max-affected in Preference-Applied. +func TestMaxAffectedEchoedUnderStrict(t *testing.T) { + srv := newServer(t) + resp := send(t, srv, http.MethodPatch, "/films?id=eq.2", `{"rating":"X"}`, map[string]string{ + "Prefer": "handling=strict, max-affected=5", + }) + if resp.StatusCode != http.StatusNoContent { + t.Fatalf("status = %d, want 204", resp.StatusCode) + } + pa := resp.Header.Get("Preference-Applied") + if pa == "" || !strings.Contains(pa, "max-affected=5") { + t.Errorf("Preference-Applied = %q, want it to echo max-affected=5", pa) + } +} diff --git a/httpapi/maxbody_test.go b/httpapi/maxbody_test.go new file mode 100644 index 0000000..d71011d --- /dev/null +++ b/httpapi/maxbody_test.go @@ -0,0 +1,77 @@ +package httpapi_test + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +// TestUnlimitedBodyByDefault checks dbrest imposes no body cap out of the box, +// matching PostgREST: a bulk insert far past the old 16 MiB limit is accepted. +func TestUnlimitedBodyByDefault(t *testing.T) { + srv := newServer(t) + + var sb strings.Builder + sb.WriteByte('[') + for i := 0; i < 20000; i++ { + if i > 0 { + sb.WriteByte(',') + } + sb.WriteString(`{"title":"`) + sb.WriteString(strings.Repeat("x", 64)) + sb.WriteString(`"}`) + } + sb.WriteByte(']') + + req := httptest.NewRequest(http.MethodPost, "/films", strings.NewReader(sb.String())) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + srv.ServeHTTP(rec, req) + if rec.Code != http.StatusCreated { + t.Fatalf("status = %d, want 201; body=%s", rec.Code, rec.Body.String()) + } +} + +// TestMaxRequestBodyRejectsOversizeWith413 checks a configured cap answers an +// oversize body with 413 PGRSTX13 and the byte bound, not a parse error. +func TestMaxRequestBodyRejectsOversizeWith413(t *testing.T) { + srv := newServer(t) + srv.SetMaxRequestBody(32) + + big := `[{"title":"` + strings.Repeat("x", 200) + `"}]` + req := httptest.NewRequest(http.MethodPost, "/films", strings.NewReader(big)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + srv.ServeHTTP(rec, req) + + if rec.Code != http.StatusRequestEntityTooLarge { + t.Fatalf("status = %d, want 413; body=%s", rec.Code, rec.Body.String()) + } + var env struct{ Code, Message string } + if err := json.NewDecoder(rec.Body).Decode(&env); err != nil { + t.Fatalf("decode: %v", err) + } + if env.Code != "PGRSTX13" { + t.Errorf("code = %q, want PGRSTX13", env.Code) + } + if !strings.Contains(env.Message, "32") { + t.Errorf("message %q should name the 32 byte bound", env.Message) + } +} + +// TestMaxRequestBodyAllowsUnderCap checks a body within the configured cap is +// processed normally. +func TestMaxRequestBodyAllowsUnderCap(t *testing.T) { + srv := newServer(t) + srv.SetMaxRequestBody(1 << 20) + + req := httptest.NewRequest(http.MethodPost, "/films", strings.NewReader(`{"title":"Solaris"}`)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + srv.ServeHTTP(rec, req) + if rec.Code != http.StatusCreated { + t.Fatalf("status = %d, want 201; body=%s", rec.Code, rec.Body.String()) + } +} diff --git a/httpapi/maxrows_test.go b/httpapi/maxrows_test.go new file mode 100644 index 0000000..62db45a --- /dev/null +++ b/httpapi/maxrows_test.go @@ -0,0 +1,110 @@ +package httpapi_test + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +// readJSONArray decodes a JSON array response body. +func readJSONArray(t *testing.T, resp *http.Response) []map[string]any { + t.Helper() + var rows []map[string]any + if err := json.NewDecoder(resp.Body).Decode(&rows); err != nil { + t.Fatalf("decode body: %v", err) + } + return rows +} + +// TestMaxRowsCapsRead checks that db-max-rows is an implicit LIMIT on reads: +// the body is truncated and Content-Range reports the served window. +func TestMaxRowsCapsRead(t *testing.T) { + srv := newServer(t) // 4 films seeded + srv.SetMaxRows(2) + + resp := do(t, srv, http.MethodGet, "/films?order=id", nil) + if resp.StatusCode != http.StatusOK { + t.Fatalf("status = %d, want 200 (no count requested)", resp.StatusCode) + } + if got := resp.Header.Get("Content-Range"); got != "0-1/*" { + t.Errorf("Content-Range = %q, want 0-1/*", got) + } + if rows := readJSONArray(t, resp); len(rows) != 2 { + t.Errorf("rows = %d, want 2", len(rows)) + } +} + +// TestMaxRowsWithExactCount checks the 206 shape: a capped read with +// count=exact reports the true total and Partial Content. +func TestMaxRowsWithExactCount(t *testing.T) { + srv := newServer(t) + srv.SetMaxRows(2) + + resp := do(t, srv, http.MethodGet, "/films?order=id", map[string]string{"Prefer": "count=exact"}) + if resp.StatusCode != http.StatusPartialContent { + t.Fatalf("status = %d, want 206", resp.StatusCode) + } + if got := resp.Header.Get("Content-Range"); got != "0-1/4" { + t.Errorf("Content-Range = %q, want 0-1/4", got) + } +} + +// TestMaxRowsMinWithRequestedLimit checks min(requested, max-rows) in both +// directions. +func TestMaxRowsMinWithRequestedLimit(t *testing.T) { + srv := newServer(t) + srv.SetMaxRows(2) + + resp := do(t, srv, http.MethodGet, "/films?order=id&limit=1", nil) + if rows := readJSONArray(t, resp); len(rows) != 1 { + t.Errorf("limit below cap: rows = %d, want 1", len(rows)) + } + resp = do(t, srv, http.MethodGet, "/films?order=id&limit=10", nil) + if rows := readJSONArray(t, resp); len(rows) != 2 { + t.Errorf("limit above cap: rows = %d, want 2", len(rows)) + } +} + +// TestMaxRowsExemptsMutationRepresentation checks the PostgREST v10+ rule: +// the representation of a write returns every affected row, uncapped. +func TestMaxRowsExemptsMutationRepresentation(t *testing.T) { + srv := newServer(t) + srv.SetMaxRows(1) + + body := `[{"title":"One"},{"title":"Two"},{"title":"Three"}]` + req := httptest.NewRequest(http.MethodPost, "/films", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Prefer", "return=representation") + rec := httptest.NewRecorder() + srv.ServeHTTP(rec, req) + resp := rec.Result() + if resp.StatusCode != http.StatusCreated { + t.Fatalf("status = %d, want 201", resp.StatusCode) + } + if rows := readJSONArray(t, resp); len(rows) != 3 { + t.Errorf("representation rows = %d, want all 3 despite max-rows=1", len(rows)) + } +} + +// TestMaxRowsCapsRPC checks that a table-returning function is capped too. +// The setof-scalar path compiles the function body verbatim in the sqlite +// backend and cannot take a window yet; the cap reaches it once that +// compiler gap closes (the RPC pagination item). +func TestMaxRowsCapsRPC(t *testing.T) { + srv := newRPCServer(t) // 3 films + srv.SetMaxRows(1) + + resp := do(t, srv, http.MethodGet, "/rpc/films_after?y=1900", nil) + if resp.StatusCode != http.StatusOK { + t.Fatalf("status = %d, want 200", resp.StatusCode) + } + var rows []any + if err := json.NewDecoder(resp.Body).Decode(&rows); err != nil { + t.Fatalf("decode body: %v", err) + } + if len(rows) != 1 { + t.Errorf("rpc rows = %d, want 1", len(rows)) + } +} diff --git a/httpapi/negotiate.go b/httpapi/negotiate.go index 95a07ae..2400aef 100644 --- a/httpapi/negotiate.go +++ b/httpapi/negotiate.go @@ -4,6 +4,8 @@ import ( "sort" "strconv" "strings" + + "github.com/tamnd/dbrest/backend" ) // The response media types dbrest can produce, in preference order. A wildcard @@ -22,13 +24,30 @@ const ( var supportedMedia = []string{mediaJSON, mediaArray, mediaObject, mediaPlan, mediaCSV, mediaOctet, mediaText} +// The internal media keys for the nulls=stripped variants of the vendor array +// and object types. They are not real Accept literals; negotiate returns them so +// the render path knows to drop null-valued keys and echo the parameterized +// Content-Type. +const ( + mediaArrayStripped = "application/vnd.pgrst.array+json;nulls=stripped" + mediaObjectStripped = "application/vnd.pgrst.object+json;nulls=stripped" +) + +// singularMedia reports whether a negotiated media type asks for a single object +// (the object vendor type or its nulls=stripped variant). +func singularMedia(media string) bool { + return media == mediaObject || media == mediaObjectStripped +} + // mediaRange is one parsed entry of an Accept header: a type/subtype pair, its -// quality value, and its position in the header for stable tie-breaking. +// quality value, its position in the header for stable tie-breaking, and whether +// it carried the nulls=stripped parameter. type mediaRange struct { - typ string - sub string - q float64 - order int + typ string + sub string + q float64 + order int + stripNulls bool } // parseAccept parses the Accept header values into media ranges sorted by @@ -48,14 +67,19 @@ func parseAccept(headers []string) []mediaRange { continue } q := 1.0 + stripNulls := false for _, p := range segs[1:] { - if v, ok := strings.CutPrefix(strings.TrimSpace(p), "q="); ok { + p = strings.TrimSpace(p) + if v, ok := strings.CutPrefix(p, "q="); ok { if f, err := strconv.ParseFloat(v, 64); err == nil { q = f } } + if v, ok := strings.CutPrefix(strings.ToLower(p), "nulls="); ok && strings.TrimSpace(v) == "stripped" { + stripNulls = true + } } - ranges = append(ranges, mediaRange{strings.ToLower(typ), strings.ToLower(sub), q, n}) + ranges = append(ranges, mediaRange{strings.ToLower(typ), strings.ToLower(sub), q, n, stripNulls}) n++ } } @@ -63,31 +87,80 @@ func parseAccept(headers []string) []mediaRange { return ranges } -// planAnalyze reports whether the Accept header for vnd.pgrst.plan+json carries -// "options=analyze", which asks for EXPLAIN ANALYZE rather than plain EXPLAIN. -func planAnalyze(headers []string) bool { +// planSubtypes are the application/vnd.pgrst.plan family subtypes dbrest +// recognizes, mapping each to its output format. The bare type and the +text +// suffix are PostgREST's text default; +json is the machine-readable form. +var planSubtypes = map[string]backend.PlanFormat{ + "vnd.pgrst.plan": backend.PlanText, + "vnd.pgrst.plan+text": backend.PlanText, + "vnd.pgrst.plan+json": backend.PlanJSON, +} + +// parsePlan scans the Accept header for the application/vnd.pgrst.plan family and +// returns the parsed plan options. The second return is false when no plan type +// is present. Output defaults to text (the bare type and +text); +json selects +// the JSON form. The for="" parameter (default application/json) and the +// options=a|b|c flags (analyze, verbose, settings, buffers, wal) ride along. +func parsePlan(headers []string) (backend.PlanOptions, bool) { for _, h := range headers { for part := range strings.SplitSeq(h, ",") { - part = strings.TrimSpace(part) - segs := strings.Split(part, ";") + segs := strings.Split(strings.TrimSpace(part), ";") typ, sub, ok := strings.Cut(strings.TrimSpace(segs[0]), "/") if !ok { continue } - if strings.ToLower(typ)+"/"+strings.ToLower(sub) != "application/vnd.pgrst.plan+json" { + typ = strings.ToLower(strings.TrimSpace(typ)) + sub = strings.ToLower(strings.TrimSpace(sub)) + format, isPlan := planSubtypes[sub] + if typ != "application" || !isPlan { continue } + opts := backend.PlanOptions{Format: format, For: mediaJSON} for _, p := range segs[1:] { - p = strings.TrimSpace(p) - if v, ok := strings.CutPrefix(strings.ToLower(p), "options="); ok { - if strings.Contains(v, "analyze") { - return true + k, v, ok := strings.Cut(strings.TrimSpace(p), "=") + if !ok { + continue + } + k = strings.ToLower(strings.TrimSpace(k)) + v = strings.Trim(strings.TrimSpace(v), `"`) + switch k { + case "for": + opts.For = v + case "options": + for _, o := range strings.Split(strings.ToLower(v), "|") { + switch strings.TrimSpace(o) { + case "analyze": + opts.Analyze = true + case "verbose": + opts.Verbose = true + case "settings": + opts.Settings = true + case "buffers": + opts.Buffers = true + case "wal": + opts.Wal = true + } } } } + return opts, true } } - return false + return backend.PlanOptions{}, false +} + +// vendorSynonym maps the suffixless PostgREST vendor spellings to their +json +// forms, which PostgREST accepts as synonyms. Any other type passes through +// unchanged. +func vendorSynonym(full string) string { + switch full { + case "application/vnd.pgrst.array": + return mediaArray + case "application/vnd.pgrst.object": + return mediaObject + default: + return full + } } // negotiate picks the best supported response media type for the Accept header. @@ -114,9 +187,27 @@ func negotiate(headers []string) (string, bool) { } } default: - full := r.typ + "/" + r.sub + // The plan family (bare, +text, +json) negotiates to the single plan + // sentinel; parsePlan recovers the exact format and options later. + if r.typ == "application" { + if _, isPlan := planSubtypes[r.sub]; isPlan { + return mediaPlan, true + } + } + full := vendorSynonym(r.typ + "/" + r.sub) for _, m := range supportedMedia { if m == full { + // nulls=stripped applies only to the vendor array and object + // types; on plain application/json the parameter is ignored, + // matching PostgREST. + if r.stripNulls { + switch m { + case mediaArray: + return mediaArrayStripped, true + case mediaObject: + return mediaObjectStripped, true + } + } return m, true } } diff --git a/httpapi/negotiate_test.go b/httpapi/negotiate_test.go index 9419eca..fe522b1 100644 --- a/httpapi/negotiate_test.go +++ b/httpapi/negotiate_test.go @@ -1,6 +1,10 @@ package httpapi -import "testing" +import ( + "testing" + + "github.com/tamnd/dbrest/backend" +) func TestNegotiateDefaults(t *testing.T) { cases := []struct { @@ -50,9 +54,92 @@ func TestNegotiateSkipsUnsupportedThenMatches(t *testing.T) { } } +// TestNegotiateSuffixlessVendorTypes checks the suffixless PostgREST vendor +// spellings resolve to the same renderers as their +json forms. +func TestNegotiateSuffixlessVendorTypes(t *testing.T) { + if got, ok := negotiate([]string{"application/vnd.pgrst.object"}); !ok || got != mediaObject { + t.Errorf("object synonym got (%q,%v), want %q", got, ok, mediaObject) + } + if got, ok := negotiate([]string{"application/vnd.pgrst.array"}); !ok || got != mediaArray { + t.Errorf("array synonym got (%q,%v), want %q", got, ok, mediaArray) + } +} + func TestNegotiateZeroQualityRefuses(t *testing.T) { // q=0 explicitly refuses a type; with nothing else acceptable this is a 406. if got, ok := negotiate([]string{"application/json;q=0"}); ok { t.Errorf("q=0 should refuse, got (%q,%v)", got, ok) } } + +// TestNegotiatePlanFamily checks every spelling of the plan media type (bare, +// +text, +json, and a parameterized form) negotiates to the single mediaPlan +// sentinel that servePlan keys on. +func TestNegotiatePlanFamily(t *testing.T) { + cases := []string{ + "application/vnd.pgrst.plan", + "application/vnd.pgrst.plan+text", + "application/vnd.pgrst.plan+json", + `application/vnd.pgrst.plan+json; for="application/json"; options=analyze`, + } + for _, accept := range cases { + t.Run(accept, func(t *testing.T) { + got, ok := negotiate([]string{accept}) + if !ok || got != mediaPlan { + t.Errorf("negotiate(%q) = (%q,%v), want (%q,true)", accept, got, ok, mediaPlan) + } + }) + } +} + +// TestParsePlanFormat checks the output format each plan spelling selects: bare +// and +text are PostgREST's text default, +json the machine-readable form. +func TestParsePlanFormat(t *testing.T) { + cases := []struct { + accept string + want backend.PlanFormat + }{ + {"application/vnd.pgrst.plan", backend.PlanText}, + {"application/vnd.pgrst.plan+text", backend.PlanText}, + {"application/vnd.pgrst.plan+json", backend.PlanJSON}, + } + for _, c := range cases { + t.Run(c.accept, func(t *testing.T) { + opts, ok := parsePlan([]string{c.accept}) + if !ok { + t.Fatalf("parsePlan(%q) not recognized", c.accept) + } + if opts.Format != c.want { + t.Errorf("Format = %d, want %d", opts.Format, c.want) + } + // for= defaults to application/json when the parameter is absent. + if opts.For != mediaJSON { + t.Errorf("For = %q, want %q", opts.For, mediaJSON) + } + }) + } +} + +// TestParsePlanNonPlan reports that a non-plan Accept is not recognized as a +// plan request. +func TestParsePlanNonPlan(t *testing.T) { + if _, ok := parsePlan([]string{"application/json"}); ok { + t.Error("application/json should not parse as a plan request") + } +} + +// TestParsePlanForAndOptions checks the for="" target and the options= +// flag list are both parsed off the plan media type. +func TestParsePlanForAndOptions(t *testing.T) { + accept := `application/vnd.pgrst.plan+json; for="text/csv"; options=analyze|verbose|settings|buffers|wal` + opts, ok := parsePlan([]string{accept}) + if !ok { + t.Fatalf("parsePlan(%q) not recognized", accept) + } + if opts.For != "text/csv" { + t.Errorf("For = %q, want text/csv", opts.For) + } + if !opts.Analyze || !opts.Verbose || !opts.Settings || !opts.Buffers || !opts.Wal { + t.Errorf("options not all set: %+v", opts) + } +} diff --git a/httpapi/options_test.go b/httpapi/options_test.go new file mode 100644 index 0000000..fb7594b --- /dev/null +++ b/httpapi/options_test.go @@ -0,0 +1,111 @@ +package httpapi_test + +import ( + "encoding/json" + "net/http" + "testing" +) + +// TestOptionsOnTableAnswersAllow checks a plain OPTIONS on a table (no CORS +// preflight headers) is 200 with the full relation verb set and no body, the way +// PostgREST answers OPTIONS without running a transaction. +func TestOptionsOnTableAnswersAllow(t *testing.T) { + srv := newServer(t) + resp := do(t, srv, http.MethodOptions, "/films", nil) + if resp.StatusCode != http.StatusOK { + t.Fatalf("status = %d, want 200", resp.StatusCode) + } + if allow := resp.Header.Get("Allow"); allow != "OPTIONS,GET,HEAD,POST,PUT,PATCH,DELETE" { + t.Errorf("Allow = %q", allow) + } + buf := make([]byte, 1) + if n, _ := resp.Body.Read(buf); n != 0 { + t.Error("OPTIONS should have no body") + } +} + +// TestOptionsOnVolatileRPCIsPostOnly checks OPTIONS on a volatile function +// answers OPTIONS,POST: a function that writes is not reachable by GET. +func TestOptionsOnVolatileRPCIsPostOnly(t *testing.T) { + srv := newRPCServer(t) + resp := do(t, srv, http.MethodOptions, "/rpc/bump_year", nil) + if resp.StatusCode != http.StatusOK { + t.Fatalf("status = %d, want 200", resp.StatusCode) + } + if allow := resp.Header.Get("Allow"); allow != "OPTIONS,POST" { + t.Errorf("Allow = %q, want OPTIONS,POST", allow) + } +} + +// TestOptionsOnReadOnlyRPCAllowsGet checks OPTIONS on a read-only function also +// answers GET and HEAD, the verbs a stable/immutable function accepts. +func TestOptionsOnReadOnlyRPCAllowsGet(t *testing.T) { + srv := newRPCServer(t) + resp := do(t, srv, http.MethodOptions, "/rpc/film_titles", nil) + if resp.StatusCode != http.StatusOK { + t.Fatalf("status = %d, want 200", resp.StatusCode) + } + if allow := resp.Header.Get("Allow"); allow != "OPTIONS,GET,HEAD,POST" { + t.Errorf("Allow = %q, want OPTIONS,GET,HEAD,POST", allow) + } +} + +// TestUnsupportedMethodIs405PGRST117 checks a verb the server implements nowhere +// is PostgREST's 405 PGRST117 naming the method, not the capability gate's 400 +// PGRST127. +func TestUnsupportedMethodIs405PGRST117(t *testing.T) { + srv := newServer(t) + resp := do(t, srv, http.MethodTrace, "/films", nil) + if resp.StatusCode != http.StatusMethodNotAllowed { + t.Fatalf("status = %d, want 405", resp.StatusCode) + } + var env map[string]any + json.NewDecoder(resp.Body).Decode(&env) + if env["code"] != "PGRST117" { + t.Errorf("code = %v, want PGRST117", env["code"]) + } +} + +// TestStrictHandlingRejectsUnknownPreference checks a read under +// Prefer: handling=strict carrying an unknown preference is a 400 PGRST122, +// while the same request under the default lenient handling succeeds. +func TestStrictHandlingRejectsUnknownPreference(t *testing.T) { + srv := newServer(t) + resp := do(t, srv, http.MethodGet, "/films", map[string]string{ + "Prefer": "handling=strict, frobnicate=yes", + }) + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("status = %d, want 400", resp.StatusCode) + } + var env map[string]any + json.NewDecoder(resp.Body).Decode(&env) + if env["code"] != "PGRST122" { + t.Errorf("code = %v, want PGRST122", env["code"]) + } + + ok := do(t, srv, http.MethodGet, "/films", map[string]string{ + "Prefer": "frobnicate=yes", + }) + if ok.StatusCode != http.StatusOK { + t.Errorf("lenient status = %d, want 200", ok.StatusCode) + } +} + +// TestDeleteOnRPCIsPGRST101 checks PUT/PATCH/DELETE on a function keep +// PostgREST's PGRST101 with the exact "Cannot use the method on RPC" +// text, distinct from the PGRST117 unsupported-method case. +func TestDeleteOnRPCIsPGRST101(t *testing.T) { + srv := newRPCServer(t) + resp := do(t, srv, http.MethodDelete, "/rpc/add_them", nil) + if resp.StatusCode != http.StatusMethodNotAllowed { + t.Fatalf("status = %d, want 405", resp.StatusCode) + } + var env map[string]any + json.NewDecoder(resp.Body).Decode(&env) + if env["code"] != "PGRST101" { + t.Errorf("code = %v, want PGRST101", env["code"]) + } + if env["message"] != "Cannot use the DELETE method on RPC" { + t.Errorf("message = %v", env["message"]) + } +} diff --git a/httpapi/pagination_test.go b/httpapi/pagination_test.go new file mode 100644 index 0000000..ae004fe --- /dev/null +++ b/httpapi/pagination_test.go @@ -0,0 +1,97 @@ +package httpapi_test + +import ( + "encoding/json" + "net/http" + "testing" +) + +// TestReadPlannedCountReturns206 checks a bounded window under a non-exact count +// is 206, not 200: PostgREST returns 206 whenever a total is known and the span +// is smaller, for every count kind. SQLite downgrades planned to an exact total, +// so the four-row table over a one-row window is a genuine partial. +func TestReadPlannedCountReturns206(t *testing.T) { + srv := newServer(t) + resp := do(t, srv, http.MethodGet, "/films?limit=1&order=id", map[string]string{ + "Prefer": "count=planned", + }) + if resp.StatusCode != http.StatusPartialContent { + t.Fatalf("status = %d, want 206", resp.StatusCode) + } + if cr := resp.Header.Get("Content-Range"); cr != "0-0/4" { + t.Errorf("Content-Range = %q, want 0-0/4", cr) + } +} + +// TestReadOffsetEqualsTotalIs206 checks an offset equal to the total is in range: +// zero rows with 206 and Content-Range "*/total", the case a paginate-until-empty +// loop lands on when the total is an exact multiple of the page size. +func TestReadOffsetEqualsTotalIs206(t *testing.T) { + srv := newServer(t) + resp := do(t, srv, http.MethodGet, "/films?offset=4&order=id", map[string]string{ + "Prefer": "count=exact", + }) + if resp.StatusCode != http.StatusPartialContent { + t.Fatalf("status = %d, want 206", resp.StatusCode) + } + if cr := resp.Header.Get("Content-Range"); cr != "*/4" { + t.Errorf("Content-Range = %q, want */4", cr) + } +} + +// TestReadOffsetBeyondTotalIs416 checks an offset strictly past the end is still +// 416, the boundary one row beyond the equal-to-total case. +func TestReadOffsetBeyondTotalIs416(t *testing.T) { + srv := newServer(t) + resp := do(t, srv, http.MethodGet, "/films?offset=5&order=id", map[string]string{ + "Prefer": "count=exact", + }) + if resp.StatusCode != http.StatusRequestedRangeNotSatisfiable { + t.Fatalf("status = %d, want 416", resp.StatusCode) + } + var env struct{ Code, Details string } + if err := json.NewDecoder(resp.Body).Decode(&env); err != nil { + t.Fatalf("decode: %v", err) + } + if env.Code != "PGRST103" { + t.Errorf("code = %q, want PGRST103", env.Code) + } + if want := "An offset of 5 was requested, but there are only 4 rows."; env.Details != want { + t.Errorf("details = %q, want %q", env.Details, want) + } +} + +// TestInvertedRangeHeaderIs416 checks a well-formed Range header whose upper +// bound is below its lower bound is the 416 range error, not silently ignored. +// A malformed header (TestMalformedRangeHeaderIgnored) still serves the full set. +func TestInvertedRangeHeaderIs416(t *testing.T) { + srv := newServer(t) + resp := do(t, srv, http.MethodGet, "/films?order=id", map[string]string{ + "Range": "5-2", + }) + if resp.StatusCode != http.StatusRequestedRangeNotSatisfiable { + t.Fatalf("status = %d, want 416", resp.StatusCode) + } + var env struct{ Code, Details string } + if err := json.NewDecoder(resp.Body).Decode(&env); err != nil { + t.Fatalf("decode: %v", err) + } + if env.Code != "PGRST103" { + t.Errorf("code = %q, want PGRST103", env.Code) + } + if want := "The lower boundary must be lower than or equal to the upper boundary in the Range header."; env.Details != want { + t.Errorf("details = %q, want %q", env.Details, want) + } +} + +// TestMalformedRangeHeaderIgnored checks a non-numeric Range header is dropped +// rather than answered with 416: PostgREST serves the full result. +func TestMalformedRangeHeaderIgnored(t *testing.T) { + srv := newServer(t) + resp := do(t, srv, http.MethodGet, "/films?order=id", map[string]string{ + "Range": "abc-def", + }) + if resp.StatusCode != http.StatusOK { + t.Fatalf("status = %d, want 200", resp.StatusCode) + } +} diff --git a/httpapi/plan_test.go b/httpapi/plan_test.go new file mode 100644 index 0000000..c77739b --- /dev/null +++ b/httpapi/plan_test.go @@ -0,0 +1,191 @@ +package httpapi_test + +import ( + "context" + "encoding/json" + "io" + "net/http" + "testing" + + "github.com/tamnd/dbrest/backend" + "github.com/tamnd/dbrest/backend/sqlite" + "github.com/tamnd/dbrest/httpapi" + "github.com/tamnd/dbrest/ir" + "github.com/tamnd/dbrest/reqctx" + "github.com/tamnd/dbrest/rpc" +) + +const planMedia = "application/vnd.pgrst.plan+json" + +// explainBackend wraps the sqlite backend with canned Explain methods, standing +// in for an engine that supports EXPLAIN. The three methods mirror the read, +// write, and call execution paths the Explainer interface covers. +type explainBackend struct { + *sqlite.Backend +} + +func (e *explainBackend) ExplainRead(context.Context, *ir.Plan, *reqctx.Context, backend.PlanOptions) ([]byte, error) { + return []byte(`[{"Plan":{"Node Type":"Seq Scan"}}]`), nil +} + +func (e *explainBackend) ExplainWrite(context.Context, *ir.Plan, *reqctx.Context, backend.PlanOptions) ([]byte, error) { + return []byte(`[{"Plan":{"Node Type":"ModifyTable"}}]`), nil +} + +func (e *explainBackend) ExplainCall(context.Context, *ir.Plan, *reqctx.Context, backend.PlanOptions) ([]byte, error) { + return []byte(`[{"Plan":{"Node Type":"Function Scan"}}]`), nil +} + +// planServer builds a server over a seeded films table with an +// EXPLAIN-capable backend, so the db-plan-enabled gate is the only variable. +func planServer(t *testing.T) *httpapi.Server { + t.Helper() + dsn := "file:" + t.Name() + "?mode=memory&cache=shared" + be, err := sqlite.Open(dsn) + if err != nil { + t.Fatalf("open: %v", err) + } + t.Cleanup(func() { be.Close() }) + if _, err := be.DB().Exec(`CREATE TABLE films (id INTEGER PRIMARY KEY, title TEXT)`); err != nil { + t.Fatalf("seed: %v", err) + } + model, err := be.Introspect(context.Background()) + if err != nil { + t.Fatalf("introspect: %v", err) + } + srv := httpapi.NewServer(&explainBackend{be}, model, nil) + srv.SetDefaultRole("web_anon") + return srv +} + +// TestPlanDisabledByDefault pins the upstream security default: without +// db-plan-enabled = true a plan request fails with the media-type error even +// when the backend could explain the query. +func TestPlanDisabledByDefault(t *testing.T) { + srv := planServer(t) + resp := do(t, srv, http.MethodGet, "/films", map[string]string{"Accept": planMedia}) + if resp.StatusCode != http.StatusNotAcceptable { + t.Fatalf("status = %d, want 406", resp.StatusCode) + } + var body struct { + Code string `json:"code"` + } + if err := json.NewDecoder(resp.Body).Decode(&body); err != nil { + t.Fatal(err) + } + if body.Code != "PGRST107" { + t.Errorf("code = %q, want PGRST107", body.Code) + } +} + +// TestPlanServedWhenEnabled checks the gate opens with the option on. +func TestPlanServedWhenEnabled(t *testing.T) { + srv := planServer(t) + srv.SetPlanEnabled(true) + resp := do(t, srv, http.MethodGet, "/films", map[string]string{"Accept": planMedia}) + if resp.StatusCode != http.StatusOK { + t.Fatalf("status = %d, want 200", resp.StatusCode) + } + // PostgREST echoes the negotiated plan media type with its parameters: the + // +json suffix, the for="" the plan was computed for (application/json + // by default), and the charset. + wantCT := `application/vnd.pgrst.plan+json; for="application/json"; charset=utf-8` + if ct := resp.Header.Get("Content-Type"); ct != wantCT { + t.Errorf("Content-Type = %q, want %q", ct, wantCT) + } + b, _ := io.ReadAll(resp.Body) + if string(b) != `[{"Plan":{"Node Type":"Seq Scan"}}]` { + t.Errorf("body = %s", b) + } +} + +// TestPlanEnabledStillNeedsExplainer keeps the older behavior under the gate: +// an enabled config on a backend without EXPLAIN support is still 406. +func TestPlanEnabledStillNeedsExplainer(t *testing.T) { + srv := newServer(t) + srv.SetPlanEnabled(true) + resp := do(t, srv, http.MethodGet, "/films", map[string]string{"Accept": planMedia}) + if resp.StatusCode != http.StatusNotAcceptable { + t.Fatalf("status = %d, want 406", resp.StatusCode) + } +} + +// TestPlanForWrite checks a mutation plan request routes to ExplainWrite and +// returns the plan instead of executing the write. This pins that the write +// handler hands a plan-typed request to servePlan before touching Execute. +func TestPlanForWrite(t *testing.T) { + srv := planServer(t) + srv.SetPlanEnabled(true) + resp := send(t, srv, http.MethodPost, "/films", `{"id":7,"title":"M"}`, map[string]string{"Accept": planMedia}) + if resp.StatusCode != http.StatusOK { + t.Fatalf("status = %d, want 200", resp.StatusCode) + } + b, _ := io.ReadAll(resp.Body) + if string(b) != `[{"Plan":{"Node Type":"ModifyTable"}}]` { + t.Errorf("body = %s", b) + } +} + +// TestPlanForWriteDisabledIs406 pins that a mutation plan request under a closed +// gate fails with the media-type error, not a 500 or a silently executed write. +func TestPlanForWriteDisabledIs406(t *testing.T) { + srv := planServer(t) + resp := send(t, srv, http.MethodPost, "/films", `{"id":7,"title":"M"}`, map[string]string{"Accept": planMedia}) + if resp.StatusCode != http.StatusNotAcceptable { + t.Fatalf("status = %d, want 406", resp.StatusCode) + } +} + +// planRPCServer is planServer with a portable function registered, so the RPC +// plan path has something to explain. +func planRPCServer(t *testing.T) *httpapi.Server { + t.Helper() + dsn := "file:" + t.Name() + "?mode=memory&cache=shared" + be, err := sqlite.Open(dsn) + if err != nil { + t.Fatalf("open: %v", err) + } + t.Cleanup(func() { be.Close() }) + if _, err := be.DB().Exec(`CREATE TABLE films (id INTEGER PRIMARY KEY, title TEXT)`); err != nil { + t.Fatalf("seed: %v", err) + } + be.Register(rpc.NewStaticRegistry([]*rpc.Function{{ + Name: "add_them", + Params: []rpc.Param{{Name: "a", Type: "integer"}, {Name: "b", Type: "integer"}}, + Returns: rpc.ReturnShape{Kind: rpc.ReturnScalar, Type: "integer"}, + Volatility: rpc.Immutable, + Query: &rpc.PortableQuery{SQL: "SELECT :a + :b"}, + }})) + model, err := be.Introspect(context.Background()) + if err != nil { + t.Fatalf("introspect: %v", err) + } + srv := httpapi.NewServer(&explainBackend{be}, model, nil) + srv.SetDefaultRole("web_anon") + return srv +} + +// TestPlanForRPC checks an RPC plan request routes to ExplainCall and returns +// the plan instead of invoking the function. +func TestPlanForRPC(t *testing.T) { + srv := planRPCServer(t) + srv.SetPlanEnabled(true) + resp := send(t, srv, http.MethodPost, "/rpc/add_them", `{"a":2,"b":3}`, map[string]string{"Accept": planMedia}) + if resp.StatusCode != http.StatusOK { + t.Fatalf("status = %d, want 200", resp.StatusCode) + } + b, _ := io.ReadAll(resp.Body) + if string(b) != `[{"Plan":{"Node Type":"Function Scan"}}]` { + t.Errorf("body = %s", b) + } +} + +// TestPlanForRPCDisabledIs406 pins that an RPC plan request under a closed gate +// fails with the media-type error rather than 500ing or running the function. +func TestPlanForRPCDisabledIs406(t *testing.T) { + srv := planRPCServer(t) + resp := send(t, srv, http.MethodPost, "/rpc/add_them", `{"a":2,"b":3}`, map[string]string{"Accept": planMedia}) + if resp.StatusCode != http.StatusNotAcceptable { + t.Fatalf("status = %d, want 406", resp.StatusCode) + } +} diff --git a/httpapi/render.go b/httpapi/render.go index 080333f..39289f7 100644 --- a/httpapi/render.go +++ b/httpapi/render.go @@ -28,19 +28,33 @@ type rendered struct { func renderFor(media string, res backend.Result, rawCols map[string]bool) (*rendered, *pgerr.APIError) { switch media { case mediaJSON, mediaArray: - out, err := renderRows(res, false, rawCols) + out, err := renderRows(res, false, rawCols, false) if err != nil { return nil, err } out.contentType = "application/json; charset=utf-8" return out, nil + case mediaArrayStripped: + out, err := renderRows(res, false, rawCols, true) + if err != nil { + return nil, err + } + out.contentType = "application/vnd.pgrst.array+json; nulls=stripped; charset=utf-8" + return out, nil case mediaObject: - out, err := renderRows(res, true, rawCols) + out, err := renderRows(res, true, rawCols, false) if err != nil { return nil, err } out.contentType = singularMediaType + "; charset=utf-8" return out, nil + case mediaObjectStripped: + out, err := renderRows(res, true, rawCols, true) + if err != nil { + return nil, err + } + out.contentType = "application/vnd.pgrst.object+json; nulls=stripped; charset=utf-8" + return out, nil case mediaCSV: return renderCSV(res) case mediaOctet: @@ -60,7 +74,8 @@ func renderFor(media string, res backend.Result, rawCols map[string]bool) (*rend // singular request. fnName is the bare function name; it is used for native-RPC // heuristic detection when fn is nil. func renderCall(media string, res backend.Result, fn *rpc.Function, fnName string) (*rendered, *pgerr.APIError) { - if fn == nil { + switch { + case fn == nil: // Native RPC: detect scalar vs table by inspecting column names. // res.Rows().Columns() does not advance the cursor; the stream remains // fully readable for the render path below. @@ -70,8 +85,12 @@ func renderCall(media string, res backend.Result, fn *rpc.Function, fnName strin } else { return renderFor(media, res, nil) } - } else if fn.Returns.Kind == rpc.ReturnTable { + case fn.Returns.Kind == rpc.ReturnTable: return renderFor(media, res, nil) + case fn.Returns.Kind == rpc.ReturnObject: + return renderCallObject(media, res) + case fn.Returns.Kind == rpc.ReturnVoid: + return renderVoid(res) } switch media { case mediaCSV: @@ -98,7 +117,7 @@ func renderCall(media string, res backend.Result, fn *rpc.Function, fnName strin } // A scalar function projects one column; if a registry declares scalar // over a wider statement, the first column is the value. - vals = append(vals, row[0]) + vals = append(vals, rawJSONValue(row[0], fn.Returns.Type)) } if err := rs.Err(); err != nil { return nil, pgerr.ErrInternal(err.Error()) @@ -109,9 +128,10 @@ func renderCall(media string, res backend.Result, fn *rpc.Function, fnName strin out.total, out.hasTotl = total, true } - if media == mediaObject { + if singularMedia(media) { if len(vals) != 1 { - return nil, pgerr.ErrSingularZeroMany() + return nil, pgerr.ErrSingularZeroMany(). + WithDetails(fmt.Sprintf("The result contains %d rows", len(vals))) } body, aerr := marshalCall(vals[0]) if aerr != nil { @@ -145,6 +165,98 @@ func renderCall(media string, res backend.Result, fn *rpc.Function, fnName strin return out, nil } +// renderCallObject shapes a function that returns a single composite row (RETURNS +// , not SETOF): PostgREST renders it as one bare JSON object, never an +// array, and a null body when the function produced no row. The CSV and scalar +// media types fall back to the column/row shapers a table read uses. The singular +// object media type keeps its content type; every other JSON media renders the +// same single object under application/json. +func renderCallObject(media string, res backend.Result) (*rendered, *pgerr.APIError) { + switch media { + case mediaCSV: + return renderCSV(res) + case mediaOctet: + return renderScalar(res, false) + case mediaText: + return renderScalar(res, true) + } + + rs := res.Rows() + defer rs.Close() + cols := rs.Columns() + + var obj []byte + n := 0 + for rs.Next() { + if n == 0 { + vals, err := rs.Values() + if err != nil { + return nil, pgerr.ErrInternal(err.Error()) + } + rb, err := encodeRowObject(cols, vals, nil, false) + if err != nil { + return nil, pgerr.ErrInternal(err.Error()) + } + obj = rb + } + n++ + } + if err := rs.Err(); err != nil { + return nil, pgerr.ErrInternal(err.Error()) + } + + out := &rendered{nRows: n} + if total, ok := res.Count(); ok { + out.total, out.hasTotl = total, true + } + if obj == nil { + obj = []byte("null") + } + out.body = obj + out.contentType = "application/json; charset=utf-8" + if singularMedia(media) { + out.contentType = singularMediaType + "; charset=utf-8" + } + return out, nil +} + +// renderVoid shapes a void-returning function: PostgREST answers 200 with a null +// JSON body, never 204, so dbrest pins the same contract across backends rather +// than letting a portable scalar-with-no-rows or a native 204 special case decide +// it. The result is drained so the statement runs to completion, then discarded. +func renderVoid(res backend.Result) (*rendered, *pgerr.APIError) { + rs := res.Rows() + defer rs.Close() + for rs.Next() { + if _, err := rs.Values(); err != nil { + return nil, pgerr.ErrInternal(err.Error()) + } + } + if err := rs.Err(); err != nil { + return nil, pgerr.ErrInternal(err.Error()) + } + return &rendered{ + body: []byte("null"), + contentType: "application/json; charset=utf-8", + }, nil +} + +// rawJSONValue embeds a json-declared scalar verbatim. An engine expression +// (a registry SELECT json_object(...), say) carries no column type the driver +// could key the conversion on, so the declared return type decides here: a +// valid-JSON string under a json/jsonb declaration becomes a RawMessage and +// the encoder emits the document rather than a quoted string, matching how +// PostgreSQL functions returning json behave through PostgREST. +func rawJSONValue(v any, declared string) any { + if declared != "json" && declared != "jsonb" { + return v + } + if s, ok := v.(string); ok && json.Valid([]byte(s)) { + return json.RawMessage(s) + } + return v +} + // marshalCall encodes one RPC value (a scalar or an array of scalars) to JSON // without HTML escaping and without the trailing newline the encoder appends. func marshalCall(v any) ([]byte, *pgerr.APIError) { @@ -165,7 +277,7 @@ func marshalCall(v any) ([]byte, *pgerr.APIError) { // This is the Go-shaped assembly path (Result.Rows). The engine-assembled path // (Result.Body) is used once the embedding subsystem emits in-engine JSON; the // observable body is identical either way. -func renderRows(res backend.Result, singular bool, rawCols map[string]bool) (*rendered, *pgerr.APIError) { +func renderRows(res backend.Result, singular bool, rawCols map[string]bool, stripNulls bool) (*rendered, *pgerr.APIError) { rs := res.Rows() defer rs.Close() cols := rs.Columns() @@ -180,23 +292,13 @@ func renderRows(res backend.Result, singular bool, rawCols map[string]bool) (*re if err != nil { return nil, pgerr.ErrInternal(err.Error()) } - obj := make(map[string]any, len(cols)) - for i, c := range cols { - if rawCols[c] { - obj[c] = rawJSON(vals[i]) - } else { - obj[c] = vals[i] - } - } // Encode each row independently so a large result streams in bounded // memory once the engine-assembled path replaces this shaper. - var rb bytes.Buffer - re := json.NewEncoder(&rb) - re.SetEscapeHTML(false) - if err := re.Encode(obj); err != nil { + rb, err := encodeRowObject(cols, vals, rawCols, stripNulls) + if err != nil { return nil, pgerr.ErrInternal(err.Error()) } - rows = append(rows, json.RawMessage(bytes.TrimRight(rb.Bytes(), "\n"))) + rows = append(rows, json.RawMessage(rb)) } if err := rs.Err(); err != nil { return nil, pgerr.ErrInternal(err.Error()) @@ -209,7 +311,8 @@ func renderRows(res backend.Result, singular bool, rawCols map[string]bool) (*re if singular { if len(rows) != 1 { - return nil, pgerr.ErrSingularZeroMany() + return nil, pgerr.ErrSingularZeroMany(). + WithDetails(fmt.Sprintf("The result contains %d rows", len(rows))) } out.body = rows[0] return out, nil @@ -222,6 +325,58 @@ func renderRows(res backend.Result, singular bool, rawCols map[string]bool) (*re return out, nil } +// encodeRowObject serializes one row as a JSON object whose keys appear in +// projection (column) order, the way PostgREST preserves select order rather +// than the alphabetical order a Go map would impose. A rawCols column carries +// engine-assembled JSON emitted verbatim; every other value is encoded normally. +func encodeRowObject(cols []string, vals []any, rawCols map[string]bool, stripNulls bool) ([]byte, error) { + var rb bytes.Buffer + rb.WriteByte('{') + first := true + for i, c := range cols { + var v any + if rawCols[c] { + v = rawJSON(vals[i]) + } else { + v = vals[i] + } + // nulls=stripped drops a key whose value is SQL NULL (a nil after the raw + // embed unwrap), so the object omits it entirely. + if stripNulls && v == nil { + continue + } + if !first { + rb.WriteByte(',') + } + first = false + key, err := jsonNoEscape(c) + if err != nil { + return nil, err + } + rb.Write(key) + rb.WriteByte(':') + val, err := jsonNoEscape(v) + if err != nil { + return nil, err + } + rb.Write(val) + } + rb.WriteByte('}') + return rb.Bytes(), nil +} + +// jsonNoEscape encodes a value to JSON the way PostgREST does: HTML characters +// stay unescaped and the encoder's trailing newline is trimmed. +func jsonNoEscape(v any) ([]byte, error) { + var b bytes.Buffer + e := json.NewEncoder(&b) + e.SetEscapeHTML(false) + if err := e.Encode(v); err != nil { + return nil, err + } + return bytes.TrimRight(b.Bytes(), "\n"), nil +} + // renderCSV writes a header row of column names followed by one RFC 4180 record // per row. A nested value (an embedded relation or a JSON column) is serialized // as its JSON text inside a single cell rather than expanded into more columns, @@ -294,10 +449,12 @@ func csvCell(v any) string { case []byte: return string(t) case bool: + // PostgreSQL's text output (what PostgREST's CSV mirrors) renders booleans + // as t/f, not the JSON true/false. if t { - return "true" + return "t" } - return "false" + return "f" case json.Number: return t.String() case float64: diff --git a/httpapi/render_test.go b/httpapi/render_test.go index 6e950a9..d231e9b 100644 --- a/httpapi/render_test.go +++ b/httpapi/render_test.go @@ -47,8 +47,8 @@ func TestCSVCellForms(t *testing.T) { {"null", nil, ""}, {"string", "Dune", "Dune"}, {"bytes", []byte("blob"), "blob"}, - {"bool-true", true, "true"}, - {"bool-false", false, "false"}, + {"bool-true", true, "t"}, + {"bool-false", false, "f"}, {"json-number", json.Number("42"), "42"}, {"float", 3.5, "3.5"}, {"int64", int64(-7), "-7"}, @@ -92,6 +92,30 @@ func TestRawJSONForms(t *testing.T) { } } +// rawJSONValue embeds a json/jsonb-declared scalar verbatim so a function +// returning json does not double-encode into a quoted string on a portable +// backend, where the driver hands the result back as TEXT. A non-json declaration +// leaves the value untouched, and an invalid-JSON string under a json declaration +// is left as a string rather than emitted as a broken document. +func TestRawJSONValue(t *testing.T) { + if got, ok := rawJSONValue(`{"a":1}`, "json").(json.RawMessage); !ok || string(got) != `{"a":1}` { + t.Errorf("json scalar = %#v, want RawMessage", got) + } + if got, ok := rawJSONValue(`[1,2]`, "jsonb").(json.RawMessage); !ok || string(got) != `[1,2]` { + t.Errorf("jsonb scalar = %#v, want RawMessage", got) + } + // A non-json declaration passes the text through as a plain string, which the + // encoder will quote. + if got := rawJSONValue(`{"a":1}`, "text"); got != `{"a":1}` { + t.Errorf("text scalar = %#v, want the string unchanged", got) + } + // Malformed JSON under a json declaration is not wrapped, so the encoder quotes + // it rather than emitting an invalid document. + if _, ok := rawJSONValue(`{not json`, "json").(json.RawMessage); ok { + t.Error("invalid JSON should not become a RawMessage") + } +} + // asAPIError normalizes a backend execution error three ways: an error that is // already an API error passes straight through, an engine-native error the // backend recognizes becomes whatever it maps to, and anything else falls back @@ -113,7 +137,7 @@ func TestAsAPIError(t *testing.T) { // Branch two: a raw engine error the backend recognizes becomes its mapping. t.Run("backend-maps-it", func(t *testing.T) { - mapped := pgerr.ErrUniqueViolation("films_pkey") + mapped := pgerr.ErrConstraintViolation("23505", "duplicate key", "", "") b := fakeBackend{mapErr: func(error) *pgerr.APIError { return mapped }} if got := asAPIError(b, errors.New("duplicate key")); got != mapped { t.Errorf("asAPIError = %#v, want the backend mapping %#v", got, mapped) diff --git a/httpapi/rendering_test.go b/httpapi/rendering_test.go new file mode 100644 index 0000000..48133c8 --- /dev/null +++ b/httpapi/rendering_test.go @@ -0,0 +1,147 @@ +package httpapi_test + +import ( + "io" + "net/http" + "strings" + "testing" +) + +// readBody returns the response body as a string. +func readBody(t *testing.T, resp *http.Response) string { + t.Helper() + b, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("read body: %v", err) + } + return string(b) +} + +// TestSelectOrderPreservedInJSON pins 02.19: object keys appear in projection +// order, not alphabetized. A select of title,id renders {"title":...,"id":...}. +func TestSelectOrderPreservedInJSON(t *testing.T) { + srv := newServer(t) + resp := do(t, srv, http.MethodGet, "/films?id=eq.1&select=title,id", nil) + if resp.StatusCode != http.StatusOK { + t.Fatalf("status = %d, want 200", resp.StatusCode) + } + body := readBody(t, resp) + titlePos := strings.Index(body, `"title"`) + idPos := strings.Index(body, `"id"`) + if titlePos < 0 || idPos < 0 { + t.Fatalf("body missing expected keys: %s", body) + } + if titlePos > idPos { + t.Errorf("keys out of select order, want title before id: %s", body) + } +} + +// TestNullsStrippedArrayOmitsNullKeys pins 02.13: the nulls=stripped parameter +// on the vendor array type drops null-valued keys from each object and the +// Content-Type echoes the parameter. Film 4 has a NULL year. +func TestNullsStrippedArrayOmitsNullKeys(t *testing.T) { + srv := newServer(t) + resp := do(t, srv, http.MethodGet, "/films?id=eq.4&select=id,title,year,rating", map[string]string{ + "Accept": "application/vnd.pgrst.array+json;nulls=stripped", + }) + if resp.StatusCode != http.StatusOK { + t.Fatalf("status = %d, want 200", resp.StatusCode) + } + if ct := resp.Header.Get("Content-Type"); ct != "application/vnd.pgrst.array+json; nulls=stripped; charset=utf-8" { + t.Errorf("Content-Type = %q", ct) + } + body := readBody(t, resp) + if strings.Contains(body, `"year"`) { + t.Errorf("null year key should be stripped: %s", body) + } + if !strings.Contains(body, `"title"`) || !strings.Contains(body, `"rating"`) { + t.Errorf("non-null keys should remain: %s", body) + } +} + +// TestNullsStrippedObjectOmitsNullKeys pins the same parameter on the object +// vendor type for a singular request. +func TestNullsStrippedObjectOmitsNullKeys(t *testing.T) { + srv := newServer(t) + resp := do(t, srv, http.MethodGet, "/films?id=eq.4&select=id,title,year", map[string]string{ + "Accept": "application/vnd.pgrst.object+json;nulls=stripped", + }) + if resp.StatusCode != http.StatusOK { + t.Fatalf("status = %d, want 200", resp.StatusCode) + } + if ct := resp.Header.Get("Content-Type"); ct != "application/vnd.pgrst.object+json; nulls=stripped; charset=utf-8" { + t.Errorf("Content-Type = %q", ct) + } + body := readBody(t, resp) + if strings.Contains(body, `"year"`) { + t.Errorf("null year key should be stripped: %s", body) + } + if strings.HasPrefix(strings.TrimSpace(body), "[") { + t.Errorf("object request should not be an array: %s", body) + } +} + +// TestNullsStrippedIgnoredOnPlainJSON pins that the parameter is vendor-only: +// plain application/json keeps null keys, matching PostgREST. +func TestNullsStrippedIgnoredOnPlainJSON(t *testing.T) { + srv := newServer(t) + resp := do(t, srv, http.MethodGet, "/films?id=eq.4&select=id,title,year", map[string]string{ + "Accept": "application/json;nulls=stripped", + }) + body := readBody(t, resp) + if !strings.Contains(body, `"year"`) { + t.Errorf("plain json should keep the null key: %s", body) + } +} + +// TestCSVQuotesAndNoTrailingBlankLine pins 02.20: a field with a comma and a +// double quote is RFC 4180 quoted the way PostgREST quotes CSV (comma forces +// quoting, inner quotes are doubled), records are \n-terminated, and there is no +// extra trailing blank line after the last record. +func TestCSVQuotesAndNoTrailingBlankLine(t *testing.T) { + srv := newServer(t) + send(t, srv, http.MethodPost, "/films", `{"id":91,"title":"A, B \"q\""}`, map[string]string{ + "Prefer": "return=minimal", + }) + resp := do(t, srv, http.MethodGet, "/films?id=eq.91&select=id,title", map[string]string{ + "Accept": "text/csv", + }) + if resp.StatusCode != http.StatusOK { + t.Fatalf("status = %d, want 200", resp.StatusCode) + } + body := readBody(t, resp) + want := "id,title\n91,\"A, B \"\"q\"\"\"\n" + if body != want { + t.Errorf("CSV body = %q, want %q", body, want) + } +} + +// TestCSVEmptyResultKeepsHeader pins the empty-result CSV shape dbrest produces: +// the column-name header line plus a newline, with no data rows. The PostgREST +// empty-result shape itself is verified separately against a live server (02.20). +func TestCSVEmptyResultKeepsHeader(t *testing.T) { + srv := newServer(t) + resp := do(t, srv, http.MethodGet, "/films?id=eq.9999&select=id,title", map[string]string{ + "Accept": "text/csv", + }) + if resp.StatusCode != http.StatusOK { + t.Fatalf("status = %d, want 200", resp.StatusCode) + } + if body := readBody(t, resp); body != "id,title\n" { + t.Errorf("empty CSV body = %q, want the header line only", body) + } +} + +// TestSelectOrderReversed pins the inverse projection to prove the order tracks +// the select, not a fixed column order: id,title renders {"id":...,"title":...}. +func TestSelectOrderReversed(t *testing.T) { + srv := newServer(t) + resp := do(t, srv, http.MethodGet, "/films?id=eq.1&select=year,rating,title", nil) + body := readBody(t, resp) + yearPos := strings.Index(body, `"year"`) + ratingPos := strings.Index(body, `"rating"`) + titlePos := strings.Index(body, `"title"`) + if yearPos >= ratingPos || ratingPos >= titlePos { + t.Errorf("keys not in select order year,rating,title: %s", body) + } +} diff --git a/httpapi/reqctx_test.go b/httpapi/reqctx_test.go index 24e8558..fb8f69a 100644 --- a/httpapi/reqctx_test.go +++ b/httpapi/reqctx_test.go @@ -2,6 +2,7 @@ package httpapi_test import ( "context" + "io" "net/http" "net/http/httptest" "strings" @@ -67,7 +68,9 @@ func captureServer(t *testing.T) (*httpapi.Server, *captureBackend) { t.Fatalf("introspect: %v", err) } cap := &captureBackend{Backend: be} - return httpapi.NewServer(cap, model, nil), cap + srv := httpapi.NewServer(cap, model, nil) + srv.SetDefaultRole("anon") + return srv, cap } func TestContextCarriesRequestMetadata(t *testing.T) { @@ -97,25 +100,95 @@ func TestContextCarriesRequestMetadata(t *testing.T) { } } -func TestContextCarriesProfileSchema(t *testing.T) { +func TestAcceptProfileUnknownSchemaIs406(t *testing.T) { srv, cap := captureServer(t) req := newReq(http.MethodGet, "/films?select=id") req.Header.Set("Accept-Profile", "reporting") - srv.ServeHTTP(newRecorder(), req) + rec := newRecorder() + srv.ServeHTTP(rec, req) - if cap.got.Schema != "reporting" { - t.Errorf("Schema = %q, want reporting (from Accept-Profile)", cap.got.Schema) + if cap.got != nil { + t.Fatal("backend executed despite an invalid Accept-Profile") + } + resp := rec.Result() + if resp.StatusCode != http.StatusNotAcceptable { + t.Fatalf("status = %d, want 406", resp.StatusCode) + } + body, _ := io.ReadAll(resp.Body) + if !strings.Contains(string(body), `"PGRST106"`) { + t.Errorf("body = %s, want code PGRST106", body) + } + if !strings.Contains(string(body), "Invalid schema: reporting") { + t.Errorf("body = %s, want message naming the schema", body) + } + if h := resp.Header.Get("Content-Profile"); h != "" { + t.Errorf("Content-Profile = %q on an error, want unset", h) } } -func TestContextWriteUsesContentProfile(t *testing.T) { +func TestContentProfileUnknownSchemaIs406(t *testing.T) { srv, cap := captureServer(t) req := newReqBody(http.MethodPost, "/films", `{"id":3,"title":"Dune"}`) req.Header.Set("Content-Profile", "staging") - srv.ServeHTTP(newRecorder(), req) + rec := newRecorder() + srv.ServeHTTP(rec, req) + + if cap.got != nil { + t.Fatal("backend executed despite an invalid Content-Profile") + } + resp := rec.Result() + if resp.StatusCode != http.StatusNotAcceptable { + t.Fatalf("status = %d, want 406", resp.StatusCode) + } + body, _ := io.ReadAll(resp.Body) + if !strings.Contains(string(body), "Invalid schema: staging") { + t.Errorf("body = %s, want message naming the schema", body) + } +} + +// TestNoProfileUsesDefaultSchema pins the default: with no profile header the +// active schema is the first exposed schema, and on a single-schema deployment +// no Content-Profile response header is emitted. +func TestNoProfileUsesDefaultSchema(t *testing.T) { + srv, cap := captureServer(t) + rec := newRecorder() + srv.ServeHTTP(rec, newReq(http.MethodGet, "/films?select=id")) - if cap.got.Schema != "staging" { - t.Errorf("Schema = %q, want staging (from Content-Profile)", cap.got.Schema) + if cap.got == nil { + t.Fatal("Execute never received a context") + } + if cap.got.Schema != "" { + t.Errorf("Schema = %q, want the default (first exposed) schema", cap.got.Schema) + } + if h := rec.Result().Header.Get("Content-Profile"); h != "" { + t.Errorf("Content-Profile = %q, want unset on a single-schema server", h) + } +} + +func TestContextCarriesPreRequest(t *testing.T) { + srv, cap := captureServer(t) + srv.SetPreRequest("check_request") + + srv.ServeHTTP(newRecorder(), newReq(http.MethodGet, "/films?select=id")) + if cap.got == nil || cap.got.PreRequest != "check_request" { + t.Fatalf("read context PreRequest = %v, want check_request", cap.got) + } + + cap.got = nil + srv.ServeHTTP(newRecorder(), newReqBody(http.MethodPost, "/films", `{"id":6,"title":"Heat"}`)) + if cap.got == nil || cap.got.PreRequest != "check_request" { + t.Fatalf("write context PreRequest = %v, want check_request", cap.got) + } +} + +func TestContextHasNoPreRequestByDefault(t *testing.T) { + srv, cap := captureServer(t) + srv.ServeHTTP(newRecorder(), newReq(http.MethodGet, "/films?select=id")) + if cap.got == nil { + t.Fatal("Execute never received a context") + } + if cap.got.PreRequest != "" { + t.Errorf("PreRequest = %q, want empty with none configured", cap.got.PreRequest) } } diff --git a/httpapi/root.go b/httpapi/root.go index ccc58cc..d930775 100644 --- a/httpapi/root.go +++ b/httpapi/root.go @@ -6,8 +6,12 @@ import ( "strings" "github.com/tamnd/dbrest/config" + "github.com/tamnd/dbrest/ir" "github.com/tamnd/dbrest/openapi" "github.com/tamnd/dbrest/pgerr" + "github.com/tamnd/dbrest/plan" + "github.com/tamnd/dbrest/reqctx" + "github.com/tamnd/dbrest/schema" ) // handleRoot serves the self-describing OpenAPI document at GET /. The document @@ -15,39 +19,184 @@ import ( // backend's declared capabilities, so it describes exactly what this server can // serve and never promises an operator the next request would reject. HEAD // returns the headers with no body. See spec 19. -func (s *Server) handleRoot(w http.ResponseWriter, r *http.Request) { +func (s *Server) handleRoot(w http.ResponseWriter, r *http.Request, id identity, activeSchema string) { + if r.Method == http.MethodOptions { + // OPTIONS on the root answers with the verb set, the way PostgREST's + // info response does, with no body and no media type. + w.Header().Set("Allow", rootAllow) + w.WriteHeader(http.StatusOK) + return + } if r.Method != http.MethodGet && r.Method != http.MethodHead { - writeError(w, pgerr.ErrUnsupported(r.Method+" requests on the root", "dbrest")) + w.Header().Set("Allow", rootAllow) + writeError(w, pgerr.ErrUnsupportedMethod(r.Method)) return } if s.openapiMode == config.OpenAPIDisabled { // openapi-mode=disabled turns the self-describing root off entirely; a - // request there is a plain 404, as PostgREST returns. - writeError(w, pgerr.New(http.StatusNotFound, "", "openapi is disabled")) + // request there is PostgREST's 404 PGRST126. + writeError(w, errRootDisabled()) + return + } + if !rootAcceptable(r.Header.Values("Accept")) { + writeError(w, pgerr.ErrNotAcceptable(acceptedList(r.Header.Values("Accept")))) + return + } + if s.rootSpec != "" { + // db-root-spec replaces the generated document with the named + // function's result, upstream's escape hatch for a custom spec. + s.serveRootSpec(w, r, id, activeSchema) return } opts := openapi.Options{ - Host: r.Host, - Schemes: []string{requestScheme(r)}, - JWT: s.verifier != nil, + Host: r.Host, + Schemes: []string{requestScheme(r)}, + SecurityActive: s.openapiSecurity, + ActiveSchema: activeSchema, + } + model := s.Model() + if comment := model.SchemaComment(activeSchema); comment != "" { + // The database comment on the active schema names the API: the first + // line is the info title, the rest the description, as v14 reads it. + title, rest, _ := strings.Cut(comment, "\n") + opts.Title = title + opts.Description = strings.TrimSpace(rest) } if s.openapiProxy != "" { applyProxyURI(&opts, s.openapiProxy) } - body, err := openapi.Generate(s.model, s.backend.Functions(), s.backend.Capabilities(), opts) + if s.openapiMode == config.OpenAPIFollowPrivileges && s.authz != nil { + // follow-privileges scopes the document to what the requesting role may + // actually do, so an anonymous caller cannot enumerate relations it + // cannot touch. The answers come from the same gate that authorizes a + // real request; ignore-privileges leaves Visibility nil and emits all. + rc := s.buildContext(r, id, activeSchema) + opts.Visibility = func(rel *schema.Relation) openapi.Actions { + return openapi.Actions{ + Get: s.probeAction(rc, rel.Name, ir.Read), + Post: s.probeAction(rc, rel.Name, ir.Insert), + Patch: s.probeAction(rc, rel.Name, ir.Update), + Delete: s.probeAction(rc, rel.Name, ir.Delete), + } + } + } + // The RPC paths come from the registry that resolves /rpc calls in the active + // schema: for a NativeRPC backend that introspects its functions, the native + // per-schema registry; otherwise the portable registry. This keeps the + // document's /rpc/ entries in step with what the call path actually serves. + body, err := openapi.Generate(model, s.rpcRegistry(activeSchema), s.backend.Capabilities(), opts) if err != nil { writeError(w, pgerr.ErrInternal(err.Error())) return } - w.Header().Set("Content-Type", openapi.MediaType) + w.Header().Set("Content-Type", openapi.MediaType+"; charset=utf-8") w.WriteHeader(http.StatusOK) if r.Method != http.MethodHead { w.Write(body) } } +// serveRootSpec invokes the db-root-spec function and serves its JSON result +// in place of the generated document. The call runs exactly like GET +// /rpc/ with no arguments, the same planning and execution path, so role +// switching and error mapping behave identically; only the response media +// type differs, staying the root's openapi+json. +func (s *Server) serveRootSpec(w http.ResponseWriter, r *http.Request, id identity, activeSchema string) { + call, apiErr := ir.ParseCall(s.rootSpec, "", nil, true, "", nil, "", "") + if apiErr != nil { + writeError(w, apiErr) + return + } + + var planned *ir.Plan + if s.backend.Capabilities().NativeRPC { + planned = &ir.Plan{Call: call, ReadOnly: true} + } else { + planned, apiErr = plan.Call(s.backend.Functions(), s.Model(), call, true, []string{activeSchema}) + if apiErr != nil { + writeError(w, apiErr) + return + } + } + + rc := s.buildContext(r, id, activeSchema) + res, err := s.backend.Execute(r.Context(), planned, rc) + if err != nil { + writeError(w, mapExecError(s.backend, err, id.anonymous)) + return + } + out, apiErr := renderCall(mediaJSON, res, planned.Func, s.rootSpec) + if apiErr != nil { + writeError(w, apiErr) + return + } + + w.Header().Set("Content-Type", openapi.MediaType+"; charset=utf-8") + w.WriteHeader(http.StatusOK) + if r.Method != http.MethodHead { + w.Write(out.body) + } +} + +// rootAllow is the verb set the root serves: the Allow value OPTIONS answers +// with and the one a 405 carries so the rejected caller knows what would work. +const rootAllow = "OPTIONS,GET,HEAD" + +// errRootDisabled is PostgREST's PGRST126: the root metadata endpoint turned +// off by openapi-mode=disabled (or an unset db-root-spec in that mode), a 404 +// with an explicit code rather than a bare not-found. +func errRootDisabled() *pgerr.APIError { + return pgerr.New(http.StatusNotFound, "PGRST126", "Root endpoint metadata is disabled") +} + +// rootAcceptable reports whether the Accept header admits the root document. +// The root produces application/openapi+json and application/json only; an +// absent header or a wildcard range accepts it, anything else is the caller's +// 406 PGRST107 (PostgREST root negotiation). +func rootAcceptable(accept []string) bool { + ranges := parseAccept(accept) + if len(ranges) == 0 { + return true + } + for _, mr := range ranges { + if mr.q <= 0 { + continue + } + if mr.typ == "*" && mr.sub == "*" { + return true + } + if mr.typ == "application" && (mr.sub == "*" || mr.sub == "openapi+json" || mr.sub == "json") { + return true + } + } + return false +} + +// acceptedList renders the requested media types for the PGRST107 message the +// way PostgREST does: parameters stripped, ordered by descending quality. +func acceptedList(accept []string) string { + ranges := parseAccept(accept) + parts := make([]string, len(ranges)) + for i, mr := range ranges { + parts[i] = mr.typ + "/" + mr.sub + } + return strings.Join(parts, ", ") +} + +// probeAction asks the authorization gate whether the role could perform one +// kind of query on a relation, by authorizing a minimal throwaway plan. Using +// the real gate keeps the document's answer identical to what a request would +// get; the probe plan is discarded, so the gate's mutations never escape. +func (s *Server) probeAction(rc *reqctx.Context, rel string, kind ir.QueryKind) bool { + q := &ir.Query{Kind: kind, Relation: ir.Ref{Name: rel}} + if kind != ir.Read { + q.Write = &ir.WriteSpec{} + } + return s.authz.Authorize(rc, &ir.Plan{Query: q}) == nil +} + // requestScheme reports the URL scheme the client reached the server with, // reading the TLS state. Behind a proxy this is the listen-side scheme; the // externally visible scheme comes from the proxy-uri configuration (spec 20). diff --git a/httpapi/root_test.go b/httpapi/root_test.go index e2880b1..d4c0936 100644 --- a/httpapi/root_test.go +++ b/httpapi/root_test.go @@ -1,12 +1,18 @@ package httpapi_test import ( + "context" "encoding/json" "net/http" "strings" "testing" + "github.com/tamnd/dbrest/auth" + "github.com/tamnd/dbrest/authz" + "github.com/tamnd/dbrest/backend/sqlite" "github.com/tamnd/dbrest/config" + "github.com/tamnd/dbrest/httpapi" + "github.com/tamnd/dbrest/rpc" ) // TestRootServesOpenAPI checks GET / returns the OpenAPI document with the @@ -17,8 +23,8 @@ func TestRootServesOpenAPI(t *testing.T) { if resp.StatusCode != http.StatusOK { t.Fatalf("status = %d, want 200", resp.StatusCode) } - if ct := resp.Header.Get("Content-Type"); ct != "application/openapi+json" { - t.Errorf("Content-Type = %q, want application/openapi+json", ct) + if ct := resp.Header.Get("Content-Type"); ct != "application/openapi+json; charset=utf-8" { + t.Errorf("Content-Type = %q, want application/openapi+json; charset=utf-8", ct) } var doc map[string]any if err := json.NewDecoder(resp.Body).Decode(&doc); err != nil { @@ -50,7 +56,7 @@ func TestRootHeadHasNoBody(t *testing.T) { if resp.StatusCode != http.StatusOK { t.Fatalf("status = %d, want 200", resp.StatusCode) } - if ct := resp.Header.Get("Content-Type"); ct != "application/openapi+json" { + if ct := resp.Header.Get("Content-Type"); ct != "application/openapi+json; charset=utf-8" { t.Errorf("Content-Type = %q", ct) } buf := make([]byte, 1) @@ -59,21 +65,255 @@ func TestRootHeadHasNoBody(t *testing.T) { } } -// TestRootDisabledIs404 checks openapi-mode=disabled turns the root off. +// TestRootNegotiatesAccept pins the root's Accept handling: openapi+json, +// plain json, and wildcards are served; anything else is 406 PGRST107 with +// the requested types echoed in q-descending order, parameters stripped. +func TestRootNegotiatesAccept(t *testing.T) { + srv := newServer(t) + for _, accept := range []string{ + "application/openapi+json", + "application/json", + "*/*", + "application/*", + "application/json;q=0.5, text/html", // one acceptable type suffices + } { + resp := do(t, srv, http.MethodGet, "/", map[string]string{"Accept": accept}) + if resp.StatusCode != http.StatusOK { + t.Errorf("Accept %q: status = %d, want 200", accept, resp.StatusCode) + } + } + + resp := do(t, srv, http.MethodGet, "/", map[string]string{"Accept": "text/csv;q=0.3, application/xml"}) + if resp.StatusCode != http.StatusNotAcceptable { + t.Fatalf("status = %d, want 406", resp.StatusCode) + } + var e struct { + Code string `json:"code"` + Message string `json:"message"` + } + if err := json.NewDecoder(resp.Body).Decode(&e); err != nil { + t.Fatalf("decode: %v", err) + } + if e.Code != "PGRST107" { + t.Errorf("code = %q, want PGRST107", e.Code) + } + if e.Message != "None of these media types are available: application/xml, text/csv" { + t.Errorf("message = %q, want q-ordered type list", e.Message) + } +} + +// TestRootDisabledIs404 checks openapi-mode=disabled turns the root off with +// PostgREST's explicit PGRST126 code, not a bare not-found. func TestRootDisabledIs404(t *testing.T) { srv := newServer(t) - srv.SetOpenAPI(config.OpenAPIDisabled, "") + srv.SetOpenAPI(config.OpenAPIDisabled, "", false) resp := do(t, srv, http.MethodGet, "/", nil) if resp.StatusCode != http.StatusNotFound { t.Fatalf("status = %d, want 404", resp.StatusCode) } + var e struct { + Code string `json:"code"` + Message string `json:"message"` + } + if err := json.NewDecoder(resp.Body).Decode(&e); err != nil { + t.Fatalf("decode: %v", err) + } + if e.Code != "PGRST126" { + t.Errorf("code = %q, want PGRST126", e.Code) + } + if e.Message != "Root endpoint metadata is disabled" { + t.Errorf("message = %q", e.Message) + } +} + +// TestRootMethodNotAllowed pins the verb gate at the root: anything besides +// GET, HEAD, and OPTIONS is 405 PGRST117 naming the method, with the Allow +// header listing what the root serves. The gate runs before the disabled +// check, so the answer is the same in every openapi-mode. +func TestRootMethodNotAllowed(t *testing.T) { + srv := newServer(t) + for _, mode := range []string{config.OpenAPIFollowPrivileges, config.OpenAPIDisabled} { + srv.SetOpenAPI(mode, "", false) + for _, method := range []string{http.MethodDelete, http.MethodPatch, http.MethodPost, http.MethodPut, "TRACE"} { + resp := do(t, srv, method, "/", nil) + if resp.StatusCode != http.StatusMethodNotAllowed { + t.Fatalf("mode %s %s /: status = %d, want 405", mode, method, resp.StatusCode) + } + if allow := resp.Header.Get("Allow"); allow != "OPTIONS,GET,HEAD" { + t.Errorf("%s /: Allow = %q, want OPTIONS,GET,HEAD", method, allow) + } + var e struct { + Code string `json:"code"` + Message string `json:"message"` + } + if err := json.NewDecoder(resp.Body).Decode(&e); err != nil { + t.Fatalf("decode: %v", err) + } + if e.Code != "PGRST117" { + t.Errorf("%s /: code = %q, want PGRST117", method, e.Code) + } + if e.Message != "Unsupported HTTP method: "+method { + t.Errorf("%s /: message = %q", method, e.Message) + } + } + } +} + +// TestRootOptionsAnswersAllow checks OPTIONS / is 200 with the verb set and +// no body, the way PostgREST's info response answers it. +func TestRootOptionsAnswersAllow(t *testing.T) { + srv := newServer(t) + resp := do(t, srv, http.MethodOptions, "/", nil) + if resp.StatusCode != http.StatusOK { + t.Fatalf("status = %d, want 200", resp.StatusCode) + } + if allow := resp.Header.Get("Allow"); allow != "OPTIONS,GET,HEAD" { + t.Errorf("Allow = %q, want OPTIONS,GET,HEAD", allow) + } + buf := make([]byte, 1) + if n, _ := resp.Body.Read(buf); n != 0 { + t.Error("OPTIONS / should have no body") + } +} + +// TestRootFollowPrivilegesFiltersDocument checks the default openapi-mode: +// the document only describes the relations and operations the requesting +// role can access, so anon and an authenticated role see different documents. +func TestRootFollowPrivilegesFiltersDocument(t *testing.T) { + srv := authzServer(t, []authz.Grant{ + {Role: "web_user", Relation: "films", Action: authz.Select}, + {Role: "web_user", Relation: "films", Action: authz.Insert}, + }, nil) + srv.SetOpenAPI(config.OpenAPIFollowPrivileges, "", false) + + // The authenticated role sees films with exactly its granted operations. + resp := do(t, srv, http.MethodGet, "/", map[string]string{ + "Authorization": "Bearer " + userToken(t, "web_user", "alice"), + }) + if resp.StatusCode != http.StatusOK { + t.Fatalf("status = %d, want 200", resp.StatusCode) + } + var doc struct { + Paths map[string]map[string]any `json:"paths"` + Definitions map[string]any `json:"definitions"` + } + if err := json.NewDecoder(resp.Body).Decode(&doc); err != nil { + t.Fatalf("decode: %v", err) + } + films, ok := doc.Paths["/films"] + if !ok { + t.Fatal("granted role should see /films") + } + for _, op := range []string{"get", "post"} { + if _, ok := films[op]; !ok { + t.Errorf("/films missing granted operation %s", op) + } + } + for _, op := range []string{"patch", "delete"} { + if _, ok := films[op]; ok { + t.Errorf("/films advertises ungranted operation %s", op) + } + } + + // Anon holds no grants: nothing is enumerated. Only the "/" entry that + // describes the document itself remains, as in v14. + resp = do(t, srv, http.MethodGet, "/", nil) + if resp.StatusCode != http.StatusOK { + t.Fatalf("anon status = %d, want 200", resp.StatusCode) + } + doc.Paths, doc.Definitions = nil, nil + if err := json.NewDecoder(resp.Body).Decode(&doc); err != nil { + t.Fatalf("decode: %v", err) + } + if len(doc.Paths) != 1 { + t.Errorf("anon sees paths %v, want only the root entry", doc.Paths) + } + if _, ok := doc.Paths["/"]; !ok { + t.Errorf("anon paths = %v, want the root entry", doc.Paths) + } + if len(doc.Definitions) != 0 { + t.Errorf("anon sees definitions %v, want none", doc.Definitions) + } +} + +// TestRootIgnorePrivilegesEmitsAll checks openapi-mode=ignore-privileges keeps +// the full document even for a role with no grants. +func TestRootIgnorePrivilegesEmitsAll(t *testing.T) { + srv := authzServer(t, nil, nil) + srv.SetOpenAPI(config.OpenAPIIgnorePrivileges, "", false) + resp := do(t, srv, http.MethodGet, "/", nil) + if resp.StatusCode != http.StatusOK { + t.Fatalf("status = %d, want 200", resp.StatusCode) + } + var doc struct { + Paths map[string]any `json:"paths"` + } + if err := json.NewDecoder(resp.Body).Decode(&doc); err != nil { + t.Fatalf("decode: %v", err) + } + if _, ok := doc.Paths["/films"]; !ok { + t.Error("ignore-privileges should still describe /films") + } +} + +// TestRootSecurityActive checks openapi-security-active emits the JWT scheme +// and a document-level security requirement, the way PostgREST v14 shapes it; +// off (the default) the document carries neither, even with JWT configured. +func TestRootSecurityActive(t *testing.T) { + srv := authServer(t, auth.Config{}) + srv.SetOpenAPI(config.OpenAPIIgnorePrivileges, "", true) + resp := do(t, srv, http.MethodGet, "/", nil) + if resp.StatusCode != http.StatusOK { + t.Fatalf("status = %d, want 200", resp.StatusCode) + } + var doc struct { + SecurityDefinitions map[string]map[string]any `json:"securityDefinitions"` + Security []map[string][]any `json:"security"` + Paths map[string]map[string]struct { + Security []map[string][]any `json:"security"` + } `json:"paths"` + } + if err := json.NewDecoder(resp.Body).Decode(&doc); err != nil { + t.Fatalf("decode: %v", err) + } + jwt, ok := doc.SecurityDefinitions["JWT"] + if !ok { + t.Fatal("securityDefinitions missing the JWT scheme") + } + if jwt["type"] != "apiKey" || jwt["name"] != "Authorization" || jwt["in"] != "header" { + t.Errorf("JWT scheme = %v", jwt) + } + if len(doc.Security) != 1 { + t.Fatalf("security = %v, want one document-level requirement", doc.Security) + } + if _, ok := doc.Security[0]["JWT"]; !ok { + t.Errorf("security requirement = %v, want JWT", doc.Security[0]) + } + // v14 attaches the requirement at the document, never per operation. + if sec := doc.Paths["/films"]["get"].Security; len(sec) != 0 { + t.Errorf("get security = %v, want none per operation", sec) + } + + // Off (the default): no securityDefinitions and no requirement at all. + srv.SetOpenAPI(config.OpenAPIIgnorePrivileges, "", false) + resp = do(t, srv, http.MethodGet, "/", nil) + doc.SecurityDefinitions, doc.Security, doc.Paths = nil, nil, nil + if err := json.NewDecoder(resp.Body).Decode(&doc); err != nil { + t.Fatalf("decode: %v", err) + } + if len(doc.SecurityDefinitions) != 0 { + t.Errorf("securityDefinitions = %v, want none when inactive", doc.SecurityDefinitions) + } + if len(doc.Security) != 0 { + t.Errorf("security = %v, want none when inactive", doc.Security) + } } // TestRootProxyURIRewritesHost checks openapi-server-proxy-uri overrides the // host, scheme, and base path the document advertises. func TestRootProxyURIRewritesHost(t *testing.T) { srv := newServer(t) - srv.SetOpenAPI(config.OpenAPIFollowPrivileges, "https://api.example.com/v1") + srv.SetOpenAPI(config.OpenAPIFollowPrivileges, "https://api.example.com/v1", false) resp := do(t, srv, http.MethodGet, "/", nil) var doc map[string]any if err := json.NewDecoder(resp.Body).Decode(&doc); err != nil { @@ -100,21 +340,92 @@ func TestRootAdvertisesServedOperators(t *testing.T) { var doc map[string]any json.NewDecoder(resp.Body).Decode(&doc) - params := doc["paths"].(map[string]any)["/films"].(map[string]any)["get"].(map[string]any)["parameters"].([]any) - for _, p := range params { - pm := p.(map[string]any) - if pm["name"] != "title" { - continue - } - desc := pm["description"].(string) - // match/imatch and fts are served on SQLite; the range operators are not. - for _, want := range []string{"match", "fts"} { - if !strings.Contains(desc, want) { - t.Errorf("expected %q advertised; desc = %q", want, desc) - } - } - if strings.Contains(desc, " sl,") || strings.Contains(desc, " adj.") { - t.Errorf("range operators should not be advertised on SQLite; desc = %q", desc) + // Operations reference the shared rowFilter parameters; the operator list + // lives on the definition in the document's parameters map. + title, ok := doc["parameters"].(map[string]any)["rowFilter.films.title"].(map[string]any) + if !ok { + t.Fatal("document is missing the rowFilter.films.title parameter") + } + desc := title["description"].(string) + // match/imatch and fts are served on SQLite; the range operators are not. + for _, want := range []string{"match", "fts"} { + if !strings.Contains(desc, want) { + t.Errorf("expected %q advertised; desc = %q", want, desc) } } + if strings.Contains(desc, " sl,") || strings.Contains(desc, " adj.") { + t.Errorf("range operators should not be advertised on SQLite; desc = %q", desc) + } +} + +// newRootSpecServer builds a server whose registry carries a custom-spec +// function and points db-root-spec at it. +func newRootSpecServer(t *testing.T) *httpapi.Server { + t.Helper() + dsn := "file:" + strings.ReplaceAll(t.Name(), "/", "_") + "?mode=memory&cache=shared" + be, err := sqlite.Open(dsn) + if err != nil { + t.Fatalf("open: %v", err) + } + t.Cleanup(func() { be.Close() }) + if _, err := be.DB().Exec(`CREATE TABLE films (id INTEGER PRIMARY KEY, title TEXT NOT NULL);`); err != nil { + t.Fatalf("seed: %v", err) + } + be.Register(rpc.NewStaticRegistry([]*rpc.Function{{ + Name: "custom_spec", + Returns: rpc.ReturnShape{Kind: rpc.ReturnScalar, Type: "json"}, + Volatility: rpc.Stable, + Query: &rpc.PortableQuery{SQL: `SELECT json_object('swagger', '2.0', 'info', json_object('title', 'My Custom API'))`}, + }})) + model, err := be.Introspect(context.Background()) + if err != nil { + t.Fatalf("introspect: %v", err) + } + srv := httpapi.NewServer(be, model, nil) + srv.SetDefaultRole("anon") + srv.SetRootSpec("custom_spec") + return srv +} + +// TestRootSpecOverridesDocument pins db-root-spec: the named function's JSON +// result replaces the generated document, served with the root's media type. +func TestRootSpecOverridesDocument(t *testing.T) { + srv := newRootSpecServer(t) + resp := do(t, srv, http.MethodGet, "/", nil) + if resp.StatusCode != http.StatusOK { + t.Fatalf("status = %d, want 200", resp.StatusCode) + } + if ct := resp.Header.Get("Content-Type"); ct != "application/openapi+json; charset=utf-8" { + t.Errorf("Content-Type = %q", ct) + } + var doc map[string]any + if err := json.NewDecoder(resp.Body).Decode(&doc); err != nil { + t.Fatalf("decode: %v", err) + } + if doc["swagger"] != "2.0" { + t.Errorf("swagger = %v", doc["swagger"]) + } + info, ok := doc["info"].(map[string]any) + if !ok || info["title"] != "My Custom API" { + t.Errorf("info = %v, want the custom title", doc["info"]) + } + if _, generated := doc["paths"]; generated { + t.Error("the generated document should be fully replaced") + } +} + +// TestRootSpecDisabledStaysOff checks openapi-mode=disabled wins over +// db-root-spec: the root stays a 404 PGRST126. +func TestRootSpecDisabledStaysOff(t *testing.T) { + srv := newRootSpecServer(t) + srv.SetOpenAPI(config.OpenAPIDisabled, "", false) + resp := do(t, srv, http.MethodGet, "/", nil) + if resp.StatusCode != http.StatusNotFound { + t.Fatalf("status = %d, want 404", resp.StatusCode) + } + var env map[string]any + json.NewDecoder(resp.Body).Decode(&env) + if env["code"] != "PGRST126" { + t.Errorf("code = %v, want PGRST126", env["code"]) + } } diff --git a/httpapi/rpc_embed_test.go b/httpapi/rpc_embed_test.go new file mode 100644 index 0000000..4bcf3cb --- /dev/null +++ b/httpapi/rpc_embed_test.go @@ -0,0 +1,225 @@ +package httpapi_test + +import ( + "context" + "encoding/json" + "net/http" + "strings" + "testing" + + "github.com/tamnd/dbrest/backend/sqlite" + "github.com/tamnd/dbrest/httpapi" + "github.com/tamnd/dbrest/rpc" +) + +// rpcEmbedFunctions returns rows of known relations so the call result supports +// embedding: recent_films is setof films (a to-one director, a many-to-many +// actors), all_directors is setof directors (a to-many films), and film_titles +// is setof text, a scalar set with no relation to embed against. +func rpcEmbedFunctions() []*rpc.Function { + return []*rpc.Function{ + { + Name: "recent_films", + Returns: rpc.ReturnShape{Kind: rpc.ReturnSetOf, Type: "films"}, + Volatility: rpc.Stable, + Query: &rpc.PortableQuery{SQL: "SELECT * FROM films ORDER BY id"}, + }, + { + Name: "all_directors", + Returns: rpc.ReturnShape{Kind: rpc.ReturnSetOf, Type: "directors"}, + Volatility: rpc.Stable, + Query: &rpc.PortableQuery{SQL: "SELECT * FROM directors ORDER BY id"}, + }, + { + Name: "film_titles", + Returns: rpc.ReturnShape{Kind: rpc.ReturnSetOf, Type: "text"}, + Volatility: rpc.Stable, + Query: &rpc.PortableQuery{SQL: "SELECT title FROM films ORDER BY id"}, + }, + } +} + +// newRPCEmbedServer seeds the canonical embedding fixture (directors, films, +// actors, roles) and registers functions that return rows of those relations, so +// /rpc embeds resolve through the same relationships a table read uses. +func newRPCEmbedServer(t testing.TB) *httpapi.Server { + t.Helper() + dsn := "file:" + strings.ReplaceAll(t.Name(), "/", "_") + "?mode=memory&cache=shared" + be, err := sqlite.Open(dsn) + if err != nil { + t.Fatalf("open: %v", err) + } + t.Cleanup(func() { be.Close() }) + + _, err = be.DB().Exec(` + CREATE TABLE directors ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL + ); + CREATE TABLE films ( + id INTEGER PRIMARY KEY, + title TEXT NOT NULL, + year INTEGER, + director_id INTEGER REFERENCES directors(id) + ); + CREATE TABLE actors ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL + ); + CREATE TABLE roles ( + film_id INTEGER NOT NULL REFERENCES films(id), + actor_id INTEGER NOT NULL REFERENCES actors(id), + PRIMARY KEY (film_id, actor_id) + ); + INSERT INTO directors (id, name) VALUES + (1, 'Lang'), (2, 'Scott'), (3, 'Villeneuve'); + INSERT INTO films (id, title, year, director_id) VALUES + (1, 'Metropolis', 1927, 1), + (2, 'Blade Runner', 1982, 2), + (3, 'Arrival', 2016, 3), + (4, 'Untitled', NULL, NULL); + INSERT INTO actors (id, name) VALUES + (1, 'Ford'), (2, 'Hauer'), (3, 'Adams'); + INSERT INTO roles (film_id, actor_id) VALUES + (2, 1), (2, 2), (3, 3); + `) + if err != nil { + t.Fatalf("seed: %v", err) + } + + be.Register(rpc.NewStaticRegistry(rpcEmbedFunctions())) + + model, err := be.Introspect(context.Background()) + if err != nil { + t.Fatalf("introspect: %v", err) + } + srv := httpapi.NewServer(be, model, nil) + srv.SetDefaultRole("anon") + return srv +} + +// A function returning rows of a relation embeds its to-one relation: each film +// carries its director as a nested object, NULL when the film has no director. +func TestRPCEmbedToOne(t *testing.T) { + srv := newRPCEmbedServer(t) + resp := do(t, srv, http.MethodGet, "/rpc/recent_films?select=title,directors(name)&order=id", nil) + if resp.StatusCode != http.StatusOK { + t.Fatalf("status = %d, want 200", resp.StatusCode) + } + var rows []map[string]any + if err := json.NewDecoder(resp.Body).Decode(&rows); err != nil { + t.Fatalf("decode: %v", err) + } + if len(rows) != 4 { + t.Fatalf("got %d rows, want 4", len(rows)) + } + if rows[0]["title"] != "Metropolis" { + t.Errorf("row 0 title = %v, want Metropolis", rows[0]["title"]) + } + d, ok := rows[0]["directors"].(map[string]any) + if !ok { + t.Fatalf("directors = %#v, want a nested object", rows[0]["directors"]) + } + if d["name"] != "Lang" { + t.Errorf("director = %v, want Lang", d["name"]) + } + // Film 4 has no director, so the to-one embed is JSON null. + if rows[3]["directors"] != nil { + t.Errorf("film 4 directors = %#v, want null", rows[3]["directors"]) + } +} + +// A function result embeds a many-to-many relation: each film carries its actors +// as an array, empty for a film with no roles. +func TestRPCEmbedToMany(t *testing.T) { + srv := newRPCEmbedServer(t) + resp := do(t, srv, http.MethodGet, "/rpc/recent_films?select=title,actors(name)&order=id", nil) + if resp.StatusCode != http.StatusOK { + t.Fatalf("status = %d, want 200", resp.StatusCode) + } + var rows []map[string]any + if err := json.NewDecoder(resp.Body).Decode(&rows); err != nil { + t.Fatalf("decode: %v", err) + } + // Blade Runner (id 2) has two actors. + actors, ok := rows[1]["actors"].([]any) + if !ok { + t.Fatalf("actors = %#v, want an array", rows[1]["actors"]) + } + if len(actors) != 2 { + t.Fatalf("Blade Runner has %d actors, want 2", len(actors)) + } + // Metropolis (id 1) has no actors: an empty array, not null. + empty, ok := rows[0]["actors"].([]any) + if !ok || len(empty) != 0 { + t.Errorf("Metropolis actors = %#v, want an empty array", rows[0]["actors"]) + } +} + +// An !inner embed on a call drops parent rows with no related match, the same as +// on a table read: only films that have actors survive. +func TestRPCEmbedInnerFilters(t *testing.T) { + srv := newRPCEmbedServer(t) + resp := do(t, srv, http.MethodGet, "/rpc/recent_films?select=title,actors!inner(name)&order=id", nil) + if resp.StatusCode != http.StatusOK { + t.Fatalf("status = %d, want 200", resp.StatusCode) + } + var rows []map[string]any + if err := json.NewDecoder(resp.Body).Decode(&rows); err != nil { + t.Fatalf("decode: %v", err) + } + if len(rows) != 2 { + t.Fatalf("got %d rows, want 2 (only films with actors)", len(rows)) + } +} + +// A to-many embed from the other side: each director carries its films. +func TestRPCEmbedToManyFilms(t *testing.T) { + srv := newRPCEmbedServer(t) + resp := do(t, srv, http.MethodGet, "/rpc/all_directors?select=name,films(title)&order=id", nil) + if resp.StatusCode != http.StatusOK { + t.Fatalf("status = %d, want 200", resp.StatusCode) + } + var rows []map[string]any + if err := json.NewDecoder(resp.Body).Decode(&rows); err != nil { + t.Fatalf("decode: %v", err) + } + films, ok := rows[0]["films"].([]any) + if !ok || len(films) != 1 { + t.Fatalf("director 1 films = %#v, want one film", rows[0]["films"]) + } + if films[0].(map[string]any)["title"] != "Metropolis" { + t.Errorf("film = %v, want Metropolis", films[0]) + } +} + +// An exact count on an embedded call carries the embed's restriction: with +// actors!inner only the two films that have actors count, and Content-Range +// reports that total rather than the function's full row set. +func TestRPCEmbedInnerCount(t *testing.T) { + srv := newRPCEmbedServer(t) + resp := do(t, srv, http.MethodGet, "/rpc/recent_films?select=title,actors!inner(name)&order=id", + map[string]string{"Prefer": "count=exact"}) + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusPartialContent { + t.Fatalf("status = %d, want 200/206", resp.StatusCode) + } + cr := resp.Header.Get("Content-Range") + if !strings.HasSuffix(cr, "/2") { + t.Errorf("Content-Range = %q, want a total of 2", cr) + } +} + +// Embedding on a function whose result is not a known relation has nothing to +// resolve against and is the read path's PGRST200. +func TestRPCEmbedOnScalarSetIsError(t *testing.T) { + srv := newRPCEmbedServer(t) + resp := do(t, srv, http.MethodGet, "/rpc/film_titles?select=directors(name)", nil) + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("status = %d, want 400", resp.StatusCode) + } + var env map[string]any + json.NewDecoder(resp.Body).Decode(&env) + if env["code"] != "PGRST200" { + t.Errorf("code = %v, want PGRST200", env["code"]) + } +} diff --git a/httpapi/rpc_render_test.go b/httpapi/rpc_render_test.go new file mode 100644 index 0000000..0786c26 --- /dev/null +++ b/httpapi/rpc_render_test.go @@ -0,0 +1,132 @@ +package httpapi + +import ( + "io" + "testing" + + "github.com/tamnd/dbrest/backend" + "github.com/tamnd/dbrest/reqctx" + "github.com/tamnd/dbrest/rpc" +) + +// rowStream is a forward-only stub over fixed rows, enough to drive renderCall. +type rowStream struct { + cols []string + rows [][]any + pos int +} + +func (s *rowStream) Columns() []string { return s.cols } +func (s *rowStream) Next() bool { + if s.pos >= len(s.rows) { + return false + } + s.pos++ + return true +} +func (s *rowStream) Values() ([]any, error) { return s.rows[s.pos-1], nil } +func (s *rowStream) Err() error { return nil } +func (s *rowStream) Close() error { return nil } + +// rowResult is a backend.Result backed by an in-memory row stream. +type rowResult struct{ s *rowStream } + +func (r rowResult) Body() io.Reader { return nil } +func (r rowResult) Rows() backend.RowStream { return r.s } +func (r rowResult) Count() (int64, bool) { return 0, false } +func (r rowResult) Affected() (int64, bool) { return 0, false } +func (r rowResult) ResponseControls() *reqctx.ResponseControls { + return &reqctx.ResponseControls{} +} + +func resultOf(cols []string, rows ...[]any) backend.Result { + return rowResult{s: &rowStream{cols: cols, rows: rows}} +} + +// TestRenderCallShapes covers finding 03-P06: renderCall shapes a native RPC +// result by the function's introspected return kind, not by a column-name guess. +// A SETOF scalar is a JSON array of bare values (no longer truncated to the first +// row); a single composite is one bare object (no longer wrapped in an array); a +// table whose lone column collides with the function name is still an array of +// objects (no longer collapsed to a scalar); a scalar with a named OUT parameter +// is the bare value (no longer an object). +func TestRenderCallShapes(t *testing.T) { + cases := []struct { + name string + fn *rpc.Function + res backend.Result + want string + }{ + { + name: "setof scalar is an array of bare values", + fn: &rpc.Function{Name: "ret_setof_integers", Returns: rpc.ReturnShape{Kind: rpc.ReturnSetOf, Type: "int4"}}, + res: resultOf([]string{"ret_setof_integers"}, []any{int64(1)}, []any{int64(2)}, []any{int64(3)}), + want: "[1,2,3]", + }, + { + name: "single composite is one bare object", + fn: &rpc.Function{Name: "ret_point_2d", Returns: rpc.ReturnShape{Kind: rpc.ReturnObject}}, + res: resultOf([]string{"x", "y"}, []any{int64(10), int64(5)}), + want: `{"x":10,"y":5}`, + }, + { + name: "single composite with no row is null", + fn: &rpc.Function{Name: "ret_point_2d", Returns: rpc.ReturnShape{Kind: rpc.ReturnObject}}, + res: resultOf([]string{"x", "y"}), + want: "null", + }, + { + name: "name-collision table is an array of objects", + fn: &rpc.Function{Name: "title", Returns: rpc.ReturnShape{Kind: rpc.ReturnTable}}, + res: resultOf([]string{"title"}, []any{"Dune"}, []any{"Arrival"}), + want: `[{"title":"Dune"},{"title":"Arrival"}]`, + }, + { + name: "scalar with a named OUT parameter is the bare value", + fn: &rpc.Function{Name: "add", Returns: rpc.ReturnShape{Kind: rpc.ReturnScalar, Type: "int4"}}, + res: resultOf([]string{"sum"}, []any{int64(7)}), + want: "7", + }, + { + name: "plain scalar is the bare value", + fn: &rpc.Function{Name: "now_year", Returns: rpc.ReturnShape{Kind: rpc.ReturnScalar, Type: "int4"}}, + res: resultOf([]string{"now_year"}, []any{int64(2026)}), + want: "2026", + }, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + out, apiErr := renderCall(mediaJSON, c.res, c.fn, c.fn.Name) + if apiErr != nil { + t.Fatalf("renderCall: %v", apiErr) + } + if string(out.body) != c.want { + t.Errorf("body = %s, want %s", out.body, c.want) + } + }) + } +} + +// A nil descriptor (a function the catalog never introspected) keeps the legacy +// column-name fallback: a lone column named after the function is a scalar, and +// anything wider renders like a table read. This guards the regression-safe path +// nativeFunc relies on when it returns nil. +func TestRenderCallNilFallback(t *testing.T) { + scalar := resultOf([]string{"answer"}, []any{int64(42)}) + out, apiErr := renderCall(mediaJSON, scalar, nil, "answer") + if apiErr != nil { + t.Fatalf("renderCall(scalar fallback): %v", apiErr) + } + if string(out.body) != "42" { + t.Errorf("scalar fallback body = %s, want 42", out.body) + } + + table := resultOf([]string{"a", "b"}, []any{int64(1), int64(2)}) + out, apiErr = renderCall(mediaJSON, table, nil, "some_fn") + if apiErr != nil { + t.Fatalf("renderCall(table fallback): %v", apiErr) + } + if string(out.body) != `[{"a":1,"b":2}]` { + t.Errorf("table fallback body = %s, want [{\"a\":1,\"b\":2}]", out.body) + } +} diff --git a/httpapi/rpc_response_controls_test.go b/httpapi/rpc_response_controls_test.go new file mode 100644 index 0000000..5bb5509 --- /dev/null +++ b/httpapi/rpc_response_controls_test.go @@ -0,0 +1,199 @@ +package httpapi_test + +import ( + "context" + "net/http" + "strings" + "testing" + + "github.com/tamnd/dbrest/backend/sqlite" + "github.com/tamnd/dbrest/httpapi" + "github.com/tamnd/dbrest/rpc" +) + +// 07.14: a portable registry function steers the response the way a PostgreSQL +// function does with the response.status / response.headers GUCs, except an +// emulated backend has no setting a single SELECT can write, so the function +// projects reserved columns of the same name. The backend lifts them into the +// response controls and strips them from the body. + +func responseControlFunctions() []*rpc.Function { + return []*rpc.Function{ + { + Name: "gone", + Returns: rpc.ReturnShape{Kind: rpc.ReturnTable, Columns: []rpc.Column{{Name: "message"}}}, + Volatility: rpc.Stable, + Query: &rpc.PortableQuery{SQL: `SELECT 'resource gone' AS message, 410 AS "response.status"`}, + }, + { + Name: "with_header", + Returns: rpc.ReturnShape{Kind: rpc.ReturnTable, Columns: []rpc.Column{{Name: "message"}}}, + Volatility: rpc.Stable, + Query: &rpc.PortableQuery{SQL: `SELECT 'ok' AS message, '[{"X-Total-Count":"42"}]' AS "response.headers"`}, + }, + { + Name: "archive", + Params: []rpc.Param{{Name: "id", Type: "integer"}}, + Returns: rpc.ReturnShape{Kind: rpc.ReturnTable, Columns: []rpc.Column{{Name: "title"}}}, + Volatility: rpc.Volatile, + Query: &rpc.PortableQuery{SQL: `UPDATE films SET year = 0 WHERE id = :id RETURNING title, 202 AS "response.status"`}, + }, + { + Name: "bad_status", + Returns: rpc.ReturnShape{Kind: rpc.ReturnTable, Columns: []rpc.Column{{Name: "message"}}}, + Volatility: rpc.Stable, + Query: &rpc.PortableQuery{SQL: `SELECT 'x' AS message, 9999 AS "response.status"`}, + }, + { + Name: "bad_header", + Returns: rpc.ReturnShape{Kind: rpc.ReturnTable, Columns: []rpc.Column{{Name: "message"}}}, + Volatility: rpc.Stable, + Query: &rpc.PortableQuery{SQL: `SELECT 'x' AS message, 'not-a-header' AS "response.headers"`}, + }, + { + Name: "bad_status_volatile", + Params: []rpc.Param{{Name: "id", Type: "integer"}}, + Returns: rpc.ReturnShape{Kind: rpc.ReturnTable, Columns: []rpc.Column{{Name: "title"}}}, + Volatility: rpc.Volatile, + Query: &rpc.PortableQuery{SQL: `UPDATE films SET year = 0 WHERE id = :id RETURNING title, 9999 AS "response.status"`}, + }, + } +} + +func newResponseControlServer(t *testing.T) *httpapi.Server { + t.Helper() + dsn := "file:" + strings.ReplaceAll(t.Name(), "/", "_") + "?mode=memory&cache=shared" + be, err := sqlite.Open(dsn) + if err != nil { + t.Fatalf("open: %v", err) + } + t.Cleanup(func() { be.Close() }) + + _, err = be.DB().Exec(` + CREATE TABLE films ( + id INTEGER PRIMARY KEY, + title TEXT NOT NULL, + year INTEGER + ); + INSERT INTO films (id, title, year) VALUES + (1, 'Metropolis', 1927), + (2, 'Blade Runner', 1982); + `) + if err != nil { + t.Fatalf("seed: %v", err) + } + be.Register(rpc.NewStaticRegistry(responseControlFunctions())) + + model, err := be.Introspect(context.Background()) + if err != nil { + t.Fatalf("introspect: %v", err) + } + srv := httpapi.NewServer(be, model, nil) + srv.SetDefaultRole("anon") + return srv +} + +// TestRPCResponseStatusOverride: a read-only function projecting response.status +// sets the HTTP status and the column never appears in the body. +func TestRPCResponseStatusOverride(t *testing.T) { + srv := newResponseControlServer(t) + resp := do(t, srv, http.MethodGet, "/rpc/gone", nil) + if resp.StatusCode != http.StatusGone { + t.Fatalf("status = %d, want 410", resp.StatusCode) + } + rows := decodeArray(t, resp) + if len(rows) != 1 || rows[0]["message"] != "resource gone" { + t.Fatalf("body = %v", rows) + } + if _, leaked := rows[0]["response.status"]; leaked { + t.Error("response.status column leaked into the body") + } +} + +// TestRPCResponseHeaderOverride: a function projecting response.headers merges the +// header into the response. +func TestRPCResponseHeaderOverride(t *testing.T) { + srv := newResponseControlServer(t) + resp := do(t, srv, http.MethodGet, "/rpc/with_header", nil) + if resp.StatusCode != http.StatusOK { + t.Fatalf("status = %d, want 200", resp.StatusCode) + } + if got := resp.Header.Get("X-Total-Count"); got != "42" { + t.Errorf("X-Total-Count = %q, want 42", got) + } + rows := decodeArray(t, resp) + if len(rows) != 1 || rows[0]["message"] != "ok" { + t.Fatalf("body = %v", rows) + } + if _, leaked := rows[0]["response.headers"]; leaked { + t.Error("response.headers column leaked into the body") + } +} + +// TestRPCResponseStatusVolatile: a volatile function steers the status the same +// way through its RETURNING projection, after the mutation it commits. +func TestRPCResponseStatusVolatile(t *testing.T) { + srv := newResponseControlServer(t) + resp := send(t, srv, http.MethodPost, "/rpc/archive", `{"id":1}`, nil) + if resp.StatusCode != http.StatusAccepted { + t.Fatalf("status = %d, want 202", resp.StatusCode) + } + rows := decodeArray(t, resp) + if len(rows) != 1 || rows[0]["title"] != "Metropolis" { + t.Fatalf("body = %v", rows) + } + if _, leaked := rows[0]["response.status"]; leaked { + t.Error("response.status column leaked into the body") + } + // The mutation committed: the archived film now has year 0. + after := do(t, srv, http.MethodGet, "/films?id=eq.1&select=year", nil) + got := decodeArray(t, after) + if len(got) != 1 || got[0]["year"].(float64) != 0 { + t.Errorf("archive did not persist: %v", got) + } +} + +// TestRPCInvalidResponseStatus: a function projecting an out-of-range status is +// PGRST112, the way PostgREST rejects a junk response.status rather than +// forwarding it. +func TestRPCInvalidResponseStatus(t *testing.T) { + srv := newResponseControlServer(t) + resp := do(t, srv, http.MethodGet, "/rpc/bad_status", nil) + if resp.StatusCode != http.StatusInternalServerError { + t.Fatalf("status = %d, want 500", resp.StatusCode) + } + if env := decodeEnvelope(t, resp); env["code"] != "PGRST112" { + t.Errorf("code = %v, want PGRST112", env["code"]) + } +} + +// TestRPCInvalidResponseHeaders: a malformed response.headers is PGRST111. +func TestRPCInvalidResponseHeaders(t *testing.T) { + srv := newResponseControlServer(t) + resp := do(t, srv, http.MethodGet, "/rpc/bad_header", nil) + if resp.StatusCode != http.StatusInternalServerError { + t.Fatalf("status = %d, want 500", resp.StatusCode) + } + if env := decodeEnvelope(t, resp); env["code"] != "PGRST111" { + t.Errorf("code = %v, want PGRST111", env["code"]) + } +} + +// TestRPCInvalidResponseStatusVolatileRollsBack: an invalid status from a +// volatile function fails before commit, so the mutation is discarded. +func TestRPCInvalidResponseStatusVolatileRollsBack(t *testing.T) { + srv := newResponseControlServer(t) + resp := send(t, srv, http.MethodPost, "/rpc/bad_status_volatile", `{"id":2}`, nil) + if resp.StatusCode != http.StatusInternalServerError { + t.Fatalf("status = %d, want 500", resp.StatusCode) + } + if env := decodeEnvelope(t, resp); env["code"] != "PGRST112" { + t.Errorf("code = %v, want PGRST112", env["code"]) + } + // The UPDATE rolled back: film 2 still has its seeded year. + after := do(t, srv, http.MethodGet, "/films?id=eq.2&select=year", nil) + got := decodeArray(t, after) + if len(got) != 1 || got[0]["year"].(float64) != 1982 { + t.Errorf("rollback failed, film 2 year = %v, want 1982", got) + } +} diff --git a/httpapi/rpc_test.go b/httpapi/rpc_test.go index 97e3542..8580da7 100644 --- a/httpapi/rpc_test.go +++ b/httpapi/rpc_test.go @@ -3,6 +3,7 @@ package httpapi_test import ( "context" "encoding/json" + "io" "net/http" "net/http/httptest" "strings" @@ -36,6 +37,18 @@ func rpcFunctions() []*rpc.Function { Volatility: rpc.Stable, Query: &rpc.PortableQuery{SQL: "SELECT title FROM films ORDER BY id"}, }, + { + Name: "get_request_method", + Returns: rpc.ReturnShape{Kind: rpc.ReturnScalar, Type: "text"}, + Volatility: rpc.Stable, + Query: &rpc.PortableQuery{SQL: "SELECT :request_method"}, + }, + { + Name: "get_jwt_claims", + Returns: rpc.ReturnShape{Kind: rpc.ReturnScalar, Type: "json"}, + Volatility: rpc.Stable, + Query: &rpc.PortableQuery{SQL: "SELECT :request_jwt_claims"}, + }, { Name: "films_after", Params: []rpc.Param{{Name: "y", Type: "integer"}}, @@ -43,6 +56,13 @@ func rpcFunctions() []*rpc.Function { Volatility: rpc.Stable, Query: &rpc.PortableQuery{SQL: "SELECT id, title FROM films WHERE year > :y ORDER BY id"}, }, + { + Name: "pick_titles", + Params: []rpc.Param{{Name: "ids", Type: "integer", Variadic: true}}, + Returns: rpc.ReturnShape{Kind: rpc.ReturnSetOf, Type: "text"}, + Volatility: rpc.Stable, + Query: &rpc.PortableQuery{SQL: "SELECT title FROM films WHERE id IN (:ids) ORDER BY id"}, + }, } } @@ -76,7 +96,9 @@ func newRPCServer(t testing.TB) *httpapi.Server { if err != nil { t.Fatalf("introspect: %v", err) } - return httpapi.NewServer(be, model, nil) + srv := httpapi.NewServer(be, model, nil) + srv.SetDefaultRole("anon") + return srv } func TestRPCGetScalarAddThem(t *testing.T) { @@ -126,6 +148,24 @@ func TestRPCUnknownFunctionIs404(t *testing.T) { } } +// /rpc//extra is a multi-segment path, not a missing function: PostgREST +// answers PGRST125 at 404, not the PGRST202 a missing function gets (item 04.8). +func TestRPCNestedPathIsInvalidPath(t *testing.T) { + srv := newRPCServer(t) + resp := do(t, srv, http.MethodGet, "/rpc/add/extra", nil) + if resp.StatusCode != http.StatusNotFound { + t.Fatalf("status = %d, want 404", resp.StatusCode) + } + var env map[string]any + json.NewDecoder(resp.Body).Decode(&env) + if env["code"] != "PGRST125" { + t.Errorf("code = %v, want PGRST125", env["code"]) + } +} + +// A GET to a volatile function fails with the read-only-transaction SQLSTATE +// 25006 at 405, the same code and status PostgREST surfaces when the read-only +// transaction rejects the function's write (item 04.6). func TestRPCGetOnVolatileIs405(t *testing.T) { srv := newRPCServer(t) resp := do(t, srv, http.MethodGet, "/rpc/bump_year?film_id=1", nil) @@ -134,8 +174,8 @@ func TestRPCGetOnVolatileIs405(t *testing.T) { } var env map[string]any json.NewDecoder(resp.Body).Decode(&env) - if env["code"] != "PGRST101" { - t.Errorf("code = %v, want PGRST101", env["code"]) + if env["code"] != "25006" { + t.Errorf("code = %v, want 25006", env["code"]) } } @@ -162,6 +202,9 @@ func TestRPCPostVolatilePersists(t *testing.T) { func TestRPCPostVolatileRollback(t *testing.T) { srv := newRPCServer(t) + // tx= is only honored under an allow-override db-tx-end policy; the default + // commit ignores it (02.4). Enable override so the rollback takes effect. + srv.SetTxEnd("commit-allow-override") resp := send(t, srv, http.MethodPost, "/rpc/bump_year", `{"film_id":2}`, map[string]string{ "Prefer": "tx=rollback", }) @@ -226,6 +269,70 @@ func TestRPCTablePostFilter(t *testing.T) { } } +// TestRPCGetArgAndColumnFilter pins the GET argument-versus-filter split: y names +// the function parameter and binds as an argument, while title names no parameter +// and post-filters the table return as a horizontal filter, the way PostgREST +// treats a non-argument query key on a table-valued function. +func TestRPCGetArgAndColumnFilter(t *testing.T) { + srv := newRPCServer(t) + resp := do(t, srv, http.MethodGet, "/rpc/films_after?y=1900&title=eq.Arrival", nil) + if resp.StatusCode != http.StatusOK { + t.Fatalf("status = %d, want 200", resp.StatusCode) + } + rows := decodeArray(t, resp) + if len(rows) != 1 { + t.Fatalf("got %d rows, want 1", len(rows)) + } + if rows[0]["title"] != "Arrival" { + t.Errorf("title = %v, want Arrival", rows[0]["title"]) + } +} + +// TestRPCVariadicGet checks a variadic parameter collects repeated query keys on +// GET and expands into the IN list, so pick_titles?ids=1&ids=3 binds both ids. +func TestRPCVariadicGet(t *testing.T) { + srv := newRPCServer(t) + resp := do(t, srv, http.MethodGet, "/rpc/pick_titles?ids=1&ids=3", nil) + if resp.StatusCode != http.StatusOK { + t.Fatalf("status = %d, want 200", resp.StatusCode) + } + var titles []string + if err := json.NewDecoder(resp.Body).Decode(&titles); err != nil { + t.Fatalf("decode: %v", err) + } + if len(titles) != 2 || titles[0] != "Metropolis" || titles[1] != "Arrival" { + t.Errorf("titles = %v, want [Metropolis Arrival]", titles) + } +} + +// TestRPCVariadicPost checks a variadic parameter takes a JSON array on POST and +// expands into the same IN list. +func TestRPCVariadicPost(t *testing.T) { + srv := newRPCServer(t) + resp := send(t, srv, http.MethodPost, "/rpc/pick_titles", `{"ids":[1,3]}`, nil) + if resp.StatusCode != http.StatusOK { + t.Fatalf("status = %d, want 200", resp.StatusCode) + } + var titles []string + if err := json.NewDecoder(resp.Body).Decode(&titles); err != nil { + t.Fatalf("decode: %v", err) + } + if len(titles) != 2 || titles[0] != "Metropolis" || titles[1] != "Arrival" { + t.Errorf("titles = %v, want [Metropolis Arrival]", titles) + } +} + +// TestRPCGetBadArgTypeIs400 checks a GET argument that does not coerce to its +// declared parameter type is a 22P02 400, the same error a read filter raises, +// rather than reaching the engine as raw text. +func TestRPCGetBadArgTypeIs400(t *testing.T) { + srv := newRPCServer(t) + resp := do(t, srv, http.MethodGet, "/rpc/add_them?a=notanint&b=3", nil) + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("status = %d, want 400", resp.StatusCode) + } +} + func TestRPCTableSingular(t *testing.T) { srv := newRPCServer(t) resp := do(t, srv, http.MethodGet, "/rpc/films_after?y=2000", map[string]string{ @@ -273,3 +380,281 @@ func BenchmarkRPCGetScalar(b *testing.B) { } } } + +// TestRPCScalarJSONReturnsRaw pins the declared-json contract: a function +// returning json emits the document itself, not a quoted string, the way a +// PostgreSQL json function behaves through PostgREST. An expression carries +// no column type, so the declared return type drives the conversion. +func TestRPCScalarJSONReturnsRaw(t *testing.T) { + dsn := "file:" + strings.ReplaceAll(t.Name(), "/", "_") + "?mode=memory&cache=shared" + be, err := sqlite.Open(dsn) + if err != nil { + t.Fatalf("open: %v", err) + } + t.Cleanup(func() { be.Close() }) + be.Register(rpc.NewStaticRegistry([]*rpc.Function{{ + Name: "payload", + Returns: rpc.ReturnShape{Kind: rpc.ReturnScalar, Type: "json"}, + Volatility: rpc.Stable, + Query: &rpc.PortableQuery{SQL: `SELECT json_object('a', 1)`}, + }})) + model, err := be.Introspect(context.Background()) + if err != nil { + t.Fatalf("introspect: %v", err) + } + srv := httpapi.NewServer(be, model, nil) + srv.SetDefaultRole("anon") + + resp := do(t, srv, http.MethodGet, "/rpc/payload", nil) + if resp.StatusCode != http.StatusOK { + t.Fatalf("status = %d, want 200", resp.StatusCode) + } + var doc map[string]any + if err := json.NewDecoder(resp.Body).Decode(&doc); err != nil { + t.Fatalf("the body should be a JSON object, not a quoted string: %v", err) + } + if doc["a"] != float64(1) { + t.Errorf("body = %v", doc) + } +} + +// TestRPCVoidReturns200Null pins the void contract: PostgREST answers a +// void-returning function with 200 and a null JSON body, never 204. The function +// runs its side effect (an INSERT here) and the response carries null. +func TestRPCVoidReturns200Null(t *testing.T) { + dsn := "file:" + strings.ReplaceAll(t.Name(), "/", "_") + "?mode=memory&cache=shared" + be, err := sqlite.Open(dsn) + if err != nil { + t.Fatalf("open: %v", err) + } + t.Cleanup(func() { be.Close() }) + if _, err := be.DB().Exec(` + CREATE TABLE films (id INTEGER PRIMARY KEY, title TEXT NOT NULL, year INTEGER); + `); err != nil { + t.Fatalf("seed: %v", err) + } + be.Register(rpc.NewStaticRegistry([]*rpc.Function{{ + Name: "touch_film", + Returns: rpc.ReturnShape{Kind: rpc.ReturnVoid}, + Volatility: rpc.Volatile, + Query: &rpc.PortableQuery{SQL: `INSERT INTO films(id, title, year) VALUES (999, 'Void', 2000)`}, + }})) + model, err := be.Introspect(context.Background()) + if err != nil { + t.Fatalf("introspect: %v", err) + } + srv := httpapi.NewServer(be, model, nil) + srv.SetDefaultRole("anon") + + resp := send(t, srv, http.MethodPost, "/rpc/touch_film", `{}`, nil) + if resp.StatusCode != http.StatusOK { + t.Fatalf("status = %d, want 200", resp.StatusCode) + } + body, _ := io.ReadAll(resp.Body) + if got := strings.TrimSpace(string(body)); got != "null" { + t.Errorf("body = %q, want null", got) + } + // The side effect ran: the row is present. + var n int + if err := be.DB().QueryRow(`SELECT count(*) FROM films WHERE id = 999`).Scan(&n); err != nil { + t.Fatalf("verify: %v", err) + } + if n != 1 { + t.Errorf("void function side effect did not persist: count = %d", n) + } +} + +// TestRPCSingleRawBodyTakesWholeBody pins the single-unnamed-parameter form: a +// function with one raw-body parameter receives the entire POST body as that one +// argument, decoded by Content-Type, rather than read as an object of named +// arguments. A JSON array body would fail the named-object decode, so its +// round-trip proves the raw-body path bound it whole. +func TestRPCSingleRawBodyTakesWholeBody(t *testing.T) { + dsn := "file:" + strings.ReplaceAll(t.Name(), "/", "_") + "?mode=memory&cache=shared" + be, err := sqlite.Open(dsn) + if err != nil { + t.Fatalf("open: %v", err) + } + t.Cleanup(func() { be.Close() }) + be.Register(rpc.NewStaticRegistry([]*rpc.Function{{ + Name: "echo_payload", + Params: []rpc.Param{{Name: "payload", Type: "json", RawBody: true}}, + Returns: rpc.ReturnShape{Kind: rpc.ReturnScalar, Type: "json"}, + Volatility: rpc.Immutable, + Query: &rpc.PortableQuery{SQL: `SELECT :payload`}, + }})) + model, err := be.Introspect(context.Background()) + if err != nil { + t.Fatalf("introspect: %v", err) + } + srv := httpapi.NewServer(be, model, nil) + srv.SetDefaultRole("anon") + + resp := send(t, srv, http.MethodPost, "/rpc/echo_payload", `[1,2,3]`, nil) + if resp.StatusCode != http.StatusOK { + t.Fatalf("status = %d, want 200", resp.StatusCode) + } + var arr []json.Number + dec := json.NewDecoder(resp.Body) + dec.UseNumber() + if err := dec.Decode(&arr); err != nil { + t.Fatalf("the array body should round-trip whole: %v", err) + } + if len(arr) != 3 || arr[0].String() != "1" || arr[2].String() != "3" { + t.Errorf("body = %v, want [1 2 3]", arr) + } +} + +// TestRPCSingleRawBodyText pins the text content type on the raw-body form: a +// text/plain body binds to the lone parameter as text and echoes back. +func TestRPCSingleRawBodyText(t *testing.T) { + dsn := "file:" + strings.ReplaceAll(t.Name(), "/", "_") + "?mode=memory&cache=shared" + be, err := sqlite.Open(dsn) + if err != nil { + t.Fatalf("open: %v", err) + } + t.Cleanup(func() { be.Close() }) + be.Register(rpc.NewStaticRegistry([]*rpc.Function{{ + Name: "shout", + Params: []rpc.Param{{Name: "line", Type: "text", RawBody: true}}, + Returns: rpc.ReturnShape{Kind: rpc.ReturnScalar, Type: "text"}, + Volatility: rpc.Immutable, + Query: &rpc.PortableQuery{SQL: `SELECT upper(:line)`}, + }})) + model, err := be.Introspect(context.Background()) + if err != nil { + t.Fatalf("introspect: %v", err) + } + srv := httpapi.NewServer(be, model, nil) + srv.SetDefaultRole("anon") + + resp := send(t, srv, http.MethodPost, "/rpc/shout", `hello`, map[string]string{ + "Content-Type": "text/plain", + }) + if resp.StatusCode != http.StatusOK { + t.Fatalf("status = %d, want 200", resp.StatusCode) + } + var s string + if err := json.NewDecoder(resp.Body).Decode(&s); err != nil { + t.Fatalf("decode: %v", err) + } + if s != "HELLO" { + t.Errorf("body = %q, want HELLO", s) + } +} + +// The reserved :request_* placeholders give a registry function the request +// context PostgreSQL functions read with current_setting (spec 15). The HTTP +// surface matches PostgREST's GUC behavior on every engine. +func TestRPCContextRequestMethod(t *testing.T) { + srv := newRPCServer(t) + resp := do(t, srv, http.MethodGet, "/rpc/get_request_method", nil) + if resp.StatusCode != http.StatusOK { + t.Fatalf("status = %d, want 200", resp.StatusCode) + } + var s string + if err := json.NewDecoder(resp.Body).Decode(&s); err != nil { + t.Fatalf("decode: %v", err) + } + if s != "GET" { + t.Errorf("body = %q, want GET", s) + } +} + +// TestRPCSetofContentRangeAlwaysPresent pins 02.7: an RPC read carries a +// Content-Range like a table read even with no count requested, the unknown +// total rendered as 0-2/*. +func TestRPCSetofContentRangeAlwaysPresent(t *testing.T) { + srv := newRPCServer(t) + resp := do(t, srv, http.MethodGet, "/rpc/film_titles", nil) + if resp.StatusCode != http.StatusOK { + t.Fatalf("status = %d, want 200", resp.StatusCode) + } + if cr := resp.Header.Get("Content-Range"); cr != "0-2/*" { + t.Errorf("Content-Range = %q, want 0-2/*", cr) + } +} + +// TestRPCRangeHeaderOnGet pins 02.7: a GET /rpc honors a unitless Range header +// the way a table read does, slicing the set. With no count the total is +// unknown, so the status stays 200 (PostgREST's rangeStatus on a missing total) +// while Content-Range echoes the slice. +func TestRPCRangeHeaderOnGet(t *testing.T) { + srv := newRPCServer(t) + resp := do(t, srv, http.MethodGet, "/rpc/film_titles", map[string]string{ + "Range": "0-1", + }) + if resp.StatusCode != http.StatusOK { + t.Fatalf("status = %d, want 200", resp.StatusCode) + } + if cr := resp.Header.Get("Content-Range"); cr != "0-1/*" { + t.Errorf("Content-Range = %q, want 0-1/*", cr) + } + var titles []string + if err := json.NewDecoder(resp.Body).Decode(&titles); err != nil { + t.Fatalf("decode: %v", err) + } + if len(titles) != 2 || titles[0] != "Metropolis" || titles[1] != "Blade Runner" { + t.Errorf("titles = %v, want [Metropolis Blade Runner]", titles) + } +} + +// TestRPCRangeHeaderOnGetWithCountIs206 pins 02.7: the same slice with an exact +// count knows the total exceeds the slice, so it is the 206 a table read gives. +func TestRPCRangeHeaderOnGetWithCountIs206(t *testing.T) { + srv := newRPCServer(t) + resp := do(t, srv, http.MethodGet, "/rpc/film_titles", map[string]string{ + "Range": "0-1", + "Prefer": "count=exact", + }) + if resp.StatusCode != http.StatusPartialContent { + t.Fatalf("status = %d, want 206", resp.StatusCode) + } + if cr := resp.Header.Get("Content-Range"); cr != "0-1/3" { + t.Errorf("Content-Range = %q, want 0-1/3", cr) + } +} + +// TestRPCRangeOutOfBoundsIs416 pins 02.7: a GET /rpc Range whose offset is past +// the known total is the same 416 a table read raises. +func TestRPCRangeOutOfBoundsIs416(t *testing.T) { + srv := newRPCServer(t) + resp := do(t, srv, http.MethodGet, "/rpc/film_titles", map[string]string{ + "Range": "5-9", + "Prefer": "count=exact", + }) + if resp.StatusCode != http.StatusRequestedRangeNotSatisfiable { + t.Fatalf("status = %d, want 416", resp.StatusCode) + } +} + +// TestRPCInvertedRangeOnGetIs416 pins 02.7: an inverted Range on a GET /rpc is +// the same 416 a table read raises, before any work runs. +func TestRPCInvertedRangeOnGetIs416(t *testing.T) { + srv := newRPCServer(t) + resp := do(t, srv, http.MethodGet, "/rpc/film_titles", map[string]string{ + "Range": "3-1", + }) + if resp.StatusCode != http.StatusRequestedRangeNotSatisfiable { + t.Fatalf("status = %d, want 416", resp.StatusCode) + } +} + +// An anonymous request still carries the resolved role in request.jwt.claims: +// PostgREST folds the role into the claims object even when the token had none, +// so the claims are {"role":""}, not {}. Verified against PostgREST +// 14.12, where an anonymous call presents {"role":""}. +func TestRPCContextJWTClaimsCarriesAnonRole(t *testing.T) { + srv := newRPCServer(t) + resp := send(t, srv, http.MethodPost, "/rpc/get_jwt_claims", `{}`, nil) + if resp.StatusCode != http.StatusOK { + t.Fatalf("status = %d, want 200", resp.StatusCode) + } + var claims map[string]any + if err := json.NewDecoder(resp.Body).Decode(&claims); err != nil { + t.Fatalf("decode: %v", err) + } + if len(claims) != 1 || claims["role"] != "anon" { + t.Errorf("claims = %v, want {\"role\":\"anon\"}", claims) + } +} diff --git a/httpapi/server.go b/httpapi/server.go index d4fca75..1aa1749 100644 --- a/httpapi/server.go +++ b/httpapi/server.go @@ -5,12 +5,15 @@ package httpapi import ( + "context" + "errors" "fmt" "io" "net/http" "net/url" "strconv" "strings" + "time" "github.com/tamnd/dbrest/auth" "github.com/tamnd/dbrest/authz" @@ -19,34 +22,66 @@ import ( "github.com/tamnd/dbrest/pgerr" "github.com/tamnd/dbrest/plan" "github.com/tamnd/dbrest/reqctx" + "github.com/tamnd/dbrest/rpc" "github.com/tamnd/dbrest/schema" ) // singularMediaType is the Accept value that asks for a single object. const singularMediaType = "application/vnd.pgrst.object+json" -// maxBodyBytes caps a request body, so a runaway payload cannot exhaust memory. -const maxBodyBytes = 16 << 20 // 16 MiB - // Server holds the resolved schema model and the backend, and serves the API. A // verifier, when set, resolves the request role from the JWT; with none, every // request runs as the static default role. type Server struct { - backend backend.Backend - model *schema.Model - searchPath []string - role string - verifier *auth.Verifier - authz *authz.Registry - openapiMode string - openapiProxy string + backend backend.Backend + cache *schema.Cache + searchPath []string + role string + verifier *auth.Verifier + authz *authz.Registry + openapiMode string + openapiProxy string + openapiSecurity bool + rootSpec string + corsOrigins []string // server-cors-allowed-origins; empty means any + maxRows int // db-max-rows; 0 means no cap + maxBody int64 // max-request-body bytes; 0 means unlimited + planEnabled bool // db-plan-enabled; plans are off by default + aggregatesOn bool // db-aggregates-enabled; aggregates are off by default + preRequest string // db-pre-request, carried to the backend per request + appSettings map[string]string + logQuery bool // log-query, carried to the backend per request + timingEnabled bool // server-timing-enabled; the Server-Timing header is off by default + txEnd ir.TxEnd // db-tx-end; governs whether Prefer: tx= is honored } // NewServer builds a Server over a backend, its introspected model, and the -// schema search path (the exposed schemas, in resolution order). It runs every -// request as the anon role until a verifier is attached with SetVerifier. +// schema search path (the exposed schemas, in resolution order). It has no +// default role: until SetDefaultRole or SetVerifier provides an identity +// source, every request is refused with 401 PGRST302, matching PostgREST's +// fail-closed posture when db-anon-role is unset. func NewServer(b backend.Backend, model *schema.Model, searchPath []string) *Server { - return &Server{backend: b, model: model, searchPath: searchPath, role: "anon"} + return &Server{backend: b, cache: schema.NewCache(model), searchPath: searchPath} +} + +// Model returns the current schema model snapshot. A handler loads it once at +// entry so one request never straddles a reload. +func (s *Server) Model() *schema.Model { return s.cache.Load() } + +// Reload re-runs introspection and publishes the fresh model, the schema +// cache reload PostgREST performs on SIGUSR1 and on NOTIFY over the +// db-channel. In-flight requests keep the snapshot they started with; a +// failed introspection leaves the old model published, so a transient +// database error never takes the running cache down. The OpenAPI document is +// generated per request from the published model and needs no separate +// regeneration. +func (s *Server) Reload(ctx context.Context) error { + model, err := s.backend.Introspect(ctx) + if err != nil { + return err + } + s.cache.Store(model) + return nil } // SetOpenAPI configures the root document. mode is the openapi-mode option: @@ -54,27 +89,123 @@ func NewServer(b backend.Backend, model *schema.Model, searchPath []string) *Ser // modes leave it on. proxyURI, when set, is the externally visible base URL the // document advertises (the openapi-server-proxy-uri option), overriding the // host and scheme the request arrived on so a document served behind a reverse -// proxy points at the public address. See spec 20. -func (s *Server) SetOpenAPI(mode, proxyURI string) { +// proxy points at the public address. securityActive is the +// openapi-security-active option: it attaches the JWT security requirement to +// every operation rather than just describing the scheme. See spec 20. +func (s *Server) SetOpenAPI(mode, proxyURI string, securityActive bool) { s.openapiMode = mode s.openapiProxy = proxyURI + s.openapiSecurity = securityActive } -// SetDefaultRole overrides the static role used for unauthenticated requests -// when no verifier is configured. It should be called with the db-anon-role -// option value so the server uses the configured anon role instead of the -// hardcoded "anon" placeholder. +// SetRootSpec names the function whose JSON result replaces the generated +// OpenAPI document, the db-root-spec option. Empty keeps the generated +// document. The function is called like GET /rpc/ with no arguments. +func (s *Server) SetRootSpec(fn string) { s.rootSpec = fn } + +// SetDefaultRole sets the static role used for unauthenticated requests when no +// verifier is configured. It should be called with the db-anon-role option +// value; left unset, tokenless requests are refused with 401 PGRST302. func (s *Server) SetDefaultRole(role string) { if role != "" { s.role = role } } +// SetMaxRows applies the db-max-rows option: a hard cap on the rows any read +// or RPC response may return, enforced as an implicit LIMIT at plan time. Zero +// means no cap. Mutation representations are exempt, matching PostgREST v10+. +func (s *Server) SetMaxRows(n int) { s.maxRows = n } + +// MaxRows reports the configured db-max-rows cap (0 when uncapped). The +// count=estimated logic uses it as the exactness threshold. +func (s *Server) MaxRows() int { return s.maxRows } + +// SetMaxRequestBody applies the max-request-body option: a byte cap on a +// request body. Zero, the default, leaves bodies unlimited as PostgREST does; +// a positive value is a runaway-payload guard that an operator opts into. +func (s *Server) SetMaxRequestBody(n int) { s.maxBody = int64(n) } + +// readBody reads a request body, honoring the optional max-request-body cap. A +// body over the cap is a 413 with the byte bound, not a parse error; a read +// error under the cap stays the generic could-not-read parse error. +func (s *Server) readBody(w http.ResponseWriter, r *http.Request) ([]byte, *pgerr.APIError) { + reader := r.Body + if s.maxBody > 0 { + reader = http.MaxBytesReader(w, r.Body, s.maxBody) + } + b, err := io.ReadAll(reader) + if err != nil { + var tooLarge *http.MaxBytesError + if errors.As(err, &tooLarge) { + return nil, pgerr.ErrBodyTooLarge(s.maxBody) + } + return nil, pgerr.ErrInvalidBody("could not read request body") + } + return b, nil +} + +// capLimit lowers *limit to the db-max-rows cap, installing the cap as the +// limit when the client did not ask for one. It returns the (possibly +// replaced) pointer so callers can assign it back into the query. +func (s *Server) capLimit(limit *int) *int { + if s.maxRows <= 0 { + return limit + } + if limit == nil || *limit > s.maxRows { + capped := s.maxRows + return &capped + } + return limit +} + +// SetCORSAllowedOrigins restricts cross-origin requests to the given origin +// list (the server-cors-allowed-origins option). With an empty list the server +// keeps the PostgREST default: any origin is accepted. +func (s *Server) SetCORSAllowedOrigins(origins []string) { s.corsOrigins = origins } + +// SetPlanEnabled applies the db-plan-enabled option. Execution plans leak +// schema and statistics detail, so PostgREST only honors the +// application/vnd.pgrst.plan+json media type when the option is on; the +// default is off, and a plan request then fails the same way as any other +// unproducible media type. +func (s *Server) SetPlanEnabled(on bool) { s.planEnabled = on } + +// SetAggregatesEnabled applies db-aggregates-enabled: when on, requests may use +// aggregate functions (count(), col.sum(), ...). It is off by default, matching +// PostgREST, so an aggregate request answers PGRST123 until an operator opts in. +func (s *Server) SetAggregatesEnabled(on bool) { s.aggregatesOn = on } + +// SetAppSettings carries the app.settings.* options to the backend on every +// request context, to be applied as transaction settings. +func (s *Server) SetAppSettings(settings map[string]string) { s.appSettings = settings } + +// SetLogQuery asks backends to echo the statements they execute, the +// log-query option. +func (s *Server) SetLogQuery(on bool) { s.logQuery = on } + +// SetServerTimingEnabled applies the server-timing-enabled option. When on, +// every response carries a Server-Timing header with the jwt/parse/plan/ +// transaction/response phase durations; the default is off, matching +// PostgREST, so the wire is unchanged until an operator opts in. +func (s *Server) SetServerTimingEnabled(on bool) { s.timingEnabled = on } + +// SetTxEnd applies the db-tx-end option, the policy that decides whether a +// request's Prefer: tx= may override the transaction outcome. The default +// commit ignores the preference, matching PostgREST. +func (s *Server) SetTxEnd(v string) { s.txEnd = ir.ParseTxEnd(v) } + // SetVerifier attaches a JWT verifier. Once set, the role and claims of each // request come from its bearer token (spec 13), and a bad token is rejected // before any query runs. With no verifier the server keeps the static role. func (s *Server) SetVerifier(v *auth.Verifier) { s.verifier = v } +// SetPreRequest names the db-pre-request function carried to the backend on +// every request context. The backend invokes it after the request context is +// in place and before the main statement (spec 13); the caller is responsible +// for refusing the option at startup on a backend that cannot honor it. +func (s *Server) SetPreRequest(fn string) { s.preRequest = fn } + // SetAuthz attaches an authorization registry. Once set, every read and write is // gated by the registry's table and column privileges and has any Row Level // Security policy injected before execution (spec 14). On a backend whose engine @@ -104,36 +235,76 @@ type identity struct { // buildContext assembles the per-request context the backend receives: the // resolved identity plus the request metadata that crosses the HTTP/query -// boundary (method, path, headers, cookies, and the selected schema). The -// frontend builds it once after authentication; on the emulated backend the -// values a policy references are later bound as parameters (spec 15). -func buildContext(r *http.Request, id identity) *reqctx.Context { +// boundary (method, path, headers, cookies, and the active schema), and the +// configured transaction-scoped settings (db-pre-request, app.settings.*, +// log-query). The frontend builds it once after authentication; on the +// emulated backend the values a policy references are later bound as +// parameters (spec 15). +func (s *Server) buildContext(r *http.Request, id identity, activeSchema string) *reqctx.Context { cookies := r.Cookies() jar := make(map[string]string, len(cookies)) for _, c := range cookies { jar[c.Name] = c.Value } return &reqctx.Context{ - Role: id.role, - Anonymous: id.anonymous, - Claims: id.claims, - Method: r.Method, - Path: r.URL.Path, - Headers: r.Header, - Cookies: jar, - Schema: requestSchema(r), + Role: id.role, + Anonymous: id.anonymous, + Claims: id.claims, + Method: r.Method, + Path: r.URL.Path, + Headers: r.Header, + Cookies: jar, + Schema: activeSchema, + PreRequest: s.preRequest, + AppSettings: s.appSettings, + LogQuery: s.logQuery, } } -// requestSchema reads the schema the client selected with the Accept-Profile -// header (reads) or the Content-Profile header (writes). It carries the choice -// onto the context; cross-schema identifier routing is the introspection -// subsystem's job (spec 08), so an unset header is the default schema. -func requestSchema(r *http.Request) string { - if r.Method == http.MethodGet || r.Method == http.MethodHead { - return r.Header.Get("Accept-Profile") +// resolveSchema negotiates the active schema for the request, the PostgREST +// profile rules: POST/PATCH/PUT/DELETE read Content-Profile, every other +// method reads Accept-Profile; no header selects the first exposed schema. A +// profile outside db-schemas is 406 PGRST106. The bool reports whether the +// schema was negotiated, which is when the client named one, or implicitly on +// a multi-schema deployment; a negotiated response echoes the active schema in +// a Content-Profile response header. +func (s *Server) resolveSchema(r *http.Request) (string, bool, *pgerr.APIError) { + var profile string + switch r.Method { + case http.MethodPost, http.MethodPatch, http.MethodPut, http.MethodDelete: + profile = r.Header.Get("Content-Profile") + default: + profile = r.Header.Get("Accept-Profile") + } + if profile == "" { + var def string + if len(s.searchPath) > 0 { + def = s.searchPath[0] + } + return def, len(s.searchPath) > 1, nil } - return r.Header.Get("Content-Profile") + for _, sch := range s.searchPath { + if sch == profile { + return profile, true, nil + } + } + return "", false, errUnacceptableSchema(profile, s.searchPath) +} + +// errUnacceptableSchema is PostgREST's PGRST106: a profile header naming a +// schema that is not exposed by db-schemas, a 406 whose hint lists the schemas +// that are. +func errUnacceptableSchema(profile string, schemas []string) *pgerr.APIError { + e := pgerr.New(http.StatusNotAcceptable, "PGRST106", "Invalid schema: "+profile) + return e.WithHint("Only the following schemas are exposed: " + strings.Join(schemas, ", ")) +} + +// applyTxPolicy resolves a request's Prefer: tx= against the db-tx-end server +// policy and returns the PGRST122 a handling=strict request earns when tx= is +// disallowed. It runs after parsing and before execution on every method. +func (s *Server) applyTxPolicy(p *ir.PreferSet) *pgerr.APIError { + p.ResolveTx(s.txEnd) + return p.StrictError() } // applyControls applies a backend's response controls and returns the status to @@ -153,9 +324,14 @@ func applyControls(w http.ResponseWriter, rc *reqctx.ResponseControls, def int) // authenticate resolves the request identity from the Authorization header. With // no verifier it is the static default role; otherwise the verifier maps the -// bearer token to a role (or anon), or returns the 401/403 the token earns. +// bearer token to a role (or anon), or returns the 401/403 the token earns. The +// no-verifier path fails closed: with no default role configured, tokenless +// requests are refused with 401 PGRST302 rather than run as anyone. func (s *Server) authenticate(r *http.Request) (identity, *pgerr.APIError) { if s.verifier == nil { + if s.role == "" { + return identity{}, pgerr.ErrJWTRequired() + } return identity{role: s.role, anonymous: true}, nil } res, apiErr := s.verifier.Authenticate(r.Header.Get("Authorization")) @@ -170,33 +346,187 @@ func (s *Server) authenticate(r *http.Request) (identity, *pgerr.APIError) { // PATCH updates; PUT upserts; DELETE deletes. RPC and OpenAPI arrive with their // subsystems; an unhandled method gets an honest error. func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if s.serveCORS(w, r) { + return + } + if r.Method == http.MethodOptions { + // OPTIONS describes the resource with an Allow header and runs no + // transaction, so it answers before authentication and schema negotiation, + // the way PostgREST does. A CORS preflight was already handled by serveCORS. + s.handleOptions(w, r) + return + } + // server-timing-enabled wraps the response so every exit path emits the + // Server-Timing header, and carries a phaseTimer to the handlers through the + // request context. The jwt phase is the only one measured here; the rest are + // recorded inside the handlers. + var timer *phaseTimer + if s.timingEnabled { + timer = &phaseTimer{} + w = &timingWriter{ResponseWriter: w, timer: timer} + r = r.WithContext(withTimer(r.Context(), timer)) + } + jwtStart := time.Now() id, apiErr := s.authenticate(r) + timer.mark("jwt", jwtStart) + if apiErr != nil { + writeError(w, apiErr) + return + } + activeSchema, negotiated, apiErr := s.resolveSchema(r) if apiErr != nil { writeError(w, apiErr) return } + if negotiated { + // PostgREST echoes the negotiated schema on successful responses so the + // client knows which schema served it; writeError strips it on failure. + w.Header().Set("Content-Profile", activeSchema) + } if fn, ok := rpcName(r.URL.Path); ok { - s.handleRPC(w, r, fn, id) + s.handleRPC(w, r, fn, id, activeSchema) return } if r.URL.Path == "/" { - s.handleRoot(w, r) + s.handleRoot(w, r, id, activeSchema) return } switch r.Method { case http.MethodGet, http.MethodHead: - s.handleRead(w, r, id) + s.handleRead(w, r, id, activeSchema) case http.MethodPost: - s.handleWrite(w, r, ir.Insert, id) + s.handleWrite(w, r, ir.Insert, id, activeSchema) case http.MethodPatch: - s.handleWrite(w, r, ir.Update, id) + s.handleWrite(w, r, ir.Update, id, activeSchema) case http.MethodPut: - s.handleWrite(w, r, ir.Upsert, id) + s.handleWrite(w, r, ir.Upsert, id, activeSchema) case http.MethodDelete: - s.handleWrite(w, r, ir.Delete, id) + s.handleWrite(w, r, ir.Delete, id, activeSchema) default: - writeError(w, pgerr.ErrUnsupported(r.Method+" requests", "dbrest")) + // A verb the server implements nowhere (TRACE, CONNECT, a custom method) + // is PostgREST's 405 PGRST117, not the capability gate's PGRST127. + writeError(w, pgerr.ErrUnsupportedMethod(r.Method)) + } +} + +// tableAllow is the Allow value an OPTIONS on a table or view answers with: the +// full verb set a relation endpoint accepts, in PostgREST's order. +const tableAllow = "OPTIONS,GET,HEAD,POST,PUT,PATCH,DELETE" + +// handleOptions answers an OPTIONS request with an Allow header naming the +// methods the resource accepts and a 200 with no body, the way PostgREST does. +// The root answers its own verb set, a function answers by volatility (a +// volatile function is POST-only, otherwise GET/HEAD/POST are allowed too), and +// every table or view answers the full relation verb set. No transaction runs. +func (s *Server) handleOptions(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/" { + w.Header().Set("Allow", rootAllow) + } else if fn, ok := rpcName(r.URL.Path); ok { + w.Header().Set("Allow", s.rpcAllow(fn)) + } else { + w.Header().Set("Allow", tableAllow) } + w.WriteHeader(http.StatusOK) +} + +// rpcAllow is the Allow value for an OPTIONS on /rpc/: a volatile function +// accepts only OPTIONS and POST, every other (read-only) function also accepts +// GET and HEAD. A non-registry (native) backend does not resolve volatility +// here, so it answers the read-capable set, matching PostgREST's default for a +// function whose volatility is not yet known. +func (s *Server) rpcAllow(fn string) string { + const readable = "OPTIONS,GET,HEAD,POST" + const writeOnly = "OPTIONS,POST" + if s.backend.Capabilities().NativeRPC { + return readable + } + for _, f := range s.backend.Functions().List() { + if f.Name == fn && !f.Volatility.ReadOnly() { + return writeOnly + } + } + return readable +} + +// corsExposedHeaders is the Access-Control-Expose-Headers value PostgREST +// returns on every cross-origin request. +const corsExposedHeaders = "Content-Encoding, Content-Location, Content-Range, Content-Type, " + + "Date, Location, Server, Transfer-Encoding, Range-Unit" + +// corsAllowedMethods is the Access-Control-Allow-Methods value PostgREST +// returns on a preflight. +const corsAllowedMethods = "GET, POST, PATCH, PUT, DELETE, OPTIONS, HEAD" + +// serveCORS answers CORS the way PostgREST v14 does and reports whether the +// request was fully handled (a preflight). A request without an Origin header +// is untouched. With server-cors-allowed-origins unset any origin is accepted +// with Access-Control-Allow-Origin: *; with the option set, a listed origin is +// reflected with Access-Control-Allow-Credentials: true and an unlisted one +// falls through to normal handling with no CORS headers (the browser enforces +// the denial). A preflight (OPTIONS with Access-Control-Request-Method) is +// answered directly with the allowed methods, the requested headers, and a +// one-day max age, before authentication and routing. +func (s *Server) serveCORS(w http.ResponseWriter, r *http.Request) bool { + origin := r.Header.Get("Origin") + if origin == "" { + return false + } + allowOrigin := "*" + credentials := false + if len(s.corsOrigins) > 0 { + found := false + for _, o := range s.corsOrigins { + if o == origin { + found = true + break + } + } + if !found { + return false + } + allowOrigin = origin + credentials = true + } + + h := w.Header() + h.Set("Access-Control-Allow-Origin", allowOrigin) + if credentials { + h.Set("Access-Control-Allow-Credentials", "true") + } + + if r.Method == http.MethodOptions && r.Header.Get("Access-Control-Request-Method") != "" { + h.Set("Access-Control-Allow-Methods", corsAllowedMethods) + h.Set("Access-Control-Allow-Headers", corsAllowedHeaders(r.Header.Get("Access-Control-Request-Headers"))) + h.Set("Access-Control-Max-Age", "86400") + w.WriteHeader(http.StatusOK) + return true + } + + h.Set("Access-Control-Expose-Headers", corsExposedHeaders) + return false +} + +// corsAllowedHeaders builds the preflight Access-Control-Allow-Headers value: +// Authorization, then the headers the client asked for, then the simple +// headers, deduplicated case-insensitively. The order matches PostgREST. +func corsAllowedHeaders(requested string) string { + out := []string{"Authorization"} + seen := map[string]bool{"authorization": true} + add := func(name string) { + name = strings.TrimSpace(name) + if name == "" || seen[strings.ToLower(name)] { + return + } + seen[strings.ToLower(name)] = true + out = append(out, name) + } + for _, name := range strings.Split(requested, ",") { + add(name) + } + for _, name := range []string{"Accept", "Accept-Language", "Content-Language"} { + add(name) + } + return strings.Join(out, ", ") } // rpcName extracts the function name from an /rpc/ path, reporting false for @@ -214,16 +544,24 @@ func rpcName(path string) (string, bool) { // string); POST reads or writes (arguments from the JSON body). A read method may // only reach a read-only function; the plan raises 405 otherwise. Any other // method is not allowed on a function. See spec 12-rpc. -func (s *Server) handleRPC(w http.ResponseWriter, r *http.Request, fn string, id identity) { - if fn == "" || strings.Contains(fn, "/") { - writeError(w, pgerr.ErrNoFunction(fn)) +func (s *Server) handleRPC(w http.ResponseWriter, r *http.Request, fn string, id identity, activeSchema string) { + if strings.Contains(fn, "/") { + // /rpc//extra is a multi-segment path, not a missing function: PostgREST + // answers PGRST125 "Invalid path specified in request URL" (item 04.8). + writeError(w, pgerr.ErrInvalidPath()) + return + } + if fn == "" { + writeError(w, pgerr.ErrNoFunction(activeSchema, fn, nil, "")) return } isGet := r.Method == http.MethodGet || r.Method == http.MethodHead if !isGet && r.Method != http.MethodPost { - writeError(w, pgerr.ErrMethodNotAllowed( - "Method "+r.Method+" not allowed on a function; use GET or POST")) + // PUT, PATCH, or DELETE on a function is PostgREST's PGRST101 with the + // exact "Cannot use the method on RPC" text. OPTIONS never + // reaches here; it is answered with an Allow header before routing. + writeError(w, pgerr.ErrInvalidRPCMethod(r.Method)) return } @@ -235,89 +573,210 @@ func (s *Server) handleRPC(w http.ResponseWriter, r *http.Request, fn string, id var body []byte if r.Method == http.MethodPost { - b, err := io.ReadAll(http.MaxBytesReader(w, r.Body, maxBodyBytes)) - if err != nil { - writeError(w, pgerr.ErrParse("could not read request body")) + b, apiErr := s.readBody(w, r) + if apiErr != nil { + writeError(w, apiErr) return } body = b } - call, apiErr := ir.ParseCall(fn, r.URL.RawQuery, r.Header.Values("Prefer"), isGet, r.Header.Get("Content-Type"), body) + // A function with a single unnamed parameter takes the whole POST body as that + // argument, decoded by Content-Type. Both the portable registry and the native + // registry (once introspected) report this form, so the body is bound to the + // parameter rather than read as a JSON object of named arguments. + rawBodyParam, rawBodyType := "", "" + if !isGet { + rawBodyParam, rawBodyType = singleRawBodyParam(s.rpcRegistry(activeSchema), fn) + } + + t := timerFrom(r.Context()) + + parseStart := time.Now() + call, apiErr := ir.ParseCall(fn, r.URL.RawQuery, r.Header.Values("Prefer"), isGet, r.Header.Get("Content-Type"), body, rawBodyParam, rawBodyType) if apiErr != nil { writeError(w, apiErr) return } - call.Singular = media == mediaObject + t.mark("parse", parseStart) + call.Singular = singularMedia(media) + if apiErr := s.applyTxPolicy(&call.Prefer); apiErr != nil { + writeError(w, apiErr) + return + } + + // A GET /rpc read honors the Range header the same way a table read does: + // it overrides ?limit=&offset= and an inverted range is 416. + if isGet { + if rangeHdr := r.Header.Get("Range"); rangeHdr != "" && !strings.Contains(rangeHdr, "=") { + off, lim, ok, inverted := parseRangeHeader(rangeHdr) + if inverted { + writeError(w, pgerr.ErrRangeNotSatisfiable(). + WithDetails("The lower boundary must be lower than or equal to the upper boundary in the Range header.")) + return + } + if ok { + call.Offset = &off + if lim >= 0 { + l := lim + call.Limit = &l + } + } + } + } + + // db-max-rows caps an RPC response like a read (an implicit LIMIT). + call.Limit = s.capLimit(call.Limit) + planStart := time.Now() var planned *ir.Plan if s.backend.Capabilities().NativeRPC { - // PostgreSQL (and any other NativeRPC backend) discovers and executes - // the function from its own catalog. We skip the portable-registry lookup - // and build a minimal plan: ReadOnly follows the HTTP method (GET/HEAD - // means read-only; POST means the function may write). The engine enforces - // the volatility constraint — if a GET reaches a volatile function the - // read-only transaction fails with SQLSTATE 25006, which maps to 405. - planned = &ir.Plan{Call: call, ReadOnly: isGet} + // A NativeRPC backend that introspects its functions resolves a known name + // through the shared planner: overload resolution (PGRST202/PGRST203), GET + // argument-versus-filter partitioning, declared-type argument coercion, and + // the volatility-driven access mode (a POST to a STABLE or IMMUTABLE function + // runs read-only) all match the portable path. An unknown name (one the + // introspection did not model) falls back to a minimal engine-planned call so + // it still reaches the catalog rather than 404ing on a registry miss; the + // engine then enforces volatility (a GET to a volatile function fails with + // SQLSTATE 25006, mapped to 405). + reg := s.rpcRegistry(activeSchema) + if registryKnows(reg, call.Function.Name) { + planned, apiErr = plan.Call(reg, s.Model(), call, isGet, []string{activeSchema}) + if apiErr != nil { + writeError(w, apiErr) + return + } + } else { + planned = &ir.Plan{Call: call, ReadOnly: isGet} + } } else { - planned, apiErr = plan.Call(s.backend.Functions(), call, isGet, s.searchPath) + planned, apiErr = plan.Call(s.backend.Functions(), s.Model(), call, isGet, []string{activeSchema}) if apiErr != nil { writeError(w, apiErr) return } } + t.mark("plan", planStart) + + rc := s.buildContext(r, id, activeSchema) + if call.Prefer.TimeZone != nil { + rc.TimeZone = *call.Prefer.TimeZone + } + + // A plan request on an RPC call returns the EXPLAIN for the function instead + // of running it; route it before Execute so mediaPlan never reaches the + // call renderer. + if media == mediaPlan { + s.servePlan(w, r, id, func(exp backend.Explainer, opts backend.PlanOptions) ([]byte, error) { + return exp.ExplainCall(r.Context(), planned, rc, opts) + }) + return + } - rc := buildContext(r, id) + txStart := time.Now() res, err := s.backend.Execute(r.Context(), planned, rc) if err != nil { writeError(w, mapExecError(s.backend, err, id.anonymous)) return } + t.mark("transaction", txStart) - out, apiErr := renderCall(media, res, planned.Func, fn) + respStart := time.Now() + var out *rendered + if len(call.Embeds) > 0 { + // An embedded call returns parent rows with nested resources, the same + // row-object shape a table read produces, so the read renderer drives it + // with the call's embed columns marked raw. + out, apiErr = renderFor(media, res, embedKeysFor(call.Embeds, call.Select)) + } else { + out, apiErr = renderCall(media, res, planned.Func, fn) + } if apiErr != nil { writeError(w, apiErr) return } + t.mark("response", respStart) s.writeCall(w, r, call, out, res.ResponseControls()) } -// writeCall writes a successful RPC response. The status is 200, or 206 when a -// bounded window over a table return did not cover the full count. A requested -// count sets Content-Range, matching a read. -func (s *Server) writeCall(w http.ResponseWriter, r *http.Request, call *ir.Call, out *rendered, ctrl *reqctx.ResponseControls) { - if applied := call.Prefer.AppliedHeader(); applied != "" { - w.Header().Set("Preference-Applied", applied) +// rpcRegistry is the registry an RPC resolves against in the active schema. A +// NativeRPC backend that introspects its functions (SchemaFunctioner) resolves +// against that schema's native registry merged with any declared portable +// registry, so overload resolution, GET argument partitioning, and the +// volatility-driven access mode all run through the shared planner over one +// function set; any other backend uses the single portable registry for every +// schema. +func (s *Server) rpcRegistry(activeSchema string) rpc.Registry { + if s.backend.Capabilities().NativeRPC { + if sf, ok := s.backend.(backend.SchemaFunctioner); ok { + // A NativeRPC backend can also carry a declared portable registry + // (the documented escape hatch for functions with no native + // equivalent). Merge keeps both reachable through one Registry, with + // a declared function shadowing a same-signature native one, so the + // call path and the OpenAPI document agree on exactly one function set. + return rpc.Merge(s.backend.Functions(), sf.SchemaFunctions(activeSchema)) + } } - w.Header().Set("Content-Type", out.contentType) + return s.backend.Functions() +} - offset := 0 - if call.Offset != nil { - offset = *call.Offset - } - if out.hasTotl { - w.Header().Set("Content-Range", contentRange(offset, out.nRows, out.total, true)) +// registryKnows reports whether a registry has any overload of a function name. The +// native path resolves a known name through the planner (so a wrong argument set is +// PGRST202 and an ambiguous one PGRST203) and falls back to a minimal engine-planned +// call for an unknown name, so a function the introspection did not model still +// reaches the engine rather than 404ing on a registry miss. +func registryKnows(reg rpc.Registry, name string) bool { + for _, fn := range reg.List() { + if fn.Name == name { + return true + } } + return false +} - status := http.StatusOK - hasWindow := call.Limit != nil || call.Offset != nil - if hasWindow && out.hasTotl && int64(out.nRows) < out.total { - status = http.StatusPartialContent +// singleRawBodyParam reports the parameter name and type of a function whose +// signature is a single unnamed argument, the form that takes the whole POST +// body as one value decoded by Content-Type. It scans every overload of the +// name and returns the first single-raw-body match; an absent or multi-parameter +// function yields empty strings, leaving the normal named-arguments path. See +// spec 12-rpc and the PostgREST single-unnamed-parameter rule. +func singleRawBodyParam(reg rpc.Registry, name string) (string, string) { + for _, fn := range reg.List() { + if fn.Name != name { + continue + } + if p, ok := fn.SingleRawBody(); ok { + return p.Name, p.Type + } } - w.WriteHeader(applyControls(w, ctrl, status)) - if r.Method != http.MethodHead { - w.Write(out.body) + return "", "" +} + +// writeCall writes a successful RPC response. An RPC read uses the same +// pagination contract as a table read: Content-Range is always present, an +// out-of-bounds offset is 416, and the 200/206 rule follows the count. +func (s *Server) writeCall(w http.ResponseWriter, r *http.Request, call *ir.Call, out *rendered, ctrl *reqctx.ResponseControls) { + offset := 0 + if call.Offset != nil { + offset = *call.Offset } + s.writePaged(w, r, call.Prefer.AppliedHeader(), offset, out, ctrl) } -func (s *Server) handleRead(w http.ResponseWriter, r *http.Request, id identity) { +func (s *Server) handleRead(w http.ResponseWriter, r *http.Request, id identity, activeSchema string) { relation := strings.Trim(r.URL.Path, "/") if relation == "" || strings.Contains(relation, "/") { - writeError(w, pgerr.ErrUnknownTable(relation)) + // A path with more than one segment names no routable resource; PostgREST + // answers PGRST125 "Invalid path specified in request URL", distinct from + // the PGRST205 a single unknown relation gets (item 04.8). + writeError(w, pgerr.ErrInvalidPath()) return } + t := timerFrom(r.Context()) + acceptHdrs := r.Header.Values("Accept") media, ok := negotiate(acceptHdrs) if !ok { @@ -325,12 +784,19 @@ func (s *Server) handleRead(w http.ResponseWriter, r *http.Request, id identity) return } + parseStart := time.Now() q, apiErr := ir.ParseRead(relation, r.URL.RawQuery, r.Header.Values("Prefer")) if apiErr != nil { writeError(w, apiErr) return } - q.Singular = media == mediaObject + t.mark("parse", parseStart) + q.Singular = singularMedia(media) + + if apiErr := s.applyTxPolicy(&q.Prefer); apiErr != nil { + writeError(w, apiErr) + return + } // Range: header overrides ?limit=&offset= and marks the request as a // Range request so the server can return 206 Partial Content. PostgREST @@ -338,7 +804,13 @@ func (s *Server) handleRead(w http.ResponseWriter, r *http.Request, id identity) // Only treat Range as item pagination when it has no unit prefix (i.e. // not "bytes=0-9" form), matching PostgREST's parsing behaviour. if rangeHdr := r.Header.Get("Range"); rangeHdr != "" && !strings.Contains(rangeHdr, "=") { - if off, lim, ok := parseRangeHeader(rangeHdr); ok { + off, lim, ok, inverted := parseRangeHeader(rangeHdr) + if inverted { + writeError(w, pgerr.ErrRangeNotSatisfiable(). + WithDetails("The lower boundary must be lower than or equal to the upper boundary in the Range header.")) + return + } + if ok { q.Offset = &off if lim >= 0 { l := lim @@ -348,56 +820,130 @@ func (s *Server) handleRead(w http.ResponseWriter, r *http.Request, id identity) } } - planned, apiErr := plan.Read(s.model, q, s.searchPath) + // db-max-rows is a hard cap on every read: the effective window is + // min(requested limit, max-rows), applied before planning so Content-Range + // and the 200/206 decision see the limit that actually ran. Mutation + // representations are exempt (PostgREST v10+), so this stays off the + // write path. + q.Limit = s.capLimit(q.Limit) + + // An estimated count crosses from exact to the planner estimate at db-max-rows; + // hand the backend that threshold so it can decide which side a result lands on. + if q.Count == ir.CountEstimated && s.maxRows > 0 { + q.CountMax = int64(s.maxRows) + } + + planStart := time.Now() + planned, apiErr := plan.Read(s.Model(), q, []string{activeSchema}, plan.Options{AggregatesEnabled: s.aggregatesOn}) if apiErr != nil { writeError(w, apiErr) return } + t.mark("plan", planStart) - rc := buildContext(r, id) + rc := s.buildContext(r, id, activeSchema) + if q.Prefer.TimeZone != nil { + rc.TimeZone = *q.Prefer.TimeZone + } if apiErr := s.authorize(rc, planned); apiErr != nil { writeError(w, apiErr) return } - // vnd.pgrst.plan+json: return EXPLAIN JSON when the backend supports it. + // application/vnd.pgrst.plan: return the EXPLAIN output when the backend + // supports it. servePlan applies the db-plan-enabled gate and the Explainer + // check so the plan media type never reaches the renderer. if media == mediaPlan { - exp, supported := s.backend.(backend.Explainer) - if !supported { - writeError(w, pgerr.ErrNotAcceptable(mediaPlan)) - return - } - planJSON, err := exp.ExplainRead(r.Context(), planned, rc, planAnalyze(acceptHdrs)) - if err != nil { - writeError(w, mapExecError(s.backend, err, id.anonymous)) - return - } - w.Header().Set("Content-Type", mediaPlan) - w.WriteHeader(http.StatusOK) - w.Write(planJSON) + s.servePlan(w, r, id, func(exp backend.Explainer, opts backend.PlanOptions) ([]byte, error) { + return exp.ExplainRead(r.Context(), planned, rc, opts) + }) return } + txStart := time.Now() res, err := s.backend.Execute(r.Context(), planned, rc) if err != nil { writeError(w, mapExecError(s.backend, err, id.anonymous)) return } + t.mark("transaction", txStart) + respStart := time.Now() out, apiErr := renderFor(media, res, embedKeys(q)) if apiErr != nil { writeError(w, apiErr) return } + t.mark("response", respStart) s.writeRead(w, r, q, out, res.ResponseControls()) } -func (s *Server) handleWrite(w http.ResponseWriter, r *http.Request, kind ir.QueryKind, id identity) { +// servePlan answers a negotiated application/vnd.pgrst.plan request for any of +// the three execution paths. It enforces the db-plan-enabled gate and the +// backend Explainer capability first (406 when either is absent, so the plan +// media type never falls through to the renderer and 500s), parses the plan +// options from the Accept header, runs the explain via the supplied selector, +// and echoes the full parameterized Content-Type. explain picks +// ExplainRead/ExplainWrite/ExplainCall on the resolved Explainer. +func (s *Server) servePlan(w http.ResponseWriter, r *http.Request, id identity, explain func(backend.Explainer, backend.PlanOptions) ([]byte, error)) { + if !s.planEnabled { + writeError(w, pgerr.ErrNotAcceptable(mediaPlan)) + return + } + exp, supported := s.backend.(backend.Explainer) + if !supported { + writeError(w, pgerr.ErrNotAcceptable(mediaPlan)) + return + } + opts, _ := parsePlan(r.Header.Values("Accept")) + planBytes, err := explain(exp, opts) + if err != nil { + writeError(w, mapExecError(s.backend, err, id.anonymous)) + return + } + w.Header().Set("Content-Type", planContentType(opts)) + w.WriteHeader(http.StatusOK) + w.Write(planBytes) +} + +// planContentType builds the response Content-Type echoed for a plan request: +// the format suffix, the for="" target, the options= flags that were +// set, and the charset, matching PostgREST's parameterized plan media type. +func planContentType(opts backend.PlanOptions) string { + sub := "application/vnd.pgrst.plan+text" + if opts.Format == backend.PlanJSON { + sub = "application/vnd.pgrst.plan+json" + } + parts := []string{sub} + if opts.For != "" { + parts = append(parts, `for="`+opts.For+`"`) + } + var flags []string + for _, f := range []struct { + on bool + name string + }{ + {opts.Analyze, "analyze"}, {opts.Verbose, "verbose"}, + {opts.Settings, "settings"}, {opts.Buffers, "buffers"}, {opts.Wal, "wal"}, + } { + if f.on { + flags = append(flags, f.name) + } + } + if len(flags) > 0 { + parts = append(parts, "options="+strings.Join(flags, "|")) + } + parts = append(parts, "charset=utf-8") + return strings.Join(parts, "; ") +} + +func (s *Server) handleWrite(w http.ResponseWriter, r *http.Request, kind ir.QueryKind, id identity, activeSchema string) { relation := strings.Trim(r.URL.Path, "/") if relation == "" || strings.Contains(relation, "/") { - writeError(w, pgerr.ErrUnknownTable(relation)) + // As in handleRead: a multi-segment path is PGRST125, not a missing table. + writeError(w, pgerr.ErrInvalidPath()) return } @@ -409,37 +955,71 @@ func (s *Server) handleWrite(w http.ResponseWriter, r *http.Request, kind ir.Que var body []byte if kind != ir.Delete { - b, err := io.ReadAll(http.MaxBytesReader(w, r.Body, maxBodyBytes)) - if err != nil { - writeError(w, pgerr.ErrParse("could not read request body")) + b, apiErr := s.readBody(w, r) + if apiErr != nil { + writeError(w, apiErr) return } body = b } + t := timerFrom(r.Context()) + + parseStart := time.Now() q, apiErr := ir.ParseWrite(kind, relation, r.URL.RawQuery, r.Header.Values("Prefer"), r.Header.Get("Content-Type"), body) if apiErr != nil { writeError(w, apiErr) return } - q.Singular = media == mediaObject + t.mark("parse", parseStart) + q.Singular = singularMedia(media) + + if apiErr := s.applyTxPolicy(&q.Prefer); apiErr != nil { + writeError(w, apiErr) + return + } + if q.Write != nil { + if q.Prefer.Tx != nil { + q.Write.Tx = *q.Prefer.Tx + } else { + q.Write.Tx = ir.TxAuto + } + } - planned, apiErr := plan.Write(s.model, q, s.searchPath) + planStart := time.Now() + planned, apiErr := plan.Write(s.Model(), q, []string{activeSchema}) if apiErr != nil { writeError(w, apiErr) return } + t.mark("plan", planStart) - rc := buildContext(r, id) + rc := s.buildContext(r, id, activeSchema) + if q.Prefer.TimeZone != nil { + rc.TimeZone = *q.Prefer.TimeZone + } if apiErr := s.authorize(rc, planned); apiErr != nil { writeError(w, apiErr) return } + + // A plan request on a write returns the EXPLAIN for the mutation instead of + // running it. servePlan gates and routes it so mediaPlan never reaches the + // renderer (which would 500 on a write). + if media == mediaPlan { + s.servePlan(w, r, id, func(exp backend.Explainer, opts backend.PlanOptions) ([]byte, error) { + return exp.ExplainWrite(r.Context(), planned, rc, opts) + }) + return + } + + txStart := time.Now() res, err := s.backend.Execute(r.Context(), planned, rc) if err != nil { writeError(w, mapExecError(s.backend, err, id.anonymous)) return } + t.mark("transaction", txStart) s.writeWrite(w, r, q, media, planned.Rel, res) } @@ -453,72 +1033,88 @@ func (s *Server) writeWrite(w http.ResponseWriter, r *http.Request, q *ir.Query, if applied := q.Prefer.AppliedHeader(); applied != "" { w.Header().Set("Preference-Applied", applied) } - // PostgREST v14 returns a Location header only for return=headers-only inserts/upserts. - // For return=representation or minimal, Location is omitted. - if (q.Kind == ir.Insert || q.Kind == ir.Upsert) && q.Write != nil && q.Write.Return == ir.ReturnHeadersOnly { + // A Location points at a newly created resource. PostgREST sets it only for a + // return=headers-only POST insert or upsert of a single row; a PUT never + // carries one (02.9). + if r.Method == http.MethodPost && (q.Kind == ir.Insert || q.Kind == ir.Upsert) && + q.Write != nil && q.Write.Return == ir.ReturnHeadersOnly { if loc := locationHeader(rel, q.Relation.Name, res); loc != "" { w.Header().Set("Location", loc) } } + // Content-Range is present on every write except PUT, shaped by method (02.8): + // POST and DELETE report the total-only "*/*" form ("*/N" with count=exact), + // PATCH the affected-row range "0-(n-1)/...". It does not depend on the return + // mode, so a minimal write carries it too. + affected, hasAff := res.Affected() + if cr := writeContentRange(r.Method, affected, hasAff, q.Count); cr != "" { + w.Header().Set("Content-Range", cr) + } + representation := q.Write.Return == ir.ReturnRepresentation if !representation { - // When count=exact was requested, include Content-Range: */ so the - // client knows how many rows were affected, matching PostgREST's wire. - if q.Count == ir.CountExact { - if n, ok := res.Affected(); ok { - w.Header().Set("Content-Range", fmt.Sprintf("*/%d", n)) - } - } w.WriteHeader(applyControls(w, ctrl, writeStatus(r.Method, q.Kind, false, ctrl))) return } + respStart := time.Now() out, apiErr := renderFor(media, res, embedKeys(q)) if apiErr != nil { writeError(w, apiErr) return } + timerFrom(r.Context()).mark("response", respStart) w.Header().Set("Content-Type", out.contentType) - if !q.Singular { - // For writes with count=exact, include the total in Content-Range. - if q.Count == ir.CountExact { - if n, ok := res.Affected(); ok { - w.Header().Set("Content-Range", contentRange(0, out.nRows, n, true)) - } else { - w.Header().Set("Content-Range", contentRange(0, out.nRows, 0, false)) - } - } else { - w.Header().Set("Content-Range", contentRange(0, out.nRows, 0, false)) - } - } w.WriteHeader(applyControls(w, ctrl, writeStatus(r.Method, q.Kind, true, ctrl))) if r.Method != http.MethodHead { w.Write(out.body) } } +// writeContentRange builds the Content-Range header for a write, shaped by the +// HTTP method (02.8). A PUT carries none. POST and DELETE report the total-only +// "*/*" form ("*/N" with count=exact); PATCH reports the affected-row range +// "0-(n-1)/..." and falls back to "*/..." when no row matched. +func writeContentRange(method string, affected int64, hasAff bool, count ir.CountKind) string { + if method == http.MethodPut { + return "" + } + total := "*" + if count == ir.CountExact && hasAff { + total = strconv.FormatInt(affected, 10) + } + if method == http.MethodPatch && hasAff && affected > 0 { + return fmt.Sprintf("0-%d/%s", affected-1, total) + } + return "*/" + total +} + // writeStatus is the status for a successful write. // - POST insert: 201 Created. -// - POST upsert where ALL rows were new inserts: 201 Created. -// - POST upsert where at least one row was an ON CONFLICT update: 200 OK. -// - PUT upsert where the row is known to be a new insert: 201 Created. -// - PUT upsert where the row is known to be an update, or unknown: 200 OK. -// - PATCH/DELETE with representation: 200 OK. -// - PATCH/DELETE without representation: 204 No Content. +// - POST merge-duplicates upsert with zero rows inserted: 200 OK. +// - POST upsert otherwise (ignore-duplicates, mixed, all-insert, unknown): 201. +// - PUT without representation (minimal, headers-only, none): 204 No Content. +// - PUT representation with a row inserted: 201 Created; else 200 OK. +// - PATCH/DELETE with representation: 200 OK; without: 204 No Content. func writeStatus(method string, kind ir.QueryKind, representation bool, ctrl *reqctx.ResponseControls) int { - if method == http.MethodPost { - // 200 when the upsert hit at least one existing row (ON CONFLICT UPDATE fired). - // 201 otherwise (new row was inserted or unknown). - if kind == ir.Upsert && ctrl != nil && ctrl.UpsertStatusKnown && !ctrl.UpsertInsert { + switch method { + case http.MethodPost: + // A POST upsert is 200 only when the resolution is merge-duplicates and no + // row was newly inserted; ignore-duplicates and mixed batches stay 201. The + // backend reports a known insert count only for a merge upsert, so a known + // zero here already implies merge-duplicates. + if kind == ir.Upsert && ctrl != nil && ctrl.UpsertStatusKnown && ctrl.InsertedRows == 0 { return http.StatusOK } return http.StatusCreated - } - if method == http.MethodPut && kind == ir.Upsert { - // PUT is semantically "create or replace"; default to 200. - // Only return 201 when the backend positively confirms a new insert. - if ctrl != nil && ctrl.UpsertStatusKnown && ctrl.UpsertInsert { + case http.MethodPut: + // A PUT answers 204 for every return mode except representation, which is + // 201 when a row was inserted and 200 when it replaced an existing one. + if !representation { + return http.StatusNoContent + } + if ctrl != nil && ctrl.UpsertStatusKnown && ctrl.InsertedRows > 0 { return http.StatusCreated } return http.StatusOK @@ -570,37 +1166,100 @@ func locationHeader(rel *schema.Relation, relation string, res backend.Result) s // quoting their text. Nested embeds are already inside their parent's JSON blob, // so only the top level matters here. func embedKeys(q *ir.Query) map[string]bool { - if len(q.Embeds) == 0 { - return nil + return embedKeysFor(q.Embeds, q.Select) +} + +// embedKeysFor computes the raw-JSON column set from an embed list and select, +// shared by the table-read path (ir.Query) and the embedded-RPC path (ir.Call), +// which carry the same two pieces. +func embedKeysFor(embeds []ir.Embed, sel []ir.SelectItem) map[string]bool { + var keys map[string]bool + add := func(k string) { + if keys == nil { + keys = make(map[string]bool) + } + keys[k] = true + } + for i := range embeds { + emb := &embeds[i] + // A spread embed lifts its columns flat rather than nesting under a key. + // A to-many spread lifts each column as a JSON array, so those lifted + // names must render raw; a to-one spread lifts plain scalars that render + // normally. A non-spread embed nests under its OutKey. + if emb.Spread { + if emb.Rel != nil && emb.Rel.Card != schema.CardToOne { + for _, name := range spreadLiftedNames(emb) { + add(name) + } + } + continue + } + add(emb.OutKey) } - keys := make(map[string]bool, len(q.Embeds)) - for i := range q.Embeds { - keys[q.Embeds[i].OutKey] = true + // A projection ending in -> (data->meta) yields JSON the renderer must splice + // verbatim, the same as an embed; a final ->> is text and renders normally. + for _, it := range sel { + if c, ok := it.(ir.Column); ok && c.Last == ir.JSONArrow { + add(c.Name()) + } } return keys } +// spreadLiftedNames returns the parent-row column names a spread embed lifts, +// mirroring the projection rules the sqlgen spread compiler uses: an empty or +// star select lifts every target column; otherwise each named column lifts under +// its output name. +func spreadLiftedNames(emb *ir.Embed) []string { + target := emb.Rel.Target + if len(emb.Query.Select) == 0 { + return target.ColumnNames() + } + var names []string + for _, it := range emb.Query.Select { + c, ok := it.(ir.Column) + if !ok { + continue + } + if len(c.Path) == 1 && c.Path[0] == "*" { + names = append(names, target.ColumnNames()...) + continue + } + names = append(names, c.Name()) + } + return names +} + // writeRead sets the headers and status for a successful read and writes the // body (omitted for HEAD). A function or policy can shape the response through // the controls: a control header is added and a non-zero control status wins // over the computed 200/206 default. func (s *Server) writeRead(w http.ResponseWriter, r *http.Request, q *ir.Query, out *rendered, ctrl *reqctx.ResponseControls) { - if applied := q.Prefer.AppliedHeader(); applied != "" { - w.Header().Set("Preference-Applied", applied) - } - w.Header().Set("Content-Type", out.contentType) - offset := 0 if q.Offset != nil { offset = *q.Offset } + s.writePaged(w, r, q.Prefer.AppliedHeader(), offset, out, ctrl) +} + +// writePaged sets the pagination headers and status shared by table reads and +// RPC reads. Content-Range is always present (the "*" total form without a +// count). An offset strictly past a known total is 416 with the upstream +// detail; an offset equal to the total is in range and yields 206 with +// Content-Range "*/total". The 200/206 rule is 206 whenever a total is known +// and the returned span is smaller, for every count kind (PostgREST v14 returns +// 206 for count=planned/estimated too). A function or policy can override the +// status and add headers through the controls. +func (s *Server) writePaged(w http.ResponseWriter, r *http.Request, applied string, offset int, out *rendered, ctrl *reqctx.ResponseControls) { + if applied != "" { + w.Header().Set("Preference-Applied", applied) + } + w.Header().Set("Content-Type", out.contentType) w.Header().Set("Content-Range", contentRange(offset, out.nRows, out.total, out.hasTotl)) - // An out-of-range offset is 416: the window starts past the end of the - // result. This is only knowable with a count, so it applies when one was - // requested (otherwise the empty window is a plain 200 with an empty array). - if offset > 0 && out.hasTotl && int64(offset) >= out.total { - rng := pgerr.ErrRangeNotSatisfiable() + if offset > 0 && out.hasTotl && int64(offset) > out.total { + rng := pgerr.ErrRangeNotSatisfiable(). + WithDetails(fmt.Sprintf("An offset of %d was requested, but there are only %d rows.", offset, out.total)) w.Header().Set("Content-Type", "application/json; charset=utf-8") w.WriteHeader(rng.HTTPStatus) if r.Method != http.MethodHead { @@ -609,47 +1268,43 @@ func (s *Server) writeRead(w http.ResponseWriter, r *http.Request, q *ir.Query, return } - w.WriteHeader(applyControls(w, ctrl, readStatus(q, out, offset))) + status := http.StatusOK + if out.hasTotl && int64(out.nRows) < out.total { + status = http.StatusPartialContent + } + w.WriteHeader(applyControls(w, ctrl, status)) if r.Method != http.MethodHead { w.Write(out.body) } } -// readStatus applies PostgREST's 200/206 rule: 206 only when a count is known -// and the page returned is genuinely partial (nRows < total). PostgREST v14 -// returns 200 for count=planned/estimated even though the total is approximate; -// the estimate is informational, not a range boundary. -func readStatus(q *ir.Query, out *rendered, _ int) int { - if !out.hasTotl { - return http.StatusOK - } - if q.Count == ir.CountExact && int64(out.nRows) < out.total { - return http.StatusPartialContent - } - return http.StatusOK -} - // parseRangeHeader parses an HTTP Range header value of the form "start-end" // (as used with Range-Unit: items). Returns (offset, limit, true) where limit -// is -1 for an open-ended range ("0-"). Returns (0, 0, false) on parse error. -func parseRangeHeader(s string) (offset, limit int, ok bool) { +// is -1 for an open-ended range ("0-"). A malformed header returns ok=false with +// inverted=false so the caller serves the full result. A well-formed header whose +// upper bound is below its lower bound returns ok=false with inverted=true, which +// PostgREST answers with 416 rather than ignoring. +func parseRangeHeader(s string) (offset, limit int, ok, inverted bool) { dash := strings.LastIndex(s, "-") if dash < 0 { - return 0, 0, false + return 0, 0, false, false } startStr, endStr := s[:dash], s[dash+1:] start, err := strconv.Atoi(startStr) if err != nil || start < 0 { - return 0, 0, false + return 0, 0, false, false } if endStr == "" { - return start, -1, true // open-ended: "0-" + return start, -1, true, false // open-ended: "0-" } end, err := strconv.Atoi(endStr) - if err != nil || end < start { - return 0, 0, false + if err != nil { + return 0, 0, false, false + } + if end < start { + return 0, 0, false, true } - return start, end - start + 1, true + return start, end - start + 1, true, false } // asAPIError normalizes a backend execution error to the API envelope, asking @@ -666,20 +1321,27 @@ func asAPIError(b backend.Backend, err error) *pgerr.APIError { } // mapExecError wraps asAPIError with the PostgREST 401/403 rule: a 42501 -// (insufficient_privilege) error to an anonymous request is 401 (authentication -// required), not 403 (forbidden). An authenticated request that is denied -// remains 403 so the caller knows to authenticate, not just retry. -// The original PostgreSQL message is preserved to match PostgREST wire behavior. +// (insufficient_privilege) error is 403 for an authenticated request, so the +// caller knows the role is wrong rather than missing, and 401 for an anonymous +// one, so it knows to authenticate. GradePrivilegeStatus is the one place the +// rule lives, so the status is correct whatever status a backend's SQLSTATE +// table assigned; mapExecError adds the bare Bearer challenge PostgREST sends on +// the 401. The original PostgreSQL message is preserved for wire compatibility. func mapExecError(b backend.Backend, err error, anonymous bool) *pgerr.APIError { - e := asAPIError(b, err) + e := pgerr.GradePrivilegeStatus(asAPIError(b, err), !anonymous) if anonymous && e.Code == pgerr.CodeInsufficientPrivilege { lifted := *e - lifted.HTTPStatus = http.StatusUnauthorized + // PostgREST sends the bare Bearer challenge on every 401, including a + // privilege denial lifted from 403 for an unauthenticated request. + lifted.WWWAuthenticate = "Bearer" return &lifted } return e } func writeError(w http.ResponseWriter, e *pgerr.APIError) { + // PostgREST does not echo Content-Profile on an error response; drop the + // header ServeHTTP may have staged before the handler failed. + w.Header().Del("Content-Profile") e.Write(w) } diff --git a/httpapi/server_test.go b/httpapi/server_test.go index 66ad422..3243256 100644 --- a/httpapi/server_test.go +++ b/httpapi/server_test.go @@ -14,6 +14,15 @@ import ( ) func newServer(t testing.TB) *httpapi.Server { + t.Helper() + srv := newServerNoRole(t) + srv.SetDefaultRole("anon") + return srv +} + +// newServerNoRole builds the test server without a default role, the state a +// bare NewServer is in before db-anon-role is applied. +func newServerNoRole(t testing.TB) *httpapi.Server { t.Helper() // A uniquely named shared-cache memory DB isolates each test's data. dsn := "file:" + strings.ReplaceAll(t.Name(), "/", "_") + "?mode=memory&cache=shared" @@ -149,6 +158,14 @@ func TestGetSingularZeroRowsIs406(t *testing.T) { if env["code"] != "PGRST116" { t.Errorf("code = %v, want PGRST116", env["code"]) } + // v14 texts: the message dropped the pre-v12 spelling and the row count + // rides in details. + if env["message"] != "Cannot coerce the result to a single JSON object" { + t.Errorf("message = %v", env["message"]) + } + if env["details"] != "The result contains 0 rows" { + t.Errorf("details = %v, want row count", env["details"]) + } } func TestGetEmptyArray(t *testing.T) { @@ -171,15 +188,45 @@ func TestUnknownTableIs404Code(t *testing.T) { if env["code"] != "PGRST205" { t.Errorf("code = %v, want PGRST205", env["code"]) } + // PGRST205 schema-qualifies the relation name (item 04.3): a backend with no + // schema namespace still reports the default public schema, as PostgREST does. + if msg, _ := env["message"].(string); msg != "Could not find the table 'public.nope' in the schema cache" { + t.Errorf("message = %q, want it schema-qualified", msg) + } +} + +// A path with more than one segment names no routable resource and is PGRST125 +// at 404, not the PGRST205 a single unknown relation gets (item 04.8). +func TestNestedTablePathIsInvalidPath(t *testing.T) { + srv := newServer(t) + for _, method := range []string{http.MethodGet, http.MethodPost, http.MethodPatch, http.MethodDelete} { + resp := do(t, srv, method, "/films/extra", nil) + if resp.StatusCode != http.StatusNotFound { + t.Errorf("%s status = %d, want 404", method, resp.StatusCode) + } + var env map[string]any + json.NewDecoder(resp.Body).Decode(&env) + if env["code"] != "PGRST125" { + t.Errorf("%s code = %v, want PGRST125", method, env["code"]) + } + if msg, _ := env["message"].(string); msg != "Invalid path specified in request URL" { + t.Errorf("%s message = %q", method, msg) + } + } } func TestUnknownColumnIsError(t *testing.T) { srv := newServer(t) resp := do(t, srv, http.MethodGet, "/films?select=bogus", nil) + // An unknown select column reaches PostgreSQL: 42703 at 400 (item 04.5), not + // the schema-cache PGRST204 reserved for write payloads. + if resp.StatusCode != http.StatusBadRequest { + t.Errorf("status = %d, want 400", resp.StatusCode) + } var env map[string]any json.NewDecoder(resp.Body).Decode(&env) - if env["code"] != "PGRST204" { - t.Errorf("code = %v, want PGRST204", env["code"]) + if env["code"] != "42703" { + t.Errorf("code = %v, want 42703", env["code"]) } } @@ -346,8 +393,9 @@ func TestPostBulkInsertNoLocation(t *testing.T) { if loc := resp.Header.Get("Location"); loc != "" { t.Errorf("bulk insert should not set Location, got %q", loc) } - if cr := resp.Header.Get("Content-Range"); cr != "0-1/*" { - t.Errorf("Content-Range = %q, want 0-1/*", cr) + // A POST reports the total-only range, never a row span (02.8). + if cr := resp.Header.Get("Content-Range"); cr != "*/*" { + t.Errorf("Content-Range = %q, want */*", cr) } if len(decodeArray(t, resp)) != 2 { t.Error("want 2 inserted rows") @@ -403,13 +451,15 @@ func TestDeleteRepresentation(t *testing.T) { } } -func TestPostUpsertMergeDuplicates(t *testing.T) { +// A merge-duplicates upsert that hits an existing row updates it; PostgREST v14 +// reports 200, not 201, because nothing new was created. +func TestPostUpsertMergeDuplicatesUpdateIs200(t *testing.T) { srv := newServer(t) resp := send(t, srv, http.MethodPost, "/films", `{"id":1,"title":"Metropolis (restored)"}`, map[string]string{ "Prefer": "return=representation, resolution=merge-duplicates", }) - if resp.StatusCode != http.StatusCreated { - t.Fatalf("status = %d, want 201", resp.StatusCode) + if resp.StatusCode != http.StatusOK { + t.Fatalf("status = %d, want 200", resp.StatusCode) } rows := decodeArray(t, resp) if len(rows) != 1 || rows[0]["title"] != "Metropolis (restored)" { @@ -417,11 +467,35 @@ func TestPostUpsertMergeDuplicates(t *testing.T) { } } -func TestPutUpsertIs200(t *testing.T) { +// A merge-duplicates upsert whose key is new inserts a row, so v14 reports 201. +func TestPostUpsertMergeDuplicatesInsertIs201(t *testing.T) { + srv := newServer(t) + resp := send(t, srv, http.MethodPost, "/films", `{"id":50,"title":"Brand New"}`, map[string]string{ + "Prefer": "return=representation, resolution=merge-duplicates", + }) + if resp.StatusCode != http.StatusCreated { + t.Fatalf("status = %d, want 201", resp.StatusCode) + } +} + +// PUT replaces or creates the addressed row. When the key did not exist the row +// is created, so v14 reports 201. +func TestPutUpsertNewIs201(t *testing.T) { srv := newServer(t) resp := send(t, srv, http.MethodPut, "/films?id=eq.20", `{"id":20,"title":"New"}`, map[string]string{ "Prefer": "return=representation", }) + if resp.StatusCode != http.StatusCreated { + t.Fatalf("status = %d, want 201", resp.StatusCode) + } +} + +// PUT to an existing key replaces it; nothing is created, so v14 reports 200. +func TestPutUpsertExistingIs200(t *testing.T) { + srv := newServer(t) + resp := send(t, srv, http.MethodPut, "/films?id=eq.1", `{"id":1,"title":"Metropolis (cut)"}`, map[string]string{ + "Prefer": "return=representation", + }) if resp.StatusCode != http.StatusOK { t.Fatalf("status = %d, want 200", resp.StatusCode) } @@ -488,3 +562,141 @@ func BenchmarkGetFilteredRead(b *testing.B) { } } } + +// TestReloadPublishesNewSchema pins the schema cache reload: DDL applied +// after startup is invisible (404 PGRST205) until Reload re-runs +// introspection, after which the new table serves and the OpenAPI document +// describes it. This is the dbrest side of PostgREST's SIGUSR1 / NOTIFY +// reload flow; the signal wiring lives in cmd. +func TestReloadPublishesNewSchema(t *testing.T) { + dsn := "file:" + strings.ReplaceAll(t.Name(), "/", "_") + "?mode=memory&cache=shared" + be, err := sqlite.Open(dsn) + if err != nil { + t.Fatalf("open: %v", err) + } + t.Cleanup(func() { be.Close() }) + if _, err := be.DB().Exec(`CREATE TABLE films (id INTEGER PRIMARY KEY, title TEXT NOT NULL);`); err != nil { + t.Fatalf("seed: %v", err) + } + model, err := be.Introspect(context.Background()) + if err != nil { + t.Fatalf("introspect: %v", err) + } + srv := httpapi.NewServer(be, model, nil) + srv.SetDefaultRole("anon") + + if _, err := be.DB().Exec(`CREATE TABLE actors (id INTEGER PRIMARY KEY, name TEXT NOT NULL);`); err != nil { + t.Fatalf("ddl: %v", err) + } + + resp := do(t, srv, http.MethodGet, "/actors", nil) + if resp.StatusCode != http.StatusNotFound { + t.Fatalf("pre-reload status = %d, want 404", resp.StatusCode) + } + var env map[string]any + json.NewDecoder(resp.Body).Decode(&env) + if env["code"] != "PGRST205" { + t.Errorf("pre-reload code = %v, want PGRST205", env["code"]) + } + + if err := srv.Reload(context.Background()); err != nil { + t.Fatalf("Reload: %v", err) + } + + resp = do(t, srv, http.MethodGet, "/actors", nil) + if resp.StatusCode != http.StatusOK { + t.Fatalf("post-reload status = %d, want 200", resp.StatusCode) + } + + resp = do(t, srv, http.MethodGet, "/", nil) + var doc map[string]any + if err := json.NewDecoder(resp.Body).Decode(&doc); err != nil { + t.Fatalf("decode document: %v", err) + } + if _, ok := doc["paths"].(map[string]any)["/actors"]; !ok { + t.Error("the document should describe the new table after reload") + } +} + +// newJSONColumnServer builds a server over a table with a JSON column, the +// shape the array round-trip tests need (films has none). +func newJSONColumnServer(t testing.TB) *httpapi.Server { + t.Helper() + dsn := "file:" + strings.ReplaceAll(t.Name(), "/", "_") + "?mode=memory&cache=shared" + be, err := sqlite.Open(dsn) + if err != nil { + t.Fatalf("open: %v", err) + } + t.Cleanup(func() { be.Close() }) + _, err = be.DB().Exec(` + CREATE TABLE todos ( + id INTEGER PRIMARY KEY, + task TEXT NOT NULL, + tags JSON + ); + INSERT INTO todos (id, task, tags) VALUES (1, 'write spec', '["pets"]'); + `) + if err != nil { + t.Fatalf("seed: %v", err) + } + model, err := be.Introspect(context.Background()) + if err != nil { + t.Fatalf("introspect: %v", err) + } + srv := httpapi.NewServer(be, model, nil) + srv.SetDefaultRole("anon") + return srv +} + +// A JSON array in a write payload must land in a JSON column as JSON text and +// read back as the same array, not as PostgreSQL {a,b} literal text. This was +// the bug that corrupted tags columns on every PATCH/POST carrying an array. +func TestPatchJSONArrayRoundTrips(t *testing.T) { + srv := newJSONColumnServer(t) + resp := send(t, srv, http.MethodPatch, "/todos?id=eq.1", `{"tags":["go","sql"]}`, map[string]string{ + "Prefer": "return=representation", + }) + if resp.StatusCode != http.StatusOK { + t.Fatalf("patch status = %d, want 200", resp.StatusCode) + } + rows := decodeArray(t, resp) + if len(rows) != 1 { + t.Fatalf("patch rows = %v", rows) + } + assertTags := func(stage string, v any) { + t.Helper() + tags, ok := v.([]any) + if !ok || len(tags) != 2 || tags[0] != "go" || tags[1] != "sql" { + t.Fatalf("%s tags = %#v, want [go sql]", stage, v) + } + } + assertTags("representation", rows[0]["tags"]) + + resp = do(t, srv, http.MethodGet, "/todos?id=eq.1", nil) + if resp.StatusCode != http.StatusOK { + t.Fatalf("get status = %d", resp.StatusCode) + } + rows = decodeArray(t, resp) + if len(rows) != 1 { + t.Fatalf("get rows = %v", rows) + } + assertTags("stored", rows[0]["tags"]) +} + +func TestPostJSONArrayRoundTrips(t *testing.T) { + srv := newJSONColumnServer(t) + resp := send(t, srv, http.MethodPost, "/todos", `{"id":2,"task":"pack","tags":["home",2]}`, map[string]string{ + "Prefer": "return=representation", + }) + if resp.StatusCode != http.StatusCreated { + t.Fatalf("post status = %d, want 201", resp.StatusCode) + } + rows := decodeArray(t, resp) + if len(rows) != 1 { + t.Fatalf("post rows = %v", rows) + } + tags, ok := rows[0]["tags"].([]any) + if !ok || len(tags) != 2 || tags[0] != "home" || tags[1] != float64(2) { + t.Fatalf("tags = %#v, want [home 2]", rows[0]["tags"]) + } +} diff --git a/httpapi/singular_write_test.go b/httpapi/singular_write_test.go new file mode 100644 index 0000000..7b48064 --- /dev/null +++ b/httpapi/singular_write_test.go @@ -0,0 +1,112 @@ +package httpapi_test + +import ( + "net/http" + "testing" +) + +// 07.11: a singular write (Accept: application/vnd.pgrst.object+json) must affect +// exactly one row. PostgREST enforces this inside the write transaction and rolls +// back when the count is zero or many, so the mutation never becomes durable. The +// check runs pre-commit, not in the renderer, which means return=minimal is held +// to the same guarantee even though it produces no body to inspect. + +// TestPatchSingularManyRollsBack: a singular PATCH matching three rows fails with +// PGRST116 and leaves every row unchanged. +func TestPatchSingularManyRollsBack(t *testing.T) { + srv := newServer(t) + resp := send(t, srv, http.MethodPatch, "/films?year=gte.1900", `{"rating":"X"}`, map[string]string{ + "Accept": "application/vnd.pgrst.object+json", + "Prefer": "return=representation", + }) + if resp.StatusCode != http.StatusNotAcceptable { + t.Fatalf("status = %d, want 406", resp.StatusCode) + } + if env := decodeEnvelope(t, resp); env["code"] != "PGRST116" { + t.Errorf("code = %v, want PGRST116", env["code"]) + } + // The transaction rolled back: no row took the new rating. + after := do(t, srv, http.MethodGet, "/films?rating=eq.X&select=id", nil) + if rows := decodeArray(t, after); len(rows) != 0 { + t.Errorf("rollback failed, %d rows were updated", len(rows)) + } +} + +// TestPatchSingularManyMinimalRollsBack: the same over-broad PATCH under +// return=minimal still fails closed before commit, even though no representation +// is computed. This is the case the renderer's post-commit check could not catch. +func TestPatchSingularManyMinimalRollsBack(t *testing.T) { + srv := newServer(t) + resp := send(t, srv, http.MethodPatch, "/films?year=gte.1900", `{"rating":"X"}`, map[string]string{ + "Accept": "application/vnd.pgrst.object+json", + "Prefer": "return=minimal", + }) + if resp.StatusCode != http.StatusNotAcceptable { + t.Fatalf("status = %d, want 406", resp.StatusCode) + } + if env := decodeEnvelope(t, resp); env["code"] != "PGRST116" { + t.Errorf("code = %v, want PGRST116", env["code"]) + } + after := do(t, srv, http.MethodGet, "/films?rating=eq.X&select=id", nil) + if rows := decodeArray(t, after); len(rows) != 0 { + t.Errorf("rollback failed, %d rows were updated", len(rows)) + } +} + +// TestPatchSingularZeroRows: a singular PATCH whose filter matches nothing is +// PGRST116; there is nothing to undo, but the wire contract still holds. +func TestPatchSingularZeroRows(t *testing.T) { + srv := newServer(t) + resp := send(t, srv, http.MethodPatch, "/films?id=eq.999", `{"rating":"X"}`, map[string]string{ + "Accept": "application/vnd.pgrst.object+json", + "Prefer": "return=representation", + }) + if resp.StatusCode != http.StatusNotAcceptable { + t.Fatalf("status = %d, want 406", resp.StatusCode) + } + if env := decodeEnvelope(t, resp); env["code"] != "PGRST116" { + t.Errorf("code = %v, want PGRST116", env["code"]) + } +} + +// TestDeleteSingularManyRollsBack: a singular DELETE matching every row fails and +// deletes nothing. +func TestDeleteSingularManyRollsBack(t *testing.T) { + srv := newServer(t) + resp := send(t, srv, http.MethodDelete, "/films", "", map[string]string{ + "Accept": "application/vnd.pgrst.object+json", + "Prefer": "return=minimal", + }) + if resp.StatusCode != http.StatusNotAcceptable { + t.Fatalf("status = %d, want 406", resp.StatusCode) + } + if env := decodeEnvelope(t, resp); env["code"] != "PGRST116" { + t.Errorf("code = %v, want PGRST116", env["code"]) + } + after := do(t, srv, http.MethodGet, "/films?select=id", nil) + if rows := decodeArray(t, after); len(rows) != 4 { + t.Errorf("rollback failed, %d rows remain, want 4", len(rows)) + } +} + +// TestPatchSingularOneRowCommits: a singular PATCH that affects exactly one row +// proceeds and persists, returning the single object. +func TestPatchSingularOneRowCommits(t *testing.T) { + srv := newServer(t) + resp := send(t, srv, http.MethodPatch, "/films?id=eq.2", `{"rating":"X"}`, map[string]string{ + "Accept": "application/vnd.pgrst.object+json", + "Prefer": "return=representation", + }) + if resp.StatusCode != http.StatusOK { + t.Fatalf("status = %d, want 200", resp.StatusCode) + } + env := decodeEnvelope(t, resp) + if env["rating"] != "X" { + t.Errorf("body rating = %v, want X", env["rating"]) + } + after := do(t, srv, http.MethodGet, "/films?id=eq.2&select=rating", nil) + rows := decodeArray(t, after) + if len(rows) != 1 || rows[0]["rating"] != "X" { + t.Errorf("write did not persist: %v", rows) + } +} diff --git a/httpapi/timezone_test.go b/httpapi/timezone_test.go new file mode 100644 index 0000000..6cd91ac --- /dev/null +++ b/httpapi/timezone_test.go @@ -0,0 +1,58 @@ +package httpapi_test + +import ( + "net/http" + "strings" + "testing" +) + +// 02.3: Prefer: timezone= sets the request timezone. A valid IANA zone is +// honored and echoed in Preference-Applied; an invalid zone is ignored under +// lenient handling and a PGRST122 violation under handling=strict. The +// engine-agnostic parse/validate/echo is exercised here against sqlite; the +// SET LOCAL timezone effect on temporal output is a live-postgres concern. + +// TestTimeZoneEchoed: a GET carrying a valid timezone echoes it back. +func TestTimeZoneEchoed(t *testing.T) { + srv := newServer(t) + resp := do(t, srv, http.MethodGet, "/films?select=id", map[string]string{ + "Prefer": "timezone=America/Los_Angeles", + }) + if resp.StatusCode != http.StatusOK { + t.Fatalf("status = %d, want 200", resp.StatusCode) + } + if pa := resp.Header.Get("Preference-Applied"); !strings.Contains(pa, "timezone=America/Los_Angeles") { + t.Errorf("Preference-Applied = %q, want the timezone echoed", pa) + } +} + +// TestTimeZoneInvalidLenientIgnored: an unknown zone under the default lenient +// handling is dropped, not echoed, and the request still succeeds. +func TestTimeZoneInvalidLenientIgnored(t *testing.T) { + srv := newServer(t) + resp := do(t, srv, http.MethodGet, "/films?select=id", map[string]string{ + "Prefer": "timezone=Mars/Phobos", + }) + if resp.StatusCode != http.StatusOK { + t.Fatalf("status = %d, want 200", resp.StatusCode) + } + if pa := resp.Header.Get("Preference-Applied"); strings.Contains(pa, "timezone") { + t.Errorf("Preference-Applied = %q, want no timezone echo for an invalid zone", pa) + } +} + +// TestTimeZoneInvalidStrictRejected: an unknown zone under handling=strict is a +// 400 PGRST122 preference violation. +func TestTimeZoneInvalidStrictRejected(t *testing.T) { + srv := newServer(t) + resp := do(t, srv, http.MethodGet, "/films?select=id", map[string]string{ + "Prefer": "handling=strict, timezone=Mars/Phobos", + }) + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("status = %d, want 400", resp.StatusCode) + } + env := decodeEnvelope(t, resp) + if env["code"] != "PGRST122" { + t.Errorf("code = %v, want PGRST122", env["code"]) + } +} diff --git a/httpapi/timing.go b/httpapi/timing.go new file mode 100644 index 0000000..7164af0 --- /dev/null +++ b/httpapi/timing.go @@ -0,0 +1,92 @@ +package httpapi + +import ( + "context" + "net/http" + "strconv" + "strings" + "time" +) + +// phaseTimer accumulates per-phase durations for a single request and renders +// them as a Server-Timing header. PostgREST emits this header under +// server-timing-enabled with the phases jwt, parse, plan, transaction, and +// response; dbrest records the same names in that pipeline order. The zero +// value is unused: a request that does not enable timing carries a nil +// *phaseTimer, and every method is nil-safe so the handlers stay uncluttered. +type phaseTimer struct { + marks []phaseMark +} + +type phaseMark struct { + name string + dur time.Duration +} + +// mark records the time since start under name. A nil receiver (timing +// disabled) is a no-op, so a handler can call t.mark unconditionally. +func (t *phaseTimer) mark(name string, start time.Time) { + if t == nil { + return + } + t.marks = append(t.marks, phaseMark{name, time.Since(start)}) +} + +// header renders the accumulated marks as a Server-Timing value, durations in +// milliseconds, the encoding PostgREST uses. An empty timer yields "". +func (t *phaseTimer) header() string { + if t == nil || len(t.marks) == 0 { + return "" + } + var b strings.Builder + for i, m := range t.marks { + if i > 0 { + b.WriteString(", ") + } + b.WriteString(m.name) + b.WriteString(";dur=") + b.WriteString(strconv.FormatFloat(float64(m.dur.Microseconds())/1000, 'f', -1, 64)) + } + return b.String() +} + +// timerKey is the private context key under which a request's phaseTimer +// travels from ServeHTTP into the handlers. +type timerKey struct{} + +// withTimer attaches a phaseTimer to a context. +func withTimer(ctx context.Context, t *phaseTimer) context.Context { + return context.WithValue(ctx, timerKey{}, t) +} + +// timerFrom returns the request's phaseTimer, or nil when timing is disabled. +func timerFrom(ctx context.Context) *phaseTimer { + t, _ := ctx.Value(timerKey{}).(*phaseTimer) + return t +} + +// timingWriter sets the Server-Timing header from its phaseTimer the first time +// the response is committed, so every exit path (success, error, plan) carries +// the header without each call site knowing about it. +type timingWriter struct { + http.ResponseWriter + timer *phaseTimer + wrote bool +} + +func (tw *timingWriter) WriteHeader(code int) { + if !tw.wrote { + tw.wrote = true + if h := tw.timer.header(); h != "" { + tw.ResponseWriter.Header().Set("Server-Timing", h) + } + } + tw.ResponseWriter.WriteHeader(code) +} + +func (tw *timingWriter) Write(b []byte) (int, error) { + if !tw.wrote { + tw.WriteHeader(http.StatusOK) + } + return tw.ResponseWriter.Write(b) +} diff --git a/httpapi/timing_test.go b/httpapi/timing_test.go new file mode 100644 index 0000000..594ecad --- /dev/null +++ b/httpapi/timing_test.go @@ -0,0 +1,53 @@ +package httpapi_test + +import ( + "net/http" + "regexp" + "strings" + "testing" +) + +// TestServerTimingAbsentByDefault checks dbrest matches a default PostgREST: no +// Server-Timing header until server-timing-enabled is set. +func TestServerTimingAbsentByDefault(t *testing.T) { + srv := newServer(t) + resp := do(t, srv, http.MethodGet, "/films?order=id", nil) + if got := resp.Header.Get("Server-Timing"); got != "" { + t.Errorf("Server-Timing = %q, want absent", got) + } +} + +// TestServerTimingEnabledOnRead checks an enabled server emits the documented +// phase names on a read. +func TestServerTimingEnabledOnRead(t *testing.T) { + srv := newServer(t) + srv.SetServerTimingEnabled(true) + resp := do(t, srv, http.MethodGet, "/films?order=id", nil) + got := resp.Header.Get("Server-Timing") + if got == "" { + t.Fatal("Server-Timing header missing") + } + for _, phase := range []string{"jwt", "parse", "plan", "transaction", "response"} { + if !strings.Contains(got, phase+";dur=") { + t.Errorf("Server-Timing %q missing phase %q", got, phase) + } + } + // Every phase carries a numeric millisecond duration. + if !regexp.MustCompile(`dur=\d`).MatchString(got) { + t.Errorf("Server-Timing %q has no numeric durations", got) + } +} + +// TestServerTimingEnabledOnError checks the header is present even when the +// request fails before a transaction, since the wrapper emits it on every exit. +func TestServerTimingEnabledOnError(t *testing.T) { + srv := newServer(t) + srv.SetServerTimingEnabled(true) + resp := do(t, srv, http.MethodGet, "/nonesuch?order=id", nil) + if resp.StatusCode == http.StatusOK { + t.Fatal("expected an error status for an unknown table") + } + if got := resp.Header.Get("Server-Timing"); got == "" || !strings.Contains(got, "jwt;dur=") { + t.Errorf("Server-Timing on error = %q, want a jwt phase", got) + } +} diff --git a/httpapi/txend_test.go b/httpapi/txend_test.go new file mode 100644 index 0000000..c8b9cd3 --- /dev/null +++ b/httpapi/txend_test.go @@ -0,0 +1,91 @@ +package httpapi_test + +import ( + "encoding/json" + "net/http" + "testing" + + "github.com/tamnd/dbrest/httpapi" +) + +// yearOf reads the year of a film for the tx= persistence checks. +func yearOf(t *testing.T, srv *httpapi.Server, id string) float64 { + t.Helper() + resp := do(t, srv, http.MethodGet, "/films?id=eq."+id+"&select=year", nil) + rows := decodeArray(t, resp) + if len(rows) != 1 { + t.Fatalf("want one row, got %d", len(rows)) + } + return rows[0]["year"].(float64) +} + +// TestTxRollbackIgnoredUnderDefaultCommit checks the default db-tx-end=commit +// ignores Prefer: tx=rollback: the write persists and tx= is not echoed, the +// PostgREST default-deployment behavior. +func TestTxRollbackIgnoredUnderDefaultCommit(t *testing.T) { + srv := newServer(t) + resp := send(t, srv, http.MethodPatch, "/films?id=eq.1", `{"year":1900}`, map[string]string{ + "Prefer": "tx=rollback", + }) + if resp.StatusCode != http.StatusNoContent { + t.Fatalf("status = %d, want 204", resp.StatusCode) + } + if applied := resp.Header.Get("Preference-Applied"); applied != "" { + t.Errorf("Preference-Applied = %q, want tx= not echoed", applied) + } + if got := yearOf(t, srv, "1"); got != 1900 { + t.Errorf("year after default-commit = %v, want 1900 (persisted)", got) + } +} + +// TestTxRollbackHonoredUnderOverride checks an allow-override policy honors +// tx=rollback: the write does not persist and tx= is echoed. +func TestTxRollbackHonoredUnderOverride(t *testing.T) { + srv := newServer(t) + srv.SetTxEnd("commit-allow-override") + resp := send(t, srv, http.MethodPatch, "/films?id=eq.1", `{"year":1901}`, map[string]string{ + "Prefer": "tx=rollback", + }) + if resp.StatusCode != http.StatusNoContent { + t.Fatalf("status = %d, want 204", resp.StatusCode) + } + if applied := resp.Header.Get("Preference-Applied"); applied != "tx=rollback" { + t.Errorf("Preference-Applied = %q, want tx=rollback", applied) + } + if got := yearOf(t, srv, "1"); got != 1927 { + t.Errorf("year after rolled-back patch = %v, want 1927 (unchanged)", got) + } +} + +// TestTxDisallowedIsStrictOffender checks a tx= under handling=strict with the +// default commit policy is the PGRST122 invalid-preference error. +func TestTxDisallowedIsStrictOffender(t *testing.T) { + srv := newServer(t) + resp := send(t, srv, http.MethodPatch, "/films?id=eq.1", `{"year":1902}`, map[string]string{ + "Prefer": "tx=rollback, handling=strict", + }) + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("status = %d, want 400", resp.StatusCode) + } + var env struct{ Code string } + if err := json.NewDecoder(resp.Body).Decode(&env); err != nil { + t.Fatalf("decode: %v", err) + } + if env.Code != "PGRST122" { + t.Errorf("code = %q, want PGRST122", env.Code) + } +} + +// TestRollbackPolicyForcesRollback checks db-tx-end=rollback rolls a write back +// even with no tx= preference, the mode test deployments use. +func TestRollbackPolicyForcesRollback(t *testing.T) { + srv := newServer(t) + srv.SetTxEnd("rollback") + resp := send(t, srv, http.MethodPatch, "/films?id=eq.1", `{"year":1903}`, nil) + if resp.StatusCode != http.StatusNoContent { + t.Fatalf("status = %d, want 204", resp.StatusCode) + } + if got := yearOf(t, srv, "1"); got != 1927 { + t.Errorf("year under rollback policy = %v, want 1927 (unchanged)", got) + } +} diff --git a/httpapi/types_test.go b/httpapi/types_test.go index cab6aa2..9c846a6 100644 --- a/httpapi/types_test.go +++ b/httpapi/types_test.go @@ -24,7 +24,8 @@ func TestFilterCoercionRejectsBadInteger(t *testing.T) { if body.Code != "22P02" { t.Errorf("code = %q, want 22P02", body.Code) } - if body.Message != `invalid input syntax for type int4: "abc"` { + // The type is spelled the way PostgreSQL's own message spells it. + if body.Message != `invalid input syntax for type integer: "abc"` { t.Errorf("message = %q", body.Message) } } diff --git a/httpapi/writeheaders_test.go b/httpapi/writeheaders_test.go new file mode 100644 index 0000000..33ba473 --- /dev/null +++ b/httpapi/writeheaders_test.go @@ -0,0 +1,210 @@ +package httpapi_test + +import ( + "io" + "net/http" + "strings" + "testing" +) + +// The write-response Content-Range is shaped by method, not by the return mode +// (02.8): POST and DELETE report the total-only "*/*" form ("*/N" with +// count=exact), PATCH reports the affected-row range, and PUT reports none. + +func TestPostContentRangeIsTotalOnly(t *testing.T) { + srv := newServer(t) + resp := send(t, srv, http.MethodPost, "/films", `{"id":30,"title":"X"}`, map[string]string{ + "Prefer": "return=minimal", + }) + if resp.StatusCode != http.StatusCreated { + t.Fatalf("status = %d, want 201", resp.StatusCode) + } + // Present even with return=minimal and no count. + if cr := resp.Header.Get("Content-Range"); cr != "*/*" { + t.Errorf("Content-Range = %q, want */*", cr) + } +} + +func TestPostContentRangeWithCount(t *testing.T) { + srv := newServer(t) + resp := send(t, srv, http.MethodPost, "/films", `{"id":31,"title":"X"}`, map[string]string{ + "Prefer": "return=minimal, count=exact", + }) + if cr := resp.Header.Get("Content-Range"); cr != "*/1" { + t.Errorf("Content-Range = %q, want */1", cr) + } +} + +func TestDeleteContentRangeIsTotalOnly(t *testing.T) { + srv := newServer(t) + resp := send(t, srv, http.MethodDelete, "/films?id=eq.3", "", nil) + if resp.StatusCode != http.StatusNoContent { + t.Fatalf("status = %d, want 204", resp.StatusCode) + } + if cr := resp.Header.Get("Content-Range"); cr != "*/*" { + t.Errorf("Content-Range = %q, want */*", cr) + } +} + +func TestDeleteContentRangeWithCount(t *testing.T) { + srv := newServer(t) + resp := send(t, srv, http.MethodDelete, "/films?id=eq.3", "", map[string]string{ + "Prefer": "count=exact", + }) + if cr := resp.Header.Get("Content-Range"); cr != "*/1" { + t.Errorf("Content-Range = %q, want */1", cr) + } +} + +func TestPatchContentRangeIsRowSpan(t *testing.T) { + srv := newServer(t) + resp := send(t, srv, http.MethodPatch, "/films?id=eq.2", `{"rating":"PG"}`, map[string]string{ + "Prefer": "return=representation", + }) + if cr := resp.Header.Get("Content-Range"); cr != "0-0/*" { + t.Errorf("Content-Range = %q, want 0-0/*", cr) + } +} + +func TestPatchContentRangeWithCount(t *testing.T) { + srv := newServer(t) + resp := send(t, srv, http.MethodPatch, "/films?year=gte.1980", `{"rating":"PG"}`, map[string]string{ + "Prefer": "return=representation, count=exact", + }) + if cr := resp.Header.Get("Content-Range"); cr != "0-1/2" { + t.Errorf("Content-Range = %q, want 0-1/2", cr) + } +} + +// A PATCH on the minimal path still carries the row span, since Content-Range +// does not depend on the return mode. +func TestPatchMinimalStillCarriesContentRange(t *testing.T) { + srv := newServer(t) + resp := send(t, srv, http.MethodPatch, "/films?id=eq.2", `{"rating":"PG"}`, map[string]string{ + "Prefer": "return=minimal", + }) + if resp.StatusCode != http.StatusNoContent { + t.Fatalf("status = %d, want 204", resp.StatusCode) + } + if cr := resp.Header.Get("Content-Range"); cr != "0-0/*" { + t.Errorf("Content-Range = %q, want 0-0/*", cr) + } +} + +// A PATCH that matches no row reports the empty-span "*/*" form (and "*/0" with +// count=exact), not a negative range. +func TestPatchNoMatchContentRange(t *testing.T) { + srv := newServer(t) + resp := send(t, srv, http.MethodPatch, "/films?id=eq.999", `{"rating":"PG"}`, map[string]string{ + "Prefer": "return=representation", + }) + if cr := resp.Header.Get("Content-Range"); cr != "*/*" { + t.Errorf("Content-Range = %q, want */*", cr) + } + withCount := send(t, srv, http.MethodPatch, "/films?id=eq.999", `{"rating":"PG"}`, map[string]string{ + "Prefer": "return=representation, count=exact", + }) + if cr := withCount.Header.Get("Content-Range"); cr != "*/0" { + t.Errorf("Content-Range = %q, want */0", cr) + } +} + +// A PUT carries no Content-Range in any return mode (02.8, 02.9). +func TestPutHasNoContentRange(t *testing.T) { + srv := newServer(t) + resp := send(t, srv, http.MethodPut, "/films?id=eq.2", `{"id":2,"title":"Blade Runner"}`, map[string]string{ + "Prefer": "return=representation", + }) + if cr := resp.Header.Get("Content-Range"); cr != "" { + t.Errorf("Content-Range = %q, want none on PUT", cr) + } +} + +// A PUT with no representation answers 204 with an empty body and no Location or +// Content-Range, the same for return=minimal, headers-only, and no preference +// (02.9). +func TestPutWithoutRepresentationIs204(t *testing.T) { + cases := []struct { + name string + prefer string + }{ + {"no-preference", ""}, + {"minimal", "return=minimal"}, + {"headers-only", "return=headers-only"}, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + srv := newServer(t) + headers := map[string]string{} + if c.prefer != "" { + headers["Prefer"] = c.prefer + } + resp := send(t, srv, http.MethodPut, "/films?id=eq.2", `{"id":2,"title":"Replaced"}`, headers) + if resp.StatusCode != http.StatusNoContent { + t.Fatalf("status = %d, want 204", resp.StatusCode) + } + body, _ := io.ReadAll(resp.Body) + if len(body) != 0 { + t.Errorf("body = %q, want empty", body) + } + if loc := resp.Header.Get("Location"); loc != "" { + t.Errorf("Location = %q, want none on PUT", loc) + } + if cr := resp.Header.Get("Content-Range"); cr != "" { + t.Errorf("Content-Range = %q, want none on PUT", cr) + } + // The replacement persisted. + after := do(t, srv, http.MethodGet, "/films?id=eq.2&select=title", nil) + rows := decodeArray(t, after) + if len(rows) != 1 || rows[0]["title"] != "Replaced" { + t.Errorf("after PUT = %v, want title Replaced", rows) + } + }) + } +} + +// An ignore-duplicates upsert that hits only existing rows inserts nothing, yet +// PostgREST still reports 201 (only merge-duplicates with zero inserted is 200). +func TestPostUpsertIgnoreDuplicatesExistingIs201(t *testing.T) { + srv := newServer(t) + resp := send(t, srv, http.MethodPost, "/films", `{"id":1,"title":"Ignored"}`, map[string]string{ + "Prefer": "return=minimal, resolution=ignore-duplicates", + }) + if resp.StatusCode != http.StatusCreated { + t.Fatalf("status = %d, want 201", resp.StatusCode) + } + // The existing row was left untouched. + after := do(t, srv, http.MethodGet, "/films?id=eq.1&select=title", nil) + rows := decodeArray(t, after) + if len(rows) != 1 || rows[0]["title"] != "Metropolis" { + t.Errorf("ignore-duplicates altered the row: %v", rows) + } +} + +// A merge-duplicates upsert whose batch mixes a new key with an existing one +// inserted at least one row, so the status is 201, not 200. +func TestPostUpsertMergeMixedBatchIs201(t *testing.T) { + srv := newServer(t) + resp := send(t, srv, http.MethodPost, "/films", + `[{"id":1,"title":"Metropolis v2"},{"id":70,"title":"Fresh"}]`, map[string]string{ + "Prefer": "return=representation, resolution=merge-duplicates", + }) + if resp.StatusCode != http.StatusCreated { + t.Fatalf("status = %d, want 201 for a mixed batch", resp.StatusCode) + } +} + +// A return=headers-only POST insert of a single row keeps its Location header, +// the one write that still carries one. +func TestPostHeadersOnlyKeepsLocation(t *testing.T) { + srv := newServer(t) + resp := send(t, srv, http.MethodPost, "/films", `{"id":80,"title":"Located"}`, map[string]string{ + "Prefer": "return=headers-only", + }) + if resp.StatusCode != http.StatusCreated { + t.Fatalf("status = %d, want 201", resp.StatusCode) + } + if loc := resp.Header.Get("Location"); !strings.Contains(loc, "id=eq.80") { + t.Errorf("Location = %q, want it to address id=eq.80", loc) + } +} diff --git a/httpapi/writerep_test.go b/httpapi/writerep_test.go new file mode 100644 index 0000000..3d0fb78 --- /dev/null +++ b/httpapi/writerep_test.go @@ -0,0 +1,126 @@ +package httpapi_test + +import ( + "net/http" + "testing" +) + +// 07.12: an empty payload is a no-op the server accepts, not a 400. A POST with +// an empty array inserts nothing and returns 201; a PATCH with an empty object +// (or the empty-array forms [] and [{}]) updates nothing and returns 204, or 200 +// with an empty representation when one is asked for. Either way no row changes +// and the write Content-Range stays */*. +func TestPostEmptyArrayInsertsNothing(t *testing.T) { + srv := newServer(t) + resp := send(t, srv, http.MethodPost, "/films", `[]`, map[string]string{ + "Prefer": "return=representation", + }) + if resp.StatusCode != http.StatusCreated { + t.Fatalf("status = %d, want 201", resp.StatusCode) + } + if cr := resp.Header.Get("Content-Range"); cr != "*/*" { + t.Errorf("Content-Range = %q, want */*", cr) + } + if rows := decodeArray(t, resp); len(rows) != 0 { + t.Errorf("body = %v, want empty array", rows) + } + // The table is untouched. + after := do(t, srv, http.MethodGet, "/films", nil) + if all := decodeArray(t, after); len(all) != 4 { + t.Errorf("row count = %d, want 4 (nothing inserted)", len(all)) + } +} + +func TestPostEmptyArrayMinimalIs201(t *testing.T) { + srv := newServer(t) + resp := send(t, srv, http.MethodPost, "/films", `[]`, nil) + if resp.StatusCode != http.StatusCreated { + t.Fatalf("status = %d, want 201", resp.StatusCode) + } + buf := make([]byte, 1) + if n, _ := resp.Body.Read(buf); n != 0 { + t.Error("minimal insert should have no body") + } +} + +func TestPatchEmptyObjectIs204NoOp(t *testing.T) { + srv := newServer(t) + resp := send(t, srv, http.MethodPatch, "/films", `{}`, nil) + if resp.StatusCode != http.StatusNoContent { + t.Fatalf("status = %d, want 204", resp.StatusCode) + } + if cr := resp.Header.Get("Content-Range"); cr != "*/*" { + t.Errorf("Content-Range = %q, want */*", cr) + } +} + +func TestPatchEmptyObjectRepresentationIs200EmptyArray(t *testing.T) { + srv := newServer(t) + resp := send(t, srv, http.MethodPatch, "/films", `{}`, map[string]string{ + "Prefer": "return=representation", + }) + if resp.StatusCode != http.StatusOK { + t.Fatalf("status = %d, want 200", resp.StatusCode) + } + if cr := resp.Header.Get("Content-Range"); cr != "*/*" { + t.Errorf("Content-Range = %q, want */*", cr) + } + if rows := decodeArray(t, resp); len(rows) != 0 { + t.Errorf("body = %v, want empty array", rows) + } +} + +// The empty-array forms of a PATCH body are the same no-op as the empty object. +func TestPatchEmptyArrayFormsAreNoOp(t *testing.T) { + for _, body := range []string{`[]`, `[{}]`} { + t.Run(body, func(t *testing.T) { + srv := newServer(t) + resp := send(t, srv, http.MethodPatch, "/films", body, nil) + if resp.StatusCode != http.StatusNoContent { + t.Errorf("PATCH %s: status = %d, want 204", body, resp.StatusCode) + } + }) + } +} + +// A PATCH array carrying a non-empty object is not a shape upstream defines; it +// stays a 400 rather than silently updating. +func TestPatchNonEmptyArrayIs400(t *testing.T) { + srv := newServer(t) + resp := send(t, srv, http.MethodPatch, "/films", `[{"rating":"X"}]`, nil) + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("status = %d, want 400", resp.StatusCode) + } +} + +// 07.3: PostgREST v13 dropped limited update/delete, so order and limit on a +// PATCH shape only the returned representation, never the mutation. The body is +// the ordered, limited slice; Content-Range still reports the full affected set; +// and every matching row is written. +func TestPatchOrderLimitShapesBodyNotMutation(t *testing.T) { + srv := newServer(t) + resp := send(t, srv, http.MethodPatch, "/films?order=id.desc&limit=2", `{"rating":"X"}`, map[string]string{ + "Prefer": "return=representation", + }) + if resp.StatusCode != http.StatusOK { + t.Fatalf("status = %d, want 200", resp.StatusCode) + } + // Content-Range carries the full affected count (4 seed rows), not the 2 the + // body shows: the limit bounds the body, not the write. + if got := resp.Header.Get("Content-Range"); got != "0-3/*" { + t.Errorf("Content-Range = %q, want 0-3/* (full affected set)", got) + } + rows := decodeArray(t, resp) + if len(rows) != 2 { + t.Fatalf("body rows = %d, want 2", len(rows)) + } + if rows[0]["id"] != float64(4) || rows[1]["id"] != float64(3) { + t.Errorf("body ids = [%v %v], want [4 3] (id desc)", rows[0]["id"], rows[1]["id"]) + } + + // Every row was updated, including the two outside the body window. + after := do(t, srv, http.MethodGet, "/films?rating=eq.X&select=id", nil) + if all := decodeArray(t, after); len(all) != 4 { + t.Errorf("updated rows = %d, want 4 (the whole table)", len(all)) + } +} diff --git a/ir/call_test.go b/ir/call_test.go index 8040854..0e60af6 100644 --- a/ir/call_test.go +++ b/ir/call_test.go @@ -6,7 +6,7 @@ import ( ) func TestParseCallGetArgsFromQuery(t *testing.T) { - c, err := ParseCall("add_them", "a=2&b=3", nil, true, "", nil) + c, err := ParseCall("add_them", "a=2&b=3", nil, true, "", nil, "", "") if err != nil { t.Fatalf("ParseCall: %v", err) } @@ -25,7 +25,7 @@ func TestParseCallGetArgsFromQuery(t *testing.T) { } func TestParseCallGetReservedKeysArePostFilters(t *testing.T) { - c, err := ParseCall("list_films", "select=title&order=year.desc&limit=5&year=gte.2000", nil, true, "", nil) + c, err := ParseCall("list_films", "select=title&order=year.desc&limit=5&year=gte.2000", nil, true, "", nil, "", "") if err != nil { t.Fatalf("ParseCall: %v", err) } @@ -49,7 +49,7 @@ func TestParseCallGetReservedKeysArePostFilters(t *testing.T) { } func TestParseCallPostArgsFromBody(t *testing.T) { - c, err := ParseCall("add_them", "", nil, false, "application/json", []byte(`{"a":2,"b":3}`)) + c, err := ParseCall("add_them", "", nil, false, "application/json", []byte(`{"a":2,"b":3}`), "", "") if err != nil { t.Fatalf("ParseCall: %v", err) } @@ -64,7 +64,7 @@ func TestParseCallPostArgsFromBody(t *testing.T) { } func TestParseCallPostQueryStringIsPostFilter(t *testing.T) { - c, err := ParseCall("list_films", "year=gte.2000&order=year", nil, false, "application/json", []byte(`{"genre":"scifi"}`)) + c, err := ParseCall("list_films", "year=gte.2000&order=year", nil, false, "application/json", []byte(`{"genre":"scifi"}`), "", "") if err != nil { t.Fatalf("ParseCall: %v", err) } @@ -84,7 +84,7 @@ func TestParseCallPostQueryStringIsPostFilter(t *testing.T) { } func TestParseCallPostNoBody(t *testing.T) { - c, err := ParseCall("now", "", nil, false, "application/json", nil) + c, err := ParseCall("now", "", nil, false, "application/json", nil, "", "") if err != nil { t.Fatalf("ParseCall: %v", err) } @@ -94,7 +94,7 @@ func TestParseCallPostNoBody(t *testing.T) { } func TestParseCallCountPrefer(t *testing.T) { - c, err := ParseCall("list_films", "", []string{"count=exact"}, true, "", nil) + c, err := ParseCall("list_films", "", []string{"count=exact"}, true, "", nil, "", "") if err != nil { t.Fatalf("ParseCall: %v", err) } @@ -104,7 +104,56 @@ func TestParseCallCountPrefer(t *testing.T) { } func TestParseCallBadJSONBody(t *testing.T) { - if _, err := ParseCall("f", "", nil, false, "application/json", []byte(`{nope`)); err == nil { + if _, err := ParseCall("f", "", nil, false, "application/json", []byte(`{nope`), "", ""); err == nil { t.Error("malformed body should error") } } + +// TestParseCallRawBodyJSON checks the single-unnamed-parameter form: a JSON body +// binds whole to the named raw-body parameter, keeping its JSON type, rather than +// being read as an object of named arguments. +func TestParseCallRawBodyJSON(t *testing.T) { + c, err := ParseCall("echo", "", nil, false, "application/json", []byte(`{"a":1}`), "payload", "json") + if err != nil { + t.Fatalf("ParseCall: %v", err) + } + if len(c.Args) != 1 { + t.Fatalf("args = %v, want one raw-body arg", c.Args) + } + obj, ok := c.Args["payload"].JSON.(map[string]any) + if !ok || obj["a"] == nil { + t.Errorf("payload = %+v, want the whole JSON object", c.Args["payload"]) + } +} + +// TestParseCallRawBodyText checks a text body binds to the raw-body parameter as +// text under text/plain, the form an unnamed text parameter takes. +func TestParseCallRawBodyText(t *testing.T) { + c, err := ParseCall("shout", "", nil, false, "text/plain", []byte("hello"), "line", "text") + if err != nil { + t.Fatalf("ParseCall: %v", err) + } + if c.Args["line"].Text != "hello" || c.Args["line"].JSON != nil { + t.Errorf("line = %+v, want text hello", c.Args["line"]) + } +} + +// TestParseCallRawBodyOctetStream checks application/octet-stream binds the raw +// bytes to the parameter as text, the bytea-bound form. +func TestParseCallRawBodyOctetStream(t *testing.T) { + c, err := ParseCall("store", "", nil, false, "application/octet-stream", []byte{0x1, 0x2}, "blob", "bytea") + if err != nil { + t.Fatalf("ParseCall: %v", err) + } + if c.Args["blob"].Text != string([]byte{0x1, 0x2}) { + t.Errorf("blob = %+v, want the raw bytes", c.Args["blob"]) + } +} + +// TestParseCallRawBodyUnsupportedMedia checks a media type the raw-body binder +// does not accept is the caller's 415, not a silent empty bind. +func TestParseCallRawBodyUnsupportedMedia(t *testing.T) { + if _, err := ParseCall("store", "", nil, false, "image/png", []byte("x"), "blob", "bytea"); err == nil { + t.Error("an unsupported media type must reject the raw body") + } +} diff --git a/ir/fts_test.go b/ir/fts_test.go index 20bff26..125c1b1 100644 --- a/ir/fts_test.go +++ b/ir/fts_test.go @@ -90,15 +90,19 @@ func TestParseFTSNegated(t *testing.T) { } // TestQuantifierStillParses guards that splitting fts-config from the quantifier -// branch did not break op(any)/op(all) on the comparison operators. +// branch did not break op(any)/op(all) on the comparison operators. PostgREST +// spells the operand as a {…} list, which becomes Value.List (item 01.1). func TestQuantifierStillParses(t *testing.T) { - cmp := fetchCompare(t, "id=eq(any).1") + cmp := fetchCompare(t, "id=eq(any).{1,2}") if cmp.Op != OpEq { t.Errorf("Op = %v, want OpEq", cmp.Op) } if cmp.Quant != QAny { t.Errorf("Quant = %d, want QAny", cmp.Quant) } + if len(cmp.Value.List) != 2 || cmp.Value.List[0] != "1" || cmp.Value.List[1] != "2" { + t.Errorf("List = %v, want [1 2]", cmp.Value.List) + } } func TestUnknownQuantifierStillErrors(t *testing.T) { diff --git a/ir/ir.go b/ir/ir.go index a081d21..5796404 100644 --- a/ir/ir.go +++ b/ir/ir.go @@ -8,7 +8,11 @@ // backend; its errors are the PGRST1xx family. See spec 05-query-ir-and-planning. package ir -import "github.com/tamnd/dbrest/schema" +import ( + "net/url" + + "github.com/tamnd/dbrest/schema" +) // QueryKind is the operation a / request performs. type QueryKind uint8 @@ -47,19 +51,51 @@ type Root struct { // Query is a /
request. type Query struct { - Kind QueryKind - Relation Ref - Select []SelectItem - Where *Cond - Order []OrderTerm - Limit *int - Offset *int - Embeds []Embed - Write *WriteSpec // non-nil for Insert/Update/Upsert/Delete - Singular bool - Count CountKind + Kind QueryKind + Relation Ref + Select []SelectItem + Where *Cond + Order []OrderTerm + Limit *int + Offset *int + Embeds []Embed + Write *WriteSpec // non-nil for Insert/Update/Upsert/Delete + Singular bool + Count CountKind + // CountMax is the db-max-rows threshold an estimated count crosses over at: a + // backend that supports estimation runs the exact count while the result stays + // at or below it and falls back to the planner estimate above it. Zero means no + // threshold was configured. It is only meaningful with Count == CountEstimated. + CountMax int64 Prefer PreferSet FromRange bool // limit/offset came from the Range request header, not ?limit= + IsPut bool // the request method was PUT, so PUT upsert validations apply + // Computed maps each of this relation's computed-field names to the schema of + // the function that backs it. The planner fills it from the resolved relation + // so the compiler can render a selected, filtered, or ordered computed field as + // a function call on the row instead of a bare column. Nil when the relation has + // no computed fields. Each embed carries its own map for its own relation. + Computed map[string]string + // Reps maps a column name to the cast functions that drive its data + // representation (PostgREST domain representations, spec 11). The planner fills + // it from the resolved relation so the compiler reformats the column through the + // domain's casts: ToJSON on read output, FromJSON on a write value, FromText on a + // filter literal. Nil when no column of the relation carries one; each embed + // carries its own map for its own relation. + Reps map[string]Rep +} + +// Rep carries the cast functions that drive a column's data representation +// (PostgREST domain representations, spec 11): a domain type with casts to and +// from json/text whose functions reformat the wire value. Each field is the +// schema-qualified function backing that direction, or empty when the domain +// declares no cast there. ToJSON formats the stored value for a response, +// FromText parses a query-string filter literal, FromJSON parses a write-body +// value. +type Rep struct { + ToJSONSchema, ToJSONFunc string + FromTextSchema, FromTextFunc string + FromJSONSchema, FromJSONFunc string } // Call is a /rpc/ request. @@ -72,9 +108,19 @@ type Call struct { Order []OrderTerm Limit *int Offset *int + // Embeds are the embedded resources requested on a function returning rows of + // a known relation (/rpc/f?select=id,client(*)). They resolve against the + // function's result relation exactly as a table read's embeds do; a call over + // a function with no relation return carries none. + Embeds []Embed Singular bool Count CountKind Prefer PreferSet + // RawGet holds a GET call's non-reserved query parameters before the + // argument-versus-filter split, which needs the resolved function's + // parameter names. PartitionGetArgs consumes it once the planner knows the + // signature. It is nil on a POST call. + RawGet url.Values } // RootSpec is a GET / request: render the OpenAPI document for a schema. @@ -105,6 +151,43 @@ type Column struct { func (Column) isSelect() {} +// ProjectedColumns returns the distinct base column names a plain select list +// names, in select order, so a write's representation reads back only the +// columns the client asked for instead of the whole row. It returns nil when +// the projection is not a simple base-column list (empty, a "*", an aggregate, +// or an embed present), telling the caller to fall back to every column. A +// column carrying an alias, a cast, or a JSON sub-path also forces the fallback, +// because the bare RETURNING/OUTPUT path cannot reshape those (that reshaping is +// the deferred write-representation embed work, item 01.19). +func (q *Query) ProjectedColumns() []string { + if len(q.Select) == 0 || len(q.Embeds) > 0 { + return nil + } + out := make([]string, 0, len(q.Select)) + seen := make(map[string]bool, len(q.Select)) + for _, it := range q.Select { + col, ok := it.(Column) + if !ok { + return nil // an aggregate or an embed reference + } + if len(col.Path) != 1 || col.Last != JSONNone || col.Cast != "" || col.Alias != "" { + return nil + } + name := col.Path[0] + if name == "*" { + return nil + } + if !seen[name] { + seen[name] = true + out = append(out, name) + } + } + if len(out) == 0 { + return nil + } + return out +} + // Name returns the output key for the column: its alias if set, else the last // path element. func (c Column) Name() string { @@ -128,12 +211,43 @@ const ( AggMax ) -// Aggregate is a column aggregate in the select list. +// Aggregate is a column aggregate in the select list. Cast is an output cast on +// the aggregate result; an input cast on the aggregated column rides on Arg.Cast. +// Legacy marks the pre-v12 bare `count` an embed select may carry: it renders a +// count of the embedded rows and is exempt from the db-aggregates-enabled gate, +// where the count()/col.agg() function forms are not. type Aggregate struct { - Func AggFunc - Arg *Column // nil for count(*) - Cast string - Alias string + Func AggFunc + Arg *Column // nil for count() + Cast string + Alias string + Legacy bool +} + +// Name is the response key an aggregate renders under: its explicit alias, else +// the function name (sum, avg, count, min, max), matching PostgREST's default. +func (a Aggregate) Name() string { + if a.Alias != "" { + return a.Alias + } + return a.Func.String() +} + +// String spells an aggregate function the way it appears in SQL and as the +// default response key. +func (f AggFunc) String() string { + switch f { + case AggSum: + return "sum" + case AggAvg: + return "avg" + case AggMin: + return "min" + case AggMax: + return "max" + default: + return "count" + } } func (Aggregate) isSelect() {} @@ -176,6 +290,12 @@ type Embed struct { Target Ref // the embedded relation as written; resolved at plan time Query Query Rel *schema.Relationship + + // EmptySelect records that the embed was written with empty parentheses, + // e.g. client(). PostgREST joins such a relation for filtering but omits its + // key from the output entirely, which an absent or rel(*) select does not. + // This distinguishes "no column list" from "select every column". + EmptySelect bool } // Cond is a node in the filter tree. @@ -196,6 +316,22 @@ type Not struct{ Kid Cond } func (Not) isCond() {} +// EmbedPredicate filters the parent on the existence of an embedded resource's +// rows. It is what an `embed=is.null` / `embed=not.is.null` filter lowers to: +// the planner reclassifies a Compare whose single-segment path names an embed's +// OutKey and whose operator is `is null` into this node, so the compiler can +// emit a semi/anti join instead of rejecting an unknown parent column. +// +// Index points into the owning Query's Embeds. Exists is true for not.is.null +// (the parent must have a matching embedded row, a semi-join / EXISTS) and false +// for is.null (it must have none, an anti-join / NOT EXISTS). See spec 09. +type EmbedPredicate struct { + Index int + Exists bool +} + +func (EmbedPredicate) isCond() {} + // FTSVariant selects the full-text query grammar of an fts predicate, one per // PostgREST operator. Parsing records the variant; each backend maps it onto its // own full-text query language (spec 21). @@ -211,6 +347,7 @@ const ( // Compare is a single column-operator-value predicate. type Compare struct { Path []string + Last JSONStep // final JSON hop kind when Path carries a -> / ->> sub-path Op Op Value Value Quant Quant @@ -222,6 +359,11 @@ type Compare struct { FTS FTSVariant Config string FullText *schema.FullTextIndex + // ColumnType is the canonical type of the column at Path[0], resolved by + // the planner from the schema. The dialect uses it to decide whether an + // engine-specific operator (e.g. json_each for array ops on SQLite) can + // apply; it is empty when the column is unknown or for multi-step paths. + ColumnType string } func (Compare) isCond() {} @@ -277,8 +419,17 @@ type Value struct { // OrderTerm is one entry in the order list. type OrderTerm struct { Path []string + Last JSONStep // final JSON hop kind when Path carries a -> / ->> sub-path Desc bool NullsFirst *bool // nil = PG default (NULLS LAST asc, NULLS FIRST desc) + + // Rel names an embedded resource when the term is order=rel(col): the parent + // is ordered by a column of a to-one embed, with Path/Last addressing the + // column inside that relation. Empty for an ordinary parent-column term. The + // name is the embed's written spelling or alias (client in client(name)); + // the planner resolves it to a relationship and the compiler lowers it as a + // correlated scalar subquery. + Rel string } // WriteSpec carries the mutation payload and options (spec 11). @@ -291,14 +442,22 @@ type WriteSpec struct { Return ReturnMode MaxRows *int64 Tx TxMode + // ColumnTypes is the canonical type of each written column, resolved by the + // planner from the relation. The compiler uses it to decide how a JSON array + // payload value lands: a json/jsonb column takes JSON text, an array column + // takes a PostgreSQL array literal. It is empty for backends or paths that do + // not resolve a schema. + ColumnTypes map[string]string } // MissingMode is the Prefer: missing= behavior for absent payload columns. type MissingMode uint8 +// MissingNull is the zero value because PostgREST inserts SQL NULL for payload +// columns a row omits; Prefer: missing=default is the opt-in for column DEFAULTs. const ( - MissingDefault MissingMode = iota - MissingNull + MissingNull MissingMode = iota + MissingDefault ) // Conflict describes an upsert conflict resolution. @@ -332,3 +491,32 @@ const ( TxCommit TxRollback ) + +// TxEnd is the db-tx-end server policy that governs whether a request may +// override the transaction outcome with Prefer: tx=. The two allow-override +// variants honor the preference; the two fixed variants ignore it and force +// their outcome server-side. +type TxEnd uint8 + +const ( + TxEndCommit TxEnd = iota + TxEndCommitAllowOverride + TxEndRollback + TxEndRollbackAllowOverride +) + +// ParseTxEnd maps a db-tx-end option string to a TxEnd. An empty or unknown +// value is the default commit, matching the config default; the config layer +// validates the spelling before this point. +func ParseTxEnd(s string) TxEnd { + switch s { + case "commit-allow-override": + return TxEndCommitAllowOverride + case "rollback": + return TxEndRollback + case "rollback-allow-override": + return TxEndRollbackAllowOverride + default: + return TxEndCommit + } +} diff --git a/ir/parse.go b/ir/parse.go index 08f65dd..20d11fc 100644 --- a/ir/parse.go +++ b/ir/parse.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/csv" "encoding/json" + "errors" "net/url" "slices" "sort" @@ -29,6 +30,9 @@ func ParseRead(relation, rawQuery string, preferHeaders []string) (*Query, *pger } q := &Query{Kind: Read, Relation: Ref{Name: relation}} q.Prefer = ParsePrefer(preferHeaders) + if perr := q.Prefer.StrictError(); perr != nil { + return nil, perr + } if q.Prefer.Count != nil { q.Count = *q.Prefer.Count } @@ -43,8 +47,14 @@ func ParseRead(relation, rawQuery string, preferHeaders []string) (*Query, *pger // limit/offset window, and the horizontal-filter tree. A write uses the filter // tree as its WHERE and the select list as its returning projection. func parseQueryString(q *Query, vals url.Values) *pgerr.APIError { - if sel := vals.Get("select"); sel != "" { - items, embeds, perr := parseSelect(sel) + // An omitted select defaults to all columns; an explicitly empty select= is a + // parse error, matching PostgREST (item 01.5). + if vals.Has("select") { + sel := vals.Get("select") + if sel == "" { + return pgerr.ErrParse("\"failed to parse select parameter ()\" (line 1, column 1)") + } + items, embeds, perr := parseSelect(sel, false) if perr != nil { return perr } @@ -66,15 +76,14 @@ func applyParams(q *Query, vals url.Values) *pgerr.APIError { if key == "select" { continue // consumed by the caller / by the embed parens } - if i := strings.IndexByte(key, '.'); i >= 0 { - if idx := findEmbed(q.Embeds, key[:i]); idx >= 0 { - prefix := key[:i] - ev := scoped[prefix] + if head, rest, ok := cutIdentAware(key, '.'); ok { + if idx := findEmbed(q.Embeds, head); idx >= 0 { + ev := scoped[head] if ev == nil { ev = url.Values{} - scoped[prefix] = ev + scoped[head] = ev } - ev[key[i+1:]] = vs + ev[rest] = vs continue } } @@ -90,9 +99,15 @@ func applyParams(q *Query, vals url.Values) *pgerr.APIError { } if lim := self.Get("limit"); lim != "" { n, e := strconv.Atoi(lim) - if e != nil || n < 0 { + if e != nil { return pgerr.ErrParse("limit must be a non-negative integer") } + if n < 0 { + // A well-formed but negative limit is PostgREST's 416 PGRST103, not a + // parse error: the requested range cannot be satisfied. + return pgerr.ErrRangeNotSatisfiable(). + WithDetails("Limit should be greater than or equal to zero.") + } q.Limit = &n } if off := self.Get("offset"); off != "" { @@ -139,8 +154,13 @@ func ParseWrite(kind QueryKind, relation, rawQuery string, preferHeaders []strin if err != nil { return nil, pgerr.ErrParse("could not parse query string") } - q := &Query{Kind: kind, Relation: Ref{Name: relation}} + // PUT is the only method the router maps to Upsert; capture it before the + // promotion below can also turn a POST into an upsert. + q := &Query{Kind: kind, Relation: Ref{Name: relation}, IsPut: kind == Upsert} q.Prefer = ParsePrefer(preferHeaders) + if perr := q.Prefer.StrictError(); perr != nil { + return nil, perr + } if q.Prefer.Count != nil { q.Count = *q.Prefer.Count } @@ -158,12 +178,17 @@ func ParseWrite(kind QueryKind, relation, rawQuery string, preferHeaders []strin if q.Prefer.Tx != nil { w.Tx = *q.Prefer.Tx } + // max-affected (strict-only; ParsePrefer cleared it under lenient) bounds the + // affected-row count the backend will tolerate before rolling back. + w.MaxRows = q.Prefer.MaxAffected - // An on_conflict target or a resolution preference makes this an upsert; PUT - // is always an upsert. The conflict target defaults to the primary key, - // which the planner fills in. + // PostgREST performs an upsert only for PUT or for a POST carrying a + // Prefer: resolution= preference. on_conflict alone leaves a POST a plain + // insert (a duplicate key then fails with 409), and both on_conflict and + // resolution are ignored entirely for PATCH and DELETE. The conflict target + // defaults to the primary key, which the planner fills in. onConflict := vals.Get("on_conflict") - if kind == Upsert || onConflict != "" || q.Prefer.Resolution != nil { + if q.IsPut || (kind == Insert && q.Prefer.Resolution != nil) { q.Kind = Upsert c := &Conflict{} if onConflict != "" { @@ -181,6 +206,14 @@ func ParseWrite(kind QueryKind, relation, rawQuery string, preferHeaders []strin if perr != nil { return nil, perr } + // Without an explicit columns= override, PostgREST requires every object + // in a bulk JSON array to carry exactly the first object's keys; columns= + // switches to RawJSON semantics and skips the check (item 01.15). + if vals.Get("columns") == "" && bodyFormat(contentType) == fmtJSON { + if perr := checkUniformKeys(objs); perr != nil { + return nil, perr + } + } w.Rows, w.Columns = buildInsert(objs, vals.Get("columns"), header) case Update: obj, perr := decodeBodyObject(contentType, body) @@ -212,13 +245,22 @@ var callReserved = map[string]bool{ // arguments (with their JSON types) and the whole query string post-filters. The // planner resolves the function and checks volatility against the method. All // errors are PGRST1xx. See spec 12-rpc. -func ParseCall(fn, rawQuery string, preferHeaders []string, isGet bool, contentType string, body []byte) (*Call, *pgerr.APIError) { +// +// rawBodyParam names the single parameter of an unnamed-argument function, when +// the resolved name is one; for such a function the whole POST body is bound to +// that parameter by Content-Type (rawBodyType is its declared type) rather than +// decoded as a JSON object of named arguments. It is "" for the ordinary +// named-argument form and on GET. +func ParseCall(fn, rawQuery string, preferHeaders []string, isGet bool, contentType string, body []byte, rawBodyParam, rawBodyType string) (*Call, *pgerr.APIError) { vals, err := url.ParseQuery(rawQuery) if err != nil { return nil, pgerr.ErrParse("could not parse query string") } c := &Call{Function: Ref{Name: fn}} c.Prefer = ParsePrefer(preferHeaders) + if perr := c.Prefer.StrictError(); perr != nil { + return nil, perr + } if c.Prefer.Count != nil { c.Count = *c.Prefer.Count } @@ -230,14 +272,20 @@ func ParseCall(fn, rawQuery string, preferHeaders []string, isGet bool, contentT if isGet { post := url.Values{} + raw := url.Values{} for k, vs := range vals { if callReserved[k] { post[k] = vs continue } - // A function argument; the last value wins, matching url.Values.Get. + // A candidate argument: the last value wins for the argument binding + // (matching url.Values.Get), but every value is retained on RawGet so the + // planner can re-read a key that turns out to be a filter, or collect a + // variadic parameter's repeats, once the signature is known. + raw[k] = vs args[k] = Value{Text: vs[len(vs)-1]} } + c.RawGet = raw if perr := parseQueryString(pq, post); perr != nil { return nil, perr } @@ -245,7 +293,15 @@ func ParseCall(fn, rawQuery string, preferHeaders []string, isGet bool, contentT if perr := parseQueryString(pq, vals); perr != nil { return nil, perr } - if len(body) > 0 { + if rawBodyParam != "" { + // Single-unnamed-parameter form: the whole body is the one argument, + // decoded by Content-Type rather than read as a named-argument object. + v, perr := bindRawBody(contentType, body, rawBodyType) + if perr != nil { + return nil, perr + } + args[rawBodyParam] = v + } else if len(body) > 0 { obj, perr := decodeBodyObject(contentType, body) if perr != nil { return nil, perr @@ -257,10 +313,54 @@ func ParseCall(fn, rawQuery string, preferHeaders []string, isGet bool, contentT } c.Select, c.Where, c.Order, c.Limit, c.Offset = pq.Select, pq.Where, pq.Order, pq.Limit, pq.Offset + c.Embeds = pq.Embeds c.Args = args return c, nil } +// PartitionGetArgs splits a GET /rpc call's query parameters into function +// arguments and post-filters once the resolved function's parameter names are +// known. A key naming a declared parameter stays an argument; every other key is +// re-read through the filter grammar and merged into the call's WHERE, matching +// how PostgREST treats a query key that does not name a parameter as a filter on +// a table-valued result. It is a no-op on a POST call, where the body carries the +// arguments and the query string already post-filtered. +func (c *Call) PartitionGetArgs(isParam func(string) bool, isVariadic func(string) bool) *pgerr.APIError { + if c.RawGet == nil { + return nil + } + filters := url.Values{} + for k, vs := range c.RawGet { + if isParam(k) { + // A variadic parameter collects every repeat of its key as a list; a + // scalar parameter already took the last value in ParseCall. + if isVariadic(k) { + c.Args[k] = Value{List: append([]string(nil), vs...)} + } + continue + } + delete(c.Args, k) + filters[k] = vs + } + if len(filters) == 0 { + return nil + } + cond, perr := parseFilters(filters) + if perr != nil { + return perr + } + if cond == nil { + return nil + } + if c.Where == nil { + c.Where = cond + return nil + } + merged := Cond(And{Kids: []Cond{*c.Where, *cond}}) + c.Where = &merged + return nil +} + // writeFormat is the request body encoding selected by Content-Type (spec 17). type writeFormat int @@ -312,18 +412,75 @@ func decodeBodyObjects(contentType string, body []byte) ([]map[string]any, []str } } +// bindRawBody binds the whole request body to a single unnamed parameter, +// decoded by Content-Type the way PostgREST routes a single-argument function: +// application/json (or an empty type) carries any JSON value, object or array or +// scalar; text/plain and text/xml carry the raw text; application/octet-stream +// carries the raw bytes as text. A content type with no raw-body decoder is the +// unsupported-media-type error, the same one the named-object path raises. +func bindRawBody(contentType string, body []byte, declaredType string) (Value, *pgerr.APIError) { + ct := strings.ToLower(strings.TrimSpace(contentType)) + if i := strings.IndexByte(ct, ';'); i >= 0 { + ct = strings.TrimSpace(ct[:i]) + } + switch ct { + case "", "application/json": + dec := json.NewDecoder(bytes.NewReader(body)) + dec.UseNumber() + var v any + if err := dec.Decode(&v); err != nil { + return Value{}, pgerr.ErrInvalidBody("") + } + return Value{JSON: v}, nil + case "text/plain", "text/xml", "application/xml": + return Value{Text: string(body)}, nil + case "application/octet-stream": + // A bytea parameter receives the raw bytes; they ride as text here and the + // engine binds them, with exact bytea typing left to the types subsystem. + return Value{Text: string(body)}, nil + default: + return Value{}, pgerr.ErrUnsupportedMediaType(contentType) + } +} + // decodeBodyObject decodes an update body into a single object of column -// assignments. CSV is not a meaningful patch format and is rejected. +// assignments. PostgREST accepts CSV for PATCH as well as POST, so a single-row +// CSV body is decoded with the same NULL rule the insert path uses. func decodeBodyObject(contentType string, body []byte) (map[string]any, *pgerr.APIError) { switch bodyFormat(contentType) { case fmtJSON: dec := json.NewDecoder(bytes.NewReader(body)) dec.UseNumber() - var obj map[string]any - if err := dec.Decode(&obj); err != nil { - return nil, pgerr.ErrParse("update body must be a JSON object") + var raw any + if err := dec.Decode(&raw); err != nil { + return nil, pgerr.ErrInvalidBody("") } - return obj, nil + switch v := raw.(type) { + case map[string]any: + return v, nil + case []any: + // PostgREST accepts the empty-array forms [] and [{}] for PATCH as a + // no-op update (an empty column set). An array carrying a non-empty + // object is not a shape upstream defines for update, so it stays a 400. + for _, e := range v { + obj, ok := e.(map[string]any) + if !ok || len(obj) > 0 { + return nil, pgerr.ErrInvalidBody("All object keys must match") + } + } + return map[string]any{}, nil + default: + return nil, pgerr.ErrInvalidBody("") + } + case fmtCSV: + objs, _, perr := decodeCSVObjects(body) + if perr != nil { + return nil, perr + } + if len(objs) != 1 { + return nil, pgerr.ErrInvalidBody("CSV update body must have exactly one data row") + } + return objs[0], nil case fmtForm: return decodeFormObject(body) default: @@ -338,7 +495,7 @@ func decodeJSONObjects(body []byte) ([]map[string]any, *pgerr.APIError) { dec.UseNumber() var raw any if err := dec.Decode(&raw); err != nil { - return nil, pgerr.ErrParse("request body is not valid JSON") + return nil, pgerr.ErrInvalidBody("") } switch v := raw.(type) { case map[string]any: @@ -348,38 +505,47 @@ func decodeJSONObjects(body []byte) ([]map[string]any, *pgerr.APIError) { for _, e := range v { obj, ok := e.(map[string]any) if !ok { - return nil, pgerr.ErrParse("insert array must contain objects") + return nil, pgerr.ErrInvalidBody("All object keys must match") } objs = append(objs, obj) } return objs, nil default: - return nil, pgerr.ErrParse("insert body must be an object or an array of objects") + return nil, pgerr.ErrInvalidBody("") } } // decodeCSVObjects parses an RFC 4180 body into row objects keyed by the header -// row. An empty field decodes to SQL NULL, matching PostgREST's default CSV null -// handling (Go's csv reader does not distinguish a quoted empty string from an -// unquoted empty field, so both map to null). +// row, with PostgREST's CSV semantics: the unquoted literal string NULL becomes +// SQL null and every other field, including an empty cell, becomes a string (an +// empty cell inserts an empty string). Go's csv reader enforces a uniform field +// count against the header, so a ragged row surfaces as PGRST102 "All lines must +// have same number of fields". func decodeCSVObjects(body []byte) ([]map[string]any, []string, *pgerr.APIError) { r := csv.NewReader(bytes.NewReader(body)) recs, err := r.ReadAll() if err != nil { - return nil, nil, pgerr.ErrParse("malformed CSV body") + var pe *csv.ParseError + if errors.As(err, &pe) && pe.Err == csv.ErrFieldCount { + return nil, nil, pgerr.ErrInvalidBody("All lines must have same number of fields") + } + return nil, nil, pgerr.ErrInvalidBody("malformed CSV body") } if len(recs) == 0 { - return nil, nil, pgerr.ErrParse("CSV body has no header row") + return nil, nil, pgerr.ErrInvalidBody("CSV body has no header row") } header := recs[0] objs := make([]map[string]any, 0, len(recs)-1) for _, rec := range recs[1:] { obj := make(map[string]any, len(header)) for i, h := range header { - if i < len(rec) && rec[i] != "" { - obj[h] = rec[i] - } else { + switch { + case i >= len(rec): + obj[h] = nil + case rec[i] == "NULL": obj[h] = nil + default: + obj[h] = rec[i] } } objs = append(objs, obj) @@ -392,7 +558,7 @@ func decodeCSVObjects(body []byte) ([]map[string]any, []string, *pgerr.APIError) func decodeFormObject(body []byte) (map[string]any, *pgerr.APIError) { vals, err := url.ParseQuery(string(body)) if err != nil { - return nil, pgerr.ErrParse("malformed form body") + return nil, pgerr.ErrInvalidBody("malformed form body") } obj := make(map[string]any, len(vals)) for k, v := range vals { @@ -403,6 +569,28 @@ func decodeFormObject(body []byte) (map[string]any, *pgerr.APIError) { return obj, nil } +// checkUniformKeys enforces PostgREST's rule that every object in a bulk insert +// shares the first object's exact key set; a mismatch is PGRST102 "All object +// keys must match". A single object (or none) is trivially uniform. The columns= +// parameter overrides the rule, so the caller skips this when it is present. +func checkUniformKeys(objs []map[string]any) *pgerr.APIError { + if len(objs) < 2 { + return nil + } + first := objs[0] + for _, obj := range objs[1:] { + if len(obj) != len(first) { + return pgerr.ErrInvalidBody("All object keys must match") + } + for k := range first { + if _, ok := obj[k]; !ok { + return pgerr.ErrInvalidBody("All object keys must match") + } + } + } + return nil +} + // buildInsert turns decoded objects into write rows and resolves the column set. // The column order is the explicit columns= parameter when present, else the CSV // header order, else the sorted keys of the first row (matching PostgREST: later @@ -455,7 +643,7 @@ func splitComma(s string) []string { // parseSelect parses the comma-separated select list at the top level. An item // containing a parenthesis is an embed (rel(...)); plain items are columns, // optionally alias:col::cast. "*" selects all columns. -func parseSelect(s string) ([]SelectItem, []Embed, *pgerr.APIError) { +func parseSelect(s string, nested bool) ([]SelectItem, []Embed, *pgerr.APIError) { // PostgREST treats a bare "*" as "all columns" — equivalent to omitting // the select parameter entirely. We normalise it to an empty list here so // the planner and compiler see no explicit projection. @@ -474,6 +662,15 @@ func parseSelect(s string) ([]SelectItem, []Embed, *pgerr.APIError) { return nil, nil, pgerr.ErrParse("empty item in select list") } if i := strings.IndexByte(raw, '('); i >= 0 { + // An item with empty parens is an aggregate (count(), amount.sum()); + // anything else is an embedded resource. The aggregate functions are a + // closed set, so a name(...) that is not one falls through to the embed. + if agg, ok, perr := parseAggregate(raw); perr != nil { + return nil, nil, perr + } else if ok { + items = append(items, agg) + continue + } emb, perr := parseEmbed(raw, i) if perr != nil { return nil, nil, perr @@ -482,10 +679,12 @@ func parseSelect(s string) ([]SelectItem, []Embed, *pgerr.APIError) { embeds = append(embeds, emb) continue } - // PostgREST supports a bare "count" inside an embed select as a virtual - // aggregate that maps to count(*) in the JSON output. - if raw == "count" { - items = append(items, Aggregate{Func: AggCount}) + // Inside an embed select, a bare "count" is the legacy virtual aggregate that + // maps to count(*) in the JSON output; it predates the count() form and is + // exempt from the db-aggregates-enabled gate. At the top level "count" is an + // ordinary column reference (PostgREST v12+). + if nested && raw == "count" { + items = append(items, Aggregate{Func: AggCount, Legacy: true}) continue } col, perr := parseColumnItem(raw) @@ -497,6 +696,103 @@ func parseSelect(s string) ([]SelectItem, []Embed, *pgerr.APIError) { return items, embeds, nil } +// aggFuncByName maps the PostgREST aggregate spellings to their IR function. +var aggFuncByName = map[string]AggFunc{ + "count": AggCount, "sum": AggSum, "avg": AggAvg, "min": AggMin, "max": AggMax, +} + +// parseAggregate recognizes the aggregate forms count() and path.func(), each +// with an optional response-key alias, an optional input cast on the aggregated +// column, and an optional output cast on the result. It reports ok=false (no +// error) when raw is not an aggregate so the caller can treat it as an embed. +func parseAggregate(raw string) (Aggregate, bool, *pgerr.APIError) { + // The function call is always empty parens. Their absence rules out an + // aggregate immediately; a non-empty pair means an embedded resource. + head, tail, found := strings.Cut(raw, "()") + if !found { + return Aggregate{}, false, nil + } + + var agg Aggregate + // Output cast trails the parens as ::type. + if tail != "" { + if !strings.HasPrefix(tail, "::") { + return Aggregate{}, false, nil + } + agg.Cast = tail[2:] + if agg.Cast == "" { + return Aggregate{}, false, pgerr.ErrParse("empty cast target") + } + if !validCastType(agg.Cast) { + return Aggregate{}, false, pgerr.ErrParse("invalid cast target " + agg.Cast) + } + } + // Strip a response-key alias: the leading name before a single ':' that is not + // part of a '::' cast and not inside quotes. + if alias, rest, ok := cutAliasAware(head); ok { + agg.Alias = unquoteIdent(alias) + head = rest + } + // The function name is the token after the last dot; no dot means the whole + // head is the function, which is only valid for the no-argument count(). + fn := head + argSpec := "" + if dot := strings.LastIndexByte(head, '.'); dot >= 0 { + fn = head[dot+1:] + argSpec = head[:dot] + } + f, ok := aggFuncByName[fn] + if !ok { + // An unknown function name with empty parens is not an aggregate; let the + // caller try it as an embed. + return Aggregate{}, false, nil + } + agg.Func = f + if argSpec == "" { + if f != AggCount { + return Aggregate{}, false, pgerr.ErrParse(fn + "() requires a column argument") + } + return agg, true, nil + } + arg, perr := parseColumnItem(argSpec) + if perr != nil { + return Aggregate{}, false, perr + } + agg.Arg = &arg + return agg, true, nil +} + +// cutAliasAware splits a select-item head on the alias colon: the first ':' that +// is a single colon (not part of a '::' cast) and lies outside double quotes. It +// returns ok=false when there is no such colon. +func cutAliasAware(s string) (alias, rest string, ok bool) { + inQuote := false + for i := 0; i < len(s); i++ { + c := s[i] + if inQuote { + if c == '\\' && i+1 < len(s) { + i++ + continue + } + if c == '"' { + inQuote = false + } + continue + } + switch c { + case '"': + inQuote = true + case ':': + if i+1 < len(s) && s[i+1] == ':' { + i++ // skip the cast '::' + continue + } + return s[:i], s[i+1:], true + } + } + return "", s, false +} + // parseEmbed parses rel(...) including an optional alias and hint. The inner // select is parsed recursively so the IR is complete; the planner resolves the // relationship. @@ -513,16 +809,27 @@ func parseEmbed(raw string, lparen int) (Embed, *pgerr.APIError) { emb.Alias = head[:c] head = head[c+1:] } - if b := strings.IndexByte(head, '!'); b >= 0 { - hint := head[b+1:] - head = head[:b] - switch hint { - case "inner": - emb.Join = JoinInner - case "left": - emb.Join = JoinLeft - default: - emb.Hint = hint + // A head may carry both a disambiguation hint and a join modifier, in either + // order (rel!hint!inner or rel!inner!hint). Split on every `!`: the first + // segment is the relation, each later one is "inner"/"left" (the join) or a + // hint. Two hints are a grammar error. + if strings.IndexByte(head, '!') >= 0 { + segs := strings.Split(head, "!") + head = segs[0] + sawHint := false + for _, seg := range segs[1:] { + switch seg { + case "inner": + emb.Join = JoinInner + case "left": + emb.Join = JoinLeft + default: + if sawHint { + return Embed{}, pgerr.ErrParse("embed carries more than one disambiguation hint") + } + emb.Hint = seg + sawHint = true + } } } if strings.HasPrefix(head, "...") { @@ -539,12 +846,17 @@ func parseEmbed(raw string, lparen int) (Embed, *pgerr.APIError) { emb.OutKey = emb.Alias } if inner != "" { - items, nested, perr := parseSelect(inner) + items, nested, perr := parseSelect(inner, true) if perr != nil { return Embed{}, perr } emb.Query.Select = items emb.Query.Embeds = nested + } else { + // Empty parentheses, client(). The relation is still joined for filtering + // but its key is omitted from the output; this is distinct from an absent + // list, which selects every column. + emb.EmptySelect = true } return emb, nil } @@ -560,11 +872,16 @@ func parseColumnItem(raw string) (Column, *pgerr.APIError) { if col.Cast == "" { return Column{}, pgerr.ErrParse("empty cast target") } + if !validCastType(col.Cast) { + return Column{}, pgerr.ErrParse("invalid cast target " + col.Cast) + } } - // alias: leading name before a single ':' (not '::', already stripped) - if i := strings.IndexByte(raw, ':'); i >= 0 { - col.Alias = raw[:i] - raw = raw[i+1:] + // alias: leading name before a single ':' (not '::', already stripped). The + // split is quote-aware so an aliased or target name may itself contain a colon + // when double-quoted (item 01.2). + if alias, rest, ok := cutIdentAware(raw, ':'); ok { + col.Alias = unquoteIdent(alias) + raw = rest } path, last, perr := parsePath(raw) if perr != nil { @@ -575,6 +892,32 @@ func parseColumnItem(raw string) (Column, *pgerr.APIError) { return col, nil } +// validCastType reports whether a ::cast target is a safe type name. PostgREST +// does not whitelist cast targets; it lets PostgreSQL resolve the name (money, +// interval, an enum, a domain, an array type), so the backend passes the type +// through verbatim. Because the type is spliced into SQL rather than bound, the +// grammar is restricted to what a real type spelling needs so nothing breaks out +// of the cast: letters, digits, underscore, spaces (double precision, time +// without time zone), a dot for schema qualification, parentheses and commas for +// precision/scale (numeric(10,2)), and brackets for arrays (int[]). The first +// character must begin an identifier. Anything else (a quote, a semicolon, an +// operator) is a parse error rather than a silent rewrite. +func validCastType(s string) bool { + for i, r := range s { + switch { + case r >= 'a' && r <= 'z', r >= 'A' && r <= 'Z', r == '_': + // always allowed + case r >= '0' && r <= '9', r == ' ', r == '.', r == '(', r == ')', r == ',', r == '[', r == ']': + if i == 0 { + return false + } + default: + return false + } + } + return true +} + // parsePath splits a column reference with optional JSON arrows into hops. // e.g. data->a->>b => {"data","a","b"} with Last=JSONArrow2. func parsePath(raw string) ([]string, JSONStep, *pgerr.APIError) { @@ -582,28 +925,46 @@ func parsePath(raw string) ([]string, JSONStep, *pgerr.APIError) { return nil, JSONNone, pgerr.ErrParse("empty column reference") } last := JSONNone - // normalize ->> and -> into a delimiter sweep + // Sweep ->> and -> into hops, but treat an arrow inside a double-quoted segment + // as part of the identifier rather than a delimiter (item 01.2). var hops []string - rest := raw - for { - i2 := strings.Index(rest, "->>") - i1 := strings.Index(rest, "->") - switch { - case i2 >= 0 && (i1 == -1 || i2 <= i1): - hops = append(hops, rest[:i2]) - rest = rest[i2+3:] - last = JSONArrow2 - case i1 >= 0: - hops = append(hops, rest[:i1]) - rest = rest[i1+2:] - last = JSONArrow - default: - hops = append(hops, rest) - rest = "" + start := 0 + inQuote := false + for i := 0; i < len(raw); { + c := raw[i] + if inQuote { + if c == '\\' && i+1 < len(raw) { + i += 2 + continue + } + if c == '"' { + inQuote = false + } + i++ + continue + } + if c == '"' { + inQuote = true + i++ + continue } - if rest == "" { - break + if c == '-' && i+1 < len(raw) && raw[i+1] == '>' { + hops = append(hops, raw[start:i]) + if i+2 < len(raw) && raw[i+2] == '>' { + last = JSONArrow2 + i += 3 + } else { + last = JSONArrow + i += 2 + } + start = i + continue } + i++ + } + hops = append(hops, raw[start:]) + for j := range hops { + hops[j] = unquoteIdent(hops[j]) } if slices.Contains(hops, "") { return nil, JSONNone, pgerr.ErrParse("empty hop in column path") @@ -623,24 +984,46 @@ func parseOrder(s string) ([]OrderTerm, *pgerr.APIError) { if p == "" { return nil, pgerr.ErrParse("empty order term") } - segs := strings.Split(p, ".") + // Peel the column quote-aware so a double-quoted name may contain a dot + // before the modifier list is split (item 01.2). + colPart, modPart, hasMods := cutIdentAware(p, '.') var t OrderTerm - path, _, perr := parsePath(segs[0]) + // An order term may name an embedded to-one resource's column: + // order=client(name) sorts the parent by the embed's column (item 07.6). + // The relation rides on Rel; the inner text is the column path, which may + // itself carry a JSON sub-path (trash_details(jsonb_col->key)). + if rel, inner, ok := cutRelOrder(colPart); ok { + t.Rel = rel + colPart = inner + } + path, last, perr := parsePath(colPart) if perr != nil { return nil, perr } t.Path = path - for _, mod := range segs[1:] { + t.Last = last + var mods []string + if hasMods { + mods = strings.Split(modPart, ".") + } + // PostgREST's grammar is column[.asc|.desc][.nullsfirst|.nullslast] in that + // fixed order: at most one direction, then at most one nulls modifier, no + // repeats and no direction after a nulls modifier (item 01.7). + var sawDir, sawNulls bool + for _, mod := range mods { switch mod { - case "asc": - t.Desc = false - case "desc": - t.Desc = true - case "nullsfirst": - v := true - t.NullsFirst = &v - case "nullslast": - v := false + case "asc", "desc": + if sawDir || sawNulls { + return nil, pgerr.ErrParse("unexpected order modifier: " + mod) + } + sawDir = true + t.Desc = mod == "desc" + case "nullsfirst", "nullslast": + if sawNulls { + return nil, pgerr.ErrParse("unexpected order modifier: " + mod) + } + sawNulls = true + v := mod == "nullsfirst" t.NullsFirst = &v default: return nil, pgerr.ErrParse("unknown order modifier: " + mod) @@ -651,6 +1034,24 @@ func parseOrder(s string) ([]OrderTerm, *pgerr.APIError) { return terms, nil } +// cutRelOrder splits an order column of the form rel(col) into the relation name +// and the inner column text. It returns ok=false for a plain column (no +// parenthesis) so the caller treats it as a parent column. The relation name must +// be non-empty and the parentheses must wrap the rest of the term; a stray or +// unbalanced parenthesis is left to parsePath, which reports it. +func cutRelOrder(s string) (rel, inner string, ok bool) { + open := strings.IndexByte(s, '(') + if open <= 0 || !strings.HasSuffix(s, ")") { + return "", "", false + } + rel = strings.TrimSpace(s[:open]) + inner = strings.TrimSpace(s[open+1 : len(s)-1]) + if rel == "" || inner == "" { + return "", "", false + } + return rel, inner, true +} + // parseFilters builds the top-level filter tree from column filters plus and=/or=. func parseFilters(vals url.Values) (*Cond, *pgerr.APIError) { var kids []Cond @@ -685,12 +1086,12 @@ func parseFilters(vals url.Values) (*Cond, *pgerr.APIError) { if reservedKeys[key] { continue } - path, _, perr := parsePath(key) + path, last, perr := parsePath(key) if perr != nil { return nil, perr } for _, v := range vals[key] { - cmp, perr := parseCompare(path, v) + cmp, perr := parseCompare(path, last, v) if perr != nil { return nil, perr } @@ -744,16 +1145,17 @@ func parseLogical(op, raw string) (Cond, *pgerr.APIError) { kids = append(kids, node) continue } - // column.op.value - col, rest, ok := strings.Cut(p, ".") + // column.op.value, the column split quote-aware so a double-quoted name may + // contain a dot (item 01.2). + col, rest, ok := cutIdentAware(p, '.') if !ok { return nil, pgerr.ErrParse("malformed predicate in logical: " + p) } - path, _, perr := parsePath(col) + path, last, perr := parsePath(col) if perr != nil { return nil, perr } - cmp, perr := parseCompare(path, rest) + cmp, perr := parseCompare(path, last, rest) if perr != nil { return nil, perr } @@ -772,8 +1174,8 @@ func parseLogical(op, raw string) (Cond, *pgerr.APIError) { } // parseCompare parses a "operator.operand" filter value against a column path. -func parseCompare(path []string, raw string) (Compare, *pgerr.APIError) { - c := Compare{Path: path} +func parseCompare(path []string, last JSONStep, raw string) (Compare, *pgerr.APIError) { + c := Compare{Path: path, Last: last} if strings.HasPrefix(raw, "not.") { c.Negate = true raw = strings.TrimPrefix(raw, "not.") @@ -820,6 +1222,25 @@ func parseCompare(path []string, raw string) (Compare, *pgerr.APIError) { return Compare{}, pgerr.ErrParse("unknown operator: " + base) } c.Op = op + // A quantifier applies to a braces list and is valid only for the operators + // PostgREST allows it on; every element is parsed from the {…} literal, with + // LIKE/ILIKE wildcards translated per element (item 01.1). + if c.Quant != QNone { + if !isQuantifiable(op) { + return Compare{}, pgerr.ErrParse("quantifier any/all is not valid for operator: " + base) + } + list, perr := parseBraceList(operand) + if perr != nil { + return Compare{}, perr + } + if op == OpLike || op == OpILike { + for i, p := range list { + list[i] = strings.ReplaceAll(p, "*", "%") + } + } + c.Value = Value{List: list} + return c, nil + } switch op { case OpIn: list, perr := parseInList(operand) @@ -836,25 +1257,23 @@ func parseCompare(path []string, raw string) (Compare, *pgerr.APIError) { } case OpLike, OpILike: // PostgREST maps * to % in LIKE/ILIKE patterns so URL-friendly wildcards work. - if c.Quant != QNone { - // like(any)/{*cat*,*laundry*} — expand {…} into a list, * → % in each. - list, perr := parseLikeList(operand) - if perr != nil { - return Compare{}, perr - } - for i, p := range list { - list[i] = strings.ReplaceAll(p, "*", "%") - } - c.Value = Value{List: list} - } else { - c.Value = Value{Text: strings.ReplaceAll(operand, "*", "%")} - } + c.Value = Value{Text: strings.ReplaceAll(operand, "*", "%")} default: c.Value = Value{Text: operand} } return c, nil } +// isQuantifiable reports whether an operator accepts an any/all quantifier, the +// set PostgREST allows: eq, gt, gte, lt, lte, like, ilike, match, imatch. +func isQuantifiable(op Op) bool { + switch op { + case OpEq, OpGt, OpGte, OpLt, OpLte, OpLike, OpILike, OpMatch, OpIMatch: + return true + } + return false +} + // parseInList parses (a,b,"c,d") into a slice, honoring double-quoted elements. func parseInList(raw string) ([]string, *pgerr.APIError) { raw = strings.TrimSpace(raw) @@ -862,8 +1281,10 @@ func parseInList(raw string) ([]string, *pgerr.APIError) { return nil, pgerr.ErrParse("in. expects a parenthesized list") } inner := raw[1 : len(raw)-1] + // PostgREST's grammar requires at least one element; ?id=in.() is a parse + // error, not an empty match (item 01.3). if inner == "" { - return []string{}, nil + return nil, pgerr.ErrParse("in. expects at least one value") } parts, err := splitTopLevel(inner, ',') if err != nil { @@ -873,29 +1294,58 @@ func parseInList(raw string) ([]string, *pgerr.APIError) { for _, p := range parts { p = strings.TrimSpace(p) if len(p) >= 2 && p[0] == '"' && p[len(p)-1] == '"' { - p = p[1 : len(p)-1] + // A quoted element may escape an interior quote as \" and a backslash as + // \\ (item 01.2); strip the quotes and unescape. + p = unescapeQuoted(p[1 : len(p)-1]) } out = append(out, p) } return out, nil } -// parseLikeList parses a {pat1,pat2,...} literal (PostgREST quantified-LIKE -// syntax) into a slice of raw pattern strings. No wildcard translation is done -// here; the caller applies * → % after parsing. -func parseLikeList(raw string) ([]string, *pgerr.APIError) { +// unescapeQuoted reverses the in-list quoting escapes: \" -> " and \\ -> \. Any +// other backslash sequence keeps the following character literally. +func unescapeQuoted(s string) string { + if !strings.ContainsRune(s, '\\') { + return s + } + var b strings.Builder + b.Grow(len(s)) + for i := 0; i < len(s); i++ { + if s[i] == '\\' && i+1 < len(s) { + i++ + b.WriteByte(s[i]) + continue + } + b.WriteByte(s[i]) + } + return b.String() +} + +// parseBraceList parses a {a,b,"c,d"} array literal (PostgREST's quantified +// operand) into its elements, honoring double-quoted elements so a comma or +// reserved character can appear inside one. No wildcard translation is done here; +// a LIKE/ILIKE caller applies * → % afterward (items 01.1, 01.2). +func parseBraceList(raw string) ([]string, *pgerr.APIError) { raw = strings.TrimSpace(raw) if len(raw) < 2 || raw[0] != '{' || raw[len(raw)-1] != '}' { - return nil, pgerr.ErrParse("like(any/all) expects a {…} list") + return nil, pgerr.ErrParse("any/all expects a {…} list") } inner := raw[1 : len(raw)-1] if inner == "" { - return []string{}, nil + return nil, pgerr.ErrParse("any/all list must have at least one value") } - parts := strings.Split(inner, ",") - out := make([]string, len(parts)) - for i, p := range parts { - out[i] = strings.TrimSpace(p) + parts, err := splitTopLevel(inner, ',') + if err != nil { + return nil, pgerr.ErrParse("malformed any/all list") + } + out := make([]string, 0, len(parts)) + for _, p := range parts { + p = strings.TrimSpace(p) + if len(p) >= 2 && p[0] == '"' && p[len(p)-1] == '"' { + p = unescapeQuoted(p[1 : len(p)-1]) + } + out = append(out, p) } return out, nil } @@ -966,31 +1416,83 @@ func opFromToken(tok string) (Op, bool) { return 0, false } -// splitTopLevel splits s on sep, ignoring sep inside () and "". +// splitTopLevel splits s on sep, ignoring sep inside (), {}, and "". Inside a +// quoted span a backslash escapes the next byte, so an escaped quote does not end +// the span and an escaped separator is not a split point (items 01.1, 01.2). func splitTopLevel(s string, sep byte) ([]string, error) { var out []string depth := 0 inQuote := false start := 0 for i := 0; i < len(s); i++ { - switch c := s[i]; { - case c == '"': - inQuote = !inQuote - case inQuote: - // skip - case c == '(': + c := s[i] + if inQuote { + switch { + case c == '\\' && i+1 < len(s): + i++ // skip the escaped byte + case c == '"': + inQuote = false + } + continue + } + switch c { + case '"': + inQuote = true + case '(', '{': depth++ - case c == ')': + case ')', '}': depth-- - case c == sep && depth == 0: - out = append(out, s[start:i]) - start = i + 1 + case sep: + if depth == 0 { + out = append(out, s[start:i]) + start = i + 1 + } } } out = append(out, s[start:]) return out, nil } +// cutIdentAware splits s at the first sep byte that is not inside a double-quoted +// identifier segment, returning the text before and after it and whether one was +// found. A backslash inside quotes escapes the next byte. This lets a reserved +// character (dot, colon) sit inside a %22-quoted column or relation name without +// being treated as a delimiter (item 01.2). +func cutIdentAware(s string, sep byte) (before, after string, found bool) { + inQuote := false + for i := 0; i < len(s); i++ { + c := s[i] + if inQuote { + if c == '\\' && i+1 < len(s) { + i++ + continue + } + if c == '"' { + inQuote = false + } + continue + } + switch c { + case '"': + inQuote = true + case sep: + return s[:i], s[i+1:], true + } + } + return s, "", false +} + +// unquoteIdent strips one layer of surrounding double quotes from an identifier +// segment so a reserved character can appear in a column or relation name; an +// interior doubled quote ("") unescapes to a single quote, as in SQL. A segment +// that is not fully quoted is returned unchanged (item 01.2). +func unquoteIdent(s string) string { + if len(s) >= 2 && s[0] == '"' && s[len(s)-1] == '"' { + return strings.ReplaceAll(s[1:len(s)-1], `""`, `"`) + } + return s +} + // sortStrings sorts in place (small slices; avoids importing sort everywhere). func sortStrings(s []string) { for i := 1; i < len(s); i++ { diff --git a/ir/parse_longtail_test.go b/ir/parse_longtail_test.go new file mode 100644 index 0000000..4a72278 --- /dev/null +++ b/ir/parse_longtail_test.go @@ -0,0 +1,256 @@ +package ir + +import ( + "reflect" + "testing" +) + +// errCode asserts ParseRead fails with the given PGRST code. +func errCode(t *testing.T, query, code string) { + t.Helper() + _, err := ParseRead("films", query, nil) + if err == nil { + t.Fatalf("ParseRead(%q): want error %s, got nil", query, code) + } + if err.Code != code { + t.Fatalf("ParseRead(%q): code = %s, want %s", query, err.Code, code) + } +} + +// --- 01.1: any/all quantifiers --- + +func TestQuantifierParsesListForEachOperator(t *testing.T) { + for _, op := range []string{"eq", "gt", "gte", "lt", "lte", "match", "imatch"} { + cmp := fetchCompare(t, "id="+op+"(any).{1,2,3}") + if cmp.Quant != QAny { + t.Errorf("%s(any): Quant = %d, want QAny", op, cmp.Quant) + } + if !reflect.DeepEqual(cmp.Value.List, []string{"1", "2", "3"}) { + t.Errorf("%s(any): List = %v, want [1 2 3]", op, cmp.Value.List) + } + } +} + +func TestQuantifierLikeTranslatesWildcards(t *testing.T) { + cmp := fetchCompare(t, "name=like(any).{*cat*,*dog*}") + if cmp.Op != OpLike || cmp.Quant != QAny { + t.Fatalf("got Op=%v Quant=%d", cmp.Op, cmp.Quant) + } + if !reflect.DeepEqual(cmp.Value.List, []string{"%cat%", "%dog%"}) { + t.Errorf("List = %v, want [%%cat%% %%dog%%]", cmp.Value.List) + } +} + +func TestQuantifierRejectedOnNonQuantifiable(t *testing.T) { + // neq and is do not take a quantifier in PostgREST. + errCode(t, "id=neq(any).{1,2}", "PGRST100") +} + +func TestQuantifierEmptyListRejected(t *testing.T) { + errCode(t, "id=eq(any).{}", "PGRST100") +} + +func TestQuantifierListInLogicalTree(t *testing.T) { + // The comma inside {…} must not split the or= tree (item 01.1 splitTopLevel). + q := mustRead(t, "or=(name.like(any).{*cat*,*dog*},year.eq.2000)") + or, ok := (*q.Where).(Or) + if !ok { + t.Fatalf("Where = %T, want Or", *q.Where) + } + if len(or.Kids) != 2 { + t.Fatalf("or has %d kids, want 2", len(or.Kids)) + } + first := or.Kids[0].(Compare) + if !reflect.DeepEqual(first.Value.List, []string{"%cat%", "%dog%"}) { + t.Errorf("first kid list = %v", first.Value.List) + } +} + +// --- 01.2: quoted identifiers and in-list escapes --- + +func TestQuotedIdentifierWithDotInFilter(t *testing.T) { + cmp := fetchCompare(t, `%22weird.name%22=eq.1`) + if !reflect.DeepEqual(cmp.Path, []string{"weird.name"}) { + t.Errorf("Path = %v, want [weird.name]", cmp.Path) + } +} + +func TestQuotedIdentifierInSelect(t *testing.T) { + q := mustRead(t, `select=%22a:b%22`) + c := q.Select[0].(Column) + if !reflect.DeepEqual(c.Path, []string{"a:b"}) { + t.Errorf("Path = %v, want [a:b]", c.Path) + } + if c.Alias != "" { + t.Errorf("Alias = %q, want empty (colon was inside quotes)", c.Alias) + } +} + +func TestQuotedIdentifierInOrder(t *testing.T) { + q := mustRead(t, `order=%22weird.name%22.desc`) + if len(q.Order) != 1 { + t.Fatalf("got %d order terms", len(q.Order)) + } + if !reflect.DeepEqual(q.Order[0].Path, []string{"weird.name"}) || !q.Order[0].Desc { + t.Errorf("order = %+v", q.Order[0]) + } +} + +func TestQuotedIdentifierInLogicalTree(t *testing.T) { + q := mustRead(t, `or=(%22weird.name%22.eq.1,year.eq.2)`) + or := (*q.Where).(Or) + first := or.Kids[0].(Compare) + if !reflect.DeepEqual(first.Path, []string{"weird.name"}) { + t.Errorf("Path = %v, want [weird.name]", first.Path) + } +} + +func TestInListBackslashEscapes(t *testing.T) { + // in.("a,b","c\"d","e\\f") -> elements with the comma, quote, and backslash. + cmp := fetchCompare(t, `tag=in.("a,b","c\"d","e\\f")`) + want := []string{"a,b", `c"d`, `e\f`} + if !reflect.DeepEqual(cmp.Value.List, want) { + t.Errorf("List = %v, want %v", cmp.Value.List, want) + } +} + +// --- 01.3: empty in.() --- + +func TestEmptyInListRejected(t *testing.T) { + errCode(t, "id=in.()", "PGRST100") +} + +// --- 01.5: empty select= --- + +func TestEmptySelectRejected(t *testing.T) { + errCode(t, "select=", "PGRST100") +} + +func TestOmittedSelectIsAllColumns(t *testing.T) { + q := mustRead(t, "year=eq.2000") + if len(q.Select) != 0 { + t.Errorf("omitted select should leave an empty projection, got %v", q.Select) + } +} + +// --- 01.4: aggregate select syntax --- + +func TestAggregateCountNoArg(t *testing.T) { + q := mustRead(t, "select=count()") + agg, ok := q.Select[0].(Aggregate) + if !ok { + t.Fatalf("Select[0] = %T, want Aggregate", q.Select[0]) + } + if agg.Func != AggCount || agg.Arg != nil || agg.Legacy { + t.Errorf("agg = %+v, want count() non-legacy no-arg", agg) + } + if agg.Name() != "count" { + t.Errorf("Name = %q, want count", agg.Name()) + } +} + +func TestAggregateColumnFunc(t *testing.T) { + for name, want := range map[string]AggFunc{ + "sum": AggSum, "avg": AggAvg, "min": AggMin, "max": AggMax, "count": AggCount, + } { + q := mustRead(t, "select=year."+name+"()") + agg := q.Select[0].(Aggregate) + if agg.Func != want { + t.Errorf("%s: Func = %d, want %d", name, agg.Func, want) + } + if agg.Arg == nil || !reflect.DeepEqual(agg.Arg.Path, []string{"year"}) { + t.Errorf("%s: Arg = %+v, want path [year]", name, agg.Arg) + } + } +} + +func TestAggregateAlias(t *testing.T) { + q := mustRead(t, "select=total:year.sum()") + agg := q.Select[0].(Aggregate) + if agg.Alias != "total" || agg.Name() != "total" { + t.Errorf("Alias = %q, want total", agg.Alias) + } +} + +func TestAggregateOutputCast(t *testing.T) { + q := mustRead(t, "select=year.sum()::text") + agg := q.Select[0].(Aggregate) + if agg.Cast != "text" { + t.Errorf("Cast = %q, want text", agg.Cast) + } + if agg.Arg == nil || agg.Arg.Cast != "" { + t.Errorf("input cast should be empty, Arg = %+v", agg.Arg) + } +} + +func TestAggregateInputCast(t *testing.T) { + q := mustRead(t, "select=year::numeric.sum()") + agg := q.Select[0].(Aggregate) + if agg.Arg == nil || agg.Arg.Cast != "numeric" { + t.Errorf("Arg = %+v, want input cast numeric", agg.Arg) + } + if agg.Cast != "" { + t.Errorf("output cast should be empty, got %q", agg.Cast) + } +} + +func TestAggregateAliasInputAndOutputCast(t *testing.T) { + q := mustRead(t, "select=total:year::numeric.sum()::text") + agg := q.Select[0].(Aggregate) + if agg.Alias != "total" || agg.Cast != "text" || agg.Arg == nil || agg.Arg.Cast != "numeric" { + t.Errorf("agg = %+v arg = %+v", agg, agg.Arg) + } +} + +func TestBareCountIsColumnAtTopLevel(t *testing.T) { + q := mustRead(t, "select=count") + c, ok := q.Select[0].(Column) + if !ok { + t.Fatalf("Select[0] = %T, want Column (top-level bare count is a column)", q.Select[0]) + } + if !reflect.DeepEqual(c.Path, []string{"count"}) { + t.Errorf("Path = %v, want [count]", c.Path) + } +} + +func TestBareCountIsLegacyAggregateInsideEmbed(t *testing.T) { + q := mustRead(t, "select=directors(count)") + emb := q.Embeds[0] + agg, ok := emb.Query.Select[0].(Aggregate) + if !ok { + t.Fatalf("embed select[0] = %T, want Aggregate", emb.Query.Select[0]) + } + if agg.Func != AggCount || !agg.Legacy { + t.Errorf("agg = %+v, want legacy count", agg) + } +} + +func TestAggregateMissingColumnRejected(t *testing.T) { + // sum() needs a column; only count() may stand alone. + errCode(t, "select=sum()", "PGRST100") +} + +// --- 01.7: order modifier grammar --- + +func TestOrderModifierGrammar(t *testing.T) { + good := []string{ + "order=year", + "order=year.asc", + "order=year.desc", + "order=year.asc.nullsfirst", + "order=year.desc.nullslast", + } + for _, q := range good { + if _, err := ParseRead("films", q, nil); err != nil { + t.Errorf("%q: unexpected error %v", q, err) + } + } + bad := []string{ + "order=year.nullsfirst.asc", // nulls before direction + "order=year.asc.desc", // two directions + "order=year.nullsfirst.nullslast", + } + for _, q := range bad { + errCode(t, q, "PGRST100") + } +} diff --git a/ir/parse_test.go b/ir/parse_test.go index 0e4a0df..9c8e300 100644 --- a/ir/parse_test.go +++ b/ir/parse_test.go @@ -37,6 +37,34 @@ func TestParseSelectAliasAndCast(t *testing.T) { } } +// A cast target is spliced into SQL, not bound, so the parser validates it +// against a safe type grammar: real type spellings pass, anything that could +// break out of the cast is a PGRST100 parse error. PostgREST itself does not +// whitelist the type name, so every well-formed spelling must survive. +func TestParseSelectCastValidation(t *testing.T) { + ok := []string{ + "price::money", "d::interval", "raw::bytea", "ip::inet", + "m::mood", "n::numeric(10,2)", "tags::int[]", "c::public.color", + "t::double precision", + } + for _, item := range ok { + if _, err := ParseRead("films", "select="+item, nil); err != nil { + t.Errorf("select=%s: unexpected error %v", item, err) + } + } + bad := []string{ + "x::te'xt", "x::text;drop", "x::text--", "x::int*2", "x::1nt", "x::ta\\b", + } + for _, item := range bad { + _, err := ParseRead("films", "select="+item, nil) + if err == nil { + t.Errorf("select=%s: want parse error, got none", item) + } else if err.Code != "PGRST100" { + t.Errorf("select=%s: code = %s, want PGRST100", item, err.Code) + } + } +} + func TestParseSelectJSONPath(t *testing.T) { q := mustRead(t, "select=data->meta->>id") c := q.Select[0].(Column) @@ -48,6 +76,34 @@ func TestParseSelectJSONPath(t *testing.T) { } } +// 07.1: a JSON-path filter keeps the base column and hops on Compare.Path and +// records the final ->/->> on Last so the compiler can type the access. +func TestParseFilterJSONPath(t *testing.T) { + q := mustRead(t, "data->phones->0->>number=eq.555") + cmp, ok := (*q.Where).(Compare) + if !ok { + t.Fatalf("where = %T, want Compare", *q.Where) + } + if !reflect.DeepEqual(cmp.Path, []string{"data", "phones", "0", "number"}) { + t.Errorf("path = %v", cmp.Path) + } + if cmp.Last != JSONArrow2 { + t.Errorf("last = %v, want JSONArrow2 (final ->>)", cmp.Last) + } +} + +// 07.1: ordering by a JSON path records the path and the final hop kind. +func TestParseOrderJSONPath(t *testing.T) { + q := mustRead(t, "order=data->>created_at.desc") + if len(q.Order) != 1 { + t.Fatalf("order terms = %d", len(q.Order)) + } + o := q.Order[0] + if !reflect.DeepEqual(o.Path, []string{"data", "created_at"}) || o.Last != JSONArrow2 || !o.Desc { + t.Errorf("order = %v last=%v desc=%v", o.Path, o.Last, o.Desc) + } +} + func TestParseEmbed(t *testing.T) { q := mustRead(t, "select=title,director(name,bio)") if len(q.Embeds) != 1 { @@ -65,6 +121,25 @@ func TestParseEmbed(t *testing.T) { } } +// 07.8: empty parentheses set EmptySelect so the compiler can join the relation +// for filtering yet hide its key, distinct from rel(*) which selects every +// column. +func TestParseEmbedEmptySelect(t *testing.T) { + q := mustRead(t, "select=title,director()") + emb := q.Embeds[0] + if !emb.EmptySelect { + t.Errorf("director() should set EmptySelect") + } + if len(emb.Query.Select) != 0 { + t.Errorf("director() select = %d, want 0", len(emb.Query.Select)) + } + + q = mustRead(t, "select=title,director(*)") + if q.Embeds[0].EmptySelect { + t.Errorf("director(*) must not set EmptySelect") + } +} + func TestParseEmbedInnerHint(t *testing.T) { q := mustRead(t, "select=director!inner(name)") if q.Embeds[0].Join != JoinInner { @@ -72,6 +147,41 @@ func TestParseEmbedInnerHint(t *testing.T) { } } +// TestParseEmbedHintWithInner covers a disambiguation hint composed with !inner +// in either order, plus !hint!left (item 01.13). +func TestParseEmbedHintWithInner(t *testing.T) { + cases := []struct { + sel string + hint string + join JoinKind + }{ + {"select=addresses!billing!inner(city)", "billing", JoinInner}, + {"select=addresses!inner!billing(city)", "billing", JoinInner}, + {"select=addresses!billing!left(city)", "billing", JoinLeft}, + {"select=addresses!billing(city)", "billing", JoinLeft}, + } + for _, c := range cases { + q := mustRead(t, c.sel) + emb := q.Embeds[0] + if emb.Target.Name != "addresses" { + t.Errorf("%s: target = %q, want addresses", c.sel, emb.Target.Name) + } + if emb.Hint != c.hint { + t.Errorf("%s: hint = %q, want %q", c.sel, emb.Hint, c.hint) + } + if emb.Join != c.join { + t.Errorf("%s: join = %v, want %v", c.sel, emb.Join, c.join) + } + } +} + +func TestParseEmbedTwoHintsRejected(t *testing.T) { + _, err := ParseRead("films", "select=addresses!one!two(city)", nil) + if err == nil || err.Code != "PGRST100" { + t.Fatalf("want PGRST100 for two hints, got %v", err) + } +} + func TestParseFiltersAnded(t *testing.T) { q := mustRead(t, "rating=gte.4&year=lt.2000") and, ok := (*q.Where).(And) @@ -119,6 +229,16 @@ func TestParseIs(t *testing.T) { } } +// is.unknown is the three-valued boolean test; the parser must accept it +// alongside null/true/false/not_null (item 07.4). +func TestParseIsUnknown(t *testing.T) { + q := mustRead(t, "done=is.unknown") + c := (*q.Where).(Compare) + if c.Op != OpIs || c.Value.Text != "unknown" { + t.Errorf("op/val = %v/%q, want OpIs/unknown", c.Op, c.Value.Text) + } +} + func TestParseQuantifier(t *testing.T) { q := mustRead(t, "tags=eq(any).{a}") c := (*q.Where).(Compare) @@ -168,11 +288,19 @@ func TestParseLimitOffset(t *testing.T) { if q.Limit == nil || *q.Limit != 10 || q.Offset == nil || *q.Offset != 20 { t.Errorf("limit/offset = %v/%v", q.Limit, q.Offset) } + // A well-formed negative limit is the 416 PGRST103 range error, with the + // upstream detail; a non-numeric limit is still a PGRST100 parse error. if _, err := ParseRead("films", "limit=-1", nil); err == nil { t.Error("negative limit should error") + } else if err.Code != "PGRST103" { + t.Errorf("negative limit code = %s, want PGRST103", err.Code) + } else if err.Details == nil || *err.Details != "Limit should be greater than or equal to zero." { + t.Errorf("negative limit details = %v", err.Details) } if _, err := ParseRead("films", "limit=abc", nil); err == nil { t.Error("non-numeric limit should error") + } else if err.Code != "PGRST100" { + t.Errorf("non-numeric limit code = %s, want PGRST100", err.Code) } } @@ -267,19 +395,52 @@ func TestParseWriteUpsertViaResolution(t *testing.T) { } } -func TestParseWriteOnConflictTarget(t *testing.T) { +// on_conflict without a resolution preference leaves a POST a plain insert, the +// way PostgREST does: a duplicate key then fails with 409 rather than silently +// overwriting the existing row (review item 01.14). +func TestParseWriteOnConflictAloneStaysInsert(t *testing.T) { q, err := ParseWrite(Insert, "films", "on_conflict=id", nil, "", []byte(`{"id":1}`)) if err != nil { t.Fatalf("ParseWrite: %v", err) } + if q.Kind != Insert { + t.Errorf("on_conflict alone should stay an insert, got %v", q.Kind) + } + if q.Write.Conflict != nil { + t.Errorf("conflict = %#v, want nil for a plain insert", q.Write.Conflict) + } +} + +// on_conflict combined with a resolution preference promotes to an upsert and +// carries the named target. +func TestParseWriteOnConflictWithResolution(t *testing.T) { + q, err := ParseWrite(Insert, "films", "on_conflict=id", []string{"resolution=merge-duplicates"}, "", []byte(`{"id":1}`)) + if err != nil { + t.Fatalf("ParseWrite: %v", err) + } if q.Kind != Upsert { - t.Errorf("on_conflict should make it an upsert, got %v", q.Kind) + t.Errorf("on_conflict with resolution should upsert, got %v", q.Kind) } if got := q.Write.Conflict.Target; len(got) != 1 || got[0] != "id" { t.Errorf("conflict target = %v, want [id]", got) } } +// A PATCH carrying a stale resolution preference is not promoted to an upsert; +// resolution and on_conflict are consulted only for inserts and PUT (01.14). +func TestParseWriteResolutionIgnoredForUpdate(t *testing.T) { + q, err := ParseWrite(Update, "films", "on_conflict=id", []string{"resolution=merge-duplicates"}, "", []byte(`{"title":"X"}`)) + if err != nil { + t.Fatalf("ParseWrite: %v", err) + } + if q.Kind != Update { + t.Errorf("Kind = %v, want Update (resolution ignored for PATCH)", q.Kind) + } + if q.Write.Conflict != nil { + t.Errorf("conflict = %#v, want nil for an update", q.Write.Conflict) + } +} + func TestParseWriteReturnAndMissing(t *testing.T) { q, err := ParseWrite(Insert, "films", "", []string{"return=representation", "missing=null"}, "", []byte(`{"title":"X"}`)) if err != nil { diff --git a/ir/prefer.go b/ir/prefer.go index d3739df..192912d 100644 --- a/ir/prefer.go +++ b/ir/prefer.go @@ -1,6 +1,12 @@ package ir -import "strings" +import ( + "strconv" + "strings" + "time" + + "github.com/tamnd/dbrest/pgerr" +) // Handling is the Prefer: handling= mode for unrecognized parameters/preferences. type Handling uint8 @@ -11,8 +17,9 @@ const ( ) // PreferSet is the parsed Prefer header. A nil pointer field means the client -// did not state that preference. Applied records, in order, the preferences the -// server actually honored, for the Preference-Applied response header. +// did not state that preference. applied records the honored "key=value" tokens +// for the Preference-Applied response header; invalid records the tokens a +// handling=strict request rejects. type PreferSet struct { Return *ReturnMode Count *CountKind @@ -21,15 +28,47 @@ type PreferSet struct { Tx *TxMode Handling Handling - // applied is the list of "key=value" tokens that were honored. - applied []string + // MaxAffected caps the rows a mutation (or RPC) may affect. It is honored only + // under handling=strict; ParsePrefer clears it under lenient so a backend can + // enforce on a non-nil pointer alone without consulting Handling. + MaxAffected *int64 + + // TimeZone is the Prefer: timezone= request timezone, validated against the Go + // tz database (the portable analog of pg_timezone_names). It is honored whenever + // valid; an invalid name is an offender, ignored under lenient and a strict + // violation under handling=strict. Backends that support it apply SET LOCAL + // timezone; the emulated render path converts temporals to it. + TimeZone *string + + // applied maps a preference key to its honored "key=value" token. The header + // is emitted in PostgREST's canonical order, not encounter order. + applied map[string]string + // invalid lists the verbatim tokens that named an unknown preference or gave a + // known one a bad value; handling=strict rejects a request carrying any. + invalid []string +} + +// preferKeys are the preference keys dbrest recognizes. A token whose key is not +// here is an unknown preference, an offender under handling=strict. +var preferKeys = map[string]bool{ + "return": true, "count": true, "resolution": true, + "missing": true, "tx": true, "handling": true, "max-affected": true, + "timezone": true, } +// applyOrder is PostgREST's fixed Preference-Applied ordering. timezone and +// max-affected are listed for when those preferences land (02.2, 02.3); an +// absent key is skipped. +var applyOrder = []string{"resolution", "missing", "return", "count", "tx", "handling", "timezone", "max-affected"} + // ParsePrefer parses one or more Prefer header values (comma-separated tokens) -// into a PreferSet. Unknown tokens are ignored here; strict handling is enforced -// by the caller against the recognized set. +// into a PreferSet. Only the first occurrence of a duplicated preference is +// honored, matching PostgREST. Unknown keys and bad values are recorded on +// invalid so a handling=strict caller can be rejected; under the default lenient +// handling they are ignored. func ParsePrefer(headers []string) PreferSet { - var p PreferSet + p := PreferSet{applied: map[string]string{}} + seen := map[string]bool{} for _, h := range headers { for tok := range strings.SplitSeq(h, ",") { tok = strings.TrimSpace(tok) @@ -38,89 +77,188 @@ func ParsePrefer(headers []string) PreferSet { } k, v, _ := strings.Cut(tok, "=") k, v = strings.TrimSpace(k), strings.TrimSpace(v) - switch k { - case "return": - switch v { - case "minimal": - m := ReturnMinimal - p.Return = &m - case "headers-only": - m := ReturnHeadersOnly - p.Return = &m - case "representation": - m := ReturnRepresentation - p.Return = &m - default: - continue - } - p.markApplied(k + "=" + v) - case "count": - switch v { - case "exact": - c := CountExact - p.Count = &c - case "planned": - c := CountPlanned - p.Count = &c - case "estimated": - c := CountEstimated - p.Count = &c - default: - continue - } - p.markApplied(k + "=" + v) - case "resolution": - switch v { - case "merge-duplicates": - r := ConflictMerge - p.Resolution = &r - case "ignore-duplicates": - r := ConflictIgnore - p.Resolution = &r - default: - continue - } - p.markApplied(k + "=" + v) - case "missing": - switch v { - case "default": - m := MissingDefault - p.Missing = &m - case "null": - m := MissingNull - p.Missing = &m - default: - continue - } - p.markApplied(k + "=" + v) - case "tx": - switch v { - case "commit": - t := TxCommit - p.Tx = &t - case "rollback": - t := TxRollback - p.Tx = &t - default: - continue - } - p.markApplied(k + "=" + v) - case "handling": - if v == "strict" { - p.Handling = HandlingStrict - p.markApplied(k + "=" + v) - } + if !preferKeys[k] { + p.invalid = append(p.invalid, tok) + continue + } + if seen[k] { + // Only the first occurrence of a preference is honored. + continue + } + seen[k] = true + if p.set(k, v) { + p.applied[k] = k + "=" + v + } else { + p.invalid = append(p.invalid, tok) } } } + // max-affected takes effect, and is echoed, only under handling=strict; under + // lenient PostgREST ignores it entirely. Clearing both here lets every later + // reader treat a non-nil MaxAffected as "enforce this" with no handling check. + if p.Handling != HandlingStrict { + delete(p.applied, "max-affected") + p.MaxAffected = nil + } return p } -// markApplied records that a "key=value" preference was honored. -func (p *PreferSet) markApplied(kv string) { p.applied = append(p.applied, kv) } +// set applies one recognized preference and reports whether the value was valid. +// A bad value leaves the field untouched and the token is recorded as an +// offender by the caller. +func (p *PreferSet) set(k, v string) bool { + switch k { + case "return": + switch v { + case "minimal": + m := ReturnMinimal + p.Return = &m + case "headers-only": + m := ReturnHeadersOnly + p.Return = &m + case "representation": + m := ReturnRepresentation + p.Return = &m + default: + return false + } + case "count": + switch v { + case "exact": + c := CountExact + p.Count = &c + case "planned": + c := CountPlanned + p.Count = &c + case "estimated": + c := CountEstimated + p.Count = &c + default: + return false + } + case "resolution": + switch v { + case "merge-duplicates": + r := ConflictMerge + p.Resolution = &r + case "ignore-duplicates": + r := ConflictIgnore + p.Resolution = &r + default: + return false + } + case "missing": + switch v { + case "default": + m := MissingDefault + p.Missing = &m + case "null": + m := MissingNull + p.Missing = &m + default: + return false + } + case "tx": + switch v { + case "commit": + t := TxCommit + p.Tx = &t + case "rollback": + t := TxRollback + p.Tx = &t + default: + return false + } + case "handling": + switch v { + case "strict": + p.Handling = HandlingStrict + case "lenient": + p.Handling = HandlingLenient + default: + return false + } + case "max-affected": + n, err := strconv.ParseInt(v, 10, 64) + if err != nil || n < 0 { + return false + } + p.MaxAffected = &n + case "timezone": + // Validate against the Go tz database, the portable analog of + // pg_timezone_names. An empty or unknown name is an offender. + if v == "" { + return false + } + if _, err := time.LoadLocation(v); err != nil { + return false + } + tz := v + p.TimeZone = &tz + } + return true +} + +// ResolveTx applies the db-tx-end server policy to a parsed tx= preference. +// PostgREST only honors Prefer: tx= under the two allow-override variants; with +// the fixed commit or rollback policy the preference is not honored, not echoed, +// and is a handling=strict offender, and the rollback policy forces a rollback +// even when the client said nothing. It must run after ParsePrefer and before +// the effective TxMode is read; a later StrictError call then sees a disallowed +// tx= as an offender. +func (p *PreferSet) ResolveTx(policy TxEnd) { + clientToken, clientSent := p.applied["tx"] + override := policy == TxEndCommitAllowOverride || policy == TxEndRollbackAllowOverride + if !override { + // tx= cannot override the outcome: drop the echo, reject it under strict, + // and force the policy's fixed outcome. + if clientSent { + delete(p.applied, "tx") + if p.Handling == HandlingStrict { + p.invalid = append(p.invalid, clientToken) + } + } + if policy == TxEndRollback { + t := TxRollback + p.Tx = &t + } else { + p.Tx = nil + } + return + } + // An allow-override mode honors a client tx= (already echoed) and otherwise + // applies the mode default, which is not echoed. + if !clientSent { + if policy == TxEndRollbackAllowOverride { + t := TxRollback + p.Tx = &t + } else { + p.Tx = nil + } + } +} -// AppliedHeader returns the Preference-Applied header value, or "" if nothing -// was applied. +// StrictError returns the PGRST122 a handling=strict request earns when it +// carries any unknown preference or bad value, and nil otherwise (including the +// default lenient handling, which ignores the offenders). +func (p *PreferSet) StrictError() *pgerr.APIError { + if p.Handling != HandlingStrict || len(p.invalid) == 0 { + return nil + } + return pgerr.ErrInvalidPreferences(p.invalid) +} + +// AppliedHeader returns the Preference-Applied header value in PostgREST's +// canonical order, or "" if nothing was applied. func (p *PreferSet) AppliedHeader() string { - return strings.Join(p.applied, ", ") + if len(p.applied) == 0 { + return "" + } + out := make([]string, 0, len(p.applied)) + for _, k := range applyOrder { + if v, ok := p.applied[k]; ok { + out = append(out, v) + } + } + return strings.Join(out, ", ") } diff --git a/ir/prefer_test.go b/ir/prefer_test.go index dd6a3cf..076523c 100644 --- a/ir/prefer_test.go +++ b/ir/prefer_test.go @@ -1,6 +1,9 @@ package ir -import "testing" +import ( + "strings" + "testing" +) func TestParsePreferRecognizesEachToken(t *testing.T) { p := ParsePrefer([]string{"return=representation, resolution=merge-duplicates, missing=null, tx=rollback, handling=strict"}) @@ -21,15 +24,125 @@ func TestParsePreferRecognizesEachToken(t *testing.T) { } } -func TestAppliedHeaderEchoesHonoredInOrder(t *testing.T) { +func TestAppliedHeaderEchoesHonoredInCanonicalOrder(t *testing.T) { + // Sent count before return; PostgREST's Preference-Applied is emitted in its + // fixed order (return before count), not request order. p := ParsePrefer([]string{"count=exact, return=minimal"}) - // Only honored tokens appear, in request order; the comma-joined form is the - // Preference-Applied response header. - if got, want := p.AppliedHeader(), "count=exact, return=minimal"; got != want { + if got, want := p.AppliedHeader(), "return=minimal, count=exact"; got != want { t.Errorf("AppliedHeader = %q, want %q", got, want) } } +// TestParsePreferFirstDuplicateWins checks only the first occurrence of a +// duplicated preference is honored, matching PostgREST, and the applied header +// carries one token. +func TestParsePreferFirstDuplicateWins(t *testing.T) { + p := ParsePrefer([]string{"count=exact, count=planned"}) + if p.Count == nil || *p.Count != CountExact { + t.Errorf("count = %v, want the first occurrence (exact)", p.Count) + } + if got, want := p.AppliedHeader(), "count=exact"; got != want { + t.Errorf("AppliedHeader = %q, want %q", got, want) + } +} + +// TestParsePreferLenientEchoed checks an explicit handling=lenient is recognized +// and echoed, where before it was dropped as unknown. +func TestParsePreferLenientEchoed(t *testing.T) { + p := ParsePrefer([]string{"handling=lenient"}) + if p.Handling != HandlingLenient { + t.Errorf("handling = %v, want lenient", p.Handling) + } + if got, want := p.AppliedHeader(), "handling=lenient"; got != want { + t.Errorf("AppliedHeader = %q, want %q", got, want) + } +} + +// TestStrictErrorRejectsOffenders checks handling=strict turns an unknown key or +// a bad value into a PGRST122, while the default lenient handling ignores them. +func TestStrictErrorRejectsOffenders(t *testing.T) { + strict := ParsePrefer([]string{"handling=strict, return=bogus, frobnicate=yes"}) + err := strict.StrictError() + if err == nil || err.Code != "PGRST122" { + t.Fatalf("StrictError = %v, want PGRST122", err) + } + lenient := ParsePrefer([]string{"return=bogus, frobnicate=yes"}) + if lenient.StrictError() != nil { + t.Error("lenient handling must not reject invalid preferences") + } +} + +// TestParsePreferMaxAffectedStrictOnly checks max-affected=N is parsed and +// echoed only under handling=strict; under lenient PostgREST ignores it, so both +// the value and the echo are dropped. +func TestParsePreferMaxAffectedStrictOnly(t *testing.T) { + strict := ParsePrefer([]string{"handling=strict, max-affected=5"}) + if strict.MaxAffected == nil || *strict.MaxAffected != 5 { + t.Fatalf("strict MaxAffected = %v, want 5", strict.MaxAffected) + } + if got, want := strict.AppliedHeader(), "handling=strict, max-affected=5"; got != want { + t.Errorf("strict AppliedHeader = %q, want %q", got, want) + } + + lenient := ParsePrefer([]string{"max-affected=5"}) + if lenient.MaxAffected != nil { + t.Errorf("lenient MaxAffected = %v, want nil (ignored)", *lenient.MaxAffected) + } + if got := lenient.AppliedHeader(); got != "" { + t.Errorf("lenient AppliedHeader = %q, want empty", got) + } +} + +// TestParsePreferMaxAffectedBadValue checks a non-integer or negative +// max-affected is an offender (PGRST122 under strict) and leaves the bound unset. +func TestParsePreferMaxAffectedBadValue(t *testing.T) { + for _, v := range []string{"abc", "-1", "1.5", ""} { + p := ParsePrefer([]string{"handling=strict, max-affected=" + v}) + if p.MaxAffected != nil { + t.Errorf("max-affected=%q set MaxAffected = %v, want nil", v, *p.MaxAffected) + } + if p.StrictError() == nil { + t.Errorf("max-affected=%q under strict should be a PGRST122 offender", v) + } + } +} + +// TestParsePreferTimeZoneValid checks a valid IANA name is captured and echoed, +// unlike max-affected it is honored under lenient too. +func TestParsePreferTimeZoneValid(t *testing.T) { + for _, h := range []string{"timezone=America/Los_Angeles", "handling=strict, timezone=America/Los_Angeles"} { + p := ParsePrefer([]string{h}) + if p.TimeZone == nil || *p.TimeZone != "America/Los_Angeles" { + t.Fatalf("%q: TimeZone = %v, want America/Los_Angeles", h, p.TimeZone) + } + if p.StrictError() != nil { + t.Errorf("%q: valid timezone should not be an offender", h) + } + // The Preference-Applied echo carries the honored timezone token. + if !strings.Contains(p.AppliedHeader(), "timezone=America/Los_Angeles") { + t.Errorf("%q: AppliedHeader = %q, missing timezone echo", h, p.AppliedHeader()) + } + } +} + +// TestParsePreferTimeZoneInvalid checks an unknown or empty zone is an offender: +// ignored (no echo) under lenient, a PGRST122 under strict. +func TestParsePreferTimeZoneInvalid(t *testing.T) { + for _, v := range []string{"Mars/Phobos", "Not_A_Zone", ""} { + lenient := ParsePrefer([]string{"timezone=" + v}) + if lenient.TimeZone != nil { + t.Errorf("timezone=%q set TimeZone = %v, want nil", v, *lenient.TimeZone) + } + if lenient.AppliedHeader() != "" { + t.Errorf("timezone=%q lenient AppliedHeader = %q, want empty", v, lenient.AppliedHeader()) + } + strict := ParsePrefer([]string{"handling=strict, timezone=" + v}) + if strict.StrictError() == nil { + t.Errorf("timezone=%q under strict should be a PGRST122 offender", v) + } + } +} + func TestAppliedHeaderSkipsUnknownAndEmpty(t *testing.T) { p := ParsePrefer([]string{"return=bogus, frobnicate=yes, count=exact"}) if got, want := p.AppliedHeader(), "count=exact"; got != want { diff --git a/ir/writebody_test.go b/ir/writebody_test.go index 5892446..45a9dfc 100644 --- a/ir/writebody_test.go +++ b/ir/writebody_test.go @@ -3,15 +3,17 @@ package ir import "testing" // A CSV insert body decodes to one row per data line, keyed by the header. The -// header order fixes the write column order, and an empty field is SQL NULL. +// header order fixes the write column order. PostgREST's CSV null rule is that +// only the unquoted literal NULL is SQL null; an empty cell is the empty string +// (item 01.16). func TestParseWriteCSVBody(t *testing.T) { - body := []byte("title,year\nDune,2021\nArrival,\n") + body := []byte("title,year,note\nDune,2021,good\nArrival,,NULL\n") q, err := ParseWrite(Insert, "films", "", nil, "text/csv", body) if err != nil { t.Fatalf("ParseWrite CSV: %v", err) } - if got := q.Write.Columns; len(got) != 2 || got[0] != "title" || got[1] != "year" { - t.Fatalf("Columns = %v, want [title year] in header order", got) + if got := q.Write.Columns; len(got) != 3 || got[0] != "title" || got[1] != "year" { + t.Fatalf("Columns = %v, want [title year note] in header order", got) } if len(q.Write.Rows) != 2 { t.Fatalf("Rows = %d, want 2", len(q.Write.Rows)) @@ -19,9 +21,22 @@ func TestParseWriteCSVBody(t *testing.T) { if v := q.Write.Rows[0]["title"]; v.JSON != "Dune" { t.Errorf("row0 title = %#v, want Dune", v.JSON) } - // The empty year field on the second row is NULL, not the empty string. - if v := q.Write.Rows[1]["year"]; v.JSON != nil { - t.Errorf("row1 year = %#v, want nil (NULL)", v.JSON) + // The empty year cell on the second row is the empty string, not NULL. + if v := q.Write.Rows[1]["year"]; v.JSON != "" { + t.Errorf("row1 year = %#v, want the empty string", v.JSON) + } + // The literal NULL token is SQL null. + if v := q.Write.Rows[1]["note"]; v.JSON != nil { + t.Errorf("row1 note = %#v, want nil (NULL)", v.JSON) + } +} + +// A CSV body whose data row has a different field count than the header is a +// PGRST102 "All lines must have same number of fields" (item 01.16). +func TestParseWriteCSVRaggedRejected(t *testing.T) { + _, err := ParseWrite(Insert, "films", "", nil, "text/csv", []byte("title,year\nDune\n")) + if err == nil || err.Code != "PGRST102" { + t.Fatalf("ragged CSV err = %v, want PGRST102", err) } } @@ -53,6 +68,37 @@ func TestParseWriteCSVMalformedRejected(t *testing.T) { } } +// A malformed JSON insert body is v14's PGRST102 at 400 with the canonical +// "Empty or invalid json" message, not a PGRST100 query-parse error (item 04.1). +func TestParseWriteMalformedJSONIsPGRST102(t *testing.T) { + _, err := ParseWrite(Insert, "films", "", nil, "application/json", []byte("{not json")) + if err == nil || err.Code != "PGRST102" { + t.Fatalf("malformed JSON err = %v, want PGRST102", err) + } + if err.HTTPStatus != 400 { + t.Errorf("status = %d, want 400", err.HTTPStatus) + } + if err.Message != "Empty or invalid json" { + t.Errorf("message = %q, want 'Empty or invalid json'", err.Message) + } +} + +// A request body whose Content-Type no parser handles is PGRST102 at 400 with +// "Content-Type not acceptable: ", not the stale 415 PGRST107 (item 04.1). +// PGRST107 stays reserved for Accept negotiation, which is always a 406. +func TestParseWriteUnsupportedContentTypeIsPGRST102(t *testing.T) { + _, err := ParseWrite(Insert, "films", "", nil, "application/x-yaml", []byte("title: Dune")) + if err == nil || err.Code != "PGRST102" { + t.Fatalf("unsupported content-type err = %v, want PGRST102", err) + } + if err.HTTPStatus != 400 { + t.Errorf("status = %d, want 400", err.HTTPStatus) + } + if err.Message != "Content-Type not acceptable: application/x-yaml" { + t.Errorf("message = %q", err.Message) + } +} + // A form-urlencoded insert body decodes to a single row of string columns. func TestParseWriteFormBody(t *testing.T) { q, err := ParseWrite(Insert, "films", "", nil, @@ -97,17 +143,39 @@ func TestParseWriteUpdateFormBody(t *testing.T) { func TestParseWriteUnsupportedMediaType(t *testing.T) { _, err := ParseWrite(Insert, "films", "", nil, "text/yaml", []byte("title: X")) - if err == nil || err.Code != "PGRST107" { - t.Fatalf("insert with unknown media type err = %v, want PGRST107", err) + if err == nil || err.Code != "PGRST102" { + t.Fatalf("insert with unknown media type err = %v, want PGRST102", err) + } +} + +// PostgREST accepts CSV for PATCH as well as POST, so a single-row CSV update +// body decodes to the column assignments (item 01.16). +func TestParseWriteUpdateCSVAccepted(t *testing.T) { + q, err := ParseWrite(Update, "films", "id=eq.1", nil, "text/csv", []byte("rating\nPG\n")) + if err != nil { + t.Fatalf("ParseWrite update CSV: %v", err) + } + if v := q.Write.Set["rating"]; v.JSON != "PG" { + t.Errorf("set rating = %#v, want PG", v.JSON) + } +} + +// A bulk JSON insert whose objects do not share the first object's keys is +// PGRST102 "All object keys must match" unless columns= overrides (item 01.15). +func TestParseWriteRaggedJSONRejected(t *testing.T) { + body := []byte(`[{"title":"A","year":2020},{"title":"B"}]`) + _, err := ParseWrite(Insert, "films", "", nil, "application/json", body) + if err == nil || err.Code != "PGRST102" { + t.Fatalf("ragged JSON array err = %v, want PGRST102", err) } } -// CSV is not a patch format, so an update body in CSV is rejected as an -// unsupported media type rather than silently parsed. -func TestParseWriteUpdateCSVRejected(t *testing.T) { - _, err := ParseWrite(Update, "films", "id=eq.1", nil, "text/csv", []byte("rating\nPG\n")) - if err == nil || err.Code != "PGRST107" { - t.Fatalf("update with CSV err = %v, want PGRST107", err) +// With columns= present the ragged-array check is skipped (RawJSON semantics): +// absent keys take the missing= behavior and extra keys are ignored. +func TestParseWriteRaggedJSONWithColumnsOK(t *testing.T) { + body := []byte(`[{"title":"A","year":2020},{"title":"B"}]`) + if _, err := ParseWrite(Insert, "films", "columns=title,year", nil, "application/json", body); err != nil { + t.Fatalf("ParseWrite with columns= should accept a ragged array: %v", err) } } @@ -221,3 +289,38 @@ func TestParseEmbedScopedParam(t *testing.T) { t.Errorf("embed Order = %+v, want name desc", emb.Order) } } + +// TestProjectedColumns covers the write-representation column projection helper +// (item 01.19): a plain base-column select narrows the returning set, while any +// shape the bare RETURNING path cannot reshape falls back to all columns (nil). +func TestProjectedColumns(t *testing.T) { + col := func(name string) SelectItem { return Column{Path: []string{name}} } + cases := []struct { + name string + q Query + want []string + }{ + {"plain list", Query{Select: []SelectItem{col("id"), col("title")}}, []string{"id", "title"}}, + {"dedup", Query{Select: []SelectItem{col("id"), col("id")}}, []string{"id"}}, + {"empty select", Query{}, nil}, + {"star", Query{Select: []SelectItem{Column{Path: []string{"*"}}}}, nil}, + {"alias falls back", Query{Select: []SelectItem{Column{Path: []string{"title"}, Alias: "t"}}}, nil}, + {"cast falls back", Query{Select: []SelectItem{Column{Path: []string{"id"}, Cast: "text"}}}, nil}, + {"json path falls back", Query{Select: []SelectItem{Column{Path: []string{"data", "k"}, Last: JSONArrow2}}}, nil}, + {"aggregate falls back", Query{Select: []SelectItem{Aggregate{Func: AggCount}}}, nil}, + {"embed present falls back", Query{Select: []SelectItem{col("id")}, Embeds: []Embed{{}}}, nil}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + got := tc.q.ProjectedColumns() + if len(got) != len(tc.want) { + t.Fatalf("ProjectedColumns() = %v, want %v", got, tc.want) + } + for i := range got { + if got[i] != tc.want[i] { + t.Fatalf("ProjectedColumns() = %v, want %v", got, tc.want) + } + } + }) + } +} diff --git a/openapi/document.go b/openapi/document.go index 3e4a7bc..6a0c9f7 100644 --- a/openapi/document.go +++ b/openapi/document.go @@ -17,11 +17,21 @@ type document struct { Definitions map[string]*schemaObject `json:"definitions"` Parameters map[string]*parameter `json:"parameters,omitempty"` SecurityDefinitions map[string]*securityScheme `json:"securityDefinitions,omitempty"` + Security []map[string][]string `json:"security,omitempty"` + ExternalDocs *externalDocs `json:"externalDocs,omitempty"` } type info struct { - Title string `json:"title"` - Version string `json:"version"` + Title string `json:"title"` + Description string `json:"description,omitempty"` + Version string `json:"version"` +} + +// externalDocs is the document-level pointer at the API reference, the block +// PostgREST emits pointing at its own documentation. +type externalDocs struct { + Description string `json:"description,omitempty"` + URL string `json:"url"` } type pathItem struct { @@ -32,33 +42,43 @@ type pathItem struct { } type operation struct { - Tags []string `json:"tags,omitempty"` - Summary string `json:"summary,omitempty"` - Parameters []*parameter `json:"parameters,omitempty"` - Responses map[string]*response `json:"responses"` - Security []map[string][]string `json:"security,omitempty"` + Tags []string `json:"tags,omitempty"` + Summary string `json:"summary,omitempty"` + Description string `json:"description,omitempty"` + Produces []string `json:"produces,omitempty"` + Parameters []*parameter `json:"parameters,omitempty"` + Responses map[string]*response `json:"responses"` } +// parameter is either a $ref to a shared definition (only Ref set) or an +// inline definition. Required is a pointer so a defined parameter carries an +// explicit "required": false the way PostgREST emits it, while a pure $ref +// entry carries nothing but the reference. type parameter struct { Ref string `json:"$ref,omitempty"` Name string `json:"name,omitempty"` In string `json:"in,omitempty"` Description string `json:"description,omitempty"` - Required bool `json:"required,omitempty"` + Required *bool `json:"required,omitempty"` Type string `json:"type,omitempty"` Format string `json:"format,omitempty"` + Enum []string `json:"enum,omitempty"` + Default any `json:"default,omitempty"` Schema *schemaObject `json:"schema,omitempty"` } type response struct { - Description string `json:"description"` + Description string `json:"description"` + Schema *schemaObject `json:"schema,omitempty"` } type schemaObject struct { - Ref string `json:"$ref,omitempty"` - Type string `json:"type,omitempty"` - Required []string `json:"required,omitempty"` - Properties map[string]*propertySchema `json:"properties,omitempty"` + Ref string `json:"$ref,omitempty"` + Description string `json:"description,omitempty"` + Type string `json:"type,omitempty"` + Items *schemaObject `json:"items,omitempty"` + Required []string `json:"required,omitempty"` + Properties map[string]*propertySchema `json:"properties,omitempty"` } type propertySchema struct { @@ -68,7 +88,8 @@ type propertySchema struct { } type securityScheme struct { - Type string `json:"type"` - Name string `json:"name"` - In string `json:"in"` + Type string `json:"type"` + Name string `json:"name"` + In string `json:"in"` + Description string `json:"description,omitempty"` } diff --git a/openapi/openapi.go b/openapi/openapi.go index 2982d2e..fba63bd 100644 --- a/openapi/openapi.go +++ b/openapi/openapi.go @@ -11,7 +11,6 @@ package openapi import ( "encoding/json" - "sort" "strings" "github.com/tamnd/dbrest/backend" @@ -24,28 +23,79 @@ import ( // MediaType is the content type PostgREST serves the root with. const MediaType = "application/openapi+json" +// docsURL is the externalDocs target, the v14 PostgREST reference the emitted +// document points at, matching the live v14 output byte for byte. +const docsURL = "https://postgrest.org/en/v14/references/api.html" + +// bodyMediaTypes is the top-level consumes and produces list v14 advertises: +// JSON, the singular-object vendor types, and CSV. +var bodyMediaTypes = []string{ + "application/json", + "application/vnd.pgrst.object+json;nulls=stripped", + "application/vnd.pgrst.object+json", + "text/csv", +} + +// rpcMediaTypes is the per-operation produces list on /rpc paths; v14 leaves +// CSV off there. +var rpcMediaTypes = []string{ + "application/json", + "application/vnd.pgrst.object+json;nulls=stripped", + "application/vnd.pgrst.object+json", +} + // Options configures the emitted document's identity, server block, and // security. The host/basePath/schemes are the externally visible address (the // listen address, or the proxy URL once that configuration lands, spec 20). type Options struct { - Title string // document title; defaults to "dbrest" + // Title and Description are the info block, sourced from the active + // schema's database comment when one exists (first line and remainder); + // the defaults are the strings v14 emits for an uncommented schema. + Title string + Description string + Version string // info.version; defaults to the compat target "14.0" Host string // host:port the API is reached at BasePath string // mount path; defaults to "/" Schemes []string // url schemes; defaults to ["http"] - // JWT advertises a bearer security scheme in securityDefinitions when true, - // matching a server with JWT auth configured (spec 13). - JWT bool - // SecurityActive attaches the security requirement to every operation, the - // PostgREST openapi-security-active setting (spec 20). With JWT defined but - // this false, the scheme is described but not enforced, PostgREST's default. + // ActiveSchema is the schema the document describes: the request's + // profile-negotiated schema, so a multi-schema deployment serves one + // document per schema and same-named relations never collide on a path key. + ActiveSchema string + + // Visibility filters the document for openapi-mode=follow-privileges: it + // answers which operations the requesting role may perform on a relation, + // and a relation with none is dropped from paths and definitions, so an + // anonymous caller cannot enumerate what it cannot touch. Nil emits + // everything, the ignore-privileges mode. + Visibility func(rel *schema.Relation) Actions + + // SecurityActive is the PostgREST openapi-security-active setting (spec + // 20): when true the document carries the JWT securityDefinitions block + // and a document-level security requirement; when false (the default) + // neither is emitted, exactly as v14 behaves regardless of whether JWT + // auth is configured. SecurityActive bool } +// Actions is the operation set a role holds on one relation, the visibility +// answer follow-privileges filters the document with. +type Actions struct { + Get, Post, Patch, Delete bool +} + +// AllActions marks every operation visible, the ignore-privileges answer. +var AllActions = Actions{Get: true, Post: true, Patch: true, Delete: true} + +func (a Actions) any() bool { return a.Get || a.Post || a.Patch || a.Delete } + func (o Options) withDefaults() Options { if o.Title == "" { - o.Title = "dbrest" + o.Title = "PostgREST API" + } + if o.Description == "" { + o.Description = "This is a dynamic API generated by PostgREST" } if o.Version == "" { o.Version = "14.0" @@ -72,84 +122,139 @@ func Generate(model *schema.Model, fns rpc.Registry, caps backend.Capabilities, func build(model *schema.Model, fns rpc.Registry, caps backend.Capabilities, opts Options) *document { opts = opts.withDefaults() doc := &document{ - Swagger: "2.0", - Info: info{Title: opts.Title, Version: opts.Version}, - Host: opts.Host, - BasePath: opts.BasePath, - Schemes: opts.Schemes, - Consumes: []string{"application/json"}, - Produces: []string{"application/json", "application/vnd.pgrst.object+json", "text/csv"}, - Paths: map[string]*pathItem{}, - Definitions: map[string]*schemaObject{}, - Parameters: reservedParameters(), + Swagger: "2.0", + Info: info{Title: opts.Title, Description: opts.Description, Version: opts.Version}, + Host: opts.Host, + BasePath: opts.BasePath, + Schemes: opts.Schemes, + Consumes: bodyMediaTypes, + Produces: bodyMediaTypes, + Paths: map[string]*pathItem{"/": rootPath()}, + Definitions: map[string]*schemaObject{}, + Parameters: reservedParameters(), + ExternalDocs: &externalDocs{Description: "PostgREST Documentation", URL: docsURL}, } ops := advertisedTokens(caps) - var security []map[string][]string - if opts.JWT { + if opts.SecurityActive { doc.SecurityDefinitions = map[string]*securityScheme{ - "JWT": {Type: "apiKey", Name: "Authorization", In: "header"}, - } - if opts.SecurityActive { - security = []map[string][]string{{"JWT": {}}} + "JWT": {Type: "apiKey", Name: "Authorization", In: "header", + Description: `Add the token prepending "Bearer " (without quotes) to it`}, } + doc.Security = []map[string][]string{{"JWT": {}}} } - for _, rel := range model.Relations() { - doc.Paths["/"+rel.Name] = relationPath(rel, ops, security) + for _, rel := range model.RelationsIn(opts.ActiveSchema) { + acts := AllActions + if opts.Visibility != nil { + acts = opts.Visibility(rel) + } + if rel.Kind == schema.KindView { + // A view path carries only get; a write-only grant leaves nothing + // to describe. + acts.Post, acts.Patch, acts.Delete = false, false, false + } + if !acts.any() { + continue + } + doc.Paths["/"+rel.Name] = relationPath(rel, acts) doc.Definitions[rel.Name] = relationDefinition(rel) + addRelationParameters(doc.Parameters, rel, ops) } if fns != nil { for _, fn := range fns.List() { - doc.Paths["/rpc/"+fn.Name] = functionPath(fn, security) + doc.Paths["/rpc/"+fn.Name] = functionPath(fn) } } return doc } -// relationPath emits the operations a relation supports. A base table gets the -// full read/write set; a view gets get only (updatable views land with the -// model flags that mark them so). Each operation lists the reserved parameters -// it honors plus one query parameter per column for horizontal filtering. -func relationPath(rel *schema.Relation, ops string, security []map[string][]string) *pathItem { - filters := columnParams(rel, ops) - get := &operation{ - Tags: []string{rel.Name}, - Parameters: concat(refs("select", "order", "limit", "offset", "rangeHeader", "preferRead"), filters), - Responses: okResponses("200", "OK"), - Security: security, +// rootPath is the "/" entry describing the root itself, byte-identical to the +// block v14 emits. +func rootPath() *pathItem { + return &pathItem{Get: &operation{ + Tags: []string{"Introspection"}, + Summary: "OpenAPI description (this document)", + Produces: []string{"application/openapi+json", "application/json"}, + Responses: map[string]*response{"200": {Description: "OK"}}, + }} +} + +// relationPath emits the operations a relation supports, each one referencing +// the shared rowFilter.
.and body.
parameter definitions +// the way v14 lays them out. A base table gets the full read/write set; a +// view gets get only (updatable views land with the model flags that mark +// them so). The acts set drops any operation the requesting role may not +// perform. +func relationPath(rel *schema.Relation, acts Actions) *pathItem { + filters := rowFilterRefs(rel) + body := refs("body." + rel.Name) + // The relation's database comment annotates every operation of the path: + // first line as the summary, the rest as the description, as v14 emits it. + summary, description := splitComment(rel.Comment) + p := &pathItem{} + if acts.Get { + p.Get = &operation{ + Tags: []string{rel.Name}, + Summary: summary, + Description: description, + Parameters: concat(filters, refs("select", "order", "range", "rangeUnit", "offset", "limit", "preferCount")), + Responses: map[string]*response{ + "200": {Description: "OK", Schema: &schemaObject{ + Type: "array", + Items: &schemaObject{Ref: "#/definitions/" + rel.Name}, + }}, + "206": {Description: "Partial Content"}, + }, + } } - p := &pathItem{Get: get} if rel.Kind == schema.KindTable { - bodyRef := "#/definitions/" + rel.Name - p.Post = &operation{ - Tags: []string{rel.Name}, - Parameters: concat(refs("select", "columns", "on_conflict", "preferWrite"), []*parameter{bodyParam(rel.Name, bodyRef)}), - Responses: okResponses("201", "Created"), - Security: security, + if acts.Post { + p.Post = &operation{ + Tags: []string{rel.Name}, + Summary: summary, + Description: description, + Parameters: concat(body, refs("select", "preferPost")), + Responses: okResponses("201", "Created"), + } } - p.Patch = &operation{ - Tags: []string{rel.Name}, - Parameters: concat(refs("select", "columns", "preferWrite"), filters, []*parameter{bodyParam(rel.Name, bodyRef)}), - Responses: okResponses("204", "No Content"), - Security: security, + if acts.Patch { + p.Patch = &operation{ + Tags: []string{rel.Name}, + Summary: summary, + Description: description, + Parameters: concat(filters, body, refs("preferReturn")), + Responses: okResponses("204", "No Content"), + } } - p.Delete = &operation{ - Tags: []string{rel.Name}, - Parameters: concat(refs("preferWrite"), filters), - Responses: okResponses("204", "No Content"), - Security: security, + if acts.Delete { + p.Delete = &operation{ + Tags: []string{rel.Name}, + Summary: summary, + Description: description, + Parameters: concat(filters, refs("preferReturn")), + Responses: okResponses("204", "No Content"), + } } } return p } +// splitComment divides a database comment the way v14 reads it: the first +// line becomes the summary, everything after the first newline the +// description, both empty when there is no comment. +func splitComment(comment string) (summary, description string) { + summary, description, _ = strings.Cut(comment, "\n") + return summary, strings.TrimSpace(description) +} + // relationDefinition builds the schema object for a relation from its columns. // A property's type/format comes from the column's canonical type; required -// lists the non-nullable columns without a default; the primary key and foreign -// keys surface in the property descriptions the way PostgREST annotates them. +// lists every non-nullable column in column order (a column with a default is +// still required, matching v14); the primary key and foreign keys surface in +// the property descriptions the way PostgREST annotates them. func relationDefinition(rel *schema.Relation) *schemaObject { - def := &schemaObject{Type: "object", Properties: map[string]*propertySchema{}} + def := &schemaObject{Type: "object", Properties: map[string]*propertySchema{}, Description: rel.Comment} pk := map[string]bool{} for _, c := range rel.PrimaryKey { pk[c] = true @@ -157,34 +262,45 @@ func relationDefinition(rel *schema.Relation) *schemaObject { for _, col := range rel.Columns { typ, format := swaggerType(col.Type) prop := &propertySchema{Type: typ, Format: format} - prop.Description = columnNote(col.Name, pk, rel.ForeignKeys) + prop.Description = columnDescription(col, pk, rel.ForeignKeys) def.Properties[col.Name] = prop - if !col.Nullable && !col.HasDefault { + if !col.Nullable { def.Required = append(def.Required, col.Name) } } - sort.Strings(def.Required) return def } -// functionPath emits the /rpc/ path. A read-only function (stable or -// immutable) is callable by GET with its arguments as query parameters and by -// POST with a body schema; a volatile function is POST only. See spec 12. -func functionPath(fn *rpc.Function, security []map[string][]string) *pathItem { +// functionPath emits the /rpc/ path with v14's layout: the "(rpc) " +// tag, the JSON-only produces list, a POST whose args body is always required +// followed by the preferParams reference, and a GET with inline query +// parameters for a read-only function (stable or immutable); a volatile +// function is POST only. See spec 12. +func functionPath(fn *rpc.Function) *pathItem { + // The function's database comment annotates its operations the same way a + // relation's does: first line as the summary, the rest as the description. + summary, description := splitComment(fn.Comment) p := &pathItem{} if fn.Volatility.ReadOnly() { p.Get = &operation{ - Tags: []string{fn.Name}, - Parameters: functionQueryParams(fn), - Responses: okResponses("200", "OK"), - Security: security, + Tags: []string{"(rpc) " + fn.Name}, + Summary: summary, + Description: description, + Produces: rpcMediaTypes, + Parameters: functionQueryParams(fn), + Responses: okResponses("200", "OK"), } } p.Post = &operation{ - Tags: []string{fn.Name}, - Parameters: []*parameter{{In: "body", Name: "args", Required: len(fn.Required()) > 0, Schema: functionBodySchema(fn)}}, - Responses: okResponses("200", "OK"), - Security: security, + Tags: []string{"(rpc) " + fn.Name}, + Summary: summary, + Description: description, + Produces: rpcMediaTypes, + Parameters: concat( + []*parameter{{In: "body", Name: "args", Required: boolPtr(true), Schema: functionBodySchema(fn)}}, + refs("preferParams"), + ), + Responses: okResponses("200", "OK"), } return p } @@ -193,7 +309,7 @@ func functionQueryParams(fn *rpc.Function) []*parameter { out := make([]*parameter, 0, len(fn.Params)) for _, pm := range fn.Params { typ, format := swaggerType(pm.Type) - out = append(out, ¶meter{Name: pm.Name, In: "query", Required: !pm.Optional, Type: typ, Format: format}) + out = append(out, ¶meter{Name: pm.Name, In: "query", Required: boolPtr(!pm.Optional), Type: typ, Format: format}) } return out } @@ -207,23 +323,52 @@ func functionBodySchema(fn *rpc.Function) *schemaObject { s.Required = append(s.Required, pm.Name) } } - sort.Strings(s.Required) return s } -// columnParams builds one query parameter per column, described with the -// operator grammar the backend can actually serve. -func columnParams(rel *schema.Relation, ops string) []*parameter { - out := make([]*parameter, 0, len(rel.Columns)) +// addRelationParameters defines the shared per-relation parameters operations +// reference: one rowFilter.
.per column and, for a table, the +// body.
payload. A rowFilter's description starts with the column's +// database comment the way v14 emits it, then carries the operator grammar +// the backend can actually serve, dbrest's capability advertisement (spec +// 19); v14 puts nothing after the comment, so the addition is tooling-safe. +// The body parameter's description is the relation comment, v14's fallback +// to the bare name when there is none. +func addRelationParameters(params map[string]*parameter, rel *schema.Relation, ops string) { for _, col := range rel.Columns { - _, format := swaggerType(col.Type) - out = append(out, ¶meter{ + desc := ops + if col.Comment != "" { + desc = col.Comment + "\n\n" + ops + } + params["rowFilter."+rel.Name+"."+col.Name] = ¶meter{ Name: col.Name, In: "query", + Required: boolPtr(false), Type: "string", - Format: format, - Description: ops, - }) + Description: desc, + } + } + if rel.Kind == schema.KindTable { + bodyDesc := rel.Name + if rel.Comment != "" { + bodyDesc = rel.Comment + } + params["body."+rel.Name] = ¶meter{ + Name: rel.Name, + In: "body", + Required: boolPtr(false), + Description: bodyDesc, + Schema: &schemaObject{Ref: "#/definitions/" + rel.Name}, + } + } +} + +// rowFilterRefs lists the relation's rowFilter references in column order, the +// horizontal-filter block every read and filtered write starts with. +func rowFilterRefs(rel *schema.Relation) []*parameter { + out := make([]*parameter, 0, len(rel.Columns)) + for _, col := range rel.Columns { + out = append(out, ¶meter{Ref: "#/parameters/rowFilter." + rel.Name + "." + col.Name}) } return out } @@ -307,48 +452,73 @@ func swaggerType(canonical string) (typ, format string) { } } -// columnNote builds the PostgREST property annotation: a primary-key note, a -// foreign-key note, or both, joined the way PostgREST concatenates them. An -// unannotated column gets an empty description. +// columnDescription builds the property description the way v14 lays it out: +// the column's database comment first, then the key annotations, a blank line +// between them. Each note carries the machine-readable or +// marker tooling parses out of the v14 +// documents. An uncommented, unannotated column gets an empty description. +func columnDescription(col *schema.Column, pk map[string]bool, fks []*schema.ForeignKey) string { + note := columnNote(col.Name, pk, fks) + switch { + case col.Comment == "": + return note + case note == "": + return col.Comment + default: + return col.Comment + "\n\n" + note + } +} + +// columnNote builds the PostgREST key annotation: a primary-key note, a +// foreign-key note, or both, empty for a plain column. func columnNote(name string, pk map[string]bool, fks []*schema.ForeignKey) string { var notes []string if pk[name] { - notes = append(notes, "Note:\nThis is a Primary Key.") + notes = append(notes, "Note:\nThis is a Primary Key.") } for _, fk := range fks { - for i, c := range fk.Columns { - if c != name { - continue - } - ref := fk.RefRelation - if i < len(fk.RefColumns) { - ref += "." + fk.RefColumns[i] - } - notes = append(notes, "Note:\nThis is a Foreign Key to `"+ref+"`.") + // v14 only annotates a single-column foreign key: makeProperty matches a + // relationship whose local columns are exactly [this column], so a composite + // FK gets no note on any of its columns. Match that. + if len(fk.Columns) != 1 || fk.Columns[0] != name { + continue } + ref := fk.RefRelation + col := "" + if len(fk.RefColumns) > 0 { + col = fk.RefColumns[0] + ref += "." + col + } + notes = append(notes, + "Note:\nThis is a Foreign Key to `"+ref+"`.") } return strings.Join(notes, "\n") } // reservedParameters defines the shared parameters operations reference by -// $ref, mirroring the reserved query and header grammar (spec 02). +// $ref. Names, descriptions, and enums match the v14 document exactly, +// including the on_conflict entry that is defined without being referenced. func reservedParameters() map[string]*parameter { return map[string]*parameter{ - "select": {Name: "select", In: "query", Type: "string", Description: "Filtering and renaming columns"}, - "order": {Name: "order", In: "query", Type: "string", Description: "Ordering"}, - "limit": {Name: "limit", In: "query", Type: "integer", Description: "Limiting and pagination"}, - "offset": {Name: "offset", In: "query", Type: "integer", Description: "Limiting and pagination"}, - "on_conflict": {Name: "on_conflict", In: "query", Type: "string", Description: "On conflict resolution columns"}, - "columns": {Name: "columns", In: "query", Type: "string", Description: "Restricting and ordering inserted columns"}, - "preferRead": {Name: "Prefer", In: "header", Type: "string", Description: "Preference: count, return"}, - "preferWrite": {Name: "Prefer", In: "header", Type: "string", Description: "Preference: return, resolution, missing"}, - "rangeHeader": {Name: "Range", In: "header", Type: "string", Description: "Limiting and pagination"}, + "select": {Name: "select", In: "query", Required: boolPtr(false), Type: "string", Description: "Filtering Columns"}, + "on_conflict": {Name: "on_conflict", In: "query", Required: boolPtr(false), Type: "string", Description: "On Conflict"}, + "order": {Name: "order", In: "query", Required: boolPtr(false), Type: "string", Description: "Ordering"}, + "range": {Name: "Range", In: "header", Required: boolPtr(false), Type: "string", Description: "Limiting and Pagination"}, + "rangeUnit": {Name: "Range-Unit", In: "header", Required: boolPtr(false), Type: "string", Description: "Limiting and Pagination", Default: "items"}, + "offset": {Name: "offset", In: "query", Required: boolPtr(false), Type: "string", Description: "Limiting and Pagination"}, + "limit": {Name: "limit", In: "query", Required: boolPtr(false), Type: "string", Description: "Limiting and Pagination"}, + "preferParams": {Name: "Prefer", In: "header", Required: boolPtr(false), Type: "string", + Description: "Preference"}, + "preferReturn": {Name: "Prefer", In: "header", Required: boolPtr(false), Type: "string", + Description: "Preference", Enum: []string{"return=representation", "return=minimal", "return=none"}}, + "preferCount": {Name: "Prefer", In: "header", Required: boolPtr(false), Type: "string", + Description: "Preference", Enum: []string{"count=none"}}, + "preferPost": {Name: "Prefer", In: "header", Required: boolPtr(false), Type: "string", + Description: "Preference", Enum: []string{"return=representation", "return=minimal", "return=none", "resolution=ignore-duplicates", "resolution=merge-duplicates"}}, } } -func bodyParam(name, ref string) *parameter { - return ¶meter{In: "body", Name: name, Schema: &schemaObject{Ref: ref}} -} +func boolPtr(b bool) *bool { return &b } func refs(names ...string) []*parameter { out := make([]*parameter, len(names)) diff --git a/openapi/openapi_test.go b/openapi/openapi_test.go index 68e6e76..4b2d755 100644 --- a/openapi/openapi_test.go +++ b/openapi/openapi_test.go @@ -72,8 +72,11 @@ func TestGenerateShape(t *testing.T) { t.Errorf("swagger = %v, want 2.0", doc["swagger"]) } info := doc["info"].(map[string]any) - if info["title"] != "dbrest" || info["version"] != "14.0" { - t.Errorf("info = %v, want dbrest/14.0", info) + if info["title"] != "PostgREST API" || info["version"] != "14.0" { + t.Errorf("info = %v, want the v14 defaults", info) + } + if info["description"] != "This is a dynamic API generated by PostgREST" { + t.Errorf("info description = %v", info["description"]) } if doc["host"] != "localhost:3000" { t.Errorf("host = %v", doc["host"]) @@ -87,6 +90,198 @@ func TestGenerateShape(t *testing.T) { } } +// TestInfoFromSchemaComment checks a provided title/description (sourced from +// the active schema's database comment) replaces the v14 defaults. +func TestInfoFromSchemaComment(t *testing.T) { + doc := decode(t, filmsModel(), nil, sqliteCaps(), Options{Title: "Films API", Description: "All the films."}) + info := doc["info"].(map[string]any) + if info["title"] != "Films API" || info["description"] != "All the films." { + t.Errorf("info = %v", info) + } +} + +// TestDocumentFraming pins the v14 framing: externalDocs pointing at the +// PostgREST reference, the vendor media types in the top-level consumes and +// produces, and the "/" entry describing the root itself. +func TestDocumentFraming(t *testing.T) { + doc := decode(t, filmsModel(), nil, sqliteCaps(), Options{}) + + ed, ok := doc["externalDocs"].(map[string]any) + if !ok { + t.Fatal("missing externalDocs") + } + if ed["description"] != "PostgREST Documentation" { + t.Errorf("externalDocs description = %v", ed["description"]) + } + if ed["url"] != "https://postgrest.org/en/v14/references/api.html" { + t.Errorf("externalDocs url = %v", ed["url"]) + } + + want := []string{ + "application/json", + "application/vnd.pgrst.object+json;nulls=stripped", + "application/vnd.pgrst.object+json", + "text/csv", + } + for _, key := range []string{"consumes", "produces"} { + got := doc[key].([]any) + if len(got) != len(want) { + t.Fatalf("%s = %v", key, got) + } + for i := range want { + if got[i] != want[i] { + t.Errorf("%s[%d] = %v, want %v", key, i, got[i], want[i]) + } + } + } + + root, ok := doc["paths"].(map[string]any)["/"].(map[string]any) + if !ok { + t.Fatal(`missing the "/" path entry`) + } + get := root["get"].(map[string]any) + if get["summary"] != "OpenAPI description (this document)" { + t.Errorf("root summary = %v", get["summary"]) + } + tags := get["tags"].([]any) + if len(tags) != 1 || tags[0] != "Introspection" { + t.Errorf("root tags = %v", tags) + } + produces := get["produces"].([]any) + if len(produces) != 2 || produces[0] != "application/openapi+json" || produces[1] != "application/json" { + t.Errorf("root produces = %v", produces) + } +} + +// TestReservedParametersMatchV14 pins the shared parameter definitions against +// the live v14 document: names, locations, descriptions, enums, and the +// explicit required:false. +func TestReservedParametersMatchV14(t *testing.T) { + doc := decode(t, filmsModel(), nil, sqliteCaps(), Options{}) + params := doc["parameters"].(map[string]any) + + cases := map[string]map[string]any{ + "select": {"name": "select", "in": "query", "description": "Filtering Columns"}, + "on_conflict": {"name": "on_conflict", "in": "query", "description": "On Conflict"}, + "order": {"name": "order", "in": "query", "description": "Ordering"}, + "range": {"name": "Range", "in": "header", "description": "Limiting and Pagination"}, + "rangeUnit": {"name": "Range-Unit", "in": "header", "description": "Limiting and Pagination", "default": "items"}, + "offset": {"name": "offset", "in": "query", "description": "Limiting and Pagination"}, + "limit": {"name": "limit", "in": "query", "description": "Limiting and Pagination"}, + "preferParams": {"name": "Prefer", "in": "header", "description": "Preference"}, + "preferReturn": {"name": "Prefer", "in": "header", "description": "Preference", + "enum": []any{"return=representation", "return=minimal", "return=none"}}, + "preferCount": {"name": "Prefer", "in": "header", "description": "Preference", + "enum": []any{"count=none"}}, + "preferPost": {"name": "Prefer", "in": "header", "description": "Preference", + "enum": []any{"return=representation", "return=minimal", "return=none", "resolution=ignore-duplicates", "resolution=merge-duplicates"}}, + } + for key, want := range cases { + got, ok := params[key].(map[string]any) + if !ok { + t.Errorf("missing parameter %q", key) + continue + } + if got["type"] != "string" { + t.Errorf("%s type = %v, want string", key, got["type"]) + } + if req, present := got["required"]; !present || req != false { + t.Errorf("%s required = %v, want explicit false", key, req) + } + for f, v := range want { + if f == "enum" { + gotEnum := got["enum"].([]any) + wantEnum := v.([]any) + if len(gotEnum) != len(wantEnum) { + t.Errorf("%s enum = %v, want %v", key, gotEnum, wantEnum) + continue + } + for i := range wantEnum { + if gotEnum[i] != wantEnum[i] { + t.Errorf("%s enum[%d] = %v, want %v", key, i, gotEnum[i], wantEnum[i]) + } + } + continue + } + if got[f] != v { + t.Errorf("%s %s = %v, want %v", key, f, got[f], v) + } + } + } +} + +// TestRelationOperationLayout pins each table operation's parameter reference +// sequence and responses against the v14 layout: GET is rowFilters then the +// read block with a 200-array-schema and a 206; POST is body/select/preferPost +// with a 201; PATCH is rowFilters/body/preferReturn and DELETE +// rowFilters/preferReturn, both 204. +func TestRelationOperationLayout(t *testing.T) { + doc := decode(t, filmsModel(), nil, sqliteCaps(), Options{}) + films := doc["paths"].(map[string]any)["/films"].(map[string]any) + + filterRefs := []string{ + "#/parameters/rowFilter.films.id", + "#/parameters/rowFilter.films.title", + "#/parameters/rowFilter.films.year", + "#/parameters/rowFilter.films.director_id", + } + refSeq := func(op string) []string { + params := films[op].(map[string]any)["parameters"].([]any) + out := make([]string, len(params)) + for i, p := range params { + out[i], _ = p.(map[string]any)["$ref"].(string) + } + return out + } + assertSeq := func(op string, want []string) { + t.Helper() + got := refSeq(op) + if strings.Join(got, " ") != strings.Join(want, " ") { + t.Errorf("%s parameters = %v, want %v", op, got, want) + } + } + + assertSeq("get", append(append([]string{}, filterRefs...), + "#/parameters/select", "#/parameters/order", "#/parameters/range", "#/parameters/rangeUnit", + "#/parameters/offset", "#/parameters/limit", "#/parameters/preferCount")) + assertSeq("post", []string{"#/parameters/body.films", "#/parameters/select", "#/parameters/preferPost"}) + assertSeq("patch", append(append([]string{}, filterRefs...), + "#/parameters/body.films", "#/parameters/preferReturn")) + assertSeq("delete", append(append([]string{}, filterRefs...), "#/parameters/preferReturn")) + + get := films["get"].(map[string]any) + resp := get["responses"].(map[string]any) + ok200 := resp["200"].(map[string]any) + schema200 := ok200["schema"].(map[string]any) + if schema200["type"] != "array" { + t.Errorf("200 schema type = %v, want array", schema200["type"]) + } + if schema200["items"].(map[string]any)["$ref"] != "#/definitions/films" { + t.Errorf("200 schema items = %v", schema200["items"]) + } + if resp["206"].(map[string]any)["description"] != "Partial Content" { + t.Errorf("206 = %v", resp["206"]) + } + for op, want := range map[string]string{"post": "201", "patch": "204", "delete": "204"} { + r := films[op].(map[string]any)["responses"].(map[string]any) + if len(r) != 1 { + t.Errorf("%s responses = %v, want only %s", op, r, want) + } + if _, ok := r[want]; !ok { + t.Errorf("%s responses lack %s", op, want) + } + } + + // The shared body parameter names the table and references its definition. + body := doc["parameters"].(map[string]any)["body.films"].(map[string]any) + if body["name"] != "films" || body["in"] != "body" || body["description"] != "films" { + t.Errorf("body.films = %v", body) + } + if body["schema"].(map[string]any)["$ref"] != "#/definitions/films" { + t.Errorf("body.films schema = %v", body["schema"]) + } +} + func TestTableHasFullOperationSet(t *testing.T) { doc := decode(t, filmsModel(), nil, sqliteCaps(), Options{}) films := doc["paths"].(map[string]any)["/films"].(map[string]any) @@ -123,14 +318,14 @@ func TestDefinitionTypesAndRequired(t *testing.T) { if id["type"] != "integer" || id["format"] != "integer" { t.Errorf("id = %v, want integer/integer", id) } - if !strings.Contains(id["description"].(string), "Primary Key") { - t.Errorf("id description = %v, want PK note", id["description"]) + if !strings.Contains(id["description"].(string), "Note:\nThis is a Primary Key.") { + t.Errorf("id description = %v, want the v14 PK marker", id["description"]) } title := props["title"].(map[string]any) if title["type"] != "string" || title["format"] != "text" { t.Errorf("title = %v, want string/text", title) } - // required = non-nullable columns without a default, sorted. + // required = every NOT NULL column in column order, as v14 lists them. req := films["required"].([]any) got := make([]string, len(req)) for i, c := range req { @@ -141,32 +336,96 @@ func TestDefinitionTypesAndRequired(t *testing.T) { } } +// TestRequiredIncludesDefaultedColumns pins the v14 quirk: a NOT NULL column +// stays in required even when it has a default (live 14.12 lists serial PKs). +func TestRequiredIncludesDefaultedColumns(t *testing.T) { + rel := &schema.Relation{ + Name: "items", Kind: schema.KindTable, + Columns: []*schema.Column{ + {Name: "id", Type: "integer", Nullable: false, HasDefault: true, Position: 1}, + {Name: "label", Type: "text", Nullable: false, Position: 2}, + {Name: "note", Type: "text", Nullable: true, Position: 3}, + }, + PrimaryKey: []string{"id"}, + } + doc := decode(t, schema.NewModel([]*schema.Relation{rel}), nil, sqliteCaps(), Options{}) + items := doc["definitions"].(map[string]any)["items"].(map[string]any) + req := items["required"].([]any) + got := make([]string, len(req)) + for i, c := range req { + got[i] = c.(string) + } + if strings.Join(got, ",") != "id,label" { + t.Errorf("required = %v, want [id label]", got) + } +} + func TestForeignKeyNote(t *testing.T) { doc := decode(t, filmsModel(), nil, sqliteCaps(), Options{}) props := doc["definitions"].(map[string]any)["films"].(map[string]any)["properties"].(map[string]any) fk := props["director_id"].(map[string]any) desc, _ := fk["description"].(string) - if !strings.Contains(desc, "Foreign Key to `directors.id`") { - t.Errorf("director_id description = %q, want FK note", desc) + if !strings.Contains(desc, "Note:\nThis is a Foreign Key to `directors.id`.") { + t.Errorf("director_id description = %q, want the v14 FK marker", desc) + } +} + +// TestCompositeForeignKeyHasNoNote: v14 annotates only single-column foreign +// keys, so a composite FK leaves its columns unannotated. A computed +// relationship is a function with no local column, so it likewise never appears +// in the document; both are covered by sourcing notes from real single-column FKs. +func TestCompositeForeignKeyHasNoNote(t *testing.T) { + rel := &schema.Relation{ + Name: "assignments", + Columns: []*schema.Column{ + {Name: "org_id", Type: "integer"}, + {Name: "user_id", Type: "integer"}, + {Name: "role", Type: "text"}, + }, + ForeignKeys: []*schema.ForeignKey{ + {Name: "assignments_member_fkey", Columns: []string{"org_id", "user_id"}, RefRelation: "members", RefColumns: []string{"org_id", "user_id"}}, + }, + } + doc := decode(t, schema.NewModel([]*schema.Relation{rel}), nil, sqliteCaps(), Options{}) + props := doc["definitions"].(map[string]any)["assignments"].(map[string]any)["properties"].(map[string]any) + for _, c := range []string{"org_id", "user_id"} { + desc, _ := props[c].(map[string]any)["description"].(string) + if strings.Contains(desc, "Foreign Key") { + t.Errorf("%s description = %q, want no FK note for a composite key", c, desc) + } } } +// rowFilterDesc resolves a column's rowFilter parameter through the shared +// parameters map, the way a client follows the $ref an operation carries. +func rowFilterDesc(t *testing.T, doc map[string]any, table, col string) string { + t.Helper() + params, ok := doc["parameters"].(map[string]any) + if !ok { + t.Fatal("document has no parameters map") + } + p, ok := params["rowFilter."+table+"."+col].(map[string]any) + if !ok { + t.Fatalf("no rowFilter.%s.%s parameter found", table, col) + } + if p["name"] != col || p["in"] != "query" || p["type"] != "string" { + t.Errorf("rowFilter.%s.%s = %v", table, col, p) + } + if req, present := p["required"]; !present || req != false { + t.Errorf("rowFilter.%s.%s required = %v, want explicit false", table, col, req) + } + desc, _ := p["description"].(string) + return desc +} + // TestOperatorAdvertisingHonorsCapabilities is the heart of the contract: a // column parameter advertises match/imatch and the fts family on SQLite (regex // and FTS5 present) but never the array/range operators (Unsupported). func TestOperatorAdvertisingHonorsCapabilities(t *testing.T) { doc := decode(t, filmsModel(), nil, sqliteCaps(), Options{}) - params := doc["paths"].(map[string]any)["/films"].(map[string]any)["get"].(map[string]any)["parameters"].([]any) - - var colDesc string - for _, p := range params { - pm := p.(map[string]any) - if pm["name"] == "title" { - colDesc = pm["description"].(string) - } - } + colDesc := rowFilterDesc(t, doc, "films", "title") if colDesc == "" { - t.Fatal("no title column parameter found") + t.Fatal("title rowFilter parameter has no description") } for _, want := range []string{"eq", "match", "imatch", "fts", "plfts", "phfts", "wfts"} { if !containsToken(colDesc, want) { @@ -187,43 +446,52 @@ func TestRegexOmittedWhenUnsupported(t *testing.T) { caps.Regex = backend.Unsupported caps.FullText = backend.FTNone doc := decode(t, filmsModel(), nil, caps, Options{}) - params := doc["paths"].(map[string]any)["/films"].(map[string]any)["get"].(map[string]any)["parameters"].([]any) - for _, p := range params { - pm := p.(map[string]any) - if pm["name"] != "title" { - continue - } - desc := pm["description"].(string) - for _, gone := range []string{"match", "imatch", "fts"} { - if containsToken(desc, gone) { - t.Errorf("operator %q should be omitted; desc = %q", gone, desc) - } + desc := rowFilterDesc(t, doc, "films", "title") + for _, gone := range []string{"match", "imatch", "fts"} { + if containsToken(desc, gone) { + t.Errorf("operator %q should be omitted; desc = %q", gone, desc) } } } -func TestSecurityDefinitionsWhenJWT(t *testing.T) { - doc := decode(t, filmsModel(), nil, sqliteCaps(), Options{JWT: true}) +func TestSecurityInactiveOmitsDefinitions(t *testing.T) { + // PostgREST's default: with openapi-security-active off the document has + // neither securityDefinitions nor a security requirement, even with JWT + // auth configured on the server. + doc := decode(t, filmsModel(), nil, sqliteCaps(), Options{}) + if _, ok := doc["securityDefinitions"]; ok { + t.Error("securityDefinitions should be absent when security is inactive") + } + if _, ok := doc["security"]; ok { + t.Error("security should be absent when security is inactive") + } +} + +func TestSecurityActiveEmitsJWTScheme(t *testing.T) { + doc := decode(t, filmsModel(), nil, sqliteCaps(), Options{SecurityActive: true}) sd, ok := doc["securityDefinitions"].(map[string]any) if !ok { - t.Fatal("no securityDefinitions with JWT enabled") + t.Fatal("no securityDefinitions with security active") } jwt := sd["JWT"].(map[string]any) if jwt["type"] != "apiKey" || jwt["name"] != "Authorization" || jwt["in"] != "header" { t.Errorf("JWT scheme = %v", jwt) } - // Default: the scheme is defined but not attached to operations. - get := doc["paths"].(map[string]any)["/films"].(map[string]any)["get"].(map[string]any) - if _, attached := get["security"]; attached { - t.Error("security should not be attached to operations by default") + if jwt["description"] != `Add the token prepending "Bearer " (without quotes) to it` { + t.Errorf("JWT description = %v", jwt["description"]) } -} - -func TestSecurityActiveAttachesRequirement(t *testing.T) { - doc := decode(t, filmsModel(), nil, sqliteCaps(), Options{JWT: true, SecurityActive: true}) + // The requirement is document-level, the way v14 attaches it. + sec, ok := doc["security"].([]any) + if !ok || len(sec) != 1 { + t.Fatalf("security = %v, want one document-level requirement", doc["security"]) + } + if _, ok := sec[0].(map[string]any)["JWT"]; !ok { + t.Errorf("security requirement = %v, want JWT", sec[0]) + } + // Operations carry no per-operation security in v14. get := doc["paths"].(map[string]any)["/films"].(map[string]any)["get"].(map[string]any) - if _, attached := get["security"]; !attached { - t.Error("security should be attached when SecurityActive is set") + if _, attached := get["security"]; attached { + t.Error("security should not be attached per operation") } } @@ -252,8 +520,48 @@ func TestRPCPaths(t *testing.T) { if _, ok := add["get"]; !ok { t.Error("immutable function should be callable by GET") } - if _, ok := add["post"]; !ok { - t.Error("function should be callable by POST") + post, ok := add["post"].(map[string]any) + if !ok { + t.Fatal("function should be callable by POST") + } + + // v14 tags rpc operations "(rpc) " and drops text/csv from produces. + tags := post["tags"].([]any) + if len(tags) != 1 || tags[0] != "(rpc) add" { + t.Errorf("rpc tags = %v, want [(rpc) add]", tags) + } + produces := post["produces"].([]any) + for _, mt := range produces { + if mt == "text/csv" { + t.Error("rpc produces should not list text/csv") + } + } + if len(produces) != 3 || produces[0] != "application/json" { + t.Errorf("rpc produces = %v", produces) + } + + // POST takes the args as one required body plus the preferParams ref; GET + // inlines one query parameter per argument. + postParams := post["parameters"].([]any) + if len(postParams) != 2 { + t.Fatalf("rpc post parameters = %v", postParams) + } + args := postParams[0].(map[string]any) + if args["name"] != "args" || args["in"] != "body" || args["required"] != true { + t.Errorf("rpc args parameter = %v", args) + } + if postParams[1].(map[string]any)["$ref"] != "#/parameters/preferParams" { + t.Errorf("rpc post parameters[1] = %v", postParams[1]) + } + getParams := add["get"].(map[string]any)["parameters"].([]any) + if len(getParams) != 2 { + t.Fatalf("rpc get parameters = %v", getParams) + } + for i, name := range []string{"a", "b"} { + p := getParams[i].(map[string]any) + if p["name"] != name || p["in"] != "query" || p["required"] != true { + t.Errorf("rpc get parameter %d = %v", i, p) + } } logEvent, ok := paths["/rpc/log_event"].(map[string]any) @@ -268,6 +576,110 @@ func TestRPCPaths(t *testing.T) { } } +// commentedModel is a films table annotated the way COMMENT ON does it: a +// two-line relation comment, a column comment on title, none elsewhere. +func commentedModel() *schema.Model { + rel := &schema.Relation{ + Name: "films", Kind: schema.KindTable, + Comment: "The films catalog\nEvery film we know about, one row each.", + Columns: []*schema.Column{ + {Name: "id", Type: "integer", Nullable: false, Position: 1}, + {Name: "title", Type: "text", Nullable: false, Position: 2, Comment: "Original release title"}, + }, + PrimaryKey: []string{"id"}, + } + return schema.NewModel([]*schema.Relation{rel}) +} + +// TestRelationCommentAnnotatesOperations pins the v14 comment pipeline on a +// relation: the first comment line becomes every operation's summary, the +// rest the description, and the definition and body parameter carry the full +// comment. +func TestRelationCommentAnnotatesOperations(t *testing.T) { + doc := decode(t, commentedModel(), nil, sqliteCaps(), Options{}) + films := doc["paths"].(map[string]any)["/films"].(map[string]any) + for _, op := range []string{"get", "post", "patch", "delete"} { + o := films[op].(map[string]any) + if o["summary"] != "The films catalog" { + t.Errorf("%s summary = %v", op, o["summary"]) + } + if o["description"] != "Every film we know about, one row each." { + t.Errorf("%s description = %v", op, o["description"]) + } + } + def := doc["definitions"].(map[string]any)["films"].(map[string]any) + if def["description"] != "The films catalog\nEvery film we know about, one row each." { + t.Errorf("definition description = %v", def["description"]) + } + body := doc["parameters"].(map[string]any)["body.films"].(map[string]any) + if body["description"] != "The films catalog\nEvery film we know about, one row each." { + t.Errorf("body.films description = %v", body["description"]) + } +} + +// TestUncommentedOperationsStayBare checks an uncommented relation emits no +// summary or description at all, matching v14's omission. +func TestUncommentedOperationsStayBare(t *testing.T) { + doc := decode(t, filmsModel(), nil, sqliteCaps(), Options{}) + get := doc["paths"].(map[string]any)["/films"].(map[string]any)["get"].(map[string]any) + if _, ok := get["summary"]; ok { + t.Errorf("summary should be omitted, got %v", get["summary"]) + } + if _, ok := get["description"]; ok { + t.Errorf("description should be omitted, got %v", get["description"]) + } +} + +// TestColumnCommentInPropertyAndRowFilter pins the column comment's two +// surfaces: ahead of the key notes in the definition property, and ahead of +// the operator advertisement in the rowFilter parameter. +func TestColumnCommentInPropertyAndRowFilter(t *testing.T) { + rel := &schema.Relation{ + Name: "films", Kind: schema.KindTable, + Columns: []*schema.Column{ + {Name: "id", Type: "integer", Nullable: false, Position: 1, Comment: "the film identifier"}, + }, + PrimaryKey: []string{"id"}, + } + doc := decode(t, schema.NewModel([]*schema.Relation{rel}), nil, sqliteCaps(), Options{}) + + prop := doc["definitions"].(map[string]any)["films"].(map[string]any)["properties"].(map[string]any)["id"].(map[string]any) + if prop["description"] != "the film identifier\n\nNote:\nThis is a Primary Key." { + t.Errorf("property description = %q", prop["description"]) + } + + desc := rowFilterDesc(t, doc, "films", "id") + if !strings.HasPrefix(desc, "the film identifier\n\n") { + t.Errorf("rowFilter description = %q, want the comment first", desc) + } + if !containsToken(desc, "eq") { + t.Errorf("rowFilter description lost the operator list: %q", desc) + } +} + +// TestFunctionCommentAnnotatesRPC pins the same split on an rpc path, sourced +// from the registry declaration's comment field. +func TestFunctionCommentAnnotatesRPC(t *testing.T) { + reg := rpc.NewStaticRegistry([]*rpc.Function{{ + Name: "add", + Comment: "Add two numbers\nReturns the sum of a and b.", + Params: []rpc.Param{{Name: "a", Type: "int4"}, {Name: "b", Type: "int4"}}, + Returns: rpc.ReturnShape{Kind: rpc.ReturnScalar, Type: "int4"}, + Volatility: rpc.Immutable, + }}) + doc := decode(t, filmsModel(), reg, sqliteCaps(), Options{}) + add := doc["paths"].(map[string]any)["/rpc/add"].(map[string]any) + for _, op := range []string{"get", "post"} { + o := add[op].(map[string]any) + if o["summary"] != "Add two numbers" { + t.Errorf("%s summary = %v", op, o["summary"]) + } + if o["description"] != "Returns the sum of a and b." { + t.Errorf("%s description = %v", op, o["description"]) + } + } +} + // TestDeterministic checks the same inputs marshal byte-for-byte the same, so a // cached document and a regenerated one match. func TestDeterministic(t *testing.T) { diff --git a/pgerr/codes.go b/pgerr/codes.go index 2c06682..69a39c4 100644 --- a/pgerr/codes.go +++ b/pgerr/codes.go @@ -3,6 +3,7 @@ package pgerr import ( "fmt" "net/http" + "strings" ) // The PGRST code families (spec 18-errors.md, section "The PGRST code families"): @@ -10,25 +11,46 @@ import ( // PGRST1xx query-string syntax // PGRST2xx schema-cache and resolution // PGRST3xx JWT and auth -// PGRST127 the one dbrest-specific code: feature unsupported on this backend +// PGRST127 upstream's "Feature not implemented"; dbrest emits it for a feature +// a backend cannot do faithfully, far more often than upstream does // // Each constructor returns a fully-formed *APIError with the spec-mandated // status. Callers add details/hint with WithDetails / WithHint. const ( - CodeParse = "PGRST100" // 400 query-string parse error - CodeMethodNotAllowed = "PGRST101" // 405 method not allowed (GET on a volatile fn) - CodeRangeUnsatisfied = "PGRST103" // 416 requested range not satisfiable - CodeMediaType = "PGRST107" // 406/415 media type not negotiable - CodeSingularZeroMany = "PGRST116" // 406 singular requested, zero or many rows - CodeNoRelationship = "PGRST200" // 400 relationship not found - CodeAmbiguousEmbed = "PGRST201" // 300 embedding ambiguous - CodeNoFunction = "PGRST202" // 404 no function matches name/args - CodeUnknownColumn = "PGRST204" // 400 column in write payload not found - CodeUnknownTable = "PGRST205" // 404 table or view not found / not exposed - CodeJWTExpired = "PGRST301" // 401 JWT expired - CodeJWTInvalid = "PGRST302" // 401 JWT malformed/bad signature/alg/nbf/aud - CodeUnsupported = "PGRST127" // 400 feature not implemented on this backend - CodeInternal = "PGRSTXX0" // 500 internal error (XX family rendered as 500) + CodeParse = "PGRST100" // 400 query-string parse error + CodeMethodNotAllowed = "PGRST101" // 405 method not allowed (GET on a volatile fn) + CodeUnsupportedMethod = "PGRST117" // 405 unsupported HTTP method on the resource + CodeInvalidPreferences = "PGRST122" // 400 invalid preference under handling=strict + CodeInvalidBody = "PGRST102" // 400 invalid request body + CodeRangeUnsatisfied = "PGRST103" // 416 requested range not satisfiable + CodePutPrimaryKey = "PGRST105" // 405 PUT filters not exactly the PK with eq + CodePutLimit = "PGRST114" // 400 limit/offset on a PUT + CodePutPayloadKey = "PGRST115" // 400 PUT payload PK differs from the URL filter + CodeMediaType = "PGRST107" // 406 Accept negotiation failed + CodeGucHeaders = "PGRST111" // 500 invalid response.headers from a function + CodeGucStatus = "PGRST112" // 500 invalid response.status from a function + CodeSingularZeroMany = "PGRST116" // 406 singular requested, zero or many rows + CodeInvalidPath = "PGRST125" // 404 invalid path in request URL + CodeRelatedOrderNotToOne = "PGRST118" // 400 order=rel(col) on a non-to-one embed + CodeRelatedOrderNotEmbedded = "PGRST108" // 400 order=rel(col) on a resource not in select + CodeNoRelationship = "PGRST200" // 400 relationship not found + CodeAmbiguousEmbed = "PGRST201" // 300 embedding ambiguous + CodeNoFunction = "PGRST202" // 404 no function matches name/args + CodeAmbiguousFunc = "PGRST203" // 300 overloaded function call ambiguous + CodeUnknownColumn = "PGRST204" // 400 column in write payload not found + CodeUnknownTable = "PGRST205" // 404 table or view not found / not exposed + CodeJWTSecretMissing = "PGRST300" // 500 a token was presented but no jwt-secret is configured + CodeJWTDecode = "PGRST301" // 401 JWT could not be decoded (parts/key/alg/signature) + CodeJWTRequired = "PGRST302" // 401 no token sent and the anonymous role is disabled + CodeJWTClaims = "PGRST303" // 401 JWT claims validation or parsing failed + CodeAggregatesOff = "PGRST123" // 400 aggregate functions used while db-aggregates-enabled is off + CodeMaxAffected = "PGRST124" // 400 mutation/RPC affected more rows than Prefer: max-affected + CodeUnsupported = "PGRST127" // 400 feature not implemented on this backend + CodeDBConnection = "PGRST000" // 503 cannot connect to the database (bad URI or service down) + CodeDBClient = "PGRST001" // 503 database client error, retrying the connection + CodeAcquireTimeout = "PGRST003" // 504 timed out acquiring a connection from the pool + CodeInternal = "PGRSTX00" // 500 internal error (upstream group X has only X00) + CodeBodyTooLarge = "PGRSTX13" // 413 request body exceeds the configured max-request-body ) // ErrParse is a query-string syntax error (bad operator, malformed logic tree). @@ -36,11 +58,32 @@ func ErrParse(msg string) *APIError { return New(http.StatusBadRequest, CodeParse, msg) } +// ErrBodyTooLarge reports a request body over the configured max-request-body +// cap. PostgREST has no body-size limit, so this exists only when an operator +// opts into one; the 413 and the explicit byte bound tell the client to split +// the load rather than presenting it as a parse failure. +func ErrBodyTooLarge(limit int64) *APIError { + return New(http.StatusRequestEntityTooLarge, CodeBodyTooLarge, + fmt.Sprintf("Request body exceeds the %d byte max-request-body limit", limit)) +} + +// ErrInvalidBody is an invalid request body (PostgREST's PGRST102, HTTP 400): +// an empty or malformed JSON or CSV payload, or a bulk insert whose objects do +// not all share the same key set ("All object keys must match"). An empty msg +// falls back to PostgREST's generic JSON-body message. +func ErrInvalidBody(msg string) *APIError { + if msg == "" { + msg = "Empty or invalid json" + } + return New(http.StatusBadRequest, CodeInvalidBody, msg) +} + // ErrSingularZeroMany is raised when a singular response was requested but zero -// or many rows were produced. +// or many rows were produced. The text is v14's; render call sites attach the +// row count as details ("The result contains N rows"). func ErrSingularZeroMany() *APIError { return New(http.StatusNotAcceptable, CodeSingularZeroMany, - "JSON object requested, multiple (or no) rows returned") + "Cannot coerce the result to a single JSON object") } // ErrRangeNotSatisfiable is raised when the requested window starts past the end @@ -59,51 +102,211 @@ func ErrNotAcceptable(offered string) *APIError { } // ErrUnsupportedMediaType is raised when a write or RPC body arrives with a -// Content-Type no parser handles. It is PGRST107 with a 415, the request-side -// twin of ErrNotAcceptable. +// Content-Type no parser handles. The published v14 error table still shows a +// stale PGRST107/415 row for this, but live v14 answers 400 PGRST102 +// "Content-Type not acceptable: " (verified against a running PostgREST +// by compat/errors_v14_test.go), so the wire behavior wins: PGRST107 stays +// reserved for failed Accept negotiation (ErrNotAcceptable, 406). func ErrUnsupportedMediaType(contentType string) *APIError { - return New(http.StatusUnsupportedMediaType, CodeMediaType, - fmt.Sprintf("Content-Type not supported: '%s'", contentType)) + return New(http.StatusBadRequest, CodeInvalidBody, + fmt.Sprintf("Content-Type not acceptable: %s", contentType)) } // ErrUnknownTable is raised when a table or view is not in the schema model -// (unknown, or not exposed by db-schemas). -func ErrUnknownTable(name string) *APIError { +// (unknown, or not exposed by db-schemas). PostgREST schema-qualifies the name in +// its PGRST205 message ("Could not find the table 'api.films' in the schema +// cache"), so schemaName is the schema the request resolved to; an empty schema +// (a backend with no namespace) normalizes to public, the name PostgREST emits. +func ErrUnknownTable(schemaName, name string) *APIError { + if schemaName == "" { + schemaName = "public" + } return New(http.StatusNotFound, CodeUnknownTable, - fmt.Sprintf("Could not find the table '%s' in the schema cache", name)) + fmt.Sprintf("Could not find the table '%s.%s' in the schema cache", schemaName, name)) } -// ErrUnknownColumn is raised when a column named in a payload or select is not -// found on the target relation. -func ErrUnknownColumn(col string) *APIError { +// ErrUnknownColumn is raised when a column named in a write payload or the +// columns= parameter is not found on the target relation; rel is that relation, +// named in the message the way PostgREST spells PGRST204 ("Could not find the +// 'X' column of 'Y' in the schema cache"). PostgREST reserves PGRST204 for those +// two cases; a column referenced by select, a filter, or order reaches +// PostgreSQL instead and surfaces as 42703 (ErrUndefinedColumn). +func ErrUnknownColumn(col, rel string) *APIError { return New(http.StatusBadRequest, CodeUnknownColumn, - fmt.Sprintf("Could not find the '%s' column in the schema cache", col)) + fmt.Sprintf("Could not find the '%s' column of '%s' in the schema cache", col, rel)) +} + +// CodeUndefinedColumn is PostgreSQL's undefined_column. In PostgREST an unknown +// column in select, a filter, or order is not caught by the schema cache; it +// reaches the server and comes back as this SQLSTATE with a 400. +const CodeUndefinedColumn = "42703" + +// ErrUndefinedColumn mirrors PostgreSQL's own message for a reference to a +// column that does not exist; column is the relation-qualified name the query +// used ("todos.nope"). Callers add the server's "Perhaps you meant to reference +// the column ..." suggestion with WithHint when a near-miss exists. +func ErrUndefinedColumn(column string) *APIError { + return New(http.StatusBadRequest, CodeUndefinedColumn, + fmt.Sprintf("column %s does not exist", column)) } // ErrNoRelationship is raised when an embed names a resource the schema model // has no relationship to (no foreign key connects them, and none is declared). -// It is PostgREST's PGRST200 with a 400. -func ErrNoRelationship(parent, target string) *APIError { - return New(http.StatusBadRequest, CodeNoRelationship, +// It is PostgREST's PGRST200 with a 400. The details name the searched pair and +// the schema the search ran in, matching the sentence PostgREST returns so a +// client sees why the embed failed and not just that it did. schemaName is the +// parent's schema; embedHint, when non-empty, is the disambiguation hint the +// request gave (after the `!`), which the details echo. +func ErrNoRelationship(parent, target, schemaName, embedHint string) *APIError { + e := New(http.StatusBadRequest, CodeNoRelationship, fmt.Sprintf("Could not find a relationship between '%s' and '%s' in the schema cache", parent, target)) + hintClause := "" + if embedHint != "" { + hintClause = fmt.Sprintf(" using the hint '%s'", embedHint) + } + return e.WithDetails(fmt.Sprintf( + "Searched for a foreign key relationship between '%s' and '%s'%s in the schema '%s', but no matches were found.", + parent, target, hintClause, schemaName)) +} + +// EmbedCandidate is one relationship that connects a parent and an embedded +// resource, rendered into a PGRST201 details entry. Cardinality, Embedding, and +// Relationship are the three keys PostgREST serializes; Name is the edge name a +// client uses to disambiguate (target!name) and is carried for the hint, not the +// details body. +type EmbedCandidate struct { + Cardinality string `json:"cardinality"` + Embedding string `json:"embedding"` + Relationship string `json:"relationship"` + Name string `json:"-"` } // ErrAmbiguousEmbed is raised when more than one relationship connects the // parent and the embedded resource and no hint disambiguates. It is PostgREST's -// PGRST201 with a 300 Multiple Choices. -func ErrAmbiguousEmbed(parent, target string) *APIError { - return New(http.StatusMultipleChoices, CodeAmbiguousEmbed, +// PGRST201 with a 300 Multiple Choices. The details carry the candidate array +// clients read to auto-disambiguate, and the hint lists the disambiguated embed +// spellings (target!name) to try, pointing at the details for the full set. +func ErrAmbiguousEmbed(parent, target string, cands []EmbedCandidate) *APIError { + e := New(http.StatusMultipleChoices, CodeAmbiguousEmbed, fmt.Sprintf("Could not embed because more than one relationship was found for '%s' and '%s'", parent, target)) + if len(cands) == 0 { + return e + } + e = e.WithDetailsJSON(cands) + spellings := make([]string, len(cands)) + for i, c := range cands { + spellings[i] = fmt.Sprintf("'%s!%s'", target, c.Name) + } + return e.WithHint(fmt.Sprintf( + "Try changing '%s' to one of the following: %s. Find the desired relationship in the 'details' key.", + target, strings.Join(spellings, ", "))) +} + +// ErrRelatedOrderNotEmbedded is raised when an order=rel(col) term names a +// relation that the request did not embed in select. It is PostgREST's PGRST108 +// with a 400; the hint points the caller at the select parameter. +func ErrRelatedOrderNotEmbedded(rel string) *APIError { + return New(http.StatusBadRequest, CodeRelatedOrderNotEmbedded, + fmt.Sprintf("'%s' is not an embedded resource in this request", rel)). + WithHint(fmt.Sprintf("Verify that '%s' is included in the 'select' query parameter.", rel)) } -// ErrNoFunction is raised when no function matches the name and argument set. -func ErrNoFunction(name string) *APIError { - return New(http.StatusNotFound, CodeNoFunction, - fmt.Sprintf("Could not find the function '%s' in the schema cache", name)) +// ErrRelatedOrderNotToOne is raised when an order=rel(col) term names an embedded +// relation that is to-many rather than many-to-one or one-to-one: a parent row +// would map to many related rows, so the related column is not a single sort key. +// It is PostgREST's PGRST118 with a 400. +func ErrRelatedOrderNotToOne(parent, rel string) *APIError { + return New(http.StatusBadRequest, CodeRelatedOrderNotToOne, + fmt.Sprintf("A related order on '%s' is not possible", rel)). + WithDetails(fmt.Sprintf("'%s' and '%s' do not form a many-to-one or one-to-one relationship", parent, rel)) } -// ErrMethodNotAllowed is raised when a read method calls a volatile function: a -// GET to a function with side effects, which PostgREST rejects with 405. +// ErrNoFunction is raised when no function matches the name and argument set. It +// names the function schema-qualified with the argument list that was searched +// for ("public.add(a, b)"), or the "without parameters" form when the call +// supplied none, the way PostgREST spells PGRST202. A non-empty hint (the nearest +// registered signature) is attached so the caller sees the closest match. +func ErrNoFunction(schemaName, name string, argNames []string, hint string) *APIError { + qualified := name + if schemaName != "" { + qualified = schemaName + "." + name + } + var msg string + if len(argNames) == 0 { + msg = fmt.Sprintf("Could not find the function %s without parameters in the schema cache", qualified) + } else { + msg = fmt.Sprintf("Could not find the function %s(%s) in the schema cache", qualified, strings.Join(argNames, ", ")) + } + e := New(http.StatusNotFound, CodeNoFunction, msg) + if hint != "" { + e = e.WithHint(hint) + } + return e +} + +// ErrAmbiguousFunction is raised when more than one overload of a function +// survives argument matching, PostgREST's PGRST203 with a 300. candidates are +// the surviving signatures, schema-qualified with their parameter lists +// ("api.add(a => integer, b => integer)"), spelled into the message the way +// upstream does. +func ErrAmbiguousFunction(candidates []string) *APIError { + e := New(http.StatusMultipleChoices, CodeAmbiguousFunc, + "Could not choose the best candidate function between: "+strings.Join(candidates, ", ")) + return e.WithHint("Try renaming the parameters or the function itself in the database so function overloading can be resolved") +} + +// ErrPutPrimaryKey is raised when a PUT's URL filters are not exactly the +// relation's primary key columns, each with eq. PostgREST insists a PUT address +// one row by its whole key, so a partial, extra, or non-eq filter is its +// PGRST105 with a 405 (verified live). +func ErrPutPrimaryKey() *APIError { + return New(http.StatusMethodNotAllowed, CodePutPrimaryKey, + "Filters must include all and only primary key columns with 'eq' operators") +} + +// ErrPutLimit is raised when a PUT carries a limit or offset; PostgREST rejects +// paginating a single-row replace as its PGRST114 with a 400. +func ErrPutLimit() *APIError { + return New(http.StatusBadRequest, CodePutLimit, + "limit/offset querystring parameters are not allowed for PUT") +} + +// ErrPutPayloadKey is raised when a PUT body's primary key values differ from +// the URL filter values, or the body is not a single object. PostgREST condemns +// the transaction so nothing is written; it is its PGRST115 with a 400. +func ErrPutPayloadKey() *APIError { + return New(http.StatusBadRequest, CodePutPayloadKey, + "Payload values do not match URL in primary key column(s)") +} + +// ErrInvalidPath is raised for a request path PostgREST has no route for: more +// than one segment after the relation, or extra segments after /rpc/. It is +// v14's PGRST125, a 404 with this exact message (verified live), distinct from +// the PGRST205 an unknown relation gets. +func ErrInvalidPath() *APIError { + return New(http.StatusNotFound, CodeInvalidPath, + "Invalid path specified in request URL") +} + +// ErrInvalidResponseHeaders is raised when a function sets response.headers to +// something other than an array of one-key string objects. PostgREST returns +// PGRST111 at 500 rather than forwarding junk headers; the message is +// upstream's. +func ErrInvalidResponseHeaders() *APIError { + return New(http.StatusInternalServerError, CodeGucHeaders, + "response.headers guc must be a JSON array composed of objects with a single key and a string value") +} + +// ErrInvalidResponseStatus is raised when a function sets response.status to +// anything that is not a valid status code; PostgREST's PGRST112 at 500. +func ErrInvalidResponseStatus() *APIError { + return New(http.StatusInternalServerError, CodeGucStatus, + "response.status guc must be a valid status code") +} + +// ErrMethodNotAllowed is a 405 PGRST101 with a caller-supplied message. Prefer +// ErrInvalidRPCMethod for the wrong-verb-on-a-function case, which carries +// upstream's exact text. func ErrMethodNotAllowed(msg string) *APIError { if msg == "" { msg = "Method not allowed" @@ -111,22 +314,88 @@ func ErrMethodNotAllowed(msg string) *APIError { return New(http.StatusMethodNotAllowed, CodeMethodNotAllowed, msg) } -// ErrUnsupported is the dbrest-specific PGRST127. The details string always -// names both the feature and the backend, per spec 18 section "PGRST127". -// Emission must happen strictly before any backend call. +// ErrInvalidRPCMethod is raised when a function is called with a verb other +// than GET, HEAD, or POST. The text matches v14's PGRST101 ("Cannot use the +// DELETE method on RPC", verified live). +func ErrInvalidRPCMethod(method string) *APIError { + return New(http.StatusMethodNotAllowed, CodeMethodNotAllowed, + fmt.Sprintf("Cannot use the %s method on RPC", method)) +} + +// ErrUnsupportedMethod is PostgREST's PGRST117 (405): an HTTP method the server +// does not implement on any resource, such as TRACE or a verb the table or +// function endpoint never answers. The text is upstream's "Unsupported HTTP +// method: ". OPTIONS is never this error; it is answered with an Allow +// header. +func ErrUnsupportedMethod(method string) *APIError { + return New(http.StatusMethodNotAllowed, CodeUnsupportedMethod, + "Unsupported HTTP method: "+method) +} + +// ErrInvalidPreferences is PostgREST's PGRST122 (400): a request under +// Prefer: handling=strict that carries an unknown preference or a known one with +// a bad value. The message is upstream's; the offending tokens ride as details, +// comma-joined, so the caller sees exactly which preferences were rejected. +func ErrInvalidPreferences(offenders []string) *APIError { + e := New(http.StatusBadRequest, CodeInvalidPreferences, + "Invalid preferences given with handling=strict") + return e.WithDetails("Invalid preferences: " + strings.Join(offenders, ", ")) +} + +// CodeReadOnlyTransaction is PostgreSQL's read_only_sql_transaction. PostgREST +// runs a GET/HEAD function call in a read-only transaction; a function that +// writes fails with this SQLSTATE, surfaced as a 405 with the server's message. +// dbrest's registry path raises it up front when a GET reaches a function +// declared volatile, since registry backends cannot run the call to find out. +const CodeReadOnlyTransaction = "25006" + +// ErrReadOnlyTransaction mirrors PostgreSQL's "cannot execute X in a read-only +// transaction" for a write attempted under a read verb; action names what was +// attempted (a statement kind, or the function for the declared-volatility +// pre-check). +func ErrReadOnlyTransaction(action string) *APIError { + return New(http.StatusMethodNotAllowed, CodeReadOnlyTransaction, + fmt.Sprintf("cannot execute %s in a read-only transaction", action)) +} + +// ErrUnsupported is PGRST127, which v14 defines as "the feature specified in +// the details field is not implemented"; the message is upstream's "Feature not +// implemented" and the details string always names both the feature and the +// backend, per spec 18 section "PGRST127". Emission must happen strictly before +// any backend call. func ErrUnsupported(feature, backend string) *APIError { - e := New(http.StatusBadRequest, CodeUnsupported, "feature not implemented on this backend") + e := New(http.StatusBadRequest, CodeUnsupported, "Feature not implemented") e = e.WithDetails(fmt.Sprintf("%s is not supported by the %s backend", feature, backend)) return e.WithHint("see the capability matrix for supported features on this backend") } +// ErrMaxAffected is PGRST124 (400): a write or RPC under Prefer: +// handling=strict, max-affected=N affected more than N rows, so dbrest rolls the +// transaction back rather than committing the over-broad change. The message is +// upstream's; the actual affected count rides as details so the client sees how +// far over the bound the query reached. +func ErrMaxAffected(affected int64) *APIError { + e := New(http.StatusBadRequest, CodeMaxAffected, + "Query result exceeds max-affected preference constraint") + return e.WithDetails(fmt.Sprintf("The query affects %d rows", affected)) +} + +// ErrAggregatesDisabled is PGRST123, raised when a request uses an aggregate +// function (count(), col.sum(), ...) while db-aggregates-enabled is off. The +// message and hint are upstream's, pointing the operator at the config flag. +func ErrAggregatesDisabled() *APIError { + e := New(http.StatusBadRequest, CodeAggregatesOff, + "Use of aggregate functions is not allowed") + return e.WithHint("Enable the 'db-aggregates-enabled' config parameter to allow the use of aggregate functions") +} + // ErrFullTextUnavailable is the PGRST127 for a full-text predicate on a column the // backend has no full-text structure for (a SQLite column with no covering FTS5 // table). It names the column so the missing structure is actionable, per spec // 21's "never silently wrong" rule: dbrest errors rather than degrading to a // substring scan. Emission happens before any backend call. func ErrFullTextUnavailable(column, backend string) *APIError { - e := New(http.StatusBadRequest, CodeUnsupported, "feature not implemented on this backend") + e := New(http.StatusBadRequest, CodeUnsupported, "Feature not implemented") e = e.WithDetails(fmt.Sprintf("full-text search on column %q has no full-text index on the %s backend", column, backend)) return e.WithHint("create a full-text index covering the column") } @@ -143,29 +412,29 @@ const ( CodeCheckViolation = "23514" // 400 fails a CHECK constraint ) -// ErrUniqueViolation is a duplicate-key conflict (PostgreSQL 23505). -func ErrUniqueViolation(detail string) *APIError { - return New(http.StatusConflict, CodeUniqueViolation, - "duplicate key value violates unique constraint").WithDetails(detail) -} - -// ErrNotNullViolation is a NULL written to a NOT NULL column (23502). -func ErrNotNullViolation(detail string) *APIError { - return New(http.StatusBadRequest, CodeNotNullViolation, - "null value violates not-null constraint").WithDetails(detail) -} - -// ErrForeignKeyViolation is a reference to a row that does not exist (23503). -func ErrForeignKeyViolation(detail string) *APIError { - return New(http.StatusConflict, CodeForeignKeyViolation, - "insert or update violates foreign key constraint").WithDetails(detail) -} - -// ErrCheckViolation is a row that fails a CHECK constraint (23514). It is also -// the fallback for any other integrity violation the backend cannot classify. -func ErrCheckViolation(detail string) *APIError { - return New(http.StatusBadRequest, CodeCheckViolation, - "new row violates check constraint").WithDetails(detail) +// ErrConstraintViolation surfaces a backend's integrity-constraint error with +// the engine's text carried through verbatim, the way PostgREST forwards +// PostgreSQL's: message names the constraint ("duplicate key value violates +// unique constraint \"todos_pkey\"") and detail carries the key ("Key (id)=(1) +// already exists."). Clients parse both, so pgerr contributes only the status: +// a key that conflicts with an existing row (23505, 23503) is a 409, the rest +// of class 23 is a 400. Drivers whose engine reports structure instead of +// PG-shaped text synthesize the message before calling this; an engine that +// supplies neither a constraint name nor the offending value passes the bare +// PostgreSQL wording with empty detail rather than leaking native text. +func ErrConstraintViolation(sqlstate, message, detail, hint string) *APIError { + status := http.StatusBadRequest + if sqlstate == CodeUniqueViolation || sqlstate == CodeForeignKeyViolation { + status = http.StatusConflict + } + e := New(status, sqlstate, message) + if detail != "" { + e = e.WithDetails(detail) + } + if hint != "" { + e = e.WithHint(hint) + } + return e } // CodeInvalidText is PostgreSQL's invalid_text_representation: an operand or @@ -174,26 +443,62 @@ func ErrCheckViolation(detail string) *APIError { // same 400 on every backend (spec 16). const CodeInvalidText = "22P02" +// pgTypeSpelling maps dbrest's canonical type names to the spellings PostgreSQL +// uses in its own error messages, so a 22P02 reads exactly like the server's +// ("invalid input syntax for type integer", never "type int4"). +var pgTypeSpelling = map[string]string{ + "int2": "smallint", + "int4": "integer", + "int8": "bigint", + "float4": "real", + "float8": "double precision", + "bool": "boolean", +} + // ErrInvalidInput is raised when a query-string operand or a payload value cannot // be coerced to its canonical type. It mirrors PostgreSQL's "invalid input syntax // for type T" message and surfaces the 22P02 SQLSTATE as a 400. func ErrInvalidInput(canonicalType, input string) *APIError { + if s, ok := pgTypeSpelling[canonicalType]; ok { + canonicalType = s + } return New(http.StatusBadRequest, CodeInvalidText, fmt.Sprintf("invalid input syntax for type %s: %q", canonicalType, input)) } -// ErrJWTExpired is raised when a JWT is past its exp (with skew applied). -func ErrJWTExpired() *APIError { - return New(http.StatusUnauthorized, CodeJWTExpired, "JWT expired") +// ErrJWTSecretMissing is raised when a request presents a Bearer token but the +// server has no key material to verify it with. It is PostgREST's PGRST300, a +// 500: the misconfiguration is on the server, not the client, and no challenge +// header is sent. +func ErrJWTSecretMissing() *APIError { + return New(http.StatusInternalServerError, CodeJWTSecretMissing, "Server lacks JWT secret") } -// ErrJWTInvalid is raised for a malformed token, bad signature, disallowed alg, -// or a failed nbf/aud check. -func ErrJWTInvalid(msg string) *APIError { - if msg == "" { - msg = "JWT invalid" - } - return New(http.StatusUnauthorized, CodeJWTInvalid, msg) +// ErrJWTDecode is raised when a JWT cannot be decoded: a wrong number of parts, +// no suitable key, a disallowed algorithm, or a failed signature check. It is +// PostgREST's PGRST301 with the RFC 6750 invalid_token challenge. +func ErrJWTDecode(msg string) *APIError { + e := New(http.StatusUnauthorized, CodeJWTDecode, msg) + e.WWWAuthenticate = BearerInvalidToken(msg) + return e +} + +// ErrJWTClaims is raised when a decoded JWT fails claims validation or parsing: +// exp/nbf/iat out of range, an audience mismatch, or an unparseable claim set. +// It is PostgREST's PGRST303 with the RFC 6750 invalid_token challenge. +func ErrJWTClaims(msg string) *APIError { + e := New(http.StatusUnauthorized, CodeJWTClaims, msg) + e.WWWAuthenticate = BearerInvalidToken(msg) + return e +} + +// ErrJWTRequired is raised when a request presents no token and the anonymous +// role is disabled, so there is no role to run it as. It is PostgREST's PGRST302 +// with the bare Bearer challenge. +func ErrJWTRequired() *APIError { + e := New(http.StatusUnauthorized, CodeJWTRequired, "Anonymous access is disabled") + e.WWWAuthenticate = "Bearer" + return e } // CodeInsufficientPrivilege is PostgreSQL's class-42 SQLSTATE for a denied role @@ -214,11 +519,34 @@ func ErrRoleNotAllowed(role string) *APIError { // JWT and was denied to anon (spec 14). func ErrPermissionDenied(relation string, anonymous bool) *APIError { status := http.StatusForbidden + e := New(status, CodeInsufficientPrivilege, + fmt.Sprintf("permission denied for table %s", relation)) if anonymous { - status = http.StatusUnauthorized + // PostgREST sends the bare Bearer challenge on every 401, including a + // privilege denial lifted from 403 for an unauthenticated request. + e.HTTPStatus = http.StatusUnauthorized + e.WWWAuthenticate = "Bearer" } - return New(status, CodeInsufficientPrivilege, - fmt.Sprintf("permission denied for table %s", relation)) + return e +} + +// GradePrivilegeStatus applies PostgREST's 42501 rule to e: insufficient +// privilege is 403 when the request was authenticated and 401 when it ran as +// anon, so an authenticated client never gets the 401 that would trigger a +// token-refresh loop. An error with any other code passes through unchanged. +// This is the one place the rule lives; the exec-error mapping and the +// per-driver SQLSTATE tables defer to it. +func GradePrivilegeStatus(e *APIError, authenticated bool) *APIError { + if e == nil || e.Code != CodeInsufficientPrivilege { + return e + } + c := *e + if authenticated { + c.HTTPStatus = http.StatusForbidden + } else { + c.HTTPStatus = http.StatusUnauthorized + } + return &c } // ErrRLSViolation is a row that fails a WITH CHECK policy on a write, mirroring @@ -234,3 +562,35 @@ func ErrRLSViolation(relation string) *APIError { func ErrInternal(msg string) *APIError { return New(http.StatusInternalServerError, CodeInternal, msg) } + +// ErrDBConnection is a failed or refused database connection (a bad URI or a +// service that is down), mapped to 503 with PostgREST's group-0 message and the +// driver's own text carried in details. A load balancer treats the 503 as +// retryable, matching PostgREST's ConnectionUsageError. +func ErrDBConnection(detail string) *APIError { + e := New(http.StatusServiceUnavailable, CodeDBConnection, + "Database connection error. Retrying the connection.") + if detail != "" { + e = e.WithDetails(detail) + } + return e +} + +// ErrDBClient is a database client error during a session (a dropped or reset +// connection mid-request), mapped to 503 with PostgREST's group-0 message and +// the driver's text in details, matching PostgREST's ClientError. +func ErrDBClient(detail string) *APIError { + e := New(http.StatusServiceUnavailable, CodeDBClient, + "Database client error. Retrying the connection.") + if detail != "" { + e = e.WithDetails(detail) + } + return e +} + +// ErrAcquireTimeout is a pool-acquisition timeout, mapped to 504 with +// PostgREST's exact group-0 message, matching AcquisitionTimeoutUsageError. +func ErrAcquireTimeout() *APIError { + return New(http.StatusGatewayTimeout, CodeAcquireTimeout, + "Timed out acquiring connection from connection pool.") +} diff --git a/pgerr/codes_test.go b/pgerr/codes_test.go index a9da682..527c33f 100644 --- a/pgerr/codes_test.go +++ b/pgerr/codes_test.go @@ -20,30 +20,45 @@ func TestConstructorStatusAndCode(t *testing.T) { code string }{ {"parse", ErrParse("bad operator"), http.StatusBadRequest, CodeParse}, + {"invalid-body", ErrInvalidBody(""), http.StatusBadRequest, CodeInvalidBody}, {"singular", ErrSingularZeroMany(), http.StatusNotAcceptable, CodeSingularZeroMany}, {"range", ErrRangeNotSatisfiable(), http.StatusRequestedRangeNotSatisfiable, CodeRangeUnsatisfied}, {"not-acceptable", ErrNotAcceptable("text/csv"), http.StatusNotAcceptable, CodeMediaType}, - {"unsupported-media", ErrUnsupportedMediaType("text/yaml"), http.StatusUnsupportedMediaType, CodeMediaType}, - {"unknown-table", ErrUnknownTable("films"), http.StatusNotFound, CodeUnknownTable}, - {"unknown-column", ErrUnknownColumn("titel"), http.StatusBadRequest, CodeUnknownColumn}, - {"no-relationship", ErrNoRelationship("films", "actors"), http.StatusBadRequest, CodeNoRelationship}, - {"ambiguous-embed", ErrAmbiguousEmbed("films", "actors"), http.StatusMultipleChoices, CodeAmbiguousEmbed}, - {"no-function", ErrNoFunction("add"), http.StatusNotFound, CodeNoFunction}, + {"unsupported-media", ErrUnsupportedMediaType("text/yaml"), http.StatusBadRequest, CodeInvalidBody}, + {"unknown-table", ErrUnknownTable("public", "films"), http.StatusNotFound, CodeUnknownTable}, + {"unknown-column", ErrUnknownColumn("titel", "films"), http.StatusBadRequest, CodeUnknownColumn}, + {"undefined-column", ErrUndefinedColumn("todos.nope"), http.StatusBadRequest, CodeUndefinedColumn}, + {"no-relationship", ErrNoRelationship("films", "actors", "public", ""), http.StatusBadRequest, CodeNoRelationship}, + {"ambiguous-embed", ErrAmbiguousEmbed("films", "actors", nil), http.StatusMultipleChoices, CodeAmbiguousEmbed}, + {"no-function", ErrNoFunction("public", "add", []string{"a", "b"}, ""), http.StatusNotFound, CodeNoFunction}, + {"ambiguous-function", ErrAmbiguousFunction([]string{"api.add(a => integer)", "api.add(a => text)"}), http.StatusMultipleChoices, CodeAmbiguousFunc}, + {"invalid-path", ErrInvalidPath(), http.StatusNotFound, CodeInvalidPath}, + {"guc-headers", ErrInvalidResponseHeaders(), http.StatusInternalServerError, CodeGucHeaders}, + {"guc-status", ErrInvalidResponseStatus(), http.StatusInternalServerError, CodeGucStatus}, {"method-not-allowed", ErrMethodNotAllowed(""), http.StatusMethodNotAllowed, CodeMethodNotAllowed}, + {"invalid-rpc-method", ErrInvalidRPCMethod("DELETE"), http.StatusMethodNotAllowed, CodeMethodNotAllowed}, + {"read-only-txn", ErrReadOnlyTransaction("UPDATE"), http.StatusMethodNotAllowed, CodeReadOnlyTransaction}, {"unsupported", ErrUnsupported("the sl operator", "mysql"), http.StatusBadRequest, CodeUnsupported}, {"fts-unavailable", ErrFullTextUnavailable("body", "sqlite"), http.StatusBadRequest, CodeUnsupported}, - {"unique", ErrUniqueViolation("Key (id)=(1) already exists"), http.StatusConflict, CodeUniqueViolation}, - {"not-null", ErrNotNullViolation("column title"), http.StatusBadRequest, CodeNotNullViolation}, - {"foreign-key", ErrForeignKeyViolation("Key (dir)=(9) is not present"), http.StatusConflict, CodeForeignKeyViolation}, - {"check", ErrCheckViolation("rating must be positive"), http.StatusBadRequest, CodeCheckViolation}, + {"constraint-unique", ErrConstraintViolation("23505", "m", "", ""), http.StatusConflict, CodeUniqueViolation}, + {"constraint-fk", ErrConstraintViolation("23503", "m", "", ""), http.StatusConflict, CodeForeignKeyViolation}, + {"constraint-not-null", ErrConstraintViolation("23502", "m", "", ""), http.StatusBadRequest, CodeNotNullViolation}, + {"constraint-check", ErrConstraintViolation("23514", "m", "", ""), http.StatusBadRequest, CodeCheckViolation}, {"invalid-input", ErrInvalidInput("integer", "abc"), http.StatusBadRequest, CodeInvalidText}, - {"jwt-expired", ErrJWTExpired(), http.StatusUnauthorized, CodeJWTExpired}, - {"jwt-invalid", ErrJWTInvalid(""), http.StatusUnauthorized, CodeJWTInvalid}, + {"jwt-secret-missing", ErrJWTSecretMissing(), http.StatusInternalServerError, CodeJWTSecretMissing}, + {"jwt-decode", ErrJWTDecode("JWT couldn't be decoded"), http.StatusUnauthorized, CodeJWTDecode}, + {"jwt-required", ErrJWTRequired(), http.StatusUnauthorized, CodeJWTRequired}, + {"jwt-claims", ErrJWTClaims("JWT expired"), http.StatusUnauthorized, CodeJWTClaims}, {"role-not-allowed", ErrRoleNotAllowed("admin"), http.StatusForbidden, CodeInsufficientPrivilege}, {"permission-denied", ErrPermissionDenied("films", false), http.StatusForbidden, CodeInsufficientPrivilege}, {"rls-violation", ErrRLSViolation("films"), http.StatusForbidden, CodeInsufficientPrivilege}, {"internal", ErrInternal("boom"), http.StatusInternalServerError, CodeInternal}, } + // The internal code is pinned to its literal: clients and monitors match the + // documented PGRSTX00, so a private spelling would never match anything. + if CodeInternal != "PGRSTX00" { + t.Errorf("CodeInternal = %q, want PGRSTX00", CodeInternal) + } for _, c := range cases { t.Run(c.name, func(t *testing.T) { if c.err.HTTPStatus != c.status { @@ -77,20 +92,169 @@ func TestPermissionDeniedAnonymousIs401(t *testing.T) { } } +// GradePrivilegeStatus is the single spelling of the 42501 rule: 403 when +// authenticated, 401 when anonymous, untouched for every other code. +func TestGradePrivilegeStatus(t *testing.T) { + native := New(http.StatusUnauthorized, CodeInsufficientPrivilege, "permission denied for table films") + if got := GradePrivilegeStatus(native, true).HTTPStatus; got != http.StatusForbidden { + t.Errorf("authenticated 42501 = %d, want 403", got) + } + if got := GradePrivilegeStatus(native, false).HTTPStatus; got != http.StatusUnauthorized { + t.Errorf("anonymous 42501 = %d, want 401", got) + } + if native.HTTPStatus != http.StatusUnauthorized { + t.Error("GradePrivilegeStatus mutated its argument") + } + other := ErrConstraintViolation("23505", "duplicate key", "", "") + if got := GradePrivilegeStatus(other, true); got != other { + t.Error("non-42501 errors must pass through unchanged") + } + if GradePrivilegeStatus(nil, true) != nil { + t.Error("nil must pass through") + } +} + // The empty-message constructors fall back to a non-empty default rather than // shipping a blank message to the client. func TestEmptyMessageDefaults(t *testing.T) { if got := ErrMethodNotAllowed("").Message; got == "" { t.Error("ErrMethodNotAllowed default message is empty") } - if got := ErrJWTInvalid("").Message; got == "" { - t.Error("ErrJWTInvalid default message is empty") - } if got := ErrMethodNotAllowed("custom").Message; got != "custom" { t.Errorf("ErrMethodNotAllowed override = %q, want custom", got) } - if got := ErrJWTInvalid("custom").Message; got != "custom" { - t.Errorf("ErrJWTInvalid override = %q, want custom", got) +} + +// The JWT errors carry the WWW-Authenticate challenge PostgREST sends on every +// 401: the RFC 6750 invalid_token form on PGRST301/PGRST303, the bare Bearer on +// PGRST302 and on an anonymous privilege denial. +func TestJWTErrorsCarryWWWAuthenticate(t *testing.T) { + wantInvalid := `Bearer error="invalid_token", error_description="JWT expired"` + if got := ErrJWTClaims("JWT expired").WWWAuthenticate; got != wantInvalid { + t.Errorf("ErrJWTClaims challenge = %q, want %q", got, wantInvalid) + } + if got := ErrJWTDecode("JWT couldn't be decoded").WWWAuthenticate; got == "" { + t.Error("ErrJWTDecode must carry an invalid_token challenge") + } + if got := ErrJWTRequired().WWWAuthenticate; got != "Bearer" { + t.Errorf("ErrJWTRequired challenge = %q, want Bearer", got) + } + if got := ErrPermissionDenied("films", true).WWWAuthenticate; got != "Bearer" { + t.Errorf("anonymous ErrPermissionDenied challenge = %q, want Bearer", got) + } + if got := ErrPermissionDenied("films", false).WWWAuthenticate; got != "" { + t.Errorf("authenticated ErrPermissionDenied challenge = %q, want none", got) + } + if got := ErrJWTSecretMissing().WWWAuthenticate; got != "" { + t.Errorf("ErrJWTSecretMissing challenge = %q, want none on a 500", got) + } +} + +// The v14 message texts replaced several pre-v12 spellings; clients match on +// them, so each retired text is pinned to its current form here. +func TestV14MessageTexts(t *testing.T) { + if got, want := ErrSingularZeroMany().Message, "Cannot coerce the result to a single JSON object"; got != want { + t.Errorf("PGRST116 message = %q, want %q", got, want) + } + if got, want := ErrUnsupported("the sl operator", "mysql").Message, "Feature not implemented"; got != want { + t.Errorf("PGRST127 message = %q, want %q", got, want) + } + if got, want := ErrFullTextUnavailable("body", "sqlite").Message, "Feature not implemented"; got != want { + t.Errorf("PGRST127 fts message = %q, want %q", got, want) + } +} + +// A 22P02 names the type the way PostgreSQL's own message does: the SQL +// standard spelling, never the internal catalog name. +func TestInvalidInputTypeSpelling(t *testing.T) { + cases := map[string]string{ + "int2": "smallint", + "int4": "integer", + "int8": "bigint", + "float4": "real", + "float8": "double precision", + "bool": "boolean", + "uuid": "uuid", // no PG alias, passes through + } + for canonical, spelled := range cases { + got := ErrInvalidInput(canonical, "abc").Message + want := `invalid input syntax for type ` + spelled + `: "abc"` + if got != want { + t.Errorf("ErrInvalidInput(%q) message = %q, want %q", canonical, got, want) + } + } +} + +// A constraint violation carries the engine's text through untouched: the +// message keeps its constraint name and the detail its key, the parts clients +// parse out of a live PostgREST response. +func TestConstraintViolationPassesTextThrough(t *testing.T) { + e := ErrConstraintViolation("23505", + `duplicate key value violates unique constraint "todos_pkey"`, + "Key (id)=(1) already exists.", "") + if e.Message != `duplicate key value violates unique constraint "todos_pkey"` { + t.Errorf("message = %q", e.Message) + } + if e.Details == nil || *e.Details != "Key (id)=(1) already exists." { + t.Errorf("details = %v", e.Details) + } + if e.Hint != nil { + t.Errorf("hint = %v, want null when the engine gave none", e.Hint) + } +} + +// 42703 carries PostgreSQL's own message shape: the qualified column, no +// quotes, exactly as a live v14 forwards it. +func TestUndefinedColumnMessage(t *testing.T) { + got := ErrUndefinedColumn("todos.nope").Message + if want := "column todos.nope does not exist"; got != want { + t.Errorf("message = %q, want %q", got, want) + } +} + +// The wrong-verb and read-only texts match a live v14's exactly. +func TestRPCMethodMessages(t *testing.T) { + if got, want := ErrInvalidRPCMethod("TRACE").Message, "Cannot use the TRACE method on RPC"; got != want { + t.Errorf("PGRST101 message = %q, want %q", got, want) + } + if got, want := ErrReadOnlyTransaction("UPDATE").Message, "cannot execute UPDATE in a read-only transaction"; got != want { + t.Errorf("25006 message = %q, want %q", got, want) + } +} + +// PGRST203 spells the surviving overloads into the message and tells the +// client how to break the tie; PGRST125's message is pinned to the live text. +func TestAmbiguousFunctionAndInvalidPath(t *testing.T) { + e := ErrAmbiguousFunction([]string{"api.add(a => integer)", "api.add(a => text)"}) + want := "Could not choose the best candidate function between: api.add(a => integer), api.add(a => text)" + if e.Message != want { + t.Errorf("PGRST203 message = %q, want %q", e.Message, want) + } + if e.Hint == nil || !strings.Contains(*e.Hint, "function overloading can be resolved") { + t.Errorf("PGRST203 hint = %v, want the renaming suggestion", e.Hint) + } + if got, want := ErrInvalidPath().Message, "Invalid path specified in request URL"; got != want { + t.Errorf("PGRST125 message = %q, want %q", got, want) + } +} + +// PGRST102 is the v14 code for every request-body failure. The default message +// is PostgREST's generic JSON-body text; a specific parser failure overrides it. +func TestInvalidBodyMessages(t *testing.T) { + if got := ErrInvalidBody("").Message; got != "Empty or invalid json" { + t.Errorf("default message = %q, want %q", got, "Empty or invalid json") + } + if got := ErrInvalidBody("All object keys must match").Message; got != "All object keys must match" { + t.Errorf("override message = %q", got) + } +} + +// The request-side media type error carries PostgREST's exact message shape, +// naming the offending Content-Type. +func TestUnsupportedMediaTypeMessage(t *testing.T) { + got := ErrUnsupportedMediaType("application/yaml").Message + if want := "Content-Type not acceptable: application/yaml"; got != want { + t.Errorf("message = %q, want %q", got, want) } } @@ -132,6 +296,65 @@ func TestWithMessage(t *testing.T) { } } +// ErrNoRelationship names the searched pair and the schema in its details, and +// the schema comes from the parent relation (item 04.4). A bare search reports +// no hint; a hinted one echoes the hint clause before the schema. +func TestNoRelationshipDetails(t *testing.T) { + bare := ErrNoRelationship("films", "directors", "public", "") + if bare.Details == nil { + t.Fatal("details are nil") + } + want := "Searched for a foreign key relationship between 'films' and 'directors' in the schema 'public', but no matches were found." + if *bare.Details != want { + t.Errorf("details = %q, want %q", *bare.Details, want) + } + + hinted := ErrNoRelationship("films", "directors", "api", "fk_director") + wantHinted := "Searched for a foreign key relationship between 'films' and 'directors' using the hint 'fk_director' in the schema 'api', but no matches were found." + if hinted.Details == nil || *hinted.Details != wantHinted { + t.Errorf("hinted details = %v, want %q", hinted.Details, wantHinted) + } +} + +// ErrAmbiguousEmbed renders the candidate array verbatim and a Try-changing hint +// listing each candidate's disambiguated embed spelling (item 04.4). With no +// candidates it degrades to message only rather than an empty array and hint. +func TestAmbiguousEmbedDetailsAndHint(t *testing.T) { + cands := []EmbedCandidate{ + {Cardinality: "many-to-one", Embedding: "films with people", Relationship: "films_director_id_fkey using films(director_id) and people(id)", Name: "films_director_id_fkey"}, + {Cardinality: "many-to-one", Embedding: "films with people", Relationship: "films_writer_id_fkey using films(writer_id) and people(id)", Name: "films_writer_id_fkey"}, + } + e := ErrAmbiguousEmbed("films", "people", cands) + if e.RawDetails == nil { + t.Fatal("details are nil, want the candidate array") + } + var got []EmbedCandidate + if err := json.Unmarshal(e.RawDetails, &got); err != nil { + t.Fatalf("details not an array: %v", err) + } + if len(got) != 2 || got[0].Cardinality != "many-to-one" { + t.Errorf("candidates = %v", got) + } + // Name is carried for the hint, not the details body. + if bytesHasName := e.RawDetails; jsonContains(bytesHasName, `"name"`) { + t.Error("details array leaked the unexported name key") + } + wantHint := "Try changing 'people' to one of the following: 'people!films_director_id_fkey', 'people!films_writer_id_fkey'. Find the desired relationship in the 'details' key." + if e.Hint == nil || *e.Hint != wantHint { + t.Errorf("hint = %v, want %q", e.Hint, wantHint) + } + + if bare := ErrAmbiguousEmbed("films", "people", nil); bare.RawDetails != nil || bare.Hint != nil { + t.Errorf("no-candidate ambiguous embed should be message only, got details=%s hint=%v", bare.RawDetails, bare.Hint) + } +} + +// jsonContains reports whether raw JSON bytes contain the literal substring, +// used to assert the details array does not serialize the candidate name key. +func jsonContains(raw json.RawMessage, sub string) bool { + return strings.Contains(string(raw), sub) +} + // JSON rendering is on the error path of every failed request, so it carries its // own benchmark: the four-key envelope with details and hint populated, the // shape a PGRST127 actually ships. diff --git a/pgerr/pgerr.go b/pgerr/pgerr.go index 12e3b96..e29aa11 100644 --- a/pgerr/pgerr.go +++ b/pgerr/pgerr.go @@ -9,6 +9,7 @@ package pgerr import ( "encoding/json" "net/http" + "strconv" ) // APIError is the canonical error value. It carries the wire envelope @@ -20,14 +21,30 @@ import ( type APIError struct { // HTTPStatus is the HTTP status code. It is not part of the JSON body. HTTPStatus int `json:"-"` + // WWWAuthenticate, when set, is emitted as the WWW-Authenticate response + // header. PostgREST sends it on every 401: the RFC 6750 invalid_token form + // on PGRST301/PGRST303 and the bare "Bearer" challenge otherwise. It is not + // part of the JSON body. + WWWAuthenticate string `json:"-"` // Code is the PGRST code (or a backend SQLSTATE passed through). Code string `json:"code"` // Message is the human-facing summary. Message string `json:"message"` // Details is extra context, or null. Details *string `json:"details"` + // RawDetails carries a details payload that is not a string: PostgREST's + // PGRST201 returns details as a JSON array of candidate relationship + // objects, which clients read to auto-disambiguate an embed. When set it + // takes precedence over Details in the rendered envelope. + RawDetails json.RawMessage `json:"-"` // Hint is a suggested fix, or null. Hint *string `json:"hint"` + // Headers are extra response headers emitted with the error. A function that + // raises a full-control error (SQLSTATE 'PGRST') supplies them in the DETAIL + // JSON's headers object; Write merges them onto the response. They are not + // part of the JSON body. This is the error-path analog of ResponseControls + // headers on the success path. + Headers http.Header `json:"-"` } // Error implements the error interface. @@ -45,6 +62,18 @@ func (e *APIError) WithDetails(details string) *APIError { return &c } +// WithDetailsJSON returns a copy of e with details set to a non-string JSON +// value, the shape PGRST201 uses for its candidate relationship array. v is +// marshaled immediately; a value that cannot marshal leaves details unchanged +// rather than corrupting the envelope. +func (e *APIError) WithDetailsJSON(v any) *APIError { + c := *e + if b, err := json.Marshal(v); err == nil { + c.RawDetails = b + } + return &c +} + // WithHint returns a copy of e with hint set. func (e *APIError) WithHint(hint string) *APIError { c := *e @@ -52,6 +81,24 @@ func (e *APIError) WithHint(hint string) *APIError { return &c } +// WithHeaders returns a copy of e carrying the given response headers, the shape +// FromRaise returns for a full-control raised error. The headers ride on the +// error and Write merges them onto the response; an empty map is a no-op. +func (e *APIError) WithHeaders(h map[string]string) *APIError { + if len(h) == 0 { + return e + } + c := *e + c.Headers = http.Header{} + for k, vs := range e.Headers { + c.Headers[k] = vs + } + for k, v := range h { + c.Headers.Set(k, v) + } + return &c +} + // WithMessage returns a copy of e with the message replaced. func (e *APIError) WithMessage(msg string) *APIError { c := *e @@ -60,33 +107,66 @@ func (e *APIError) WithMessage(msg string) *APIError { } // body is the exact JSON shape sent to the client. Keys are always present; -// Details and Hint are encoded as null when nil because they are pointers. +// details and hint are encoded as null when unset. details is raw so it can be +// a string, null, or PGRST201's array of relationship candidates. type body struct { - Code string `json:"code"` - Message string `json:"message"` - Details *string `json:"details"` - Hint *string `json:"hint"` + Code string `json:"code"` + Message string `json:"message"` + Details json.RawMessage `json:"details"` + Hint *string `json:"hint"` } // JSON returns the rendered envelope bytes for e. func (e *APIError) JSON() []byte { + details := json.RawMessage("null") + switch { + case e.RawDetails != nil: + details = e.RawDetails + case e.Details != nil: + if b, err := json.Marshal(*e.Details); err == nil { + details = b + } + } b, _ := json.Marshal(body{ Code: e.Code, Message: e.Message, - Details: e.Details, + Details: details, Hint: e.Hint, }) return b } -// Write renders e onto w: it sets the JSON content type and the status, then -// writes the envelope. It is the single place an error reaches the client. +// Write renders e onto w: it sets the JSON content type, the Proxy-Status +// header, the WWW-Authenticate challenge when one is carried, and the status, +// then writes the envelope. It is the single place an error reaches the +// client. v14 adds Proxy-Status to every error response so a HEAD request, +// whose status alone is not descriptive enough, still names the error code; +// the "PostgREST" identifier is kept byte-identical for wire compatibility. func (e *APIError) Write(w http.ResponseWriter) { + // A full-control raised error carries its own headers; merge them first so the + // fixed envelope headers below win, keeping the body well-formed even if a + // function tries to override Content-Type. + for k, vs := range e.Headers { + for _, v := range vs { + w.Header().Add(k, v) + } + } w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.Header().Set("Proxy-Status", "PostgREST; error="+e.Code) + if e.WWWAuthenticate != "" { + w.Header().Set("WWW-Authenticate", e.WWWAuthenticate) + } w.WriteHeader(e.HTTPStatus) _, _ = w.Write(e.JSON()) } +// BearerInvalidToken renders the RFC 6750 challenge PostgREST sends with a JWT +// decode or claims error: Bearer error="invalid_token" with the error message +// quoted into error_description. +func BearerInvalidToken(msg string) string { + return `Bearer error="invalid_token", error_description=` + strconv.Quote(msg) +} + // New builds an APIError from its parts. func New(status int, code, message string) *APIError { return &APIError{HTTPStatus: status, Code: code, Message: message} diff --git a/pgerr/pgerr_test.go b/pgerr/pgerr_test.go index c048cc1..8fd8a5c 100644 --- a/pgerr/pgerr_test.go +++ b/pgerr/pgerr_test.go @@ -45,6 +45,54 @@ func TestWithDetailsHintImmutable(t *testing.T) { } } +// PGRST201 returns details as a JSON array of candidate relationship objects; +// the envelope must carry it as an array, not a quoted string, while string +// details and null keep their existing encodings. +func TestDetailsCanCarryNonStringJSON(t *testing.T) { + candidates := []map[string]string{{ + "cardinality": "many-to-one", + "embedding": "orders with addresses", + "relationship": "billing using orders(billing_address_id) and addresses(id)", + }} + base := ErrAmbiguousEmbed("orders", "addresses", nil) + e := base.WithDetailsJSON(candidates) + if base.RawDetails != nil { + t.Error("WithDetailsJSON mutated the receiver") + } + + var m map[string]json.RawMessage + if err := json.Unmarshal(e.JSON(), &m); err != nil { + t.Fatalf("envelope not valid json: %v", err) + } + var got []map[string]string + if err := json.Unmarshal(m["details"], &got); err != nil { + t.Fatalf("details is not a JSON array: %v: %s", err, m["details"]) + } + if len(got) != 1 || got[0]["embedding"] != "orders with addresses" { + t.Errorf("details round-trip = %v", got) + } + + // A string details still renders as a JSON string. + var sm map[string]json.RawMessage + se := ErrParse("x").WithDetails("plain text") + if err := json.Unmarshal(se.JSON(), &sm); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if string(sm["details"]) != `"plain text"` { + t.Errorf("string details = %s, want %q", sm["details"], `"plain text"`) + } + + // Raw details win over a previously set string. + both := se.WithDetailsJSON([]int{1, 2}) + var bm map[string]json.RawMessage + if err := json.Unmarshal(both.JSON(), &bm); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if string(bm["details"]) != "[1,2]" { + t.Errorf("raw details = %s, want [1,2]", bm["details"]) + } +} + func TestUnsupportedNamesFeatureAndBackend(t *testing.T) { e := ErrUnsupported("range operator 'sl'", "mysql") if e.HTTPStatus != http.StatusBadRequest { @@ -64,13 +112,18 @@ func TestUnsupportedNamesFeatureAndBackend(t *testing.T) { func TestWriteSetsStatusAndContentType(t *testing.T) { rec := httptest.NewRecorder() - ErrUnknownTable("films").Write(rec) + ErrUnknownTable("public", "films").Write(rec) if rec.Code != http.StatusNotFound { t.Errorf("status = %d, want 404", rec.Code) } if ct := rec.Header().Get("Content-Type"); ct != "application/json; charset=utf-8" { t.Errorf("content-type = %q", ct) } + // v14 names the error code in Proxy-Status so HEAD requests can identify + // the failure; the value matches a live v14's byte for byte. + if ps := rec.Header().Get("Proxy-Status"); ps != "PostgREST; error=PGRST205" { + t.Errorf("Proxy-Status = %q, want %q", ps, "PostgREST; error=PGRST205") + } var b body if err := json.Unmarshal(rec.Body.Bytes(), &b); err != nil { t.Fatalf("body not valid json: %v", err) @@ -78,6 +131,35 @@ func TestWriteSetsStatusAndContentType(t *testing.T) { if b.Code != CodeUnknownTable { t.Errorf("code = %s", b.Code) } + if h := rec.Header().Get("WWW-Authenticate"); h != "" { + t.Errorf("a non-auth error must not carry WWW-Authenticate, got %q", h) + } +} + +func TestWriteEmitsWWWAuthenticate(t *testing.T) { + rec := httptest.NewRecorder() + ErrJWTClaims("JWT expired").Write(rec) + want := `Bearer error="invalid_token", error_description="JWT expired"` + if h := rec.Header().Get("WWW-Authenticate"); h != want { + t.Errorf("WWW-Authenticate = %q, want %q", h, want) + } +} + +// A full-control raised error carries response headers that Write must merge +// onto the response, while the fixed envelope headers still win (item 04.9). +func TestWriteEmitsCarriedHeaders(t *testing.T) { + rec := httptest.NewRecorder() + e := New(402, "123", "Payment Required").WithHeaders(map[string]string{ + "X-Reason": "quota", + "Content-Type": "text/plain", // a function must not be able to break the body + }) + e.Write(rec) + if h := rec.Header().Get("X-Reason"); h != "quota" { + t.Errorf("X-Reason = %q, want quota", h) + } + if ct := rec.Header().Get("Content-Type"); ct != "application/json; charset=utf-8" { + t.Errorf("content-type = %q, want the reserved envelope type to win", ct) + } } func TestAs(t *testing.T) { diff --git a/pgerr/raise.go b/pgerr/raise.go new file mode 100644 index 0000000..0f69903 --- /dev/null +++ b/pgerr/raise.go @@ -0,0 +1,76 @@ +package pgerr + +import ( + "encoding/json" + "fmt" + "net/http" +) + +// A function can take full control of the response by raising SQLSTATE 'PGRST' +// with a JSON object in MESSAGE ({code, message, details?, hint?}, the +// envelope) and a JSON object in DETAIL ({status, headers, status_text?}, the +// response control). PostgREST forwards the envelope verbatim, sets the HTTP +// status from detail.status, and applies detail.headers; a payload it cannot +// parse is reported as PGRST121 at 500 with details naming the malformed +// field. All texts and the obligatory-key rules below were verified against a +// live v14 (code and message are obligatory in MESSAGE; status and headers, +// which may be an empty object, are obligatory in DETAIL). + +// CodeRaiseParse is PGRST121: the MESSAGE or DETAIL payload of a RAISE +// SQLSTATE 'PGRST' could not be parsed. +const CodeRaiseParse = "PGRST121" + +const ( + raiseParseMessage = `Could not parse JSON in the "RAISE SQLSTATE 'PGRST'" error` + raiseMessageHint = "MESSAGE must be a JSON object with obligatory keys: 'code', 'message' and optional keys: 'details', 'hint'." + raiseDetailHint = "DETAIL must be a JSON object with obligatory keys: 'status', 'headers' and optional key: 'status_text'." +) + +// ErrRaiseParse is the PGRST121 envelope, with details naming the malformed +// field and the hint spelling the expected shape. +func ErrRaiseParse(details, hint string) *APIError { + return New(http.StatusInternalServerError, CodeRaiseParse, raiseParseMessage). + WithDetails(details).WithHint(hint) +} + +// raiseMessage is the envelope object a function puts in MESSAGE. Pointer +// fields distinguish a missing obligatory key from an empty value. +type raiseMessage struct { + Code *string `json:"code"` + Message *string `json:"message"` + Details *string `json:"details"` + Hint *string `json:"hint"` +} + +// raiseDetail is the response-control object a function puts in DETAIL. +type raiseDetail struct { + Status *int `json:"status"` + StatusText *string `json:"status_text"` + Headers map[string]string `json:"headers"` +} + +// FromRaise assembles the client-controlled error from the MESSAGE and DETAIL +// strings of a RAISE SQLSTATE 'PGRST'. On success it returns the function's +// envelope with the status from detail.status, plus the headers to apply to +// the response. When either payload cannot be parsed it returns the PGRST121 +// envelope and no headers, exactly as PostgREST does; pass detail as the empty +// string when the RAISE carried no DETAIL. +func FromRaise(message, detail string) (*APIError, map[string]string) { + var m raiseMessage + if err := json.Unmarshal([]byte(message), &m); err != nil || m.Code == nil || m.Message == nil { + return ErrRaiseParse( + fmt.Sprintf("Invalid JSON value for MESSAGE: '%s'", message), raiseMessageHint), nil + } + if detail == "" { + return ErrRaiseParse("DETAIL is missing in the RAISE statement", raiseDetailHint), nil + } + var d raiseDetail + if err := json.Unmarshal([]byte(detail), &d); err != nil || d.Status == nil || d.Headers == nil { + return ErrRaiseParse( + fmt.Sprintf("Invalid JSON value for DETAIL: '%s'", detail), raiseDetailHint), nil + } + e := New(*d.Status, *m.Code, *m.Message) + e.Details = m.Details + e.Hint = m.Hint + return e, d.Headers +} diff --git a/pgerr/raise_test.go b/pgerr/raise_test.go new file mode 100644 index 0000000..0786260 --- /dev/null +++ b/pgerr/raise_test.go @@ -0,0 +1,86 @@ +package pgerr + +import ( + "net/http" + "testing" +) + +// The happy path: a function controls status, headers, and the whole envelope. +// The payloads mirror the documented example, which a live v14 answers with +// 402, the X-Powered-By header, and the envelope verbatim. +func TestFromRaiseFullControl(t *testing.T) { + e, headers := FromRaise( + `{"code":"123","message":"Payment Required","details":"Quota exceeded","hint":"Upgrade your plan"}`, + `{"status":402,"headers":{"X-Powered-By":"Nerd Rage"}}`) + if e.HTTPStatus != http.StatusPaymentRequired { + t.Errorf("status = %d, want 402", e.HTTPStatus) + } + if e.Code != "123" || e.Message != "Payment Required" { + t.Errorf("envelope = %s: %s", e.Code, e.Message) + } + if e.Details == nil || *e.Details != "Quota exceeded" { + t.Errorf("details = %v", e.Details) + } + if e.Hint == nil || *e.Hint != "Upgrade your plan" { + t.Errorf("hint = %v", e.Hint) + } + if headers["X-Powered-By"] != "Nerd Rage" { + t.Errorf("headers = %v", headers) + } +} + +// details and hint are optional in MESSAGE; headers may be an empty object. +func TestFromRaiseMinimal(t *testing.T) { + e, headers := FromRaise(`{"code":"123","message":"m"}`, `{"status":402,"headers":{}}`) + if e.Code != "123" || e.HTTPStatus != 402 { + t.Errorf("envelope = %s status %d", e.Code, e.HTTPStatus) + } + if e.Details != nil || e.Hint != nil { + t.Errorf("details/hint should stay null: %v %v", e.Details, e.Hint) + } + if headers == nil || len(headers) != 0 { + t.Errorf("headers = %v, want empty map", headers) + } +} + +// Every malformed payload comes back as the PGRST121 envelope with details +// naming the field and the hint spelling the expected shape; the texts are +// pinned to a live v14's byte for byte. +func TestFromRaiseParseFailures(t *testing.T) { + cases := []struct { + name string + message, detail string + details, hint string + }{ + {"message not json", "not json", `{"status":402,"headers":{}}`, + "Invalid JSON value for MESSAGE: 'not json'", raiseMessageHint}, + {"message missing code", `{"message":"no code"}`, `{"status":419,"headers":{}}`, + `Invalid JSON value for MESSAGE: '{"message":"no code"}'`, raiseMessageHint}, + {"detail not json", `{"code":"123","message":"ok"}`, "nope", + "Invalid JSON value for DETAIL: 'nope'", raiseDetailHint}, + {"detail missing headers", `{"code":"123","message":"m"}`, `{"status":402}`, + `Invalid JSON value for DETAIL: '{"status":402}'`, raiseDetailHint}, + {"detail missing", `{"code":"123","message":"just msg"}`, "", + "DETAIL is missing in the RAISE statement", raiseDetailHint}, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + e, headers := FromRaise(c.message, c.detail) + if headers != nil { + t.Errorf("headers = %v, want none on a parse failure", headers) + } + if e.HTTPStatus != http.StatusInternalServerError || e.Code != CodeRaiseParse { + t.Errorf("got %d %s, want 500 PGRST121", e.HTTPStatus, e.Code) + } + if e.Message != raiseParseMessage { + t.Errorf("message = %q", e.Message) + } + if e.Details == nil || *e.Details != c.details { + t.Errorf("details = %v, want %q", e.Details, c.details) + } + if e.Hint == nil || *e.Hint != c.hint { + t.Errorf("hint = %v, want %q", e.Hint, c.hint) + } + }) + } +} diff --git a/plan/aggregate_test.go b/plan/aggregate_test.go new file mode 100644 index 0000000..e873ae6 --- /dev/null +++ b/plan/aggregate_test.go @@ -0,0 +1,66 @@ +package plan + +import ( + "testing" + + "github.com/tamnd/dbrest/ir" +) + +// aggQuery builds a films read whose projection carries the given select items. +func aggQuery(items ...ir.SelectItem) *ir.Query { + return &ir.Query{ + Kind: ir.Read, + Relation: ir.Ref{Name: "films"}, + Select: items, + } +} + +func TestAggregateGatedOffByDefault(t *testing.T) { + q := aggQuery(ir.Aggregate{Func: ir.AggCount}) + _, err := Read(model(), q, nil, Options{}) // AggregatesEnabled defaults false + if err == nil || err.Code != "PGRST123" { + t.Fatalf("want PGRST123 with aggregates off, got %v", err) + } +} + +func TestAggregateAllowedWhenEnabled(t *testing.T) { + q := aggQuery( + ir.Column{Path: []string{"year"}}, + ir.Aggregate{Func: ir.AggSum, Arg: &ir.Column{Path: []string{"id"}}}, + ) + if _, err := Read(model(), q, nil, Options{AggregatesEnabled: true}); err != nil { + t.Fatalf("unexpected error with aggregates on: %v", err) + } +} + +func TestAggregateArgColumnValidated(t *testing.T) { + // nope is not a films column; even with aggregates enabled the arg is checked. + q := aggQuery(ir.Aggregate{Func: ir.AggSum, Arg: &ir.Column{Path: []string{"nope"}}}) + _, err := Read(model(), q, nil, Options{AggregatesEnabled: true}) + // An aggregate over a column that does not exist reaches PostgreSQL: 42703 + // (item 04.5), not the schema-cache PGRST204. + if err == nil || err.Code != "42703" { + t.Fatalf("want 42703 for unknown aggregate column, got %v", err) + } +} + +func TestLegacyEmbedCountExemptFromGate(t *testing.T) { + // A legacy bare count carried by an embed is allowed even with aggregates off. + // It is validated through the embedded relation, so use a real relationship. + q := &ir.Query{ + Kind: ir.Read, + Relation: ir.Ref{Name: "films"}, + Embeds: []ir.Embed{{ + Target: ir.Ref{Name: "directors"}, + OutKey: "directors", + Query: ir.Query{ + Kind: ir.Read, + Relation: ir.Ref{Name: "directors"}, + Select: []ir.SelectItem{ir.Aggregate{Func: ir.AggCount, Legacy: true}}, + }, + }}, + } + if _, err := Read(nullEmbedModel(), q, []string{"public"}, Options{}); err != nil { + t.Fatalf("legacy embed count should be exempt from the gate, got %v", err) + } +} diff --git a/plan/bench_test.go b/plan/bench_test.go index e8dc750..e521d10 100644 --- a/plan/bench_test.go +++ b/plan/bench_test.go @@ -58,7 +58,7 @@ func BenchmarkReadPlan(b *testing.B) { for b.Loop() { // A fresh query each iteration: Read binds resolved pointers onto it, so a // reused value would measure planning an already-planned query. - if _, err := Read(m, newQuery(), path); err != nil { + if _, err := Read(m, newQuery(), path, Options{}); err != nil { b.Fatal(err) } } diff --git a/plan/call_test.go b/plan/call_test.go index 8ed8006..e24d0af 100644 --- a/plan/call_test.go +++ b/plan/call_test.go @@ -1,6 +1,7 @@ package plan import ( + "strings" "testing" "github.com/tamnd/dbrest/ir" @@ -25,7 +26,7 @@ func TestCallResolvesFunction(t *testing.T) { c := &ir.Call{Function: ir.Ref{Name: "add_them"}, Args: map[string]ir.Value{ "a": {Text: "2"}, "b": {Text: "3"}, }} - p, err := Call(reg(addThem()), c, true, nil) + p, err := Call(reg(addThem()), nil, c, true, nil) if err != nil { t.Fatalf("Call: %v", err) } @@ -39,7 +40,7 @@ func TestCallResolvesFunction(t *testing.T) { func TestCallNoFunctionIs404(t *testing.T) { c := &ir.Call{Function: ir.Ref{Name: "nope"}} - _, err := Call(reg(addThem()), c, true, nil) + _, err := Call(reg(addThem()), nil, c, true, nil) if err == nil || err.Code != "PGRST202" { t.Fatalf("want PGRST202, got %v", err) } @@ -48,12 +49,152 @@ func TestCallNoFunctionIs404(t *testing.T) { func TestCallArgMismatchIs404(t *testing.T) { // add_them needs a and b; only a is supplied. c := &ir.Call{Function: ir.Ref{Name: "add_them"}, Args: map[string]ir.Value{"a": {Text: "2"}}} - _, err := Call(reg(addThem()), c, true, nil) + _, err := Call(reg(addThem()), nil, c, true, nil) if err == nil || err.Code != "PGRST202" { t.Fatalf("want PGRST202, got %v", err) } } +// TestCallAmbiguousOverloadIs300 checks that two overloads tying at the top score +// surface as PGRST203 (a 300) carrying both competing signatures, rather than the +// planner silently picking one. Two single-optional-parameter overloads called +// with no arguments are equally good. +func TestCallAmbiguousOverloadIs300(t *testing.T) { + left := &rpc.Function{ + Name: "f", + Params: []rpc.Param{{Name: "a", Type: "integer", Optional: true}}, + Returns: rpc.ReturnShape{Kind: rpc.ReturnScalar}, + Volatility: rpc.Immutable, + Query: &rpc.PortableQuery{SQL: "SELECT 1"}, + } + right := &rpc.Function{ + Name: "f", + Params: []rpc.Param{{Name: "b", Type: "integer", Optional: true}}, + Returns: rpc.ReturnShape{Kind: rpc.ReturnScalar}, + Volatility: rpc.Immutable, + Query: &rpc.PortableQuery{SQL: "SELECT 1"}, + } + c := &ir.Call{Function: ir.Ref{Name: "f"}} + _, err := Call(reg(left, right), nil, c, true, nil) + if err == nil || err.Code != "PGRST203" { + t.Fatalf("want PGRST203, got %v", err) + } + if err.HTTPStatus != 300 { + t.Errorf("status = %d, want 300", err.HTTPStatus) + } +} + +// TestCallNoFunctionMessageQualifiedWithHint checks the PGRST202 message names the +// function schema-qualified with the searched argument list, and that an overload +// of the same name rides along as the nearest-signature hint. +func TestCallNoFunctionMessageQualifiedWithHint(t *testing.T) { + // add_them(a, b) exists; the call supplies (a, c), matching neither overload. + c := &ir.Call{Function: ir.Ref{Name: "add_them"}, Args: map[string]ir.Value{ + "a": {Text: "1"}, "c": {Text: "2"}, + }} + _, err := Call(reg(addThem()), nil, c, true, []string{"api"}) + if err == nil || err.Code != "PGRST202" { + t.Fatalf("want PGRST202, got %v", err) + } + if want := "api.add_them(a, c)"; !strings.Contains(err.Message, want) { + t.Errorf("message = %q, want it to mention %q", err.Message, want) + } + if err.Hint == nil { + t.Fatal("PGRST202 should carry a nearest-signature hint") + } + if want := "add_them(a => integer, b => integer)"; !strings.Contains(*err.Hint, want) { + t.Errorf("hint = %q, want it to mention %q", *err.Hint, want) + } +} + +// TestCallNoParameterlessMessage checks the "without parameters" phrasing when the +// call names a function with no arguments and none is registered. +func TestCallNoParameterlessMessage(t *testing.T) { + c := &ir.Call{Function: ir.Ref{Name: "ghost"}} + _, err := Call(reg(addThem()), nil, c, true, []string{"api"}) + if err == nil || err.Code != "PGRST202" { + t.Fatalf("want PGRST202, got %v", err) + } + if want := "api.ghost without parameters"; !strings.Contains(err.Message, want) { + t.Errorf("message = %q, want it to mention %q", err.Message, want) + } +} + +// TestCallGetPartitionsArgsFromFilters checks the GET argument-versus-filter +// split: a key naming a declared parameter binds as an argument, while a key that +// does not name a parameter is re-read as a post-filter on the table return. The +// function still resolves, and the filter lands in the call's WHERE. +func TestCallGetPartitionsArgsFromFilters(t *testing.T) { + c := &ir.Call{ + Function: ir.Ref{Name: "films_after"}, + Args: map[string]ir.Value{"y": {Text: "2000"}, "title": {Text: "eq.Arrival"}}, + RawGet: map[string][]string{ + "y": {"2000"}, + "title": {"eq.Arrival"}, + }, + } + p, err := Call(reg(filmsAfter()), nil, c, true, nil) + if err != nil { + t.Fatalf("Call: %v", err) + } + if p.Func == nil || p.Func.Name != "films_after" { + t.Fatalf("function not bound: %+v", p.Func) + } + // y stayed an argument; title moved out of the argument map. + if _, ok := c.Args["y"]; !ok { + t.Error("declared parameter y should remain an argument") + } + if _, ok := c.Args["title"]; ok { + t.Error("title names no parameter and should not be an argument") + } + // title became a post-filter in the WHERE tree. + if c.Where == nil { + t.Fatal("the non-parameter key should have become a filter") + } + cmp, ok := (*c.Where).(ir.Compare) + if !ok || len(cmp.Path) != 1 || cmp.Path[0] != "title" || cmp.Op != ir.OpEq { + t.Errorf("WHERE = %#v, want title eq filter", *c.Where) + } +} + +// TestCallGetFilterUnknownColumnRejected checks a partitioned filter is still +// validated against the table return's declared columns, so a non-parameter key +// naming no column reaches PostgreSQL as 42703 (item 04.5) rather than silently +// dropped. +func TestCallGetFilterUnknownColumnRejected(t *testing.T) { + c := &ir.Call{ + Function: ir.Ref{Name: "films_after"}, + Args: map[string]ir.Value{"y": {Text: "2000"}, "ghost": {Text: "eq.1"}}, + RawGet: map[string][]string{ + "y": {"2000"}, + "ghost": {"eq.1"}, + }, + } + _, err := Call(reg(filmsAfter()), nil, c, true, nil) + if err == nil || err.Code != "42703" { + t.Fatalf("want 42703, got %v", err) + } +} + +// TestCallGetArgTypeCoercion checks a GET text argument is validated against its +// declared parameter type, so a non-integer value for an integer parameter is the +// same 22P02 a read filter raises, on every backend. +func TestCallGetArgTypeCoercion(t *testing.T) { + c := &ir.Call{ + Function: ir.Ref{Name: "add_them"}, + Args: map[string]ir.Value{"a": {Text: "notanint"}, "b": {Text: "3"}}, + RawGet: map[string][]string{"a": {"notanint"}, "b": {"3"}}, + } + _, err := Call(reg(addThem()), nil, c, true, nil) + if err == nil || err.HTTPStatus != 400 { + t.Fatalf("want a 400 coercion error, got %v", err) + } +} + +// A GET reaching a volatile function fails the way PostgREST's does: the read-only +// transaction rejects the write with SQLSTATE 25006 at 405, not a PGRST101. The +// registry path raises it from the declared volatility since it cannot run the +// call, but the code and status a client sees match the native path (item 04.6). func TestCallGetOnVolatileIs405(t *testing.T) { vol := &rpc.Function{ Name: "do_thing", @@ -62,9 +203,9 @@ func TestCallGetOnVolatileIs405(t *testing.T) { Query: &rpc.PortableQuery{SQL: "SELECT 1"}, } c := &ir.Call{Function: ir.Ref{Name: "do_thing"}} - _, err := Call(reg(vol), c, true, nil) - if err == nil || err.Code != "PGRST101" { - t.Fatalf("want PGRST101, got %v", err) + _, err := Call(reg(vol), nil, c, true, nil) + if err == nil || err.Code != "25006" { + t.Fatalf("want 25006, got %v", err) } if err.HTTPStatus != 405 { t.Errorf("status = %d, want 405", err.HTTPStatus) @@ -79,7 +220,7 @@ func TestCallPostOnVolatileIsAllowed(t *testing.T) { Query: &rpc.PortableQuery{SQL: "SELECT 1"}, } c := &ir.Call{Function: ir.Ref{Name: "do_thing"}} - p, err := Call(reg(vol), c, false, nil) + p, err := Call(reg(vol), nil, c, false, nil) if err != nil { t.Fatalf("Call: %v", err) } @@ -101,9 +242,9 @@ func TestCallPostFilterUnknownColumn(t *testing.T) { Args: map[string]ir.Value{"y": {Text: "2000"}}, Select: []ir.SelectItem{ir.Column{Path: []string{"bogus"}}}, } - _, err := Call(reg(tab), c, true, nil) - if err == nil || err.Code != "PGRST204" { - t.Fatalf("want PGRST204, got %v", err) + _, err := Call(reg(tab), nil, c, true, nil) + if err == nil || err.Code != "42703" { + t.Fatalf("want 42703, got %v", err) } } @@ -120,7 +261,7 @@ func TestCallPostFilterKnownColumnOK(t *testing.T) { Args: map[string]ir.Value{"y": {Text: "2000"}}, Select: []ir.SelectItem{ir.Column{Path: []string{"title"}}}, } - if _, err := Call(reg(tab), c, true, nil); err != nil { + if _, err := Call(reg(tab), nil, c, true, nil); err != nil { t.Fatalf("Call: %v", err) } } @@ -156,7 +297,7 @@ func TestCallPostFilterWhereTreeKnownColumns(t *testing.T) { ir.Not{Kid: ir.Compare{Path: []string{"id"}, Op: ir.OpEq, Value: ir.Value{Text: "0"}}}, }}, }}) - if _, err := Call(reg(filmsAfter()), callWith(where), true, nil); err != nil { + if _, err := Call(reg(filmsAfter()), nil, callWith(where), true, nil); err != nil { t.Fatalf("Call with a valid filter tree: %v", err) } } @@ -176,9 +317,9 @@ func TestCallPostFilterWhereTreeUnknownColumn(t *testing.T) { } for name, where := range cases { t.Run(name, func(t *testing.T) { - _, err := Call(reg(filmsAfter()), callWith(where), true, nil) - if err == nil || err.Code != "PGRST204" { - t.Fatalf("want PGRST204, got %v", err) + _, err := Call(reg(filmsAfter()), nil, callWith(where), true, nil) + if err == nil || err.Code != "42703" { + t.Fatalf("want 42703, got %v", err) } }) } @@ -191,9 +332,9 @@ func TestCallPostFilterOrderUnknownColumn(t *testing.T) { Args: map[string]ir.Value{"y": {Text: "2000"}}, Order: []ir.OrderTerm{{Path: []string{"ghost"}}}, } - _, err := Call(reg(filmsAfter()), c, true, nil) - if err == nil || err.Code != "PGRST204" { - t.Fatalf("want PGRST204, got %v", err) + _, err := Call(reg(filmsAfter()), nil, c, true, nil) + if err == nil || err.Code != "42703" { + t.Fatalf("want 42703, got %v", err) } } @@ -206,7 +347,7 @@ func TestCallScalarReturnSkipsFilterValidation(t *testing.T) { Args: map[string]ir.Value{"a": {Text: "1"}, "b": {Text: "2"}}, Where: &where, } - if _, err := Call(reg(addThem()), c, true, nil); err != nil { + if _, err := Call(reg(addThem()), nil, c, true, nil); err != nil { t.Fatalf("scalar return should not validate post-filter columns: %v", err) } } diff --git a/plan/coerce_test.go b/plan/coerce_test.go index 7695ad7..d7907ac 100644 --- a/plan/coerce_test.go +++ b/plan/coerce_test.go @@ -9,7 +9,7 @@ import ( func TestReadCoercesIntegerFilter(t *testing.T) { where := ir.Cond(ir.Compare{Path: []string{"year"}, Op: ir.OpEq, Value: ir.Value{Text: "abc"}}) q := &ir.Query{Relation: ir.Ref{Name: "films"}, Where: &where} - _, err := Read(model(), q, nil) + _, err := Read(model(), q, nil, Options{}) if err == nil || err.Code != "22P02" { t.Fatalf("a non-integer operand on an integer column should be 22P02, got %v", err) } @@ -18,7 +18,7 @@ func TestReadCoercesIntegerFilter(t *testing.T) { func TestReadAcceptsValidIntegerFilter(t *testing.T) { where := ir.Cond(ir.Compare{Path: []string{"year"}, Op: ir.OpGte, Value: ir.Value{Text: "2000"}}) q := &ir.Query{Relation: ir.Ref{Name: "films"}, Where: &where} - if _, err := Read(model(), q, nil); err != nil { + if _, err := Read(model(), q, nil, Options{}); err != nil { t.Fatalf("a valid integer operand should pass, got %v", err) } } @@ -26,7 +26,7 @@ func TestReadAcceptsValidIntegerFilter(t *testing.T) { func TestReadCoercesInListMembers(t *testing.T) { where := ir.Cond(ir.Compare{Path: []string{"id"}, Op: ir.OpIn, Value: ir.Value{List: []string{"1", "2", "x"}}}) q := &ir.Query{Relation: ir.Ref{Name: "films"}, Where: &where} - _, err := Read(model(), q, nil) + _, err := Read(model(), q, nil, Options{}) if err == nil || err.Code != "22P02" { t.Fatalf("a bad member of an in-list on an integer column should be 22P02, got %v", err) } @@ -36,7 +36,7 @@ func TestReadTextFilterAcceptsAnything(t *testing.T) { // A text column carries any operand through to the engine. where := ir.Cond(ir.Compare{Path: []string{"title"}, Op: ir.OpEq, Value: ir.Value{Text: "anything 123 !@#"}}) q := &ir.Query{Relation: ir.Ref{Name: "films"}, Where: &where} - if _, err := Read(model(), q, nil); err != nil { + if _, err := Read(model(), q, nil, Options{}); err != nil { t.Fatalf("a text operand should never be rejected, got %v", err) } } @@ -46,7 +46,7 @@ func TestReadLikePatternNotCoerced(t *testing.T) { // left for the engine, not rejected as a bad integer. where := ir.Cond(ir.Compare{Path: []string{"year"}, Op: ir.OpLike, Value: ir.Value{Text: "20%"}}) q := &ir.Query{Relation: ir.Ref{Name: "films"}, Where: &where} - if _, err := Read(model(), q, nil); err != nil { + if _, err := Read(model(), q, nil, Options{}); err != nil { t.Fatalf("a like pattern should not be coerced, got %v", err) } } diff --git a/plan/embed_test.go b/plan/embed_test.go index 3955c6d..d7b430b 100644 --- a/plan/embed_test.go +++ b/plan/embed_test.go @@ -1,6 +1,8 @@ package plan import ( + "encoding/json" + "strings" "testing" "github.com/tamnd/dbrest/ir" @@ -38,7 +40,7 @@ func readEmbed(t *testing.T, m *schema.Model, sel string) (*ir.Plan, *pgerr.APIE if perr != nil { t.Fatalf("ParseRead: %v", perr) } - return Read(m, q, nil) + return Read(m, q, nil, Options{}) } func TestEmbedResolvesAndBinds(t *testing.T) { @@ -65,6 +67,28 @@ func TestEmbedNoRelationshipIsPGRST200(t *testing.T) { if err == nil || err.Code != "PGRST200" { t.Fatalf("want PGRST200, got %v", err) } + // PGRST200 names the searched pair in its details (item 04.4), so a client + // learns which relationship was looked for, not just that one was missing. + if err.Details == nil { + t.Fatal("PGRST200 details are nil, want the searched-pair sentence") + } + want := "Searched for a foreign key relationship between 'films' and 'ghosts'" + if !strings.Contains(*err.Details, want) || !strings.Contains(*err.Details, "but no matches were found.") { + t.Errorf("details = %q, want it to contain %q", *err.Details, want) + } +} + +// A PGRST200 raised for a hinted embed echoes the hint in the details, so a +// client sees the search was constrained by the hint it gave (item 04.4). +func TestEmbedNoRelationshipWithHintEchoesHint(t *testing.T) { + m := embedModel() + _, err := readEmbed(t, m, "title,ghosts!nope(x)") + if err == nil || err.Code != "PGRST200" { + t.Fatalf("want PGRST200, got %v", err) + } + if err.Details == nil || !strings.Contains(*err.Details, "using the hint 'nope'") { + t.Errorf("details = %v, want it to echo the hint", err.Details) + } } func TestEmbedAmbiguousIsPGRST201(t *testing.T) { @@ -77,6 +101,48 @@ func TestEmbedAmbiguousIsPGRST201(t *testing.T) { if err.HTTPStatus != 300 { t.Errorf("status = %d, want 300", err.HTTPStatus) } + // The details carry the candidate array a client reads to auto-disambiguate + // (item 04.4): one entry per surviving edge, each with its cardinality, the + // "parent with target" embedding, and the join-column relationship spelling. + if err.RawDetails == nil { + t.Fatal("PGRST201 details are nil, want the candidate array") + } + var cands []map[string]string + if uerr := json.Unmarshal(err.RawDetails, &cands); uerr != nil { + t.Fatalf("details is not a JSON array: %v: %s", uerr, err.RawDetails) + } + if len(cands) != 2 { + t.Fatalf("got %d candidates, want 2: %v", len(cands), cands) + } + byRel := map[string]map[string]string{} + for _, c := range cands { + byRel[c["relationship"]] = c + } + want := "films_director_id_fkey using films(director_id) and people(id)" + got, ok := byRel[want] + if !ok { + t.Fatalf("no candidate with relationship %q, got %v", want, cands) + } + if got["cardinality"] != "many-to-one" { + t.Errorf("cardinality = %q, want many-to-one", got["cardinality"]) + } + if got["embedding"] != "films with people" { + t.Errorf("embedding = %q, want %q", got["embedding"], "films with people") + } + // The hint lists each disambiguated embed spelling and points at the details. + if err.Hint == nil { + t.Fatal("PGRST201 hint is nil, want the Try-changing list") + } + for _, frag := range []string{ + "Try changing 'people' to one of the following:", + "'people!films_director_id_fkey'", + "'people!films_writer_id_fkey'", + "Find the desired relationship in the 'details' key.", + } { + if !strings.Contains(*err.Hint, frag) { + t.Errorf("hint = %q, missing %q", *err.Hint, frag) + } + } } func TestEmbedHintDisambiguates(t *testing.T) { @@ -89,7 +155,72 @@ func TestEmbedHintDisambiguates(t *testing.T) { func TestEmbedUnknownColumnInEmbedIsRejected(t *testing.T) { m := embedModel() _, err := readEmbed(t, m, "title,people!director_id(nope)") - if err == nil || err.Code != "PGRST204" { - t.Fatalf("want PGRST204, got %v", err) + // An unknown column inside an embed's select reaches PostgreSQL: 42703 + // (item 04.5), not the schema-cache PGRST204. + if err == nil || err.Code != "42703" { + t.Fatalf("want 42703, got %v", err) + } +} + +// commentsModel is a self-referential thread: comments.parent_id references +// comments.id, so the derived forward (parent) and backward (children) edges +// share a hint set and a bare or hinted embed is ambiguous (PGRST201). +func commentsModel() *schema.Model { + comments := &schema.Relation{ + Name: "comments", + Columns: []*schema.Column{ + {Name: "id", Type: "integer", Position: 1}, + {Name: "parent_id", Type: "integer", Position: 2}, + {Name: "body", Type: "text", Position: 3}, + }, + PrimaryKey: []string{"id"}, + ForeignKeys: []*schema.ForeignKey{ + {Name: "comments_parent_id_fkey", Columns: []string{"parent_id"}, RefRelation: "comments", RefColumns: []string{"id"}}, + }, + } + return schema.NewModel([]*schema.Relation{comments}) +} + +// TestEmbedSelfReferentialIsAmbiguous covers 01.10: a self FK alone leaves the +// recursive embed ambiguous, with no hint able to pick a direction. +func TestEmbedSelfReferentialIsAmbiguous(t *testing.T) { + m := commentsModel() + q, perr := ir.ParseRead("comments", "select=body,comments(body)", nil) + if perr != nil { + t.Fatalf("ParseRead: %v", perr) + } + _, err := Read(m, q, nil, Options{}) + if err == nil || err.Code != "PGRST201" { + t.Fatalf("want PGRST201, got %v", err) + } +} + +// TestEmbedDeclaredEdgeResolvesRecursive covers the 01.10 escape hatch: a +// declared computed relationship names one direction of the self FK, so the +// recursive embed resolves and binds with the declared cardinality. +func TestEmbedDeclaredEdgeResolvesRecursive(t *testing.T) { + m := commentsModel() + m.AddDeclaredRelationship(schema.DeclaredRel{ + Name: "children", + ParentSchema: "", ParentName: "comments", + TargetSchema: "", TargetName: "comments", + Card: schema.CardToMany, + Local: []string{"id"}, + Foreign: []string{"parent_id"}, + }) + q, perr := ir.ParseRead("comments", "select=body,children:comments!children(body)", nil) + if perr != nil { + t.Fatalf("ParseRead: %v", perr) + } + pl, err := Read(m, q, nil, Options{}) + if err != nil { + t.Fatalf("Read: %v", err) + } + emb := pl.Query.Embeds[0] + if emb.Rel == nil || emb.Rel.Name != "children" { + t.Fatalf("embed edge = %v, want children", emb.Rel) + } + if emb.Cardinality != ir.CardToMany { + t.Errorf("cardinality = %v, want to-many", emb.Cardinality) } } diff --git a/plan/plan.go b/plan/plan.go index 1d2740b..1c29eb4 100644 --- a/plan/plan.go +++ b/plan/plan.go @@ -6,7 +6,12 @@ package plan import ( + "encoding/json" "errors" + "fmt" + "sort" + "strconv" + "strings" "github.com/tamnd/dbrest/ir" "github.com/tamnd/dbrest/pgerr" @@ -15,6 +20,17 @@ import ( "github.com/tamnd/dbrest/schema" ) +// Options carries request-level toggles the planner needs that are not part of +// the query itself. The zero value matches a default PostgREST: aggregates are +// off, so an aggregate select item is rejected with PGRST123 until the +// db-aggregates-enabled option turns it on. +type Options struct { + // AggregatesEnabled mirrors db-aggregates-enabled. When false, a request using + // count()/col.sum()/... is rejected with PGRST123; the legacy bare count an + // embed may carry is exempt and always allowed. + AggregatesEnabled bool +} + // Read resolves a parsed read query against the model and returns an executable // plan. searchPath orders the schemas an unqualified relation is looked up in. // @@ -22,37 +38,88 @@ import ( // Embeds, aggregates, and JSON paths are validated by their own subsystems as // they land; a query carrying one is passed through for the compiler to reject // with a clear PGRST127 rather than being silently accepted here. -func Read(model *schema.Model, q *ir.Query, searchPath []string) (*ir.Plan, *pgerr.APIError) { +func Read(model *schema.Model, q *ir.Query, searchPath []string, opts Options) (*ir.Plan, *pgerr.APIError) { rel, ok := model.Lookup(q.Relation.Name, searchPath) if !ok { - return nil, pgerr.ErrUnknownTable(q.Relation.Name) + return nil, pgerr.ErrUnknownTable(searchSchema(searchPath), q.Relation.Name) } // Bind the resolved schema/name back onto the query so the compiler emits a // fully qualified, model-validated reference. q.Relation = ir.Ref{Schema: rel.Schema, Name: rel.Name} - if err := validateSelect(rel, q.Select); err != nil { + if err := validateSelect(rel, q.Select, opts.AggregatesEnabled); err != nil { return nil, err } + // A filter naming an embed (films?actors=not.is.null) is an existence test on + // the relationship, not a parent column. Reclassify those before column + // validation so they are not rejected as unknown columns, then validate the + // rest of the tree. See item 01.12. + reclassifyEmbedFilters(q) if err := validateCond(rel, q.Where); err != nil { return nil, err } if err := validateOrder(rel, q.Order); err != nil { return nil, err } - if err := resolveEmbeds(model, rel, q, searchPath); err != nil { + if err := resolveEmbeds(model, rel, q, searchPath, opts.AggregatesEnabled); err != nil { + return nil, err + } + // Related-order terms (order=rel(col)) are validated once the embeds they + // reference are resolved, so the relationship's cardinality is known. + if err := validateRelatedOrder(rel, q); err != nil { return nil, err } + bindComputed(rel, q) + bindReps(rel, q) return &ir.Plan{Query: q, Rel: rel, ReadOnly: true}, nil } +// bindComputed copies the relation's computed-field name-to-schema mapping onto +// the query so the compiler can render a selected, filtered, or ordered computed +// field as a function call on the row. It maps every computed field the relation +// exposes, not only the referenced ones, which keeps it a single pass and costs +// nothing for the common case of a relation with none (the map stays nil). +func bindComputed(rel *schema.Relation, q *ir.Query) { + if len(rel.Computed) == 0 { + return + } + m := make(map[string]string, len(rel.Computed)) + for i := range rel.Computed { + m[rel.Computed[i].Name] = rel.Computed[i].FuncSchema + } + q.Computed = m +} + +// bindReps copies the data-representation cast set of every column that carries +// one onto the query, so the compiler reformats that column through its domain's +// casts (spec 11): ToJSON on read output, FromJSON on a write value, FromText on a +// filter literal. It costs nothing for the common relation with no representation +// column (the map stays nil). +func bindReps(rel *schema.Relation, q *ir.Query) { + var m map[string]ir.Rep + for _, c := range rel.Columns { + if c.Rep == nil { + continue + } + if m == nil { + m = make(map[string]ir.Rep) + } + m[c.Name] = ir.Rep{ + ToJSONSchema: c.Rep.ToJSON.Schema, ToJSONFunc: c.Rep.ToJSON.Name, + FromTextSchema: c.Rep.FromText.Schema, FromTextFunc: c.Rep.FromText.Name, + FromJSONSchema: c.Rep.FromJSON.Schema, FromJSONFunc: c.Rep.FromJSON.Name, + } + } + q.Reps = m +} + // resolveEmbeds binds every embed of a query against the model: it finds the // relationship from the parent to the embedded resource, applies a disambiguation // hint, and recurses into nested embeds. A missing relationship is PGRST200; an // ambiguous one (more than one surviving edge) is PGRST201. The embed's nested // select, filters, and ordering are validated against the embedded relation. -func resolveEmbeds(model *schema.Model, parent *schema.Relation, q *ir.Query, searchPath []string) *pgerr.APIError { +func resolveEmbeds(model *schema.Model, parent *schema.Relation, q *ir.Query, searchPath []string, aggEnabled bool) *pgerr.APIError { for i := range q.Embeds { emb := &q.Embeds[i] rel, err := resolveOne(model, parent, emb, searchPath) @@ -64,7 +131,7 @@ func resolveEmbeds(model *schema.Model, parent *schema.Relation, q *ir.Query, se // Bind the embedded relation so the compiler emits a model-validated ref. emb.Query.Relation = ir.Ref{Schema: rel.Target.Schema, Name: rel.Target.Name} - if err := validateSelect(rel.Target, emb.Query.Select); err != nil { + if err := validateSelect(rel.Target, emb.Query.Select, aggEnabled); err != nil { return err } if err := validateCond(rel.Target, emb.Query.Where); err != nil { @@ -73,9 +140,16 @@ func resolveEmbeds(model *schema.Model, parent *schema.Relation, q *ir.Query, se if err := validateOrder(rel.Target, emb.Query.Order); err != nil { return err } - if err := resolveEmbeds(model, rel.Target, &emb.Query, searchPath); err != nil { + if err := resolveEmbeds(model, rel.Target, &emb.Query, searchPath, aggEnabled); err != nil { + return err + } + // A nested related order (tasks.order=projects(id)) references the embed's + // own sub-embeds, now resolved. + if err := validateRelatedOrder(rel.Target, &emb.Query); err != nil { return err } + bindComputed(rel.Target, &emb.Query) + bindReps(rel.Target, &emb.Query) } return nil } @@ -86,7 +160,13 @@ func resolveEmbeds(model *schema.Model, parent *schema.Relation, q *ir.Query, se func resolveOne(model *schema.Model, parent *schema.Relation, emb *ir.Embed, searchPath []string) (*schema.Relationship, *pgerr.APIError) { cands, found := model.Relationships(parent, emb.Target.Name, searchPath) if !found || len(cands) == 0 { - return nil, pgerr.ErrNoRelationship(parent.Name, emb.Target.Name) + // A computed relationship is embedded by the function name, which need not + // equal the target relation name, so the target-name path above cannot see + // it. Fall back to resolving the edge by name and inferring its target. + if rel, ok := model.ComputedRelByName(parent, emb.Target.Name, searchPath); ok { + return rel, nil + } + return nil, pgerr.ErrNoRelationship(parent.Name, emb.Target.Name, searchSchema(searchPath), emb.Hint) } if emb.Hint != "" { filtered := cands[:0:0] @@ -99,15 +179,35 @@ func resolveOne(model *schema.Model, parent *schema.Relation, emb *ir.Embed, sea } switch len(cands) { case 0: - return nil, pgerr.ErrNoRelationship(parent.Name, emb.Target.Name) + return nil, pgerr.ErrNoRelationship(parent.Name, emb.Target.Name, searchSchema(searchPath), emb.Hint) case 1: c := cands[0] return &c, nil default: - return nil, pgerr.ErrAmbiguousEmbed(parent.Name, emb.Target.Name) + return nil, pgerr.ErrAmbiguousEmbed(parent.Name, emb.Target.Name, embedCandidates(parent, cands)) } } +// embedCandidates renders the surviving relationships into the PGRST201 details +// entries: each carries its four-way cardinality, the "parent with target" +// embedding, and the "name using parent(cols) and target(cols)" relationship +// spelling PostgREST reports. A many-to-many edge spells its end columns the same +// way; the cardinality field marks it as the junction case. +func embedCandidates(parent *schema.Relation, cands []schema.Relationship) []pgerr.EmbedCandidate { + out := make([]pgerr.EmbedCandidate, len(cands)) + for i, c := range cands { + out[i] = pgerr.EmbedCandidate{ + Cardinality: c.Cardinality, + Embedding: parent.Name + " with " + c.Target.Name, + Relationship: fmt.Sprintf("%s using %s(%s) and %s(%s)", + c.Name, parent.Name, strings.Join(c.Local, ", "), + c.Target.Name, strings.Join(c.Foreign, ", ")), + Name: c.Name, + } + } + return out +} + // toCardinality maps the schema cardinality to the IR's. func toCardinality(c schema.Card) ir.Cardinality { if c == schema.CardToMany { @@ -128,11 +228,13 @@ func toCardinality(c schema.Card) ir.Cardinality { func Write(model *schema.Model, q *ir.Query, searchPath []string) (*ir.Plan, *pgerr.APIError) { rel, ok := model.Lookup(q.Relation.Name, searchPath) if !ok { - return nil, pgerr.ErrUnknownTable(q.Relation.Name) + return nil, pgerr.ErrUnknownTable(searchSchema(searchPath), q.Relation.Name) } q.Relation = ir.Ref{Schema: rel.Schema, Name: rel.Name} - if err := validateSelect(rel, q.Select); err != nil { + // A write's return=representation projection is a read shape, but PostgREST + // does not allow aggregates there, so the gate stays closed on this path. + if err := validateSelect(rel, q.Select, false); err != nil { return nil, err } if err := validateCond(rel, q.Where); err != nil { @@ -141,31 +243,195 @@ func Write(model *schema.Model, q *ir.Query, searchPath []string) (*ir.Plan, *pg if err := validateWrite(rel, q.Write); err != nil { return nil, err } + // A return=representation body is shaped by the same select/embeds a read + // uses, so resolve the embeds against the target relation here. An unknown or + // ambiguous relationship is the read path's PGRST200/201 rather than being + // silently dropped from the response. See item 01.19. + if err := resolveEmbeds(model, rel, q, searchPath, false); err != nil { + return nil, err + } + if q.IsPut { + if err := validatePut(rel, q); err != nil { + return nil, err + } + } + // A return=representation body renders through the read path, and a write value + // of a domain column parses through its from-json cast, so the same maps a read + // uses are bound here too (spec 11). + bindComputed(rel, q) + bindReps(rel, q) return &ir.Plan{Query: q, Rel: rel, ReadOnly: false}, nil } +// validatePut enforces PostgREST's PUT contract before any write: the URL +// filters must be exactly the relation's primary key columns, each with eq +// (PGRST105); no limit or offset may be present (PGRST114); and the body must +// be a single object whose primary key values equal the URL's (PGRST115). A PUT +// addresses one row by its whole key, so anything looser is rejected here rather +// than writing the wrong row. +func validatePut(rel *schema.Relation, q *ir.Query) *pgerr.APIError { + if q.Limit != nil || q.Offset != nil { + return pgerr.ErrPutLimit() + } + eqs, ok := putEqFilters(q.Where) + if !ok { + return pgerr.ErrPutPrimaryKey() + } + pk := rel.PrimaryKey + if len(pk) == 0 || len(eqs) != len(pk) { + return pgerr.ErrPutPrimaryKey() + } + for _, c := range pk { + if _, ok := eqs[c]; !ok { + return pgerr.ErrPutPrimaryKey() + } + } + w := q.Write + if w == nil || len(w.Rows) != 1 { + return pgerr.ErrPutPayloadKey() + } + row := w.Rows[0] + for _, c := range pk { + v, ok := row[c] + if !ok || !putKeyMatches(rel, c, v, eqs[c]) { + return pgerr.ErrPutPayloadKey() + } + } + return nil +} + +// putEqFilters flattens a PUT's WHERE into a map of column to operand text, +// accepting only a conjunction of single-column, non-negated, unquantified eq +// comparisons. It returns ok=false for any other shape (a non-eq operator, an +// or/not tree, a JSON path, or a quantifier), none of which a PUT may carry. +func putEqFilters(c *ir.Cond) (map[string]string, bool) { + out := map[string]string{} + var walk func(n ir.Cond) bool + walk = func(n ir.Cond) bool { + switch v := n.(type) { + case ir.And: + for _, k := range v.Kids { + if !walk(k) { + return false + } + } + return true + case ir.Compare: + if v.Op != ir.OpEq || len(v.Path) != 1 || v.Quant != ir.QNone || v.Negate { + return false + } + out[v.Path[0]] = v.Value.Text + return true + default: + return false + } + } + if c == nil { + return out, true + } + return out, walk(*c) +} + +// putKeyMatches reports whether a payload value for a primary key column equals +// the URL filter text. Both sides are coerced through the column's type so 1 and +// "1" agree; if the type is unknown or either side fails to coerce, the raw text +// forms are compared. +func putKeyMatches(rel *schema.Relation, col string, payload ir.Value, urlText string) bool { + pj := jsonScalarText(payload.JSON) + if c, ok := rel.Column(col); ok && c.Type != "" { + pv, perr := pgtypes.ParseScalar(c.Type, pj) + uv, uerr := pgtypes.ParseScalar(c.Type, urlText) + if perr == nil && uerr == nil { + return fmt.Sprint(pv) == fmt.Sprint(uv) + } + } + return pj == urlText +} + +// jsonScalarText renders a decoded JSON scalar as the text PostgREST would +// compare against a URL operand. A JSON number prints without a trailing zero so +// 1 stays "1", not "1.000000". +func jsonScalarText(v any) string { + switch t := v.(type) { + case nil: + return "" + case string: + return t + case bool: + if t { + return "true" + } + return "false" + case float64: + return strconv.FormatFloat(t, 'f', -1, 64) + case json.Number: + return t.String() + default: + b, _ := json.Marshal(t) + return string(b) + } +} + // Call resolves a parsed RPC call against the function registry and returns an // executable plan. It selects the overload the argument set satisfies (PGRST202 // when none does), enforces the volatility-versus-method rule (a GET to a // volatile function is 405), and validates a post-filter select/where/order // against a table return's declared columns. The resolved function and the // read-only decision travel on the plan for the backend to lower. See spec 12. -func Call(reg rpc.Registry, c *ir.Call, isGet bool, searchPath []string) (*ir.Plan, *pgerr.APIError) { +func Call(reg rpc.Registry, model *schema.Model, c *ir.Call, isGet bool, searchPath []string) (*ir.Plan, *pgerr.APIError) { + // On GET the argument-versus-filter split needs the function's parameter + // names, which the registry knows by function name (the union across every + // overload). A query key naming a parameter is an argument; the rest are + // re-read as filters on a table-valued result. An unknown function is left + // unpartitioned so resolution raises PGRST202, rather than a stray key being + // mis-parsed as a filter on a result that does not exist. + if isGet { + if params, known := paramNameSet(reg, c.Function.Name); known { + variadic := variadicNameSet(reg, c.Function.Name) + if perr := c.PartitionGetArgs( + func(k string) bool { return params[k] }, + func(k string) bool { return variadic[k] }, + ); perr != nil { + return nil, perr + } + } + } + args := make(rpc.ArgSet, len(c.Args)) for name := range c.Args { args[name] = true } - fn, ok := reg.Lookup(c.Function.Name, args) + activeSchema := "" + if len(searchPath) > 0 { + activeSchema = searchPath[0] + } + fn, ambiguous, ok := reg.Resolve(c.Function.Name, args) if !ok { - return nil, pgerr.ErrNoFunction(c.Function.Name) + if len(ambiguous) > 0 { + return nil, pgerr.ErrAmbiguousFunction(ambiguous) + } + argNames := sortedArgNames(c.Args) + return nil, pgerr.ErrNoFunction(activeSchema, c.Function.Name, argNames, nearestSignature(reg, activeSchema, c.Function.Name, args)) } // A read method may only call a read-only function; a write-capable function - // requires POST so it runs in a read-write transaction. + // requires POST so it runs in a read-write transaction. PostgREST does not + // pre-reject this: it runs the call in a read-only transaction and lets the + // server fail with SQLSTATE 25006 (405, "cannot execute ... in a read-only + // transaction"). A registry backend cannot run the function to find out, so it + // raises the same SQLSTATE up front from the declared volatility, keeping the + // code and status a client sees identical to the native path's. if isGet && !fn.Volatility.ReadOnly() { - return nil, pgerr.ErrMethodNotAllowed( - "Cannot call a volatile function with GET; use POST") + return nil, pgerr.ErrReadOnlyTransaction("function " + c.Function.Name) + } + + // A GET argument arrives as text; validate it against the declared parameter + // type so an invalid value is the same 22P02 on every backend, the way a read + // filter is coerced. A POST argument is already typed by the JSON body, and an + // empty text value stays an empty string rather than becoming NULL. + if err := coerceCallArgs(fn, c); err != nil { + return nil, err } // Post-filters apply to a table return; validate their columns against the @@ -176,10 +442,190 @@ func Call(reg rpc.Registry, c *ir.Call, isGet bool, searchPath []string) (*ir.Pl return nil, err } + // A function returning rows of a known relation supports embeds on its result, + // resolved the same way a table read's embeds are. A call with embeds over a + // function whose result is not a relation has nothing to embed against. + if len(c.Embeds) > 0 { + if err := resolveCallEmbeds(model, fn, c, searchPath); err != nil { + return nil, err + } + } + c.ReadOnly = fn.Volatility.ReadOnly() return &ir.Plan{Call: c, Func: fn, ReadOnly: c.ReadOnly}, nil } +// returnRelation resolves the relation whose rows a function returns, when its +// return type names one (returns setof clients, returns clients). A scalar, +// setof-scalar, anonymous table(...), or void return names no relation, so its +// result has no relationships to embed against. +func returnRelation(model *schema.Model, fn *rpc.Function, searchPath []string) (*schema.Relation, bool) { + if model == nil { + return nil, false + } + switch fn.Returns.Kind { + case rpc.ReturnSetOf, rpc.ReturnScalar: + if fn.Returns.Type == "" { + return nil, false + } + return model.Lookup(fn.Returns.Type, searchPath) + default: + return nil, false + } +} + +// searchSchema is the schema a relationship search names in its PGRST200 details: +// the exposed schema the request resolved to, which PostgREST reports there. It +// is the first entry of the active search path, normalized to public when that is +// empty. A backend with no schema namespace (SQLite) resolves to the empty +// schema, but PostgREST always names a schema in this sentence, so the wire form +// reports public rather than an empty-quoted gap. +func searchSchema(searchPath []string) string { + if len(searchPath) > 0 && searchPath[0] != "" { + return searchPath[0] + } + return "public" +} + +// resolveCallEmbeds binds an RPC call's embeds against the function's result +// relation. It mirrors the read path by projecting the call's select/where/order +// onto a synthetic query over that relation, so resolveEmbeds, the embed-filter +// reclassification, and related-order validation all apply unchanged. The +// resolved embeds and any reclassified filter tree are carried back onto the +// call. A function whose result is not a relation cannot be embedded on, which is +// the read path's PGRST200. +func resolveCallEmbeds(model *schema.Model, fn *rpc.Function, c *ir.Call, searchPath []string) *pgerr.APIError { + retRel, ok := returnRelation(model, fn, searchPath) + if !ok { + return pgerr.ErrNoRelationship(fn.Name, c.Embeds[0].Target.Name, searchSchema(searchPath), c.Embeds[0].Hint) + } + q := &ir.Query{ + Kind: ir.Read, + Relation: ir.Ref{Schema: retRel.Schema, Name: retRel.Name}, + Select: c.Select, + Where: c.Where, + Order: c.Order, + Embeds: c.Embeds, + } + reclassifyEmbedFilters(q) + if err := resolveEmbeds(model, retRel, q, searchPath, false); err != nil { + return err + } + if err := validateRelatedOrder(retRel, q); err != nil { + return err + } + c.Where = q.Where + c.Embeds = q.Embeds + return nil +} + +// sortedArgNames returns the call's argument names in a stable order, the list +// PostgREST echoes in a PGRST202 message. +func sortedArgNames(args map[string]ir.Value) []string { + out := make([]string, 0, len(args)) + for name := range args { + out = append(out, name) + } + sort.Strings(out) + return out +} + +// nearestSignature returns the registered overload of the same name whose +// parameter set is closest to the requested arguments, rendered as a "Perhaps you +// meant to call ..." hint. It returns an empty string when nothing of that name +// is registered, so the caller attaches no hint. +func nearestSignature(reg rpc.Registry, schemaName, name string, args rpc.ArgSet) string { + var best *rpc.Function + bestScore := -1 + for _, f := range reg.List() { + if f.Name != name { + continue + } + score := 0 + for _, p := range f.Params { + if args[p.Name] { + score++ + } else { + score-- // a parameter the call did not supply is a small mismatch + } + } + if score > bestScore { + best, bestScore = f, score + } + } + if best == nil { + return "" + } + return "Perhaps you meant to call the function " + best.Signature(schemaName) +} + +// paramNameSet is the union of parameter names across every overload of a +// function name, and whether the name is registered at all. PostgREST partitions +// a GET call's query keys against this set, independent of which overload +// eventually resolves, so a key naming any overload's parameter is an argument +// rather than a filter. The found flag separates a known parameterless function +// (partition its keys as filters) from an unknown name (leave the keys so +// resolution raises PGRST202). +func paramNameSet(reg rpc.Registry, name string) (set map[string]bool, found bool) { + set = map[string]bool{} + for _, f := range reg.List() { + if f.Name != name { + continue + } + found = true + for _, p := range f.Params { + set[p.Name] = true + } + } + return set, found +} + +// variadicNameSet is the set of variadic parameter names across every overload of +// a function name, so a GET call can collect that key's repeats into a list. +func variadicNameSet(reg rpc.Registry, name string) map[string]bool { + set := map[string]bool{} + for _, f := range reg.List() { + if f.Name != name { + continue + } + for _, p := range f.Params { + if p.Variadic { + set[p.Name] = true + } + } + } + return set +} + +// coerceCallArgs validates each GET text argument against its declared parameter +// type, turning a bad value into the 22P02 the read path raises. A POST argument +// is typed by the JSON body and skipped; an undeclared argument cannot reach here +// because resolution already rejected it. A parameter with no declared type is +// carried through unchanged. A variadic argument validates each collected element. +func coerceCallArgs(fn *rpc.Function, c *ir.Call) *pgerr.APIError { + for name, v := range c.Args { + if v.JSON != nil { + continue // a POST argument, already typed + } + p, ok := fn.Param(name) + if !ok || p.Type == "" { + continue + } + if p.Variadic { + for _, e := range v.List { + if err := coerce(p.Type, e); err != nil { + return err + } + } + continue + } + if err := coerce(p.Type, v.Text); err != nil { + return err + } + } + return nil +} + // validateCallFilters checks an RPC call's post-filter columns against a table // return's declared columns. It is a no-op for scalar and setof-scalar returns // and for a table return whose columns are not declared. @@ -201,7 +647,7 @@ func validateCallFilters(fn *rpc.Function, c *ir.Call) *pgerr.APIError { continue } if !has(col.Path) { - return pgerr.ErrUnknownColumn(col.Path[0]) + return pgerr.ErrUndefinedColumn(col.Path[0]) } } if err := validateCallCond(cols, c.Where); err != nil { @@ -209,7 +655,7 @@ func validateCallFilters(fn *rpc.Function, c *ir.Call) *pgerr.APIError { } for _, t := range c.Order { if !has(t.Path) { - return pgerr.ErrUnknownColumn(t.Path[0]) + return pgerr.ErrUndefinedColumn(t.Path[0]) } } return nil @@ -241,7 +687,7 @@ func validateCallCond(cols map[string]bool, c *ir.Cond) *pgerr.APIError { return validateCallCond(cols, &n.Kid) case ir.Compare: if len(n.Path) > 0 && !cols[n.Path[0]] { - return pgerr.ErrUnknownColumn(n.Path[0]) + return pgerr.ErrUndefinedColumn(n.Path[0]) } } return nil @@ -254,40 +700,110 @@ func validateWrite(rel *schema.Relation, w *ir.WriteSpec) *pgerr.APIError { return nil } // The insert column set (first-row keys or explicit columns=) is what the - // compiler writes; validating it covers the payload that reaches SQL. + // compiler writes; validating it covers the payload that reaches SQL. Each + // resolved column carries its canonical type so the compiler can lower a JSON + // array payload value to the shape the target column accepts (json/jsonb text + // vs a PostgreSQL array literal). + types := map[string]string{} for _, c := range w.Columns { - if !rel.HasColumn(c) { - return pgerr.ErrUnknownColumn(c) + col, ok := rel.Column(c) + if !ok { + return pgerr.ErrUnknownColumn(c, rel.Name) } + types[c] = col.Type } for k := range w.Set { - if !rel.HasColumn(k) { - return pgerr.ErrUnknownColumn(k) + col, ok := rel.Column(k) + if !ok { + return pgerr.ErrUnknownColumn(k, rel.Name) } + types[k] = col.Type } + w.ColumnTypes = types if w.Conflict != nil && len(w.Conflict.Target) == 0 { w.Conflict.Target = rel.PrimaryKey } return nil } -func validateSelect(rel *schema.Relation, items []ir.SelectItem) *pgerr.APIError { +func validateSelect(rel *schema.Relation, items []ir.SelectItem, aggEnabled bool) *pgerr.APIError { for _, it := range items { - col, ok := it.(ir.Column) - if !ok { - // Aggregates and embeds are checked by their subsystems; leave them. - continue - } - if isStarPath(col.Path) { - continue - } - if err := checkColumn(rel, col.Path); err != nil { - return err + switch v := it.(type) { + case ir.Column: + if isStarPath(v.Path) { + continue + } + if err := checkColumn(rel, v.Path); err != nil { + return err + } + case ir.Aggregate: + // The count()/col.agg() function forms are gated behind + // db-aggregates-enabled; the legacy bare count an embed carries is exempt. + if !v.Legacy && !aggEnabled { + return pgerr.ErrAggregatesDisabled() + } + if v.Arg != nil { + if isStarPath(v.Arg.Path) { + continue + } + if err := checkColumn(rel, v.Arg.Path); err != nil { + return err + } + } + default: + // Embed references are checked by resolveEmbeds. } } return nil } +// reclassifyEmbedFilters rewrites, in place, every Compare in the query's filter +// tree whose single-segment path names an embed's OutKey and whose operator is +// `is null` into an ir.EmbedPredicate. PostgREST reads films?actors=not.is.null +// as a semi-join on the actors relationship and films?actors=is.null as an +// anti-join, both usable inside or=(...); without this rewrite the embed name +// would be validated as a parent column and rejected. not.is.null carries the +// Compare's Negate, which becomes Exists (the parent must have a matching row). +// See item 01.12. +func reclassifyEmbedFilters(q *ir.Query) { + if q.Where == nil || len(q.Embeds) == 0 { + return + } + idx := make(map[string]int, len(q.Embeds)) + for i := range q.Embeds { + idx[q.Embeds[i].OutKey] = i + } + var rw func(c ir.Cond) ir.Cond + rw = func(c ir.Cond) ir.Cond { + switch n := c.(type) { + case ir.And: + for i := range n.Kids { + n.Kids[i] = rw(n.Kids[i]) + } + return n + case ir.Or: + for i := range n.Kids { + n.Kids[i] = rw(n.Kids[i]) + } + return n + case ir.Not: + n.Kid = rw(n.Kid) + return n + case ir.Compare: + if n.Op == ir.OpIs && n.Value.Text == "null" && len(n.Path) == 1 { + if i, ok := idx[n.Path[0]]; ok { + return ir.EmbedPredicate{Index: i, Exists: n.Negate} + } + } + return n + default: + return c + } + } + nc := rw(*q.Where) + q.Where = &nc +} + func validateCond(rel *schema.Relation, c *ir.Cond) *pgerr.APIError { if c == nil { return nil @@ -321,8 +837,34 @@ func validateCond(rel *schema.Relation, c *ir.Cond) *pgerr.APIError { // structure (SQLite's FTS5) raises PGRST127. See spec 21. if n.Op == ir.OpFTS && len(n.Path) == 1 { n.FullText = rel.FullTextIndexFor(n.Path[0]) + // Carry the column's canonical type so the dialect can skip the + // to_tsvector wrap on a column that is already tsvector, the way + // PostgREST does (Plan.hs "Do not apply to_tsvector to tsvector + // types"). A wrap on a tsvector column raises 42883 in PostgreSQL. + if col, ok := rel.Column(n.Path[0]); ok { + n.ColumnType = col.Type + } *c = n } + // Array operators carry the column's canonical type so the dialect can + // decide whether the column supports array semantics (e.g. SQLite's + // json_each only applies to JSON-typed columns). See spec 21. + if (n.Op == ir.OpContains || n.Op == ir.OpContained || n.Op == ir.OpOverlap) && len(n.Path) == 1 { + if col, ok := rel.Column(n.Path[0]); ok { + n.ColumnType = col.Type + *c = n + } + } + // eq/neq carry the column's canonical type so the compiler binds the + // literal "true"/"false" as a boolean only when the column is boolean; + // against a text column the words stay text, matching PostgreSQL's + // type-driven coercion (item 07.4). + if (n.Op == ir.OpEq || n.Op == ir.OpNeq) && len(n.Path) == 1 { + if col, ok := rel.Column(n.Path[0]); ok { + n.ColumnType = col.Type + *c = n + } + } } return nil } @@ -342,6 +884,21 @@ func checkOperand(rel *schema.Relation, c ir.Compare) *pgerr.APIError { if !ok { return nil } + // A quantified comparison (eq/gt/gte/lt/lte over a {…} list) carries its + // operands in the list; coerce each against the column type. Quantified + // pattern operators (like/ilike/match/imatch) take patterns, not typed values, + // and are left alone (item 01.1). + if c.Quant != ir.QNone { + switch c.Op { + case ir.OpEq, ir.OpGt, ir.OpGte, ir.OpLt, ir.OpLte: + for _, v := range c.Value.List { + if err := coerce(col.Type, v); err != nil { + return err + } + } + } + return nil + } switch c.Op { case ir.OpEq, ir.OpNeq, ir.OpGt, ir.OpGte, ir.OpLt, ir.OpLte: return coerce(col.Type, c.Value.Text) @@ -371,6 +928,12 @@ func coerce(canonicalType, text string) *pgerr.APIError { func validateOrder(rel *schema.Relation, terms []ir.OrderTerm) *pgerr.APIError { for _, t := range terms { + // A related-order term (order=rel(col)) addresses a column of an embedded + // resource, not the parent. It is validated against the resolved embed in + // validateRelatedOrder, after the embeds are bound. + if t.Rel != "" { + continue + } if err := checkColumn(rel, t.Path); err != nil { return err } @@ -378,15 +941,65 @@ func validateOrder(rel *schema.Relation, terms []ir.OrderTerm) *pgerr.APIError { return nil } +// validateRelatedOrder checks every order=rel(col) term of a query against its +// resolved embeds: the named relation must be embedded in this request (PGRST108 +// otherwise) and must be a to-one relationship (PGRST118 otherwise, since a +// to-many embed gives no single sort key). The embed's own column is then +// validated against the embedded relation. The embeds must already be resolved. +func validateRelatedOrder(parent *schema.Relation, q *ir.Query) *pgerr.APIError { + for _, t := range q.Order { + if t.Rel == "" { + continue + } + emb := findEmbedByName(q.Embeds, t.Rel) + if emb == nil { + return pgerr.ErrRelatedOrderNotEmbedded(t.Rel) + } + if emb.Cardinality != ir.CardToOne { + return pgerr.ErrRelatedOrderNotToOne(parent.Name, t.Rel) + } + if err := checkColumn(emb.Rel.Target, t.Path); err != nil { + return err + } + } + return nil +} + +// findEmbedByName returns the embed an order=rel(col) term refers to, matched by +// the embed's alias when it has one, otherwise its written target name. This is +// the same spelling PostgREST resolves the related-order relation against. +func findEmbedByName(embeds []ir.Embed, name string) *ir.Embed { + for i := range embeds { + emb := &embeds[i] + written := emb.Alias + if written == "" { + written = emb.Target.Name + } + if written == name { + return emb + } + } + return nil +} + // checkColumn validates that the base column of a path exists on the relation. // Only the base (first hop) is checked here; JSON sub-paths are opaque to the -// model and validated when the JSON subsystem lands. +// model and validated when the JSON subsystem lands. A column named in select, a +// filter, or order that does not exist is PostgreSQL's 42703 (the reference +// reaches the server under PostgREST), relation-qualified the way the server +// spells it, not the schema-cache PGRST204 reserved for write payloads. func checkColumn(rel *schema.Relation, path []string) *pgerr.APIError { if len(path) == 0 { return nil } - if !rel.HasColumn(path[0]) { - return pgerr.ErrUnknownColumn(path[0]) + if rel.HasColumn(path[0]) { + return nil } - return nil + // A computed field is a function-backed virtual column: it is selectable, + // filterable, and orderable wherever a real column is, so the name resolves + // here and the compiler renders it as a function call on the row. + if _, ok := rel.ComputedFieldFor(path[0]); ok { + return nil + } + return pgerr.ErrUndefinedColumn(rel.Name + "." + path[0]) } diff --git a/plan/plan_test.go b/plan/plan_test.go index da4f8ed..e0532d0 100644 --- a/plan/plan_test.go +++ b/plan/plan_test.go @@ -19,7 +19,7 @@ func model() *schema.Model { func TestReadResolvesRelation(t *testing.T) { q := &ir.Query{Relation: ir.Ref{Name: "films"}, Select: []ir.SelectItem{ir.Column{Path: []string{"title"}}}} - p, err := Read(model(), q, nil) + p, err := Read(model(), q, nil, Options{}) if err != nil { t.Fatalf("Read: %v", err) } @@ -33,7 +33,7 @@ func TestReadResolvesRelation(t *testing.T) { func TestReadUnknownTable(t *testing.T) { q := &ir.Query{Relation: ir.Ref{Name: "ghosts"}} - _, err := Read(model(), q, nil) + _, err := Read(model(), q, nil, Options{}) if err == nil || err.Code != "PGRST205" { t.Fatalf("want PGRST205, got %v", err) } @@ -41,26 +41,56 @@ func TestReadUnknownTable(t *testing.T) { func TestReadUnknownColumnInSelect(t *testing.T) { q := &ir.Query{Relation: ir.Ref{Name: "films"}, Select: []ir.SelectItem{ir.Column{Path: []string{"bogus"}}}} - _, err := Read(model(), q, nil) - if err == nil || err.Code != "PGRST204" { - t.Fatalf("want PGRST204, got %v", err) + _, err := Read(model(), q, nil, Options{}) + // A select column that does not exist reaches PostgreSQL: 42703, not the + // schema-cache PGRST204 reserved for write payloads (item 04.5). + if err == nil || err.Code != "42703" { + t.Fatalf("want 42703, got %v", err) } } func TestReadUnknownColumnInFilter(t *testing.T) { where := ir.Cond(ir.Compare{Path: []string{"missing"}, Op: ir.OpEq, Value: ir.Value{Text: "x"}}) q := &ir.Query{Relation: ir.Ref{Name: "films"}, Where: &where} - _, err := Read(model(), q, nil) - if err == nil || err.Code != "PGRST204" { - t.Fatalf("want PGRST204, got %v", err) + _, err := Read(model(), q, nil, Options{}) + if err == nil || err.Code != "42703" { + t.Fatalf("want 42703, got %v", err) } } func TestReadUnknownColumnInOrder(t *testing.T) { q := &ir.Query{Relation: ir.Ref{Name: "films"}, Order: []ir.OrderTerm{{Path: []string{"nope"}}}} - _, err := Read(model(), q, nil) - if err == nil || err.Code != "PGRST204" { - t.Fatalf("want PGRST204, got %v", err) + _, err := Read(model(), q, nil, Options{}) + if err == nil || err.Code != "42703" { + t.Fatalf("want 42703, got %v", err) + } +} + +// The planner stamps an eq/neq filter with its column's canonical type so the +// compiler can decide whether "true"/"false" binds as a boolean (item 07.4). +func TestReadStampsColumnTypeOnEq(t *testing.T) { + m := schema.NewModel([]*schema.Relation{ + {Name: "flags", Kind: schema.KindTable, Columns: []*schema.Column{ + {Name: "id", Type: "integer", Position: 1}, + {Name: "done", Type: "bool", Position: 2}, + {Name: "label", Type: "text", Position: 3}, + }}, + }) + where := ir.Cond(ir.And{Kids: []ir.Cond{ + ir.Compare{Path: []string{"done"}, Op: ir.OpEq, Value: ir.Value{Text: "true"}}, + ir.Compare{Path: []string{"label"}, Op: ir.OpNeq, Value: ir.Value{Text: "x"}}, + }}) + q := &ir.Query{Relation: ir.Ref{Name: "flags"}, Where: &where} + p, err := Read(m, q, nil, Options{}) + if err != nil { + t.Fatalf("Read: %v", err) + } + kids := (*p.Query.Where).(ir.And).Kids + if ct := kids[0].(ir.Compare).ColumnType; ct != "bool" { + t.Errorf("done ColumnType = %q, want bool", ct) + } + if ct := kids[1].(ir.Compare).ColumnType; ct != "text" { + t.Errorf("label ColumnType = %q, want text", ct) } } @@ -72,8 +102,8 @@ func TestReadNestedLogicalColumnChecked(t *testing.T) { }}, }}) q := &ir.Query{Relation: ir.Ref{Name: "films"}, Where: &where} - _, err := Read(model(), q, nil) - if err == nil || err.Code != "PGRST204" { + _, err := Read(model(), q, nil, Options{}) + if err == nil || err.Code != "42703" { t.Fatalf("nested unknown column should be caught, got %v", err) } } @@ -124,6 +154,12 @@ func TestWriteUnknownInsertColumn(t *testing.T) { if err == nil || err.Code != "PGRST204" { t.Fatalf("want PGRST204 for unknown insert column, got %v", err) } + // A write payload is the schema-cache case PGRST204 is reserved for, and v14 + // names the relation in the message (item 04.3). + want := "Could not find the 'bogus' column of 'films' in the schema cache" + if err.Message != want { + t.Errorf("message = %q, want %q", err.Message, want) + } } func TestWriteUnknownUpdateColumn(t *testing.T) { @@ -156,3 +192,164 @@ func TestWriteUpsertDefaultsConflictToPK(t *testing.T) { t.Errorf("conflict target = %v, want the primary key [id]", got) } } + +// putQuery builds a PUT upsert addressing id=eq. with a single-object +// body carrying id=, plus optional extra filters and limit. +func putQuery(idFilter, idBody string, extra *ir.Cond, limit *int) *ir.Query { + eq := ir.Cond(ir.Compare{Path: []string{"id"}, Op: ir.OpEq, Value: ir.Value{Text: idFilter}}) + where := eq + if extra != nil { + where = ir.And{Kids: []ir.Cond{eq, *extra}} + } + return &ir.Query{ + Kind: ir.Upsert, + IsPut: true, + Relation: ir.Ref{Name: "films"}, + Where: &where, + Limit: limit, + Write: &ir.WriteSpec{ + Columns: []string{"id", "title"}, + Rows: []map[string]ir.Value{{"id": {JSON: idBody}, "title": {JSON: "X"}}}, + Conflict: &ir.Conflict{}, + }, + } +} + +func TestWritePutHappyPath(t *testing.T) { + q := putQuery("1", "1", nil, nil) + if _, err := Write(modelPK(), q, nil); err != nil { + t.Fatalf("Write: %v", err) + } +} + +func TestWritePutPartialKeyIs405(t *testing.T) { + // A non-eq filter on the key column is not a valid PUT addressing. + where := ir.Cond(ir.Compare{Path: []string{"id"}, Op: ir.OpGt, Value: ir.Value{Text: "1"}}) + q := &ir.Query{ + Kind: ir.Upsert, IsPut: true, Relation: ir.Ref{Name: "films"}, Where: &where, + Write: &ir.WriteSpec{Columns: []string{"id"}, Rows: []map[string]ir.Value{{"id": {JSON: float64(1)}}}, Conflict: &ir.Conflict{}}, + } + _, err := Write(modelPK(), q, nil) + if err == nil || err.Code != "PGRST105" { + t.Fatalf("want PGRST105, got %v", err) + } +} + +func TestWritePutExtraFilterIs405(t *testing.T) { + extra := ir.Cond(ir.Compare{Path: []string{"title"}, Op: ir.OpEq, Value: ir.Value{Text: "X"}}) + _, err := Write(modelPK(), putQuery("1", "1", &extra, nil), nil) + if err == nil || err.Code != "PGRST105" { + t.Fatalf("want PGRST105 for a non-PK filter, got %v", err) + } +} + +func TestWritePutLimitIs400(t *testing.T) { + lim := 1 + _, err := Write(modelPK(), putQuery("1", "1", nil, &lim), nil) + if err == nil || err.Code != "PGRST114" { + t.Fatalf("want PGRST114, got %v", err) + } +} + +func TestWritePutPayloadMismatchIs400(t *testing.T) { + // URL says id=eq.999, body says id=1: the keys disagree. + _, err := Write(modelPK(), putQuery("999", "1", nil, nil), nil) + if err == nil || err.Code != "PGRST115" { + t.Fatalf("want PGRST115, got %v", err) + } +} + +func TestWritePutMultiRowIs400(t *testing.T) { + q := putQuery("1", "1", nil, nil) + q.Write.Rows = append(q.Write.Rows, map[string]ir.Value{"id": {JSON: float64(1)}, "title": {JSON: "Y"}}) + _, err := Write(modelPK(), q, nil) + if err == nil || err.Code != "PGRST115" { + t.Fatalf("want PGRST115 for a multi-row PUT body, got %v", err) + } +} + +// embedModel wires directors (one) <- films (many) through a forward FK so an +// embed of films on directors resolves to a single relationship. +func nullEmbedModel() *schema.Model { + directors := &schema.Relation{Schema: "public", Name: "directors", Kind: schema.KindTable, Columns: []*schema.Column{ + {Name: "id", Type: "integer", Position: 1}, + {Name: "name", Type: "text", Position: 2}, + }} + films := &schema.Relation{Schema: "public", Name: "films", Kind: schema.KindTable, Columns: []*schema.Column{ + {Name: "id", Type: "integer", Position: 1}, + {Name: "title", Type: "text", Position: 2}, + {Name: "director_id", Type: "integer", Position: 3}, + }, ForeignKeys: []*schema.ForeignKey{{ + Name: "films_director_id_fkey", Columns: []string{"director_id"}, + RefSchema: "public", RefRelation: "directors", RefColumns: []string{"id"}, + }}} + return schema.NewModel([]*schema.Relation{directors, films}) +} + +// A filter naming an embed (directors?films=not.is.null) is reclassified into an +// EmbedPredicate before column validation, rather than being rejected as an +// unknown parent column (item 01.12). not.is.null sets Exists. +func TestReadReclassifiesEmbedNullFilter(t *testing.T) { + for _, tc := range []struct { + name string + negate bool + exists bool + }{ + {"not.is.null is a semi-join", true, true}, + {"is.null is an anti-join", false, false}, + } { + t.Run(tc.name, func(t *testing.T) { + where := ir.Cond(ir.Compare{Path: []string{"films"}, Op: ir.OpIs, Value: ir.Value{Text: "null"}, Negate: tc.negate}) + q := &ir.Query{ + Relation: ir.Ref{Name: "directors"}, + Select: []ir.SelectItem{ir.EmbedRef{Index: 0}}, + Where: &where, + Embeds: []ir.Embed{{Target: ir.Ref{Name: "films"}, OutKey: "films"}}, + } + if _, err := Read(nullEmbedModel(), q, []string{"public"}, Options{}); err != nil { + t.Fatalf("Read: %v", err) + } + pred, ok := (*q.Where).(ir.EmbedPredicate) + if !ok { + t.Fatalf("Where = %T, want ir.EmbedPredicate", *q.Where) + } + if pred.Index != 0 || pred.Exists != tc.exists { + t.Errorf("predicate = %+v, want Index 0 Exists %v", pred, tc.exists) + } + }) + } +} + +// A null filter naming a real parent column (not an embed) stays an ordinary +// Compare and is column-validated as usual. +func TestReadEmbedNullReclassifyLeavesColumns(t *testing.T) { + where := ir.Cond(ir.Compare{Path: []string{"title"}, Op: ir.OpIs, Value: ir.Value{Text: "null"}, Negate: true}) + q := &ir.Query{ + Relation: ir.Ref{Name: "directors"}, + Select: []ir.SelectItem{ir.EmbedRef{Index: 0}}, + Where: &where, + Embeds: []ir.Embed{{Target: ir.Ref{Name: "films"}, OutKey: "films"}}, + } + // title is not a directors column; the filter is a Compare, so column + // validation rejects it rather than mistaking it for an embed predicate. + _, err := Read(nullEmbedModel(), q, []string{"public"}, Options{}) + if err == nil || err.Code != "42703" { + t.Fatalf("want 42703 for unknown filter column, got %v", err) + } +} + +// A write whose select embeds a resource with no relationship is the read +// path's PGRST200 rather than silently dropping the embed (item 01.19). +func TestWriteResolvesEmbedsRejectsUnknown(t *testing.T) { + q := &ir.Query{ + Kind: ir.Insert, + Relation: ir.Ref{Name: "films"}, + Select: []ir.SelectItem{ir.EmbedRef{Index: 0}}, + Embeds: []ir.Embed{{Target: ir.Ref{Name: "ghosts"}, OutKey: "ghosts"}}, + Write: &ir.WriteSpec{Columns: []string{"title"}, Rows: []map[string]ir.Value{{"title": {JSON: "X"}}}, Return: ir.ReturnRepresentation}, + } + _, err := Write(model(), q, nil) + if err == nil || err.Code != "PGRST200" { + t.Fatalf("want PGRST200 for an unknown write embed, got %v", err) + } +} diff --git a/plan/related_order_test.go b/plan/related_order_test.go new file mode 100644 index 0000000..237dbf6 --- /dev/null +++ b/plan/related_order_test.go @@ -0,0 +1,81 @@ +package plan + +import ( + "testing" + + "github.com/tamnd/dbrest/ir" + "github.com/tamnd/dbrest/pgerr" + "github.com/tamnd/dbrest/schema" +) + +// planOrder parses and plans a films read with the given select and order +// strings, returning the plan or the planner error. +func planOrder(t *testing.T, m *schema.Model, relation, query string) (*ir.Plan, *pgerr.APIError) { + t.Helper() + q, perr := ir.ParseRead(relation, query, nil) + if perr != nil { + t.Fatalf("ParseRead: %v", perr) + } + return Read(m, q, nil, Options{}) +} + +// A related order over a to-one embed resolves: the term carries the relation +// name and the planner accepts it once the embed and its column check out (item +// 07.6). +func TestRelatedOrderToOneResolves(t *testing.T) { + m := embedModel() + pl, err := planOrder(t, m, "films", + "select=title,people!director_id(name)&order=people(name).asc") + if err != nil { + t.Fatalf("Read: %v", err) + } + var found bool + for _, ot := range pl.Query.Order { + if ot.Rel == "people" && len(ot.Path) == 1 && ot.Path[0] == "name" { + found = true + } + } + if !found { + t.Errorf("planned order missing related term, got %+v", pl.Query.Order) + } +} + +// Ordering by a relation not embedded in the request is PGRST108: the term names +// a resource the select never pulled in. +func TestRelatedOrderNotEmbeddedIsPGRST108(t *testing.T) { + m := embedModel() + _, err := planOrder(t, m, "films", "select=title&order=people(name).asc") + if err == nil || err.Code != "PGRST108" { + t.Fatalf("want PGRST108, got %v", err) + } + if err.HTTPStatus != 400 { + t.Errorf("status = %d, want 400", err.HTTPStatus) + } +} + +// Ordering by a to-many embed is PGRST118: a parent cannot sort on a column of a +// resource it has many of. people own many films, so people?order=films(title) +// is not a to-one relation. +func TestRelatedOrderToManyIsPGRST118(t *testing.T) { + m := embedModel() + _, err := planOrder(t, m, "people", + "select=name,films!director_id(title)&order=films(title).asc") + if err == nil || err.Code != "PGRST118" { + t.Fatalf("want PGRST118, got %v", err) + } + if err.HTTPStatus != 400 { + t.Errorf("status = %d, want 400", err.HTTPStatus) + } +} + +// A related order naming a real embed but an unknown column on the target is the +// ordinary unknown-column rejection (42703, the reference reaches PostgreSQL), +// not a relation error. +func TestRelatedOrderUnknownColumnIsRejected(t *testing.T) { + m := embedModel() + _, err := planOrder(t, m, "films", + "select=title,people!director_id(name)&order=people(nope).asc") + if err == nil || err.Code != "42703" { + t.Fatalf("want 42703, got %v", err) + } +} diff --git a/plan/rpc_embed_test.go b/plan/rpc_embed_test.go new file mode 100644 index 0000000..cdeb02e --- /dev/null +++ b/plan/rpc_embed_test.go @@ -0,0 +1,75 @@ +package plan + +import ( + "testing" + + "github.com/tamnd/dbrest/ir" + "github.com/tamnd/dbrest/rpc" +) + +// callReg wraps one function in a static registry for the call planner. +func callReg(fn *rpc.Function) rpc.Registry { + return rpc.NewStaticRegistry([]*rpc.Function{fn}) +} + +// filmsSetof returns rows of the films relation, the embeddable RPC shape. +func filmsSetof() *rpc.Function { + return &rpc.Function{ + Name: "recent_films", + Returns: rpc.ReturnShape{Kind: rpc.ReturnSetOf, Type: "films"}, + Volatility: rpc.Stable, + Query: &rpc.PortableQuery{SQL: "SELECT * FROM films"}, + } +} + +// A call over a function returning rows of a relation resolves its embeds against +// that relation, binding the relationship the same way a table read does. +func TestCallResolvesEmbedAgainstReturnRelation(t *testing.T) { + m := embedModel() + c, perr := ir.ParseCall("recent_films", "select=title,people!director_id(name)", nil, true, "", nil, "", "") + if perr != nil { + t.Fatalf("ParseCall: %v", perr) + } + pl, err := Call(callReg(filmsSetof()), m, c, true, nil) + if err != nil { + t.Fatalf("Call: %v", err) + } + if len(pl.Call.Embeds) != 1 { + t.Fatalf("got %d embeds, want 1", len(pl.Call.Embeds)) + } + emb := pl.Call.Embeds[0] + if emb.Rel == nil { + t.Fatal("embed relationship not bound") + } + if emb.Cardinality != ir.CardToOne { + t.Errorf("cardinality = %v, want to-one", emb.Cardinality) + } + if emb.Query.Relation.Name != "people" { + t.Errorf("embed relation = %q, want people", emb.Query.Relation.Name) + } +} + +// A function whose result is not a known relation has nothing to embed against, +// which is the read path's PGRST200. +func TestCallEmbedOnScalarReturnIsPGRST200(t *testing.T) { + m := embedModel() + scalar := &rpc.Function{ + Name: "film_titles", + Returns: rpc.ReturnShape{Kind: rpc.ReturnSetOf, Type: "text"}, + Volatility: rpc.Stable, + Query: &rpc.PortableQuery{SQL: "SELECT title FROM films"}, + } + c, perr := ir.ParseCall("film_titles", "select=people(name)", nil, true, "", nil, "", "") + if perr != nil { + t.Fatalf("ParseCall: %v", perr) + } + _, err := Call(callReg(scalar), m, c, true, nil) + if err == nil { + t.Fatal("want an error embedding on a scalar-set return") + } + if err.Code != pgerrCodeNoRelationship { + t.Errorf("code = %q, want %q", err.Code, pgerrCodeNoRelationship) + } +} + +const pgerrCodeNoRelationship = "PGRST200" diff --git a/reqctx/reqctx.go b/reqctx/reqctx.go index 5981d6f..4093b49 100644 --- a/reqctx/reqctx.go +++ b/reqctx/reqctx.go @@ -16,6 +16,7 @@ package reqctx import ( "encoding/json" + "maps" "sort" "strings" ) @@ -43,6 +44,27 @@ type Context struct { // Content-Profile choice), or "" for the default. Cross-schema routing is the // introspection subsystem's job (spec 08); this field carries the choice. Schema string + // PreRequest names the db-pre-request function the backend must invoke after + // the request context is in place and before the main statement, in the same + // transaction (spec 13). Empty means none is configured. An error the + // function raises aborts the request through normal error mapping, and any + // response controls it writes are applied at render time. + PreRequest string + + // AppSettings are the app.settings.* options, keys without the prefix. A + // backend applies them as transaction settings (GUCs on PostgreSQL) next + // to the request context. + AppSettings map[string]string + + // LogQuery asks the backend to echo the statements it executes for this + // request, the log-query option. + LogQuery bool + + // TimeZone is the Prefer: timezone= request timezone, already validated by + // ir.ParsePrefer. A backend that supports it applies SET LOCAL timezone for the + // request; the emulated render path converts temporals to it. Empty means the + // client stated no timezone and the engine default stands. + TimeZone string controls ResponseControls } @@ -50,12 +72,31 @@ type Context struct { // ClaimsJSON marshals the verified claims into the object request.jwt.claims // carries. It is "{}" when there are no claims, never null, so a backend that // writes the GUC verbatim and a policy that reads it both see a valid object. +// +// The resolved request role is folded in under the "role" key, overwriting any +// role the token carried, exactly as PostgREST does (PreQuery.hs inserts the +// resolved role into the claims before writing the GUC). So an anonymous request +// presents {"role":""} rather than {}, and the common RLS pattern +// current_setting('request.jwt.claims', true)::json->>'role' reads the role on +// every request, the case PostgREST guarantees a value. func (c *Context) ClaimsJSON() []byte { - if len(c.Claims) == 0 { - return []byte("{}") + if c.Role == "" { + // No resolved role to fold in (the fail-closed frontend always sets one, so + // this is the degenerate case); marshal the claims as they are. + if len(c.Claims) == 0 { + return []byte("{}") + } + b, err := json.Marshal(c.Claims) + if err != nil { + return []byte("{}") + } + return b } + merged := make(map[string]any, len(c.Claims)+1) + maps.Copy(merged, c.Claims) + merged["role"] = c.Role // encoding/json sorts map keys, so the output is deterministic. - b, err := json.Marshal(c.Claims) + b, err := json.Marshal(merged) if err != nil { return []byte("{}") } @@ -63,13 +104,24 @@ func (c *Context) ClaimsJSON() []byte { } // HeadersJSON marshals the request headers into the object request.headers -// carries: a JSON object of lower-cased header name to value, with a multi-valued -// header joined by ", " as HTTP defines. Keys are sorted for a deterministic -// document. +// carries: a JSON object of lower-cased header name to value. Keys are sorted +// for a deterministic document. +// +// Two rules match PostgREST exactly (ApiRequest.hs). The Cookie header is +// excluded, because it is delivered separately as request.cookies; including it +// would leak raw cookie material into code that only consults headers. A header +// sent more than once resolves to its last value (later wins), not a comma-join, +// reproducing how PostgREST's later pair overwrites the earlier in the object. func (c *Context) HeadersJSON() []byte { flat := make(map[string]string, len(c.Headers)) for k, vs := range c.Headers { - flat[strings.ToLower(k)] = strings.Join(vs, ", ") + lk := strings.ToLower(k) + if lk == "cookie" { + continue + } + if len(vs) > 0 { + flat[lk] = vs[len(vs)-1] + } } return marshalSortedObject(flat) } @@ -117,13 +169,15 @@ type ResponseControls struct { Status int // Headers are extra response headers to merge in. Headers map[string]string - // UpsertInsert is set by the backend when it can determine whether the upsert - // resulted in a pure INSERT (all rows were new, true) or hit existing rows - // (false). Only backends that support this detection (PostgreSQL via xmax) - // set UpsertStatusKnown = true; others leave it false, and the HTTP layer - // defaults to 201 for POST upserts. + // InsertedRows is the number of payload rows the upsert inserted as new (the + // rest replaced existing rows). The HTTP layer reads it to separate a + // zero-inserted merge from a mixed batch: a POST merge upsert is 200 only when + // it is zero, a PUT is 201 only when it is positive. A backend sets it together + // with UpsertStatusKnown = true when it can detect insert-vs-update (sqlite by a + // pre-write key probe, PostgreSQL via xmax); others leave the status unknown and + // the HTTP layer defaults to 201 for POST upserts. UpsertStatusKnown bool - UpsertInsert bool + InsertedRows int } // Controls returns a pointer to the mutable response controls. diff --git a/reqctx/reqctx_test.go b/reqctx/reqctx_test.go index 33cf100..e0d6213 100644 --- a/reqctx/reqctx_test.go +++ b/reqctx/reqctx_test.go @@ -24,13 +24,43 @@ func TestHeadersJSONFlattensAndLowercases(t *testing.T) { "Accept": {"application/json"}, "Cache-Control": {"no-cache", "no-store"}, }} - // Lower-cased names, sorted keys, a multi-valued header joined by ", ". - want := `{"accept":"application/json","cache-control":"no-cache, no-store","x-tenant":"acme"}` + // Lower-cased names, sorted keys, a repeated header resolved to its last value + // (PostgREST's later-wins), not a comma-join. + want := `{"accept":"application/json","cache-control":"no-store","x-tenant":"acme"}` if got := string(c.HeadersJSON()); got != want { t.Errorf("HeadersJSON() = %q, want %q", got, want) } } +// The Cookie header is excluded from request.headers (it is request.cookies), +// matching PostgREST, verified against PostgREST 14.12. A repeated header keeps +// its last value. +func TestHeadersJSONExcludesCookieAndKeepsLast(t *testing.T) { + c := &Context{Headers: map[string][]string{ + "Cookie": {"sessionid=abc123"}, + "X-Dup": {"first", "second"}, + }} + want := `{"x-dup":"second"}` + if got := string(c.HeadersJSON()); got != want { + t.Errorf("HeadersJSON() = %q, want %q", got, want) + } +} + +// The resolved role is folded into request.jwt.claims under "role", overwriting +// any token role, matching PostgREST 14.12 (anonymous presents the anon role). +func TestClaimsJSONFoldsInRole(t *testing.T) { + // Anonymous: no claims, the resolved role still appears. + anon := &Context{Role: "anon"} + if got := string(anon.ClaimsJSON()); got != `{"role":"anon"}` { + t.Errorf("anon ClaimsJSON() = %q, want {\"role\":\"anon\"}", got) + } + // Authenticated: the resolved role overwrites whatever the token carried. + auth := &Context{Role: "web_user", Claims: map[string]any{"role": "stale", "sub": "alice"}} + if got := string(auth.ClaimsJSON()); got != `{"role":"web_user","sub":"alice"}` { + t.Errorf("auth ClaimsJSON() = %q", got) + } +} + func TestHeadersJSONEmptyIsObject(t *testing.T) { c := &Context{} if got := string(c.HeadersJSON()); got != "{}" { diff --git a/rpc/merge_test.go b/rpc/merge_test.go new file mode 100644 index 0000000..b987c84 --- /dev/null +++ b/rpc/merge_test.go @@ -0,0 +1,94 @@ +package rpc + +import ( + "reflect" + "sort" + "testing" +) + +// names lists every function name a registry exposes, with overloads repeated, in +// the order List returns them, so a test can assert the merged candidate set. +func names(reg Registry) []string { + var out []string + for _, f := range reg.List() { + out = append(out, f.Signature("")) + } + return out +} + +// Merge returns the non-empty side unchanged when the other is empty, so the +// common case (no declared registry alongside the native one) adds no overhead. +func TestMergeEmptySides(t *testing.T) { + native := NewStaticRegistry([]*Function{ + {Name: "add", Params: []Param{{Name: "a"}, {Name: "b"}}}, + }) + if got := Merge(EmptyRegistry{}, native); got != Registry(native) { + t.Errorf("Merge(empty, native) should return native unchanged") + } + portable := NewStaticRegistry([]*Function{{Name: "echo", Params: []Param{{Name: "x"}}}}) + if got := Merge(portable, EmptyRegistry{}); got != Registry(portable) { + t.Errorf("Merge(portable, empty) should return portable unchanged") + } +} + +// A function declared in the primary registry with the same signature as a native +// one shadows it: the primary version is the one resolved, so an operator's +// explicit declaration wins. +func TestMergePrimaryShadowsSameSignature(t *testing.T) { + portable := NewStaticRegistry([]*Function{ + {Name: "add", Params: []Param{{Name: "a", Type: "int"}, {Name: "b", Type: "int"}}, + Query: &PortableQuery{SQL: "SELECT :a + :b"}}, + }) + native := NewStaticRegistry([]*Function{ + {Name: "add", Params: []Param{{Name: "a", Type: "int"}, {Name: "b", Type: "int"}}}, + }) + merged := Merge(portable, native) + + got := names(merged) + want := []string{"add(a => int, b => int)"} + if !reflect.DeepEqual(got, want) { + t.Fatalf("merged signatures = %v, want %v", got, want) + } + fn, ok := merged.Lookup("add", ArgSet{"a": true, "b": true}) + if !ok { + t.Fatal("add should resolve in the merged registry") + } + if fn.Query == nil { + t.Error("the declared (portable) add should win over the native one") + } +} + +// Overloads that differ in any parameter are kept as distinct candidates from both +// sides, and overload resolution runs across the union: the right arg set picks the +// right source. +func TestMergeKeepsDistinctOverloads(t *testing.T) { + portable := NewStaticRegistry([]*Function{ + {Name: "f", Params: []Param{{Name: "a", Type: "text"}}, + Query: &PortableQuery{SQL: "SELECT :a"}}, + }) + native := NewStaticRegistry([]*Function{ + {Name: "f", Params: []Param{{Name: "a", Type: "text"}, {Name: "b", Type: "text"}}}, + {Name: "g", Params: []Param{{Name: "x", Type: "int"}}}, + }) + merged := Merge(portable, native) + + got := names(merged) + sort.Strings(got) + want := []string{"f(a => text)", "f(a => text, b => text)", "g(x => int)"} + if !reflect.DeepEqual(got, want) { + t.Fatalf("merged signatures = %v, want %v", got, want) + } + + // The one-arg form is the portable overload; the two-arg form is native. + one, ok := merged.Lookup("f", ArgSet{"a": true}) + if !ok || one.Query == nil { + t.Errorf("f(a) should resolve to the portable overload, got %+v ok=%v", one, ok) + } + two, ok := merged.Lookup("f", ArgSet{"a": true, "b": true}) + if !ok || two.Query != nil { + t.Errorf("f(a,b) should resolve to the native overload, got %+v ok=%v", two, ok) + } + if _, ok := merged.Lookup("g", ArgSet{"x": true}); !ok { + t.Error("g(x) from the native side should resolve in the merged registry") + } +} diff --git a/rpc/registry.go b/rpc/registry.go index 596233a..6ba5791 100644 --- a/rpc/registry.go +++ b/rpc/registry.go @@ -9,7 +9,12 @@ // *Function without a cycle. package rpc -import "sort" +import ( + "encoding/json" + "fmt" + "sort" + "strings" +) // Volatility classifies a function's effect, which fixes the methods it allows // and the transaction mode it runs in (spec 12). A registry entry that omits it @@ -41,9 +46,11 @@ const ( type ReturnKind uint8 const ( - ReturnScalar ReturnKind = iota // returns -> a single value - ReturnSetOf // returns setof -> an array of values - ReturnTable // returns table(...) -> an array of objects + ReturnScalar ReturnKind = iota // returns -> a single value + ReturnSetOf // returns setof -> an array of values + ReturnTable // returns table(...) -> an array of objects + ReturnVoid // returns void -> 200 with a null body + ReturnObject // returns -> one object, not an array ) // ReturnShape is a function's declared result. Type is the canonical type of a @@ -68,15 +75,36 @@ type Param struct { Type string // canonical type (spec 16) Optional bool // may be omitted; Default is bound in its place Default any // value bound when an optional param is omitted (nil = NULL) - Variadic bool // collects the trailing values (not yet lowered; see notes) + Variadic bool // collects the trailing values into a list, expanded at lowering + // RawBody marks the single-unnamed-parameter form: PostgREST binds the whole + // raw request body to this parameter, decoded by Content-Type (a JSON value of + // any kind for application/json, raw text for text/plain and text/xml, raw + // bytes for application/octet-stream), rather than treating the body as a JSON + // object of named arguments. The parameter keeps a name so the SQL body can + // reference its placeholder. + RawBody bool +} + +// SingleRawBody reports whether the function takes exactly one parameter bound +// from the raw request body, the unnamed-argument form. Such a function receives +// the whole body as that one argument regardless of the body's JSON shape. +func (f *Function) SingleRawBody() (Param, bool) { + if len(f.Params) == 1 && f.Params[0].RawBody { + return f.Params[0], true + } + return Param{}, false } // Function is one callable function descriptor. Exactly one realization is set; // this slice implements the portable Query (native discovery from an engine // catalog is a later slice). type Function struct { - Name string - Params []Param + Name string + Params []Param + // Comment is the database comment on the function (COMMENT ON FUNCTION, + // or the registry declaration's comment field). The OpenAPI generator + // splits it into the rpc operation's summary and description, as v14 does. + Comment string Returns ReturnShape Volatility Volatility Security SecurityMode @@ -90,11 +118,13 @@ type PortableQuery struct { SQL string } -// Required reports the names of the function's non-optional parameters. +// Required reports the names of the function's non-optional parameters. A +// variadic parameter is never required: PostgreSQL accepts a variadic call with +// zero trailing arguments, so an omitted variadic still satisfies an overload. func (f *Function) Required() []string { var req []string for _, p := range f.Params { - if !p.Optional { + if !p.Optional && !p.Variadic { req = append(req, p.Name) } } @@ -123,10 +153,30 @@ type Registry interface { // set present in the request. The bool is false when no overload matches; the // caller raises PGRST202. Lookup(name string, args ArgSet) (*Function, bool) + // Resolve is Lookup that also reports ambiguity: it returns the chosen overload + // (ok true), or ok false with the competing signatures when two overloads are + // equally good (the caller raises PGRST203), or ok false with no signatures + // when none match (PGRST202). Lookup is Resolve collapsed to (fn, ok). + Resolve(name string, args ArgSet) (fn *Function, ambiguous []string, ok bool) // List enumerates the exposed functions in a stable order, for OpenAPI. List() []*Function } +// Signature renders the function as PostgREST spells it in a PGRST202/PGRST203 +// message: name(param => type, ...), or name() when it takes no parameters. The +// schema, when given, qualifies the name (api.add(...)). +func (f *Function) Signature(schemaName string) string { + name := f.Name + if schemaName != "" { + name = schemaName + "." + name + } + parts := make([]string, len(f.Params)) + for i, p := range f.Params { + parts[i] = p.Name + " => " + p.Type + } + return name + "(" + strings.Join(parts, ", ") + ")" +} + // StaticRegistry is a portable registry built in memory: one or more overloads // per name, declared programmatically (and, once configuration lands, from // config). It is the realization for engines with no usable stored-procedure @@ -156,12 +206,23 @@ func NewStaticRegistry(fns []*Function) *StaticRegistry { // declare. Among satisfiable overloads it prefers an exact parameter-set match, // then the most specific (largest required set), deterministically. func (r *StaticRegistry) Lookup(name string, args ArgSet) (*Function, bool) { + fn, _, ok := r.Resolve(name, args) + return fn, ok +} + +// Resolve selects the overload for an argument set and reports ambiguity. It +// scores every satisfiable overload (an exact parameter-set match wins outright, +// then the largest required set), and when two overloads tie at the top score it +// returns them as competing signatures instead of silently picking one, so the +// caller raises PGRST203 the way PostgREST does for unresolvable overloads. +func (r *StaticRegistry) Resolve(name string, args ArgSet) (*Function, []string, bool) { cands := r.byName[name] if len(cands) == 0 { - return nil, false + return nil, nil, false } var best *Function bestScore := -1 + var tied []*Function for _, f := range cands { if !satisfiable(f, args) { continue @@ -170,14 +231,25 @@ func (r *StaticRegistry) Lookup(name string, args ArgSet) (*Function, bool) { if exactMatch(f, args) { score += 1000 // an exact parameter-set match wins outright } - if score > bestScore { - best, bestScore = f, score + switch { + case score > bestScore: + best, bestScore, tied = f, score, []*Function{f} + case score == bestScore: + tied = append(tied, f) } } if best == nil { - return nil, false + return nil, nil, false + } + if len(tied) > 1 { + sigs := make([]string, len(tied)) + for i, f := range tied { + sigs[i] = f.Signature("") + } + sort.Strings(sigs) + return nil, sigs, false } - return best, true + return best, nil, true } // List returns the functions in stable name order (overloads in declared order). @@ -218,6 +290,140 @@ func exactMatch(f *Function, args ArgSet) bool { return true } +// ParseRegistry decodes a JSON function-registry declaration into a +// StaticRegistry ready to Register on a backend. The JSON is an array of +// function objects; each carries: +// +// name string required; bare function name +// sql string required; parameterized SQL with :name placeholders +// comment string optional; surfaces in the OpenAPI document +// params []{name, type, optional?, default?} +// returns {kind: "scalar"|"setof"|"table", type?, columns?} +// volatility "volatile"|"stable"|"immutable" (default: volatile) +// +// Returns an error when the JSON is malformed; an empty array yields an empty +// registry. Schemas are stripped from names; a name of "api.add" resolves as "add". +func ParseRegistry(rawJSON string) (*StaticRegistry, error) { + rawJSON = strings.TrimSpace(rawJSON) + if rawJSON == "" { + return NewStaticRegistry(nil), nil + } + type paramDecl struct { + Name string `json:"name"` + Type string `json:"type"` + Optional bool `json:"optional"` + Default any `json:"default"` + Variadic bool `json:"variadic"` + RawBody bool `json:"rawBody"` + } + type returnDecl struct { + Kind string `json:"kind"` + Type string `json:"type"` + Columns []struct { + Name string `json:"name"` + Type string `json:"type"` + } `json:"columns"` + } + type fnDecl struct { + Name string `json:"name"` + SQL string `json:"sql"` + Comment string `json:"comment"` + Params []paramDecl `json:"params"` + Returns returnDecl `json:"returns"` + Volatility string `json:"volatility"` + } + var decls []fnDecl + if err := json.Unmarshal([]byte(rawJSON), &decls); err != nil { + return nil, fmt.Errorf("function-registry: %w", err) + } + fns := make([]*Function, 0, len(decls)) + for _, d := range decls { + // Strip schema prefix (e.g. "api.add" → "add"). + name := d.Name + if dot := strings.LastIndex(name, "."); dot >= 0 { + name = name[dot+1:] + } + var vol Volatility + switch strings.ToLower(d.Volatility) { + case "stable": + vol = Stable + case "immutable": + vol = Immutable + default: + vol = Volatile + } + params := make([]Param, len(d.Params)) + for i, p := range d.Params { + params[i] = Param(p) + } + var ret ReturnShape + switch strings.ToLower(d.Returns.Kind) { + case "void": + ret.Kind = ReturnVoid + case "object": + ret.Kind = ReturnObject + case "setof": + ret.Kind = ReturnSetOf + case "table": + ret.Kind = ReturnTable + ret.Columns = make([]Column, len(d.Returns.Columns)) + for i, c := range d.Returns.Columns { + ret.Columns[i] = Column{Name: c.Name, Type: c.Type} + } + default: + ret.Kind = ReturnScalar + } + ret.Type = d.Returns.Type + fns = append(fns, &Function{ + Name: name, + Params: params, + Comment: d.Comment, + Returns: ret, + Volatility: vol, + Query: &PortableQuery{SQL: d.SQL}, + }) + } + return NewStaticRegistry(fns), nil +} + +// Merge composes two registries into one, with primary taking precedence on an +// exact-signature collision. It is how a NativeRPC backend exposes both a declared +// portable registry and its introspected native functions through the single +// Registry the frontend resolves against: a function declared via Register shadows +// a native function of the same signature (the explicit declaration wins), while +// every other native function and every distinct portable overload stays reachable. +// Overload resolution then runs across the union in one place, so a wrong argument +// set is PGRST202 and an ambiguous one PGRST203 regardless of which source an +// overload came from. The signature key is the function's printed signature +// (name plus parameter names and types), so two overloads that differ in any +// parameter are kept as distinct candidates rather than one shadowing the other. +// +// When either side is empty the other is returned unchanged, so the common +// case (no declared registry) costs nothing. +func Merge(primary, secondary Registry) Registry { + pl := primary.List() + sl := secondary.List() + if len(pl) == 0 { + return secondary + } + if len(sl) == 0 { + return primary + } + seen := make(map[string]bool, len(pl)) + merged := make([]*Function, 0, len(pl)+len(sl)) + for _, f := range pl { + seen[f.Signature("")] = true + merged = append(merged, f) + } + for _, f := range sl { + if seen[f.Signature("")] { + continue // shadowed by a same-signature primary function + } + merged = append(merged, f) + } + return NewStaticRegistry(merged) +} + // EmptyRegistry is a registry with no functions; every Lookup misses. A backend // that has not been given any functions returns this so the frontend raises a // clean PGRST202 rather than dereferencing nil. @@ -226,5 +432,10 @@ type EmptyRegistry struct{} // Lookup always misses: an empty registry has no functions. func (EmptyRegistry) Lookup(string, ArgSet) (*Function, bool) { return nil, false } +// Resolve always misses with no ambiguity: an empty registry has no functions. +func (EmptyRegistry) Resolve(string, ArgSet) (*Function, []string, bool) { + return nil, nil, false +} + // List returns no functions. func (EmptyRegistry) List() []*Function { return nil } diff --git a/rpc/registry_test.go b/rpc/registry_test.go index 4a0971b..5bb8453 100644 --- a/rpc/registry_test.go +++ b/rpc/registry_test.go @@ -115,3 +115,171 @@ func TestEmptyRegistry(t *testing.T) { t.Error("EmptyRegistry.List must be nil") } } + +// TestResolveAmbiguousOverloads checks that two overloads tying at the top score +// resolve to PGRST203 input: ok false with both competing signatures, rather than +// silently picking one. Two single-optional-parameter overloads called with no +// arguments are each satisfiable with the same (zero required) score. +func TestResolveAmbiguousOverloads(t *testing.T) { + left := &Function{Name: "f", Params: []Param{{Name: "a", Optional: true}}} + right := &Function{Name: "f", Params: []Param{{Name: "b", Optional: true}}} + reg := NewStaticRegistry([]*Function{left, right}) + + fn, ambiguous, ok := reg.Resolve("f", ArgSet{}) + if ok || fn != nil { + t.Fatalf("Resolve f() = %v, ok %v, want ambiguous miss", fn, ok) + } + want := []string{"f(a => )", "f(b => )"} + if !reflect.DeepEqual(ambiguous, want) { + t.Errorf("ambiguous = %v, want %v", ambiguous, want) + } +} + +// TestResolveExactWinsOverAmbiguous checks that an exact parameter-set match +// breaks a tie outright: f(a,b) and f(a,c) both take two arguments, but calling +// with exactly {a,b} names f(a,b)'s parameters and no other's. +func TestResolveExactWinsOverAmbiguous(t *testing.T) { + ab := &Function{Name: "f", Params: []Param{{Name: "a"}, {Name: "b"}}} + ac := &Function{Name: "f", Params: []Param{{Name: "a"}, {Name: "c"}}} + reg := NewStaticRegistry([]*Function{ab, ac}) + + fn, ambiguous, ok := reg.Resolve("f", ArgSet{"a": true, "b": true}) + if !ok || fn != ab || ambiguous != nil { + t.Fatalf("Resolve f(a,b) = %v, ambiguous %v, ok %v", fn, ambiguous, ok) + } +} + +// TestResolveUnknownName checks an unknown name misses cleanly (PGRST202 input): +// ok false with no competing signatures. +func TestResolveUnknownName(t *testing.T) { + reg := NewStaticRegistry(nil) + fn, ambiguous, ok := reg.Resolve("nope", nil) + if ok || fn != nil || ambiguous != nil { + t.Errorf("Resolve(nope) = %v, %v, %v", fn, ambiguous, ok) + } +} + +// TestSignature checks the PostgREST-style rendering used in PGRST202/PGRST203 +// messages: schema-qualified name with each parameter as "name => type", and the +// parameterless form collapsing to name(). +func TestSignature(t *testing.T) { + f := &Function{Name: "add", Params: []Param{ + {Name: "a", Type: "int4"}, + {Name: "b", Type: "int4"}, + }} + if got := f.Signature("api"); got != "api.add(a => int4, b => int4)" { + t.Errorf("Signature = %q", got) + } + if got := f.Signature(""); got != "add(a => int4, b => int4)" { + t.Errorf("unqualified Signature = %q", got) + } + none := &Function{Name: "now"} + if got := none.Signature("api"); got != "api.now()" { + t.Errorf("parameterless Signature = %q", got) + } +} + +// TestParseRegistryVoidKind checks a "void" return declaration decodes to the +// void kind, which the renderer answers with 200 and a null body. +func TestParseRegistryVoidKind(t *testing.T) { + reg, err := ParseRegistry(`[{ + "name": "touch", + "sql": "insert into log default values", + "returns": {"kind": "void"} + }]`) + if err != nil { + t.Fatalf("ParseRegistry: %v", err) + } + f, ok := reg.Lookup("touch", ArgSet{}) + if !ok { + t.Fatal("touch not found") + } + if f.Returns.Kind != ReturnVoid { + t.Errorf("return kind = %v, want ReturnVoid", f.Returns.Kind) + } +} + +// TestParseRegistryVariadic checks a "variadic": true parameter decodes to a +// Variadic param, which Required omits so a zero-argument call still resolves. +func TestParseRegistryVariadic(t *testing.T) { + reg, err := ParseRegistry(`[{ + "name": "pick", + "sql": "select title from films where id in (:ids)", + "params": [{"name": "ids", "type": "integer", "variadic": true}], + "returns": {"kind": "setof", "type": "text"} + }]`) + if err != nil { + t.Fatalf("ParseRegistry: %v", err) + } + f, ok := reg.Lookup("pick", ArgSet{}) + if !ok { + t.Fatal("a variadic-only function must resolve with no arguments") + } + if len(f.Params) != 1 || !f.Params[0].Variadic { + t.Errorf("params = %+v, want one variadic", f.Params) + } + if len(f.Required()) != 0 { + t.Errorf("Required = %v, want none for a variadic", f.Required()) + } +} + +// TestParseRegistryRawBody checks a "rawBody": true parameter decodes to a +// raw-body param, and SingleRawBody recognizes a one-parameter function of that +// shape as taking the whole POST body as its single unnamed argument. +func TestParseRegistryRawBody(t *testing.T) { + reg, err := ParseRegistry(`[{ + "name": "echo", + "sql": "select :payload", + "params": [{"name": "payload", "type": "json", "rawBody": true}], + "returns": {"kind": "scalar", "type": "json"} + }]`) + if err != nil { + t.Fatalf("ParseRegistry: %v", err) + } + f, ok := reg.Lookup("echo", ArgSet{"payload": true}) + if !ok { + t.Fatal("echo not found") + } + if len(f.Params) != 1 || !f.Params[0].RawBody { + t.Errorf("params = %+v, want one raw-body param", f.Params) + } + p, ok := f.SingleRawBody() + if !ok || p.Name != "payload" || p.Type != "json" { + t.Errorf("SingleRawBody = %+v, %v", p, ok) + } +} + +// TestSingleRawBodyRejectsMultiParam checks SingleRawBody only fires on a lone +// parameter: a function with a raw-body parameter beside another is not the +// single-unnamed-argument form. +func TestSingleRawBodyRejectsMultiParam(t *testing.T) { + f := &Function{Name: "f", Params: []Param{ + {Name: "payload", RawBody: true}, + {Name: "tag"}, + }} + if _, ok := f.SingleRawBody(); ok { + t.Error("a raw-body parameter beside another is not a single raw body") + } +} + +// TestParseRegistryComment checks a declaration's comment field rides into the +// Function, where the OpenAPI generator reads it. +func TestParseRegistryComment(t *testing.T) { + reg, err := ParseRegistry(`[{ + "name": "add", + "sql": "select :a + :b", + "comment": "Add two numbers\nReturns the sum.", + "params": [{"name": "a", "type": "int4"}, {"name": "b", "type": "int4"}], + "returns": {"kind": "scalar", "type": "int4"} + }]`) + if err != nil { + t.Fatalf("ParseRegistry: %v", err) + } + f, ok := reg.Lookup("add", ArgSet{"a": true, "b": true}) + if !ok { + t.Fatal("add not found") + } + if f.Comment != "Add two numbers\nReturns the sum." { + t.Errorf("Comment = %q", f.Comment) + } +} diff --git a/schema/cache.go b/schema/cache.go new file mode 100644 index 0000000..1556c8e --- /dev/null +++ b/schema/cache.go @@ -0,0 +1,26 @@ +package schema + +import "sync/atomic" + +// Cache publishes the schema model. Readers take an immutable snapshot with +// Load and keep using it for the whole request even if a reload lands midway; +// Store swaps in a freshly introspected model atomically, the reload +// mechanism PostgREST drives from SIGUSR1 and NOTIFY on the db-channel. The +// models themselves are never mutated after publication. +type Cache struct { + p atomic.Pointer[Model] +} + +// NewCache publishes the initial model. +func NewCache(m *Model) *Cache { + c := &Cache{} + c.p.Store(m) + return c +} + +// Load returns the current model snapshot. +func (c *Cache) Load() *Model { return c.p.Load() } + +// Store publishes a new model. In-flight readers keep their snapshot; the +// next Load sees the new one. +func (c *Cache) Store(m *Model) { c.p.Store(m) } diff --git a/schema/model.go b/schema/model.go index e243a02..2ccbf36 100644 --- a/schema/model.go +++ b/schema/model.go @@ -9,7 +9,6 @@ package schema import ( "slices" - "strings" ) // Kind distinguishes the relation flavors the planner cares about. @@ -28,15 +27,57 @@ type Model struct { relations map[string]*Relation // order preserves a deterministic relation order for OpenAPI and tests. order []string + // schemaComments holds the database comment on each exposed schema, the + // source of the OpenAPI info title and description (first line and rest). + schemaComments map[string]string + // declared holds relationships supplied outside the catalog: config-declared + // edges on an FK-less backend (MongoDB) and emulated computed relationships. + // The planner treats them like derived edges; a declared edge whose name + // equals a derived one overrides it (spec 09). Empty on a pure catalog model. + declared []DeclaredRel +} + +// AddDeclaredRelationship registers a relationship that does not come from a +// foreign key: a config-declared edge or an emulated computed relationship. It +// is called during introspection or config load, before the model is published. +// The planner resolves it alongside catalog edges, and it overrides a derived +// edge of the same name, the way a computed relationship overrides an +// auto-detected one in PostgREST (spec 09). +func (m *Model) AddDeclaredRelationship(d DeclaredRel) { + m.declared = append(m.declared, d) +} + +// SetSchemaComment records a schema's database comment. It is called during +// introspection, before the model is published; readers use SchemaComment. +func (m *Model) SetSchemaComment(schemaName, comment string) { + if m.schemaComments == nil { + m.schemaComments = make(map[string]string) + } + m.schemaComments[schemaName] = comment +} + +// SchemaComment returns the database comment on the named schema, or "" when +// none was recorded. +func (m *Model) SchemaComment(schemaName string) string { + return m.schemaComments[schemaName] } // Relation is one table or view in the exposed schema. type Relation struct { - Schema string - Name string - Kind Kind + Schema string + Name string + Kind Kind + // Comment is the database comment on the relation (COMMENT ON TABLE, or + // the declared-schema equivalent). The OpenAPI generator splits it into + // the operation summary (first line) and description (rest), as v14 does. + Comment string Columns []*Column PrimaryKey []string // column names forming the PK, in order; may be empty + // Unique are the relation's unique constraints, each a set of column names. A + // foreign key whose columns match the PK or one of these is one-to-one from the + // referenced side, so the reverse embed renders as an object (spec 09). An + // engine whose introspector does not read unique constraints leaves this empty. + Unique [][]string // ForeignKeys are the relation's outgoing foreign keys, the raw material the // planner resolves embeds from (spec 09). Empty on an engine without them. ForeignKeys []*ForeignKey @@ -45,20 +86,109 @@ type Relation struct { // from the FTS5 virtual tables that shadow a base table; an engine with // column-agnostic full-text (PostgreSQL's tsvector) leaves it empty. FullText []*FullTextIndex + // ViewColumns maps this view's output columns to the base-relation columns + // they project, when the relation is a view whose definition the introspector + // resolved to simple base-column references. It is empty for tables and for + // views the introspector does not project (UNIONs, expression columns). The + // model projects base-table foreign keys onto the view through it, so a view + // embeds the way the base table does (spec 09). + ViewColumns []ViewColumn + // Computed are the relation's computed fields: functions taking the relation's + // row type and returning a scalar, exposed as virtual columns the client can + // select, filter, and order by (PostgREST computed fields). The planner accepts + // their names where a real column is accepted and the compiler renders each as a + // function call on the row. Empty for engines or relations without any. + Computed []ComputedField + // ComputedRels are the relation's computed relationships: functions taking the + // relation's row type and returning a set (to-many) or a single row (to-one) of + // another relation, exposed as embeddable edges (PostgREST computed + // relationships, the escape hatch for recursive embeds). The planner resolves an + // embed name against them like a foreign-key edge and the compiler embeds by + // calling the function on the parent row. Empty for relations without any. + ComputedRels []ComputedRel - byName map[string]*Column + byName map[string]*Column + byComputed map[string]*ComputedField +} + +// ComputedRel is a function-backed embeddable edge: a function taking the parent +// relation's row type and returning rows of a target relation. Name is the edge +// name a client embeds by (the function name); FuncSchema is the schema the +// function lives in; Target names the relation its rows belong to; Card is to-many +// when the function is set-returning, to-one when it returns a single row. +type ComputedRel struct { + Name string + FuncSchema string + TargetSchema string + TargetName string + Card Card +} + +// ComputedField is a function-backed virtual column: a function taking the +// relation's row type and returning a scalar. Name is the field as the client +// selects it (the function name); FuncSchema is the schema the function lives in +// (PostgREST requires it to match the relation's schema); Type is the canonical +// return type, surfaced in OpenAPI and used for type-driven coercion. +type ComputedField struct { + Name string + FuncSchema string + Type string +} + +// ViewColumn records that a view's output column projects one base-relation +// column. The introspector emits these by parsing the view definition; the model +// uses them to carry base-table foreign keys onto the view. +type ViewColumn struct { + Name string // the view's output column name + BaseSchema string + BaseRelation string + BaseColumn string } // Column is one attribute of a relation. type Column struct { - Name string - Type string // canonical PG type name (spec 16) + Name string + Type string // canonical PG type name (spec 16) + // Comment is the database comment on the column. The OpenAPI generator + // surfaces it on the column's rowFilter parameter and ahead of the pk/fk + // notes in the definition property, matching v14. + Comment string Nullable bool HasDefault bool + // Identity reports whether the column is an auto-generated identity/serial + // column (IDENTITY on SQL Server, SERIAL/GENERATED ALWAYS AS IDENTITY on + // PostgreSQL). Backends that support explicit-identity inserts (e.g. SQL + // Server's IDENTITY_INSERT) use this to decide whether to enable it. + Identity bool // Position is the 1-based ordinal, used for stable ordering. Position int + // Rep, when non-nil, is the column's data-representation cast set: the column's + // type is a domain whose casts to and from json/text reformat the wire value + // (PostgREST domain representations, spec 11). Nil for an ordinary column. + Rep *Representation +} + +// Representation is a column's data-representation cast set (PostgREST domain +// representations, spec 11): a domain type whose casts to and from json/text +// reformat the wire value. ToJSON formats the stored value for a response, +// FromText parses a query-string filter literal, FromJSON parses a write-body +// value. A direction the domain declares no cast for is the zero FuncRef. +type Representation struct { + ToJSON FuncRef + FromText FuncRef + FromJSON FuncRef +} + +// FuncRef names a schema-qualified function. The zero value (empty Name) marks +// an absent cast in a Representation. +type FuncRef struct { + Schema string + Name string } +// IsZero reports whether the reference names no function (an absent cast). +func (f FuncRef) IsZero() bool { return f.Name == "" } + // FullTextIndex is an engine-side full-text facility covering one or more of a // relation's columns. The planner attaches the covering index to an fts predicate // so the compiler can lower the engine's match form; a backend that requires one @@ -96,6 +226,7 @@ func NewModel(rels []*Relation) *Model { } m.relations[key] = r } + m.projectViews() return m } @@ -104,6 +235,10 @@ func (r *Relation) index() { for _, c := range r.Columns { r.byName[c.Name] = c } + r.byComputed = make(map[string]*ComputedField, len(r.Computed)) + for i := range r.Computed { + r.byComputed[r.Computed[i].Name] = &r.Computed[i] + } } // Column returns the named column and whether it exists. @@ -118,6 +253,14 @@ func (r *Relation) HasColumn(name string) bool { return ok } +// ComputedFieldFor returns the named computed field and whether it exists. A +// computed field is selectable, filterable, and orderable like a real column but +// is backed by a function call rather than stored data. +func (r *Relation) ComputedFieldFor(name string) (*ComputedField, bool) { + c, ok := r.byComputed[name] + return c, ok +} + // ColumnNames returns the column names in ordinal order. It is the whole-row // projection a write returns when the client asks for the representation but // names no explicit columns. @@ -139,23 +282,39 @@ func Key(schemaName, rel string) string { return schemaName + "." + rel } -// Lookup resolves a possibly-qualified relation name. An unqualified name -// (no dot) is matched first directly, then against each schema in searchPath in -// order, mirroring PostgREST's exposed-schema / search-path resolution. +// Lookup resolves a relation name against the search path, trying each schema +// in order. Request resolution passes the single active schema (selected by +// the profile headers, defaulting to the first exposed schema), so a request +// can never reach a relation outside it: PostgREST treats the path segment as +// a bare name within the active schema, never as a qualified reference. With +// an empty searchPath the name is matched directly against the model keys, +// the mode introspection-internal callers use. func (m *Model) Lookup(name string, searchPath []string) (*Relation, bool) { - if r, ok := m.relations[name]; ok { + if len(searchPath) == 0 { + r, ok := m.relations[name] return r, ok } - if !strings.Contains(name, ".") { - for _, s := range searchPath { - if r, ok := m.relations[Key(s, name)]; ok { - return r, ok - } + for _, s := range searchPath { + if r, ok := m.relations[Key(s, name)]; ok { + return r, ok } } return nil, false } +// RelationsIn returns the relations of one schema in deterministic insertion +// order. It is the per-schema view the OpenAPI root builds its document from, +// so two same-named relations in different schemas can never collide there. +func (m *Model) RelationsIn(schemaName string) []*Relation { + var out []*Relation + for _, k := range m.order { + if r := m.relations[k]; r.Schema == schemaName { + out = append(out, r) + } + } + return out +} + // Relations returns the relations in deterministic insertion order. func (m *Model) Relations() []*Relation { out := make([]*Relation, 0, len(m.order)) diff --git a/schema/model_test.go b/schema/model_test.go index 39d732a..a2cda5c 100644 --- a/schema/model_test.go +++ b/schema/model_test.go @@ -40,15 +40,50 @@ func TestLookupSearchPath(t *testing.T) { t.Error("secrets should not resolve when only public is searched") } - // Qualified resolves regardless of search path. + // An empty search path matches model keys directly, the + // introspection-internal mode. if _, ok := m.Lookup("private.secrets", nil); !ok { - t.Error("private.secrets should resolve when fully qualified") + t.Error("private.secrets should resolve against the model key with no search path") } // Unknown stays unknown. if _, ok := m.Lookup("nope", []string{"public"}); ok { t.Error("unknown relation should not resolve") } + + // A request never reaches outside its active schema with a dotted name: + // the path segment is a bare name within the schema, never a qualified + // reference (PostgREST profile semantics). + if _, ok := m.Lookup("private.secrets", []string{"public"}); ok { + t.Error("a dotted name must not escape the active schema") + } + if _, ok := m.Lookup("public.users", []string{"public"}); ok { + t.Error("a dotted name must not bypass search-path keying") + } +} + +func TestRelationsIn(t *testing.T) { + m := sampleModel() + + pub := m.RelationsIn("public") + if len(pub) != 2 || pub[0].Name != "users" || pub[1].Name != "todos" { + t.Fatalf("RelationsIn(public) = %v, want [users todos]", names(pub)) + } + priv := m.RelationsIn("private") + if len(priv) != 1 || priv[0].Name != "secrets" { + t.Fatalf("RelationsIn(private) = %v, want [secrets]", names(priv)) + } + if got := m.RelationsIn("nope"); len(got) != 0 { + t.Errorf("RelationsIn(nope) = %v, want empty", names(got)) + } +} + +func names(rels []*Relation) []string { + out := make([]string, len(rels)) + for i, r := range rels { + out[i] = r.Name + } + return out } func TestColumnLookup(t *testing.T) { diff --git a/schema/relationship.go b/schema/relationship.go index 7642067..749a1c1 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -30,6 +30,22 @@ type ForeignKey struct { RefSchema string RefRelation string RefColumns []string + // SourceRelation, when set, is the base relation this foreign key was + // projected from onto a view. It makes the base table name an extra + // disambiguation hint for the view's relationship, the third hint kind + // PostgREST documents for view-sourced edges (spec 09). + SourceRelation string +} + +// hintNames is the set of disambiguation names a derived edge over this foreign +// key exposes: the constraint name, each participating column, and, for a foreign +// key projected onto a view, the base table name. +func (fk *ForeignKey) hintNames() []string { + hints := append([]string{fk.Name}, fk.Columns...) + if fk.SourceRelation != "" { + hints = append(hints, fk.SourceRelation) + } + return hints } // references reports whether this foreign key points at the given relation. @@ -51,6 +67,22 @@ type Relationship struct { JLocal []string JForeign []string + // FuncSchema and FuncName, when set, mark a computed relationship: the edge is + // not a column join but a call to FuncSchema.FuncName(parent_row), which yields + // the target rows. Local/Foreign/Junction are unused for such an edge; the + // compiler renders the function call in the embed's FROM and correlates through + // the row argument instead of a join predicate (spec 11). + FuncSchema string + FuncName string + + // Cardinality is the four-way spelling PostgREST reports in a PGRST201 details + // array ("many-to-one", "one-to-one", "one-to-many", "many-to-many"), derived + // the way upstream derives it: a forward foreign key is many-to-one, or + // one-to-one when its columns are unique on the parent; a backward foreign key + // is one-to-many, or one-to-one when unique on the target; a junction edge is + // many-to-many. Card stays the planner's two-way join shape. + Cardinality string + // hints is the set of names a disambiguation hint may match: the edge name // and each participating column. Matched case-sensitively, like PostgREST. hints []string @@ -80,71 +112,297 @@ func (m *Model) Relationships(parent *Relation, targetName string, searchPath [] // Forward: a foreign key on the parent pointing at the target is to-one. for _, fk := range parent.ForeignKeys { if fk.references(target) { + card := "many-to-one" + if isUnique(parent, fk.Columns) { + card = "one-to-one" + } out = append(out, Relationship{ - Name: fk.Name, - Card: CardToOne, - Target: target, - Local: fk.Columns, - Foreign: fk.RefColumns, - hints: append([]string{fk.Name}, fk.Columns...), + Name: fk.Name, + Card: CardToOne, + Cardinality: card, + Target: target, + Local: fk.Columns, + Foreign: fk.RefColumns, + hints: fk.hintNames(), }) } } - // Backward: a foreign key on the target pointing at the parent is to-many - // (the reverse view of the same key). + // Backward: a foreign key on the target pointing at the parent is the reverse + // view of the same key. It is to-many in general, but to-one when the FK + // columns are unique on the target (its primary key or a unique constraint), + // because then at most one target row references each parent row (spec 09). for _, fk := range target.ForeignKeys { if fk.references(parent) { + card := CardToMany + cardinality := "one-to-many" + if isUnique(target, fk.Columns) { + card = CardToOne + cardinality = "one-to-one" + } out = append(out, Relationship{ - Name: fk.Name, - Card: CardToMany, - Target: target, - Local: fk.RefColumns, - Foreign: fk.Columns, - hints: append([]string{fk.Name}, fk.Columns...), + Name: fk.Name, + Card: card, + Cardinality: cardinality, + Target: target, + Local: fk.RefColumns, + Foreign: fk.Columns, + hints: fk.hintNames(), }) } } - // Many-to-many: a junction relation with a foreign key to each side. The - // junction is not the parent or the target; its two keys supply the two hops. + // Many-to-many: a junction relation whose foreign keys to the two ends are + // part of its composite primary key. Every (toParent, toTarget) FK pair is a + // separate, hintable edge, so two keys to one end make the embed ambiguous + // rather than silently picking one (spec 09). for _, j := range m.Relations() { if j == parent || j == target { continue } - toParent, toTarget := junctionKeys(j, parent, target) - if toParent == nil || toTarget == nil { + for _, toParent := range junctionFKs(j, parent) { + for _, toTarget := range junctionFKs(j, target) { + if toParent == toTarget { + continue // a self-to-self junction needs two distinct keys + } + out = append(out, Relationship{ + Name: j.Name, + Card: CardToMany, + Cardinality: "many-to-many", + Target: target, + Local: toParent.RefColumns, + Foreign: toTarget.RefColumns, + Junction: j, + JLocal: toParent.Columns, + JForeign: toTarget.Columns, + hints: junctionHints(j, toTarget), + }) + } + } + } + + // Declared and computed edges: relationships supplied outside the catalog. A + // declared edge whose name equals a derived one overrides it, so a computed + // relationship can replace an auto-detected edge and a config-declared edge can + // disambiguate a self-referential FK that derivation leaves ambiguous (spec 09). + // Function-backed computed relationships (spec 11) join this set and override the + // same way: an edge named like a derived one replaces it. + declared := m.declaredEdges(parent, target) + declared = append(declared, computedRelEdges(parent, target)...) + if len(declared) > 0 { + overridden := make(map[string]bool, len(declared)) + for _, d := range declared { + overridden[d.Name] = true + } + kept := out[:0:0] + for _, e := range out { + if !overridden[e.Name] { + kept = append(kept, e) + } + } + kept = append(kept, declared...) + out = kept + } + + return out, true +} + +// declaredEdges returns the registered declared and computed relationships from +// parent to target as resolved Relationship values, resolving each junction +// relation against the model. An entry whose target or junction is not in the +// model is skipped rather than failing the whole resolution. +func (m *Model) declaredEdges(parent, target *Relation) []Relationship { + var out []Relationship + for _, d := range m.declared { + if d.ParentName != parent.Name || d.ParentSchema != parent.Schema { + continue + } + if d.TargetName != target.Name || d.TargetSchema != target.Schema { + continue + } + rel := Relationship{ + Name: d.Name, + Card: d.Card, + Cardinality: declaredCardinality(d.Card), + Target: target, + Local: d.Local, + Foreign: d.Foreign, + hints: append([]string{d.Name}, d.Hints...), + } + if d.JunctionName != "" { + j, ok := m.Lookup(d.JunctionName, junctionPath(d.JunctionSchema)) + if !ok { + continue + } + rel.Junction = j + rel.JLocal = d.JLocal + rel.JForeign = d.JForeign + rel.Cardinality = "many-to-many" + } + out = append(out, rel) + } + return out +} + +// ComputedRelByName resolves a computed relationship on parent by its edge name +// (the function name), inferring the target relation from the function's return +// type. PostgREST embeds a computed relationship by the function name, which need +// not equal the target relation name, so the planner cannot resolve it by the +// target-name path the way a foreign-key edge resolves. It returns the resolved +// edge and whether a computed relationship of that name exists with its target in +// the model. +func (m *Model) ComputedRelByName(parent *Relation, name string, searchPath []string) (*Relationship, bool) { + for _, cr := range parent.ComputedRels { + if cr.Name != name { + continue + } + target, ok := m.Lookup(cr.TargetName, []string{cr.TargetSchema}) + if !ok { + return nil, false + } + return &Relationship{ + Name: cr.Name, + Card: cr.Card, + Cardinality: declaredCardinality(cr.Card), + Target: target, + FuncSchema: cr.FuncSchema, + FuncName: cr.Name, + hints: []string{cr.Name}, + }, true + } + return nil, false +} + +// computedRelEdges returns the function-backed edges from parent to target: each +// computed relationship on the parent whose target is this relation, as a +// Relationship carrying the function to call instead of join columns. The edge is +// hintable by its name (the function name), matching how a derived edge is +// hintable by its constraint name. +func computedRelEdges(parent, target *Relation) []Relationship { + var out []Relationship + for _, cr := range parent.ComputedRels { + if cr.TargetName != target.Name || cr.TargetSchema != target.Schema { continue } out = append(out, Relationship{ - Name: j.Name, - Card: CardToMany, - Target: target, - Local: toParent.RefColumns, - Foreign: toTarget.RefColumns, - Junction: j, - JLocal: toParent.Columns, - JForeign: toTarget.Columns, - hints: []string{j.Name, toParent.Name, toTarget.Name}, + Name: cr.Name, + Card: cr.Card, + Cardinality: declaredCardinality(cr.Card), + Target: target, + FuncSchema: cr.FuncSchema, + FuncName: cr.Name, + hints: []string{cr.Name}, }) } + return out +} - return out, true +// junctionPath turns a declared junction schema into a one-element search path, +// or an empty path (direct key match) when the schema is unset. +func junctionPath(schemaName string) []string { + if schemaName == "" { + return nil + } + return []string{schemaName} +} + +// DeclaredRel is a relationship supplied outside the catalog: a config-declared +// edge on an FK-less backend, or an emulated computed relationship. The planner +// resolves it exactly like a derived edge, and it overrides a derived edge of the +// same name (spec 09). Local and Foreign are the parent and target join columns; +// for a many-to-many declared edge JunctionName names the junction relation and +// JLocal/JForeign are its columns on the parent and target sides. +type DeclaredRel struct { + Name string + ParentSchema string + ParentName string + TargetSchema string + TargetName string + Card Card + Local []string + Foreign []string + JunctionSchema string + JunctionName string + JLocal []string + JForeign []string + // Hints are extra names a disambiguation hint may match, beyond the edge name + // (the participating column names a computed relationship wants hintable). + Hints []string +} + +// declaredCardinality spells a declared edge's two-way Card as the PGRST201 +// four-way cardinality. A declared edge carries no parent-side uniqueness or +// direction, so a to-one edge reads as many-to-one and a to-many edge as +// one-to-many; a junction edge is set to many-to-many by the caller. +func declaredCardinality(c Card) string { + if c == CardToMany { + return "one-to-many" + } + return "many-to-one" +} + +// isUnique reports whether cols (as a set) is the relation's primary key or one +// of its unique constraints, the test that makes a referencing FK one-to-one. +func isUnique(r *Relation, cols []string) bool { + if sameColumnSet(r.PrimaryKey, cols) { + return true + } + for _, u := range r.Unique { + if sameColumnSet(u, cols) { + return true + } + } + return false +} + +// sameColumnSet reports whether two column-name lists hold the same set, +// ignoring order (constraint membership does not depend on column order). +func sameColumnSet(a, b []string) bool { + if len(a) != len(b) || len(a) == 0 { + return false + } + for _, x := range a { + if !slices.Contains(b, x) { + return false + } + } + return true } -// junctionKeys finds the two foreign keys that make j a junction between parent -// and target: one pointing at the parent and a distinct one pointing at the -// target. The distinctness guard matters for a self-referential many-to-many, -// where both keys point at the same relation. -func junctionKeys(j, parent, target *Relation) (toParent, toTarget *ForeignKey) { +// junctionHints is the hint set for a many-to-many edge: the junction name and +// the target-pointing foreign key, by its constraint name and its columns. The +// hint identifies the edge by how the junction reaches the target, which is what +// disambiguates a self-referential junction where both directions share the same +// pair of columns and only the target side differs (PostgREST). +func junctionHints(j *Relation, toTarget *ForeignKey) []string { + hints := []string{j.Name, toTarget.Name} + hints = append(hints, toTarget.Columns...) + return hints +} + +// junctionFKs returns the foreign keys on j that point at end and whose columns +// are part of j's primary key, the PostgREST rule for what makes j a junction. +// A table with an FK to a relation but no PK over those columns is an incidental +// referencing table, not a junction, so it yields no edge. +func junctionFKs(j, end *Relation) []*ForeignKey { + var out []*ForeignKey for _, fk := range j.ForeignKeys { - if toParent == nil && fk.references(parent) { - toParent = fk - continue + if fk.references(end) && isSubset(fk.Columns, j.PrimaryKey) { + out = append(out, fk) } - if toTarget == nil && fk.references(target) { - toTarget = fk + } + return out +} + +// isSubset reports whether every column in cols appears in set. +func isSubset(cols, set []string) bool { + if len(cols) == 0 { + return false + } + for _, c := range cols { + if !slices.Contains(set, c) { + return false } } - return toParent, toTarget + return true } diff --git a/schema/relationship_test.go b/schema/relationship_test.go index 998bd03..14ce3ce 100644 --- a/schema/relationship_test.go +++ b/schema/relationship_test.go @@ -18,9 +18,10 @@ func buildEmbedModel() *Model { } actors := &Relation{Schema: "public", Name: "actors", Columns: cols("id", "name")} roles := &Relation{ - Schema: "public", - Name: "roles", - Columns: cols("film_id", "actor_id"), + Schema: "public", + Name: "roles", + Columns: cols("film_id", "actor_id"), + PrimaryKey: []string{"film_id", "actor_id"}, // the composite PK that makes roles a junction ForeignKeys: []*ForeignKey{ {Name: "roles_film_id_fkey", Columns: []string{"film_id"}, RefSchema: "public", RefRelation: "films", RefColumns: []string{"id"}}, {Name: "roles_actor_id_fkey", Columns: []string{"actor_id"}, RefSchema: "public", RefRelation: "actors", RefColumns: []string{"id"}}, @@ -151,3 +152,275 @@ func TestRelationshipsAmbiguous(t *testing.T) { t.Errorf("writer_id hint matched %d edges, want 1", matched) } } + +// TestRelationshipsReverseToOneOnPrimaryKey covers 01.8: a foreign key whose +// columns are the referencing relation's primary key is one-to-one, so its +// reverse view renders as an object. profiles.user_id is both the PK of +// profiles and an FK to users, so a user has at most one profile. +func TestRelationshipsReverseToOneOnPrimaryKey(t *testing.T) { + users := &Relation{Schema: "public", Name: "users", Columns: cols("id", "name")} + profiles := &Relation{ + Schema: "public", + Name: "profiles", + Columns: cols("user_id", "bio"), + PrimaryKey: []string{"user_id"}, + ForeignKeys: []*ForeignKey{{ + Name: "profiles_user_id_fkey", Columns: []string{"user_id"}, + RefSchema: "public", RefRelation: "users", RefColumns: []string{"id"}, + }}, + } + m := NewModel([]*Relation{users, profiles}) + cands, _ := m.Relationships(rel(t, m, "users"), "profiles", []string{"public"}) + if len(cands) != 1 { + t.Fatalf("got %d candidates, want 1", len(cands)) + } + if cands[0].Card != CardToOne { + t.Errorf("Card = %v, want to-one (FK is the profiles PK)", cands[0].Card) + } +} + +// TestRelationshipsReverseToOneOnUniqueConstraint covers 01.8 via a unique +// constraint rather than the primary key: profiles has its own surrogate PK, +// but a UNIQUE(user_id) constraint still makes the FK one-to-one. +func TestRelationshipsReverseToOneOnUniqueConstraint(t *testing.T) { + users := &Relation{Schema: "public", Name: "users", Columns: cols("id", "name")} + profiles := &Relation{ + Schema: "public", + Name: "profiles", + Columns: cols("id", "user_id", "bio"), + PrimaryKey: []string{"id"}, + Unique: [][]string{{"user_id"}}, + ForeignKeys: []*ForeignKey{{ + Name: "profiles_user_id_fkey", Columns: []string{"user_id"}, + RefSchema: "public", RefRelation: "users", RefColumns: []string{"id"}, + }}, + } + m := NewModel([]*Relation{users, profiles}) + cands, _ := m.Relationships(rel(t, m, "users"), "profiles", []string{"public"}) + if len(cands) != 1 { + t.Fatalf("got %d candidates, want 1", len(cands)) + } + if cands[0].Card != CardToOne { + t.Errorf("Card = %v, want to-one (FK matches a unique constraint)", cands[0].Card) + } +} + +// TestRelationshipsReverseToManyWithoutUnique covers the 01.8 negative: a +// plain FK that is neither the PK nor unique stays to-many, the ordinary +// reverse-view case (a director owns many films). +func TestRelationshipsReverseToManyWithoutUnique(t *testing.T) { + m := buildEmbedModel() + cands, _ := m.Relationships(rel(t, m, "directors"), "films", []string{"public"}) + if len(cands) != 1 { + t.Fatalf("got %d candidates, want 1", len(cands)) + } + if cands[0].Card != CardToMany { + t.Errorf("Card = %v, want to-many (FK is neither PK nor unique)", cands[0].Card) + } +} + +// TestRelationshipsIncidentalReferencingTableNotJunction covers 01.9: a table +// that has foreign keys to both ends but does not key them as its primary key +// is an incidental referencing table, not a junction, so it yields no edge. +// Here log references both films and actors but keys on its own id. +func TestRelationshipsIncidentalReferencingTableNotJunction(t *testing.T) { + films := &Relation{Schema: "public", Name: "films", Columns: cols("id", "title")} + actors := &Relation{Schema: "public", Name: "actors", Columns: cols("id", "name")} + log := &Relation{ + Schema: "public", + Name: "log", + Columns: cols("id", "film_id", "actor_id"), + PrimaryKey: []string{"id"}, // keyed on its own surrogate id, not the FK pair + ForeignKeys: []*ForeignKey{ + {Name: "log_film_id_fkey", Columns: []string{"film_id"}, RefSchema: "public", RefRelation: "films", RefColumns: []string{"id"}}, + {Name: "log_actor_id_fkey", Columns: []string{"actor_id"}, RefSchema: "public", RefRelation: "actors", RefColumns: []string{"id"}}, + }, + } + m := NewModel([]*Relation{films, actors, log}) + cands, _ := m.Relationships(rel(t, m, "films"), "actors", []string{"public"}) + if len(cands) != 0 { + t.Fatalf("got %d candidates, want 0 (log is not a junction)", len(cands)) + } +} + +// TestRelationshipsJunctionWithExtraPrimaryKeyColumn covers 01.9: the FK +// columns only need to be a subset of the composite primary key, so a junction +// that adds another column to its PK (here a role discriminator) still embeds. +func TestRelationshipsJunctionWithExtraPrimaryKeyColumn(t *testing.T) { + films := &Relation{Schema: "public", Name: "films", Columns: cols("id", "title")} + actors := &Relation{Schema: "public", Name: "actors", Columns: cols("id", "name")} + roles := &Relation{ + Schema: "public", + Name: "roles", + Columns: cols("film_id", "actor_id", "character"), + PrimaryKey: []string{"film_id", "actor_id", "character"}, + ForeignKeys: []*ForeignKey{ + {Name: "roles_film_id_fkey", Columns: []string{"film_id"}, RefSchema: "public", RefRelation: "films", RefColumns: []string{"id"}}, + {Name: "roles_actor_id_fkey", Columns: []string{"actor_id"}, RefSchema: "public", RefRelation: "actors", RefColumns: []string{"id"}}, + }, + } + m := NewModel([]*Relation{films, actors, roles}) + cands, _ := m.Relationships(rel(t, m, "films"), "actors", []string{"public"}) + if len(cands) != 1 { + t.Fatalf("got %d candidates, want 1", len(cands)) + } + if cands[0].Junction == nil || cands[0].Junction.Name != "roles" { + t.Fatalf("Junction = %v, want roles", cands[0].Junction) + } +} + +// TestRelationshipsSelfJunctionTwoKeys covers 01.9 with a self-referential +// many-to-many: a friendship junction has two FKs to users, which yields two +// distinct edges (one per direction), so an unqualified embed is ambiguous and +// a column hint disambiguates it. +func TestRelationshipsSelfJunctionTwoKeys(t *testing.T) { + users := &Relation{Schema: "public", Name: "users", Columns: cols("id", "name")} + friendships := &Relation{ + Schema: "public", + Name: "friendships", + Columns: cols("user_id", "friend_id"), + PrimaryKey: []string{"user_id", "friend_id"}, + ForeignKeys: []*ForeignKey{ + {Name: "friendships_user_id_fkey", Columns: []string{"user_id"}, RefSchema: "public", RefRelation: "users", RefColumns: []string{"id"}}, + {Name: "friendships_friend_id_fkey", Columns: []string{"friend_id"}, RefSchema: "public", RefRelation: "users", RefColumns: []string{"id"}}, + }, + } + m := NewModel([]*Relation{users, friendships}) + cands, _ := m.Relationships(rel(t, m, "users"), "users", []string{"public"}) + if len(cands) != 2 { + t.Fatalf("got %d candidates, want 2 (the two junction directions)", len(cands)) + } + matched := 0 + for _, c := range cands { + if c.MatchesHint("friend_id") { + matched++ + } + } + if matched != 1 { + t.Errorf("friend_id hint matched %d edges, want 1", matched) + } +} + +// TestDeclaredRelationshipAddsEdge covers 01.10: a declared relationship makes an +// edge embeddable where no foreign key derives one. Here directors and actors +// share no key, but a declared edge connects them as the planner would resolve it. +func TestDeclaredRelationshipAddsEdge(t *testing.T) { + m := buildEmbedModel() + m.AddDeclaredRelationship(DeclaredRel{ + Name: "favorite_actor", + ParentSchema: "public", ParentName: "directors", + TargetSchema: "public", TargetName: "actors", + Card: CardToOne, + Local: []string{"id"}, + Foreign: []string{"id"}, + }) + cands, _ := m.Relationships(rel(t, m, "directors"), "actors", []string{"public"}) + if len(cands) != 1 { + t.Fatalf("got %d candidates, want 1", len(cands)) + } + if cands[0].Name != "favorite_actor" || cands[0].Card != CardToOne { + t.Errorf("edge = %q %v, want favorite_actor to-one", cands[0].Name, cands[0].Card) + } +} + +// TestDeclaredRelationshipOverridesDerived covers the 01.10 override rule: a +// computed/declared edge whose name equals a derived edge replaces it, so the +// derived cardinality and join give way to the declared one. +func TestDeclaredRelationshipOverridesDerived(t *testing.T) { + m := buildEmbedModel() + // The derived forward edge films->directors is named films_director_id_fkey + // and is to-one. Override it with a declared edge of the same name. + m.AddDeclaredRelationship(DeclaredRel{ + Name: "films_director_id_fkey", + ParentSchema: "public", ParentName: "films", + TargetSchema: "public", TargetName: "directors", + Card: CardToMany, // deliberately different from the derived to-one + Local: []string{"director_id"}, + Foreign: []string{"id"}, + }) + cands, _ := m.Relationships(rel(t, m, "films"), "directors", []string{"public"}) + if len(cands) != 1 { + t.Fatalf("got %d candidates, want 1 (override, not addition)", len(cands)) + } + if cands[0].Card != CardToMany { + t.Errorf("Card = %v, want to-many (declared edge overrides derived)", cands[0].Card) + } +} + +// TestDeclaredRelationshipDisambiguatesSelfFK covers the recursive-embed escape +// hatch from 01.10: a self-referential foreign key derives forward and backward +// edges that share a hint set, so a declared edge with its own name is the only +// way to name one direction unambiguously. +func TestDeclaredRelationshipDisambiguatesSelfFK(t *testing.T) { + comments := &Relation{ + Schema: "public", + Name: "comments", + Columns: cols("id", "parent_id", "body"), + PrimaryKey: []string{"id"}, + ForeignKeys: []*ForeignKey{{ + Name: "comments_parent_id_fkey", Columns: []string{"parent_id"}, + RefSchema: "public", RefRelation: "comments", RefColumns: []string{"id"}, + }}, + } + m := NewModel([]*Relation{comments}) + // Without a declared edge the self FK yields two edges (parent and children + // views) that a hint cannot separate. + base, _ := m.Relationships(rel(t, m, "comments"), "comments", []string{"public"}) + if len(base) != 2 { + t.Fatalf("self FK derived %d edges, want 2 (the ambiguous pair)", len(base)) + } + + m.AddDeclaredRelationship(DeclaredRel{ + Name: "children", + ParentSchema: "public", ParentName: "comments", + TargetSchema: "public", TargetName: "comments", + Card: CardToMany, + Local: []string{"id"}, + Foreign: []string{"parent_id"}, + }) + cands, _ := m.Relationships(rel(t, m, "comments"), "comments", []string{"public"}) + matched := cands[:0:0] + for _, c := range cands { + if c.MatchesHint("children") { + matched = append(matched, c) + } + } + if len(matched) != 1 { + t.Fatalf("children hint matched %d edges, want 1", len(matched)) + } + if matched[0].Card != CardToMany { + t.Errorf("Card = %v, want to-many", matched[0].Card) + } +} + +// TestDeclaredManyToManyJunction covers a declared edge that crosses a junction, +// the FK-less backend's path to a many-to-many embed (spec 09). +func TestDeclaredManyToManyJunction(t *testing.T) { + authors := &Relation{Schema: "public", Name: "authors", Columns: cols("id", "name")} + books := &Relation{Schema: "public", Name: "books", Columns: cols("id", "title")} + authorship := &Relation{Schema: "public", Name: "authorship", Columns: cols("author_id", "book_id")} + m := NewModel([]*Relation{authors, books, authorship}) + m.AddDeclaredRelationship(DeclaredRel{ + Name: "books", + ParentSchema: "public", ParentName: "authors", + TargetSchema: "public", TargetName: "books", + Card: CardToMany, + Local: []string{"id"}, + Foreign: []string{"id"}, + JunctionSchema: "public", + JunctionName: "authorship", + JLocal: []string{"author_id"}, + JForeign: []string{"book_id"}, + }) + cands, _ := m.Relationships(rel(t, m, "authors"), "books", []string{"public"}) + if len(cands) != 1 { + t.Fatalf("got %d candidates, want 1", len(cands)) + } + c := cands[0] + if c.Junction == nil || c.Junction.Name != "authorship" { + t.Fatalf("Junction = %v, want authorship", c.Junction) + } + if c.JLocal[0] != "author_id" || c.JForeign[0] != "book_id" { + t.Errorf("junction hops = %v / %v, want author_id / book_id", c.JLocal, c.JForeign) + } +} diff --git a/schema/views.go b/schema/views.go new file mode 100644 index 0000000..6d38dc7 --- /dev/null +++ b/schema/views.go @@ -0,0 +1,108 @@ +package schema + +// This file projects base-table foreign keys onto views. PostgREST makes a view +// embeddable by inferring its relationships from the base tables behind it: when +// a base table's foreign-key columns survive unchanged into the view's select +// list, the view inherits that foreign key under the view's own column names. The +// same projection makes a view embeddable from both directions, because the +// reverse view of an inherited key is resolved by the ordinary backward scan. +// View-over-view chains resolve because projection runs to a fixpoint, so an +// inner view's inherited keys are available when the outer view projects. A view +// the introspector cannot resolve to plain base columns (a UNION, an expression +// column) carries no ViewColumns and so inherits nothing, matching PostgREST. + +// projectViews carries base-table foreign keys onto every view in the model, +// using each view's column-to-base mapping. It repeats until no new key is added +// so that chains of views (a view selecting from another view) resolve, bounded +// by the relation count since each pass can only add keys. +func (m *Model) projectViews() { + for pass := 0; pass < len(m.order); pass++ { + added := false + for _, key := range m.order { + v := m.relations[key] + if v.Kind != KindView || len(v.ViewColumns) == 0 { + continue + } + if m.projectOneView(v) { + added = true + } + } + if !added { + return + } + } +} + +// projectOneView adds to view v every base-table foreign key whose columns all +// survive into v, naming the projected key's columns by the view columns that +// expose them. It reports whether it added a key this pass (a projected key is +// added once; a second pass over the same view is a no-op). +func (m *Model) projectOneView(v *Relation) bool { + // Index the view's exposure of each base relation: base (schema,rel,col) to + // the view column that projects it. A base column may surface under several + // view columns; the first is used, the way a join projection names one. + exposes := map[string]map[string]string{} // baseRelKey -> baseCol -> viewCol + for _, vc := range v.ViewColumns { + bk := Key(vc.BaseSchema, vc.BaseRelation) + cols := exposes[bk] + if cols == nil { + cols = map[string]string{} + exposes[bk] = cols + } + if _, seen := cols[vc.BaseColumn]; !seen { + cols[vc.BaseColumn] = vc.Name + } + } + + added := false + for bk, cols := range exposes { + base, ok := m.relations[bk] + if !ok { + continue + } + for _, fk := range base.ForeignKeys { + viewCols, ok := mapColumns(fk.Columns, cols) + if !ok { + continue // not every FK column survives into the view + } + if v.hasProjectedFK(fk, base) { + continue + } + v.ForeignKeys = append(v.ForeignKeys, &ForeignKey{ + Name: fk.Name, + Columns: viewCols, + RefSchema: fk.RefSchema, + RefRelation: fk.RefRelation, + RefColumns: fk.RefColumns, + SourceRelation: base.Name, + }) + added = true + } + } + return added +} + +// mapColumns translates base column names to the view columns that expose them, +// reporting ok=false if any base column is not exposed by the view. +func mapColumns(baseCols []string, exposed map[string]string) ([]string, bool) { + out := make([]string, len(baseCols)) + for i, c := range baseCols { + vc, ok := exposed[c] + if !ok { + return nil, false + } + out[i] = vc + } + return out, true +} + +// hasProjectedFK reports whether the view already carries this base foreign key, +// so a second projection pass does not duplicate it. +func (r *Relation) hasProjectedFK(fk *ForeignKey, base *Relation) bool { + for _, existing := range r.ForeignKeys { + if existing.Name == fk.Name && existing.SourceRelation == base.Name { + return true + } + } + return false +} diff --git a/schema/views_test.go b/schema/views_test.go new file mode 100644 index 0000000..0f8beba --- /dev/null +++ b/schema/views_test.go @@ -0,0 +1,156 @@ +package schema + +import "testing" + +// viewModel wires a films table with a foreign key to directors, plus a view +// film_view that projects the film columns (including director_id) one-to-one. +// The view should inherit the films->directors foreign key under its own columns. +func viewModel() *Model { + directors := &Relation{Schema: "public", Name: "directors", Columns: cols("id", "name")} + films := &Relation{ + Schema: "public", + Name: "films", + Columns: cols("id", "title", "director_id"), + PrimaryKey: []string{"id"}, + ForeignKeys: []*ForeignKey{{ + Name: "films_director_id_fkey", Columns: []string{"director_id"}, + RefSchema: "public", RefRelation: "directors", RefColumns: []string{"id"}, + }}, + } + filmView := &Relation{ + Schema: "public", + Name: "film_view", + Kind: KindView, + Columns: cols("id", "title", "director_id"), + ViewColumns: []ViewColumn{ + {Name: "id", BaseSchema: "public", BaseRelation: "films", BaseColumn: "id"}, + {Name: "title", BaseSchema: "public", BaseRelation: "films", BaseColumn: "title"}, + {Name: "director_id", BaseSchema: "public", BaseRelation: "films", BaseColumn: "director_id"}, + }, + } + return NewModel([]*Relation{directors, films, filmView}) +} + +// TestViewInheritsForwardForeignKey covers 01.11: a view that exposes the FK +// column embeds the referenced table as a to-one, the same as the base table. +func TestViewInheritsForwardForeignKey(t *testing.T) { + m := viewModel() + cands, found := m.Relationships(rel(t, m, "film_view"), "directors", []string{"public"}) + if !found { + t.Fatal("directors not found") + } + if len(cands) != 1 { + t.Fatalf("got %d candidates, want 1", len(cands)) + } + if cands[0].Card != CardToOne { + t.Errorf("Card = %v, want to-one", cands[0].Card) + } + if cands[0].Local[0] != "director_id" { + t.Errorf("Local = %v, want [director_id]", cands[0].Local) + } +} + +// TestViewEmbeddedFromBaseTable covers the reverse direction: a base table +// embeds the view as a to-many through the projected key's reverse view. +func TestViewEmbeddedFromBaseTable(t *testing.T) { + m := viewModel() + cands, _ := m.Relationships(rel(t, m, "directors"), "film_view", []string{"public"}) + if len(cands) != 1 { + t.Fatalf("got %d candidates, want 1", len(cands)) + } + if cands[0].Card != CardToMany { + t.Errorf("Card = %v, want to-many", cands[0].Card) + } + if cands[0].Foreign[0] != "director_id" { + t.Errorf("Foreign = %v, want [director_id]", cands[0].Foreign) + } +} + +// TestViewForeignKeyAcceptsBaseTableHint covers the third hint kind: a +// view-sourced relationship accepts the base table name as a disambiguation hint. +func TestViewForeignKeyAcceptsBaseTableHint(t *testing.T) { + m := viewModel() + cands, _ := m.Relationships(rel(t, m, "film_view"), "directors", []string{"public"}) + if len(cands) != 1 { + t.Fatalf("got %d candidates, want 1", len(cands)) + } + if !cands[0].MatchesHint("films") { + t.Error("view relationship should accept the base table name films as a hint") + } +} + +// TestViewWithoutFKColumnInheritsNothing covers the PostgREST condition: a view +// that drops the foreign-key column does not inherit the relationship. +func TestViewWithoutFKColumnInheritsNothing(t *testing.T) { + directors := &Relation{Schema: "public", Name: "directors", Columns: cols("id", "name")} + films := &Relation{ + Schema: "public", + Name: "films", + Columns: cols("id", "title", "director_id"), + ForeignKeys: []*ForeignKey{{ + Name: "films_director_id_fkey", Columns: []string{"director_id"}, + RefSchema: "public", RefRelation: "directors", RefColumns: []string{"id"}, + }}, + } + // The view exposes only id and title; director_id does not survive. + titlesView := &Relation{ + Schema: "public", + Name: "titles", + Kind: KindView, + Columns: cols("id", "title"), + ViewColumns: []ViewColumn{ + {Name: "id", BaseSchema: "public", BaseRelation: "films", BaseColumn: "id"}, + {Name: "title", BaseSchema: "public", BaseRelation: "films", BaseColumn: "title"}, + }, + } + m := NewModel([]*Relation{directors, films, titlesView}) + cands, _ := m.Relationships(rel(t, m, "titles"), "directors", []string{"public"}) + if len(cands) != 0 { + t.Fatalf("got %d candidates, want 0 (FK column dropped by the view)", len(cands)) + } +} + +// TestViewOverViewChainsForeignKey covers recursive resolution: a view selecting +// from another view inherits the foreign key through the chain. +func TestViewOverViewChainsForeignKey(t *testing.T) { + directors := &Relation{Schema: "public", Name: "directors", Columns: cols("id", "name")} + films := &Relation{ + Schema: "public", + Name: "films", + Columns: cols("id", "title", "director_id"), + ForeignKeys: []*ForeignKey{{ + Name: "films_director_id_fkey", Columns: []string{"director_id"}, + RefSchema: "public", RefRelation: "directors", RefColumns: []string{"id"}, + }}, + } + inner := &Relation{ + Schema: "public", + Name: "film_view", + Kind: KindView, + Columns: cols("id", "title", "director_id"), + ViewColumns: []ViewColumn{ + {Name: "id", BaseSchema: "public", BaseRelation: "films", BaseColumn: "id"}, + {Name: "title", BaseSchema: "public", BaseRelation: "films", BaseColumn: "title"}, + {Name: "director_id", BaseSchema: "public", BaseRelation: "films", BaseColumn: "director_id"}, + }, + } + // outer selects from the inner view, renaming director_id to dir. + outer := &Relation{ + Schema: "public", + Name: "film_view2", + Kind: KindView, + Columns: cols("id", "dir"), + ViewColumns: []ViewColumn{ + {Name: "id", BaseSchema: "public", BaseRelation: "film_view", BaseColumn: "id"}, + {Name: "dir", BaseSchema: "public", BaseRelation: "film_view", BaseColumn: "director_id"}, + }, + } + m := NewModel([]*Relation{directors, films, inner, outer}) + cands, _ := m.Relationships(rel(t, m, "film_view2"), "directors", []string{"public"}) + if len(cands) != 1 { + t.Fatalf("got %d candidates, want 1 (FK chained through the inner view)", len(cands)) + } + if cands[0].Local[0] != "dir" { + t.Errorf("Local = %v, want [dir] (renamed by the outer view)", cands[0].Local) + } +}