Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 49 additions & 31 deletions controlplane/admin/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
package admin

import (
"bytes"
"encoding/json"
"errors"
"io"
"net/http"
"time"

Expand All @@ -16,6 +18,11 @@ import (

var errWarehousePayloadNotAllowed = errors.New("warehouse payload must be updated via /orgs/:id/warehouse")

// maxWarehousePutBodyBytes caps the admin PUT body. Warehouse payloads are
// under 10 KB in practice; 1 MiB leaves room for future fields while keeping
// the handler from loading unbounded input into memory.
const maxWarehousePutBodyBytes = 1 << 20

// WorkerStatus represents a worker's current status for the API.
type WorkerStatus struct {
ID int `json:"id"`
Expand Down Expand Up @@ -306,6 +313,7 @@ func managedWarehouseUpsertColumns() []string {
"metadata_store_port",
"metadata_store_database_name",
"metadata_store_username",
"pgbouncer_enabled",
"s3_provider",
"s3_region",
"s3_bucket",
Expand Down Expand Up @@ -402,6 +410,7 @@ type apiHandler struct {
type managedWarehouseRequest struct {
WarehouseDatabase configstore.ManagedWarehouseDatabase `json:"warehouse_database"`
MetadataStore configstore.ManagedWarehouseMetadataStore `json:"metadata_store"`
PgBouncer configstore.ManagedWarehousePgBouncer `json:"pgbouncer"`
S3 configstore.ManagedWarehouseS3 `json:"s3"`
WorkerIdentity configstore.ManagedWarehouseWorkerIdentity `json:"worker_identity"`
WarehouseDatabaseCredentials configstore.SecretRef `json:"warehouse_database_credentials"`
Expand All @@ -424,35 +433,14 @@ type managedWarehouseRequest struct {
FailedAt *time.Time `json:"failed_at"`
}

func (r managedWarehouseRequest) toManagedWarehouse() configstore.ManagedWarehouse {
return configstore.ManagedWarehouse{
WarehouseDatabase: r.WarehouseDatabase,
MetadataStore: r.MetadataStore,
S3: r.S3,
WorkerIdentity: r.WorkerIdentity,
WarehouseDatabaseCredentials: r.WarehouseDatabaseCredentials,
MetadataStoreCredentials: r.MetadataStoreCredentials,
S3Credentials: r.S3Credentials,
RuntimeConfig: r.RuntimeConfig,
State: r.State,
StatusMessage: r.StatusMessage,
WarehouseDatabaseState: r.WarehouseDatabaseState,
WarehouseDatabaseStatusMessage: r.WarehouseDatabaseStatusMessage,
MetadataStoreState: r.MetadataStoreState,
MetadataStoreStatusMessage: r.MetadataStoreStatusMessage,
S3State: r.S3State,
S3StatusMessage: r.S3StatusMessage,
IdentityState: r.IdentityState,
IdentityStatusMessage: r.IdentityStatusMessage,
SecretsState: r.SecretsState,
SecretsStatusMessage: r.SecretsStatusMessage,
ReadyAt: r.ReadyAt,
FailedAt: r.FailedAt,
}
}

func decodeStrictWarehouseRequest(c *gin.Context, dst *managedWarehouseRequest) error {
dec := json.NewDecoder(c.Request.Body)
// decodeStrictWarehouseRequest validates a PUT body by decoding it into
// managedWarehouseRequest with DisallowUnknownFields. This whitelists which
// top-level fields a caller may set; the actual merge is performed separately
// by unmarshaling the same body onto an existing ManagedWarehouse (see
// putManagedWarehouse) so missing keys — at any nesting level — preserve
// whatever the stored row already holds.
func decodeStrictWarehouseRequest(body []byte, dst *managedWarehouseRequest) error {
dec := json.NewDecoder(bytes.NewReader(body))
dec.DisallowUnknownFields()
return dec.Decode(dst)
}
Expand Down Expand Up @@ -558,12 +546,42 @@ func (h *apiHandler) getManagedWarehouse(c *gin.Context) {

func (h *apiHandler) putManagedWarehouse(c *gin.Context) {
orgID := c.Param("id")

body, err := io.ReadAll(http.MaxBytesReader(c.Writer, c.Request.Body, maxWarehousePutBodyBytes))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}

// Strict decode rejects unknown top-level fields and malformed JSON. We
// don't use the decoded value directly; it just gates which keys the body
// is allowed to carry.
var req managedWarehouseRequest
if err := decodeStrictWarehouseRequest(c, &req); err != nil {
if err := decodeStrictWarehouseRequest(body, &req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}

// Load any existing row, then unmarshal the body onto it. json.Unmarshal
// only overwrites fields whose JSON keys appear in the body — both at the
// top level AND within each nested struct. Callers can therefore PATCH a
// single field (e.g. `{"metadata_store":{"database_name":"x"}}`) without
// wiping sibling fields. Note: concurrent PUTs on the same org can still
// interleave (read-modify-write across two store calls); the admin API is
// low-frequency enough that we accept this for now.
existing, err := h.store.GetManagedWarehouse(orgID)
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
var warehouse configstore.ManagedWarehouse
if err == nil {
warehouse = *existing
}
if err := json.Unmarshal(body, &warehouse); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
warehouse := req.toManagedWarehouse()
cfgView := &configstore.ManagedWarehouseConfig{
OrgID: orgID,
WarehouseDatabase: warehouse.WarehouseDatabase,
Expand Down
121 changes: 121 additions & 0 deletions controlplane/admin/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,127 @@ func TestPutWarehouseUpsertsForExistingOrg(t *testing.T) {
}
}

func TestPutWarehouseMergesPartialUpdateIntoExistingWarehouse(t *testing.T) {
store := newFakeAPIStore()
seedOrgWithWarehouse(store, "analytics")
router := newTestAPIRouter(store)

body := []byte(`{
"pgbouncer": {
"enabled": true
}
}`)

req := httptest.NewRequest(http.MethodPut, "/api/v1/orgs/analytics/warehouse", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)

if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d: %s", rec.Code, http.StatusOK, rec.Body.String())
}

warehouse := store.warehouses["analytics"]
if warehouse == nil {
t.Fatal("expected stored warehouse")
}
if !warehouse.PgBouncer.Enabled {
t.Fatal("expected pgbouncer to be enabled")
}
if warehouse.MetadataStore.DatabaseName != "analytics_metadata" {
t.Fatalf("expected metadata db analytics_metadata, got %q", warehouse.MetadataStore.DatabaseName)
}
if warehouse.S3.Bucket != "analytics-bucket" {
t.Fatalf("expected s3 bucket analytics-bucket, got %q", warehouse.S3.Bucket)
}
if warehouse.MetadataStoreCredentials.Name != "analytics-metadata" {
t.Fatalf("expected metadata secret analytics-metadata, got %q", warehouse.MetadataStoreCredentials.Name)
}
if warehouse.State != configstore.ManagedWarehouseStateReady {
t.Fatalf("expected state ready, got %q", warehouse.State)
}
}

func TestPutWarehouseDisablesPgBouncerWhenSetToFalse(t *testing.T) {
store := newFakeAPIStore()
seedOrgWithWarehouse(store, "analytics")
store.warehouses["analytics"].PgBouncer = configstore.ManagedWarehousePgBouncer{Enabled: true}
router := newTestAPIRouter(store)

body := []byte(`{"pgbouncer": {"enabled": false}}`)
req := httptest.NewRequest(http.MethodPut, "/api/v1/orgs/analytics/warehouse", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)

if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d: %s", rec.Code, http.StatusOK, rec.Body.String())
}
if store.warehouses["analytics"].PgBouncer.Enabled {
t.Fatal("expected pgbouncer to be disabled after PUT with enabled=false")
}
}

func TestPutWarehousePreservesNestedFieldsOnPartialUpdate(t *testing.T) {
store := newFakeAPIStore()
seedOrgWithWarehouse(store, "analytics")
router := newTestAPIRouter(store)

// Send only one inner field. Every other metadata_store field must stay
// as seeded — confirms the merge is nested-aware, not whole-struct replace.
body := []byte(`{"metadata_store": {"database_name": "renamed_metadata"}}`)
req := httptest.NewRequest(http.MethodPut, "/api/v1/orgs/analytics/warehouse", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)

if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d: %s", rec.Code, http.StatusOK, rec.Body.String())
}

got := store.warehouses["analytics"].MetadataStore
if got.DatabaseName != "renamed_metadata" {
t.Fatalf("database_name = %q, want renamed_metadata", got.DatabaseName)
}
if got.Endpoint != "analytics-metadata.cluster.example" {
t.Fatalf("endpoint = %q, want analytics-metadata.cluster.example (nested fields were wiped)", got.Endpoint)
}
if got.Region != "us-east-1" {
t.Fatalf("region = %q, want us-east-1", got.Region)
}
if got.Port != 5432 {
t.Fatalf("port = %d, want 5432", got.Port)
}
if got.Kind != "dedicated_rds" {
t.Fatalf("kind = %q, want dedicated_rds", got.Kind)
}
if got.Engine != "postgres" {
t.Fatalf("engine = %q, want postgres", got.Engine)
}
if got.Username != "metadata_user" {
t.Fatalf("username = %q, want metadata_user", got.Username)
}
}

func TestPutWarehouseRejectsOversizedBody(t *testing.T) {
store := newFakeAPIStore()
seedOrgWithWarehouse(store, "analytics")
router := newTestAPIRouter(store)

// Pad the body past the 1 MiB cap inside a valid top-level field so the
// reader errors on size rather than JSON parsing.
oversized := strings.Repeat("a", (1<<20)+1024)
body := []byte(`{"status_message": "` + oversized + `"}`)
req := httptest.NewRequest(http.MethodPut, "/api/v1/orgs/analytics/warehouse", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)

if rec.Code != http.StatusBadRequest {
t.Fatalf("status = %d, want %d: %s", rec.Code, http.StatusBadRequest, rec.Body.String())
}
}

func TestPutWarehouseRejectsSecretRefsOutsideTenantScope(t *testing.T) {
store := newFakeAPIStore()
store.orgs["analytics"] = &configstore.Org{Name: "analytics"}
Expand Down
13 changes: 13 additions & 0 deletions controlplane/configstore/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,16 @@ type ManagedWarehouseMetadataStore struct {
Username string `gorm:"size:255" json:"username"`
}

// ManagedWarehousePgBouncer captures per-org opt-in state for the per-Duckling
// PgBouncer pooler provisioned by the Crossplane composition. When Enabled is
// true, the provisioner controller sets spec.metadataStore.pgbouncer.enabled
// on the Duckling CR at creation time; worker DSN routing through the pooler
// is driven by status.metadataStore.pgbouncerEndpoint (populated by the
// composition once the pooler Service is up).
type ManagedWarehousePgBouncer struct {
Enabled bool `json:"enabled"`
}

// ManagedWarehouseS3 stores object-store metadata for an org's warehouse.
type ManagedWarehouseS3 struct {
Provider string `gorm:"size:64" json:"provider"`
Expand Down Expand Up @@ -99,6 +109,7 @@ type ManagedWarehouse struct {

WarehouseDatabase ManagedWarehouseDatabase `gorm:"embedded;embeddedPrefix:warehouse_database_" json:"warehouse_database"`
MetadataStore ManagedWarehouseMetadataStore `gorm:"embedded;embeddedPrefix:metadata_store_" json:"metadata_store"`
PgBouncer ManagedWarehousePgBouncer `gorm:"embedded;embeddedPrefix:pgbouncer_" json:"pgbouncer"`
S3 ManagedWarehouseS3 `gorm:"embedded;embeddedPrefix:s3_" json:"s3"`
WorkerIdentity ManagedWarehouseWorkerIdentity `gorm:"embedded;embeddedPrefix:worker_identity_" json:"worker_identity"`

Expand Down Expand Up @@ -299,6 +310,7 @@ type ManagedWarehouseConfig struct {

WarehouseDatabase ManagedWarehouseDatabase
MetadataStore ManagedWarehouseMetadataStore
PgBouncer ManagedWarehousePgBouncer
S3 ManagedWarehouseS3
WorkerIdentity ManagedWarehouseWorkerIdentity

Expand Down Expand Up @@ -335,6 +347,7 @@ func copyManagedWarehouseConfig(warehouse *ManagedWarehouse) *ManagedWarehouseCo
AuroraMaxACU: warehouse.AuroraMaxACU,
WarehouseDatabase: warehouse.WarehouseDatabase,
MetadataStore: warehouse.MetadataStore,
PgBouncer: warehouse.PgBouncer,
S3: warehouse.S3,
WorkerIdentity: warehouse.WorkerIdentity,
WarehouseDatabaseCredentials: warehouse.WarehouseDatabaseCredentials,
Expand Down
53 changes: 53 additions & 0 deletions controlplane/org_activation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,59 @@ import (
"k8s.io/client-go/kubernetes/fake"
)

func TestDucklingMetadataStoreAddressUsesDirectEndpointByDefault(t *testing.T) {
status := &provisioner.DucklingStatus{}
status.MetadataStore.Endpoint = "direct-aurora.example.internal"

host, port, viaPgBouncer, err := ducklingMetadataStoreAddress(status, "analytics")
if err != nil {
t.Fatalf("ducklingMetadataStoreAddress: %v", err)
}
if host != "direct-aurora.example.internal" {
t.Fatalf("host = %q, want direct endpoint", host)
}
if port != 5432 {
t.Fatalf("port = %d, want 5432", port)
}
if viaPgBouncer {
t.Fatal("expected viaPgBouncer=false when no pooler endpoint is present")
}
}

func TestDucklingMetadataStoreAddressPrefersPgBouncerEndpoint(t *testing.T) {
status := &provisioner.DucklingStatus{}
status.MetadataStore.Endpoint = "direct-aurora.example.internal"
status.MetadataStore.PgBouncerEndpoint = "pooler.ducklings.svc.cluster.local:6543"

host, port, viaPgBouncer, err := ducklingMetadataStoreAddress(status, "analytics")
if err != nil {
t.Fatalf("ducklingMetadataStoreAddress: %v", err)
}
if host != "pooler.ducklings.svc.cluster.local" {
t.Fatalf("host = %q, want pooler host", host)
}
if port != 6543 {
t.Fatalf("port = %d, want 6543", port)
}
if !viaPgBouncer {
t.Fatal("expected viaPgBouncer=true when pooler endpoint is present")
}
}

func TestDucklingMetadataStoreAddressRejectsInvalidPgBouncerEndpoint(t *testing.T) {
status := &provisioner.DucklingStatus{}
status.MetadataStore.Endpoint = "direct-aurora.example.internal"
status.MetadataStore.PgBouncerEndpoint = "not-a-host-port"

_, _, _, err := ducklingMetadataStoreAddress(status, "analytics")
if err == nil {
t.Fatal("expected invalid pgbouncer endpoint to fail")
}
if !strings.Contains(err.Error(), "parse pgbouncerEndpoint") {
t.Fatalf("expected parse error, got %v", err)
}
}

func TestSharedWorkerActivatorBuildsActivationRequestFromManagedWarehouse(t *testing.T) {
clientset := fake.NewSimpleClientset(
&corev1.Secret{
Expand Down
8 changes: 6 additions & 2 deletions controlplane/provisioner/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,12 @@ func (c *Controller) reconcilePending(ctx context.Context, w *configstore.Manage
}

// Create the Duckling CR
log.Info("Creating Duckling CR.")
if err := c.duckling.Create(ctx, w.OrgID, w.AuroraMinACU, w.AuroraMaxACU); err != nil {
log.Info("Creating Duckling CR.", "pgbouncer_enabled", w.PgBouncer.Enabled)
if err := c.duckling.Create(ctx, w.OrgID, CreateOptions{
MinACU: w.AuroraMinACU,
MaxACU: w.AuroraMaxACU,
PgBouncerEnabled: w.PgBouncer.Enabled,
}); err != nil {
log.Error("Failed to create Duckling CR.", "error", err)
_ = c.store.UpdateWarehouseState(w.OrgID, configstore.ManagedWarehouseStatePending, map[string]interface{}{
"state": configstore.ManagedWarehouseStateFailed,
Expand Down
Loading
Loading