diff --git a/controlplane/admin/api.go b/controlplane/admin/api.go index d17dfae..0a87b0d 100644 --- a/controlplane/admin/api.go +++ b/controlplane/admin/api.go @@ -3,8 +3,10 @@ package admin import ( + "bytes" "encoding/json" "errors" + "io" "net/http" "time" @@ -16,6 +18,11 @@ import ( var errWarehousePayloadNotAllowed = errors.New("warehouse payload must be updated via /orgs/:id/warehouse") +// maxWarehousePutBodyBytes caps the admin PUT body. Warehouse payloads are +// under 10 KB in practice; 1 MiB leaves room for future fields while keeping +// the handler from loading unbounded input into memory. +const maxWarehousePutBodyBytes = 1 << 20 + // WorkerStatus represents a worker's current status for the API. type WorkerStatus struct { ID int `json:"id"` @@ -306,6 +313,7 @@ func managedWarehouseUpsertColumns() []string { "metadata_store_port", "metadata_store_database_name", "metadata_store_username", + "pgbouncer_enabled", "s3_provider", "s3_region", "s3_bucket", @@ -402,6 +410,7 @@ type apiHandler struct { type managedWarehouseRequest struct { WarehouseDatabase configstore.ManagedWarehouseDatabase `json:"warehouse_database"` MetadataStore configstore.ManagedWarehouseMetadataStore `json:"metadata_store"` + PgBouncer configstore.ManagedWarehousePgBouncer `json:"pgbouncer"` S3 configstore.ManagedWarehouseS3 `json:"s3"` WorkerIdentity configstore.ManagedWarehouseWorkerIdentity `json:"worker_identity"` WarehouseDatabaseCredentials configstore.SecretRef `json:"warehouse_database_credentials"` @@ -424,35 +433,14 @@ type managedWarehouseRequest struct { FailedAt *time.Time `json:"failed_at"` } -func (r managedWarehouseRequest) toManagedWarehouse() configstore.ManagedWarehouse { - return configstore.ManagedWarehouse{ - WarehouseDatabase: r.WarehouseDatabase, - MetadataStore: r.MetadataStore, - S3: r.S3, - WorkerIdentity: r.WorkerIdentity, - WarehouseDatabaseCredentials: r.WarehouseDatabaseCredentials, - MetadataStoreCredentials: r.MetadataStoreCredentials, - S3Credentials: r.S3Credentials, - RuntimeConfig: r.RuntimeConfig, - State: r.State, - StatusMessage: r.StatusMessage, - WarehouseDatabaseState: r.WarehouseDatabaseState, - WarehouseDatabaseStatusMessage: r.WarehouseDatabaseStatusMessage, - MetadataStoreState: r.MetadataStoreState, - MetadataStoreStatusMessage: r.MetadataStoreStatusMessage, - S3State: r.S3State, - S3StatusMessage: r.S3StatusMessage, - IdentityState: r.IdentityState, - IdentityStatusMessage: r.IdentityStatusMessage, - SecretsState: r.SecretsState, - SecretsStatusMessage: r.SecretsStatusMessage, - ReadyAt: r.ReadyAt, - FailedAt: r.FailedAt, - } -} - -func decodeStrictWarehouseRequest(c *gin.Context, dst *managedWarehouseRequest) error { - dec := json.NewDecoder(c.Request.Body) +// decodeStrictWarehouseRequest validates a PUT body by decoding it into +// managedWarehouseRequest with DisallowUnknownFields. This whitelists which +// top-level fields a caller may set; the actual merge is performed separately +// by unmarshaling the same body onto an existing ManagedWarehouse (see +// putManagedWarehouse) so missing keys — at any nesting level — preserve +// whatever the stored row already holds. +func decodeStrictWarehouseRequest(body []byte, dst *managedWarehouseRequest) error { + dec := json.NewDecoder(bytes.NewReader(body)) dec.DisallowUnknownFields() return dec.Decode(dst) } @@ -558,12 +546,42 @@ func (h *apiHandler) getManagedWarehouse(c *gin.Context) { func (h *apiHandler) putManagedWarehouse(c *gin.Context) { orgID := c.Param("id") + + body, err := io.ReadAll(http.MaxBytesReader(c.Writer, c.Request.Body, maxWarehousePutBodyBytes)) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + // Strict decode rejects unknown top-level fields and malformed JSON. We + // don't use the decoded value directly; it just gates which keys the body + // is allowed to carry. var req managedWarehouseRequest - if err := decodeStrictWarehouseRequest(c, &req); err != nil { + if err := decodeStrictWarehouseRequest(body, &req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + // Load any existing row, then unmarshal the body onto it. json.Unmarshal + // only overwrites fields whose JSON keys appear in the body — both at the + // top level AND within each nested struct. Callers can therefore PATCH a + // single field (e.g. `{"metadata_store":{"database_name":"x"}}`) without + // wiping sibling fields. Note: concurrent PUTs on the same org can still + // interleave (read-modify-write across two store calls); the admin API is + // low-frequency enough that we accept this for now. + existing, err := h.store.GetManagedWarehouse(orgID) + if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + var warehouse configstore.ManagedWarehouse + if err == nil { + warehouse = *existing + } + if err := json.Unmarshal(body, &warehouse); err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } - warehouse := req.toManagedWarehouse() cfgView := &configstore.ManagedWarehouseConfig{ OrgID: orgID, WarehouseDatabase: warehouse.WarehouseDatabase, diff --git a/controlplane/admin/api_test.go b/controlplane/admin/api_test.go index f4df52f..ade4f61 100644 --- a/controlplane/admin/api_test.go +++ b/controlplane/admin/api_test.go @@ -437,6 +437,127 @@ func TestPutWarehouseUpsertsForExistingOrg(t *testing.T) { } } +func TestPutWarehouseMergesPartialUpdateIntoExistingWarehouse(t *testing.T) { + store := newFakeAPIStore() + seedOrgWithWarehouse(store, "analytics") + router := newTestAPIRouter(store) + + body := []byte(`{ + "pgbouncer": { + "enabled": true + } + }`) + + req := httptest.NewRequest(http.MethodPut, "/api/v1/orgs/analytics/warehouse", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d: %s", rec.Code, http.StatusOK, rec.Body.String()) + } + + warehouse := store.warehouses["analytics"] + if warehouse == nil { + t.Fatal("expected stored warehouse") + } + if !warehouse.PgBouncer.Enabled { + t.Fatal("expected pgbouncer to be enabled") + } + if warehouse.MetadataStore.DatabaseName != "analytics_metadata" { + t.Fatalf("expected metadata db analytics_metadata, got %q", warehouse.MetadataStore.DatabaseName) + } + if warehouse.S3.Bucket != "analytics-bucket" { + t.Fatalf("expected s3 bucket analytics-bucket, got %q", warehouse.S3.Bucket) + } + if warehouse.MetadataStoreCredentials.Name != "analytics-metadata" { + t.Fatalf("expected metadata secret analytics-metadata, got %q", warehouse.MetadataStoreCredentials.Name) + } + if warehouse.State != configstore.ManagedWarehouseStateReady { + t.Fatalf("expected state ready, got %q", warehouse.State) + } +} + +func TestPutWarehouseDisablesPgBouncerWhenSetToFalse(t *testing.T) { + store := newFakeAPIStore() + seedOrgWithWarehouse(store, "analytics") + store.warehouses["analytics"].PgBouncer = configstore.ManagedWarehousePgBouncer{Enabled: true} + router := newTestAPIRouter(store) + + body := []byte(`{"pgbouncer": {"enabled": false}}`) + req := httptest.NewRequest(http.MethodPut, "/api/v1/orgs/analytics/warehouse", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d: %s", rec.Code, http.StatusOK, rec.Body.String()) + } + if store.warehouses["analytics"].PgBouncer.Enabled { + t.Fatal("expected pgbouncer to be disabled after PUT with enabled=false") + } +} + +func TestPutWarehousePreservesNestedFieldsOnPartialUpdate(t *testing.T) { + store := newFakeAPIStore() + seedOrgWithWarehouse(store, "analytics") + router := newTestAPIRouter(store) + + // Send only one inner field. Every other metadata_store field must stay + // as seeded — confirms the merge is nested-aware, not whole-struct replace. + body := []byte(`{"metadata_store": {"database_name": "renamed_metadata"}}`) + req := httptest.NewRequest(http.MethodPut, "/api/v1/orgs/analytics/warehouse", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d: %s", rec.Code, http.StatusOK, rec.Body.String()) + } + + got := store.warehouses["analytics"].MetadataStore + if got.DatabaseName != "renamed_metadata" { + t.Fatalf("database_name = %q, want renamed_metadata", got.DatabaseName) + } + if got.Endpoint != "analytics-metadata.cluster.example" { + t.Fatalf("endpoint = %q, want analytics-metadata.cluster.example (nested fields were wiped)", got.Endpoint) + } + if got.Region != "us-east-1" { + t.Fatalf("region = %q, want us-east-1", got.Region) + } + if got.Port != 5432 { + t.Fatalf("port = %d, want 5432", got.Port) + } + if got.Kind != "dedicated_rds" { + t.Fatalf("kind = %q, want dedicated_rds", got.Kind) + } + if got.Engine != "postgres" { + t.Fatalf("engine = %q, want postgres", got.Engine) + } + if got.Username != "metadata_user" { + t.Fatalf("username = %q, want metadata_user", got.Username) + } +} + +func TestPutWarehouseRejectsOversizedBody(t *testing.T) { + store := newFakeAPIStore() + seedOrgWithWarehouse(store, "analytics") + router := newTestAPIRouter(store) + + // Pad the body past the 1 MiB cap inside a valid top-level field so the + // reader errors on size rather than JSON parsing. + oversized := strings.Repeat("a", (1<<20)+1024) + body := []byte(`{"status_message": "` + oversized + `"}`) + req := httptest.NewRequest(http.MethodPut, "/api/v1/orgs/analytics/warehouse", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want %d: %s", rec.Code, http.StatusBadRequest, rec.Body.String()) + } +} + func TestPutWarehouseRejectsSecretRefsOutsideTenantScope(t *testing.T) { store := newFakeAPIStore() store.orgs["analytics"] = &configstore.Org{Name: "analytics"} diff --git a/controlplane/configstore/models.go b/controlplane/configstore/models.go index a95b395..438732e 100644 --- a/controlplane/configstore/models.go +++ b/controlplane/configstore/models.go @@ -71,6 +71,16 @@ type ManagedWarehouseMetadataStore struct { Username string `gorm:"size:255" json:"username"` } +// ManagedWarehousePgBouncer captures per-org opt-in state for the per-Duckling +// PgBouncer pooler provisioned by the Crossplane composition. When Enabled is +// true, the provisioner controller sets spec.metadataStore.pgbouncer.enabled +// on the Duckling CR at creation time; worker DSN routing through the pooler +// is driven by status.metadataStore.pgbouncerEndpoint (populated by the +// composition once the pooler Service is up). +type ManagedWarehousePgBouncer struct { + Enabled bool `json:"enabled"` +} + // ManagedWarehouseS3 stores object-store metadata for an org's warehouse. type ManagedWarehouseS3 struct { Provider string `gorm:"size:64" json:"provider"` @@ -99,6 +109,7 @@ type ManagedWarehouse struct { WarehouseDatabase ManagedWarehouseDatabase `gorm:"embedded;embeddedPrefix:warehouse_database_" json:"warehouse_database"` MetadataStore ManagedWarehouseMetadataStore `gorm:"embedded;embeddedPrefix:metadata_store_" json:"metadata_store"` + PgBouncer ManagedWarehousePgBouncer `gorm:"embedded;embeddedPrefix:pgbouncer_" json:"pgbouncer"` S3 ManagedWarehouseS3 `gorm:"embedded;embeddedPrefix:s3_" json:"s3"` WorkerIdentity ManagedWarehouseWorkerIdentity `gorm:"embedded;embeddedPrefix:worker_identity_" json:"worker_identity"` @@ -299,6 +310,7 @@ type ManagedWarehouseConfig struct { WarehouseDatabase ManagedWarehouseDatabase MetadataStore ManagedWarehouseMetadataStore + PgBouncer ManagedWarehousePgBouncer S3 ManagedWarehouseS3 WorkerIdentity ManagedWarehouseWorkerIdentity @@ -335,6 +347,7 @@ func copyManagedWarehouseConfig(warehouse *ManagedWarehouse) *ManagedWarehouseCo AuroraMaxACU: warehouse.AuroraMaxACU, WarehouseDatabase: warehouse.WarehouseDatabase, MetadataStore: warehouse.MetadataStore, + PgBouncer: warehouse.PgBouncer, S3: warehouse.S3, WorkerIdentity: warehouse.WorkerIdentity, WarehouseDatabaseCredentials: warehouse.WarehouseDatabaseCredentials, diff --git a/controlplane/org_activation_test.go b/controlplane/org_activation_test.go index 836beda..3a19bdf 100644 --- a/controlplane/org_activation_test.go +++ b/controlplane/org_activation_test.go @@ -15,6 +15,59 @@ import ( "k8s.io/client-go/kubernetes/fake" ) +func TestDucklingMetadataStoreAddressUsesDirectEndpointByDefault(t *testing.T) { + status := &provisioner.DucklingStatus{} + status.MetadataStore.Endpoint = "direct-aurora.example.internal" + + host, port, viaPgBouncer, err := ducklingMetadataStoreAddress(status, "analytics") + if err != nil { + t.Fatalf("ducklingMetadataStoreAddress: %v", err) + } + if host != "direct-aurora.example.internal" { + t.Fatalf("host = %q, want direct endpoint", host) + } + if port != 5432 { + t.Fatalf("port = %d, want 5432", port) + } + if viaPgBouncer { + t.Fatal("expected viaPgBouncer=false when no pooler endpoint is present") + } +} + +func TestDucklingMetadataStoreAddressPrefersPgBouncerEndpoint(t *testing.T) { + status := &provisioner.DucklingStatus{} + status.MetadataStore.Endpoint = "direct-aurora.example.internal" + status.MetadataStore.PgBouncerEndpoint = "pooler.ducklings.svc.cluster.local:6543" + + host, port, viaPgBouncer, err := ducklingMetadataStoreAddress(status, "analytics") + if err != nil { + t.Fatalf("ducklingMetadataStoreAddress: %v", err) + } + if host != "pooler.ducklings.svc.cluster.local" { + t.Fatalf("host = %q, want pooler host", host) + } + if port != 6543 { + t.Fatalf("port = %d, want 6543", port) + } + if !viaPgBouncer { + t.Fatal("expected viaPgBouncer=true when pooler endpoint is present") + } +} + +func TestDucklingMetadataStoreAddressRejectsInvalidPgBouncerEndpoint(t *testing.T) { + status := &provisioner.DucklingStatus{} + status.MetadataStore.Endpoint = "direct-aurora.example.internal" + status.MetadataStore.PgBouncerEndpoint = "not-a-host-port" + + _, _, _, err := ducklingMetadataStoreAddress(status, "analytics") + if err == nil { + t.Fatal("expected invalid pgbouncer endpoint to fail") + } + if !strings.Contains(err.Error(), "parse pgbouncerEndpoint") { + t.Fatalf("expected parse error, got %v", err) + } +} + func TestSharedWorkerActivatorBuildsActivationRequestFromManagedWarehouse(t *testing.T) { clientset := fake.NewSimpleClientset( &corev1.Secret{ diff --git a/controlplane/provisioner/controller.go b/controlplane/provisioner/controller.go index 162f603..4d7aa53 100644 --- a/controlplane/provisioner/controller.go +++ b/controlplane/provisioner/controller.go @@ -119,8 +119,12 @@ func (c *Controller) reconcilePending(ctx context.Context, w *configstore.Manage } // Create the Duckling CR - log.Info("Creating Duckling CR.") - if err := c.duckling.Create(ctx, w.OrgID, w.AuroraMinACU, w.AuroraMaxACU); err != nil { + log.Info("Creating Duckling CR.", "pgbouncer_enabled", w.PgBouncer.Enabled) + if err := c.duckling.Create(ctx, w.OrgID, CreateOptions{ + MinACU: w.AuroraMinACU, + MaxACU: w.AuroraMaxACU, + PgBouncerEnabled: w.PgBouncer.Enabled, + }); err != nil { log.Error("Failed to create Duckling CR.", "error", err) _ = c.store.UpdateWarehouseState(w.OrgID, configstore.ManagedWarehouseStatePending, map[string]interface{}{ "state": configstore.ManagedWarehouseStateFailed, diff --git a/controlplane/provisioner/controller_test.go b/controlplane/provisioner/controller_test.go index 4515746..e06db4b 100644 --- a/controlplane/provisioner/controller_test.go +++ b/controlplane/provisioner/controller_test.go @@ -164,6 +164,41 @@ func TestReconcilePendingCreatesCR(t *testing.T) { if fs.warehouses["org-a"].ProvisioningStartedAt == nil { t.Fatal("expected provisioning_started_at to be set") } + + // Default path has PgBouncer.Enabled=false — no pgbouncer block should appear. + if _, present := metadataStore["pgbouncer"]; present { + t.Fatalf("expected no pgbouncer block when disabled, got %v", metadataStore["pgbouncer"]) + } +} + +func TestReconcilePendingEmitsPgBouncerBlock(t *testing.T) { + dc, fakeK8s := newFakeDucklingClient() + fs := newFakeStore() + fs.warehouses["org-pgb"] = &configstore.ManagedWarehouse{ + OrgID: "org-pgb", + State: configstore.ManagedWarehouseStatePending, + AuroraMinACU: 0.5, + AuroraMaxACU: 4, + PgBouncer: configstore.ManagedWarehousePgBouncer{Enabled: true}, + } + + ctrl := NewControllerWithClient(fs, dc, time.Second) + ctx := context.Background() + ctrl.reconcile(ctx) + + cr, err := fakeK8s.Resource(ducklingGVR).Namespace(ducklingNamespace).Get(ctx, ducklingName("org-pgb"), metav1.GetOptions{}) + if err != nil { + t.Fatalf("expected CR to exist: %v", err) + } + spec := cr.Object["spec"].(map[string]interface{}) + metadataStore := spec["metadataStore"].(map[string]interface{}) + pgb, ok := metadataStore["pgbouncer"].(map[string]interface{}) + if !ok { + t.Fatalf("expected pgbouncer block in metadataStore, got %v", metadataStore) + } + if pgb["enabled"] != true { + t.Fatalf("expected pgbouncer.enabled=true, got %v", pgb["enabled"]) + } } func TestReconcileProvisioningAllReady(t *testing.T) { diff --git a/controlplane/provisioner/k8s_client.go b/controlplane/provisioner/k8s_client.go index bebac91..f037b99 100644 --- a/controlplane/provisioner/k8s_client.go +++ b/controlplane/provisioner/k8s_client.go @@ -74,9 +74,28 @@ func ducklingName(orgID string) string { return strings.ReplaceAll(orgID, "-", "") } +// CreateOptions carries per-org knobs that shape the generated Duckling CR. +type CreateOptions struct { + MinACU float64 + MaxACU float64 + PgBouncerEnabled bool +} + // Create creates a Duckling CR for the given org. -func (d *DucklingClient) Create(ctx context.Context, orgID string, minACU, maxACU float64) error { +func (d *DucklingClient) Create(ctx context.Context, orgID string, opts CreateOptions) error { name := ducklingName(orgID) + metadataStore := map[string]interface{}{ + "type": "aurora", + "aurora": map[string]interface{}{ + "minACU": opts.MinACU, + "maxACU": opts.MaxACU, + }, + } + if opts.PgBouncerEnabled { + metadataStore["pgbouncer"] = map[string]interface{}{ + "enabled": true, + } + } cr := &unstructured.Unstructured{ Object: map[string]interface{}{ "apiVersion": "k8s.posthog.com/v1alpha1", @@ -86,13 +105,7 @@ func (d *DucklingClient) Create(ctx context.Context, orgID string, minACU, maxAC "namespace": ducklingNamespace, }, "spec": map[string]interface{}{ - "metadataStore": map[string]interface{}{ - "type": "aurora", - "aurora": map[string]interface{}{ - "minACU": minACU, - "maxACU": maxACU, - }, - }, + "metadataStore": metadataStore, "dataStore": map[string]interface{}{ "type": "s3bucket", }, diff --git a/controlplane/shared_worker_activator.go b/controlplane/shared_worker_activator.go index c654644..707d33d 100644 --- a/controlplane/shared_worker_activator.go +++ b/controlplane/shared_worker_activator.go @@ -195,25 +195,9 @@ func (a *SharedWorkerActivator) buildDuckLakeConfigFromDuckling(ctx context.Cont return server.DuckLakeConfig{}, fmt.Errorf("duckling CR %q has no data store bucket", orgID) } - // Prefer the PgBouncer endpoint when the Duckling exposes one — the - // Crossplane composition sets status.metadataStore.pgbouncerEndpoint - // (as ":") when a per-Duckling pooler is provisioned. - // Otherwise connect directly to the metadata store on its default port. - host := status.MetadataStore.Endpoint - port := 5432 // Aurora always uses 5432 - viaPgBouncer := false - if pgb := status.MetadataStore.PgBouncerEndpoint; pgb != "" { - h, p, err := net.SplitHostPort(pgb) - if err != nil { - return server.DuckLakeConfig{}, fmt.Errorf("parse pgbouncerEndpoint %q for org %q: %w", pgb, orgID, err) - } - portNum, err := strconv.Atoi(p) - if err != nil { - return server.DuckLakeConfig{}, fmt.Errorf("parse pgbouncerEndpoint port %q for org %q: %w", p, orgID, err) - } - host = h - port = portNum - viaPgBouncer = true + host, port, viaPgBouncer, err := ducklingMetadataStoreAddress(status, orgID) + if err != nil { + return server.DuckLakeConfig{}, err } dl := server.DuckLakeConfig{ @@ -250,6 +234,29 @@ func (a *SharedWorkerActivator) buildDuckLakeConfigFromDuckling(ctx context.Cont return dl, nil } +func ducklingMetadataStoreAddress(status *provisioner.DucklingStatus, orgID string) (host string, port int, viaPgBouncer bool, err error) { + host = status.MetadataStore.Endpoint + port = 5432 // Aurora always uses 5432 + + // Prefer the PgBouncer endpoint when the Duckling exposes one — the + // Crossplane composition sets status.metadataStore.pgbouncerEndpoint + // (as ":") when a per-Duckling pooler is provisioned. + pgb := status.MetadataStore.PgBouncerEndpoint + if pgb == "" { + return host, port, false, nil + } + + h, p, err := net.SplitHostPort(pgb) + if err != nil { + return "", 0, false, fmt.Errorf("parse pgbouncerEndpoint %q for org %q: %w", pgb, orgID, err) + } + portNum, err := strconv.Atoi(p) + if err != nil { + return "", 0, false, fmt.Errorf("parse pgbouncerEndpoint port %q for org %q: %w", p, orgID, err) + } + return h, portNum, true, nil +} + // buildDuckLakeConfigFromConfigStore reads infrastructure details from the config store // and K8s Secrets. Used for non-Crossplane warehouses (manual seed, MinIO, etc.). func (a *SharedWorkerActivator) buildDuckLakeConfigFromConfigStore(ctx context.Context, warehouse *configstore.ManagedWarehouseConfig) (server.DuckLakeConfig, error) {