diff --git a/.env.gcp.template b/.env.gcp.template index 7674763b5c..94df21dac7 100644 --- a/.env.gcp.template +++ b/.env.gcp.template @@ -81,6 +81,8 @@ DASHBOARD_API_COUNT= SUPABASE_DB_CONNECTION_STRING= # Enable dashboard-api auth user sync background worker (default: false) ENABLE_AUTH_USER_SYNC_BACKGROUND_WORKER= +# Enable dashboard-api billing/team provisioning sink (default: false) +ENABLE_BILLING_HTTP_TEAM_PROVISION_SINK= # Filestore cache for builds shared across cluster (default:false) FILESTORE_CACHE_ENABLED= diff --git a/iac/modules/job-dashboard-api/main.tf b/iac/modules/job-dashboard-api/main.tf index 7797675d6b..1156c559fc 100644 --- a/iac/modules/job-dashboard-api/main.tf +++ b/iac/modules/job-dashboard-api/main.tf @@ -2,6 +2,7 @@ locals { base_env = { GIN_MODE = "release" ENVIRONMENT = var.environment + ADMIN_TOKEN = var.admin_token POSTGRES_CONNECTION_STRING = var.postgres_connection_string AUTH_DB_CONNECTION_STRING = var.auth_db_connection_string AUTH_DB_READ_REPLICA_CONNECTION_STRING = var.auth_db_read_replica_connection_string @@ -12,6 +13,9 @@ locals { REDIS_CLUSTER_URL = var.redis_cluster_url REDIS_TLS_CA_BASE64 = var.redis_tls_ca_base64 ENABLE_AUTH_USER_SYNC_BACKGROUND_WORKER = tostring(var.enable_auth_user_sync_background_worker) + ENABLE_BILLING_HTTP_TEAM_PROVISION_SINK = tostring(var.enable_billing_http_team_provision_sink) + BILLING_SERVER_URL = var.billing_server_url + BILLING_SERVER_API_TOKEN = var.billing_server_api_token OTEL_COLLECTOR_GRPC_ENDPOINT = "localhost:${var.otel_collector_grpc_port}" LOGS_COLLECTOR_ADDRESS = "http://localhost:${var.logs_proxy_port.port}" } diff --git a/iac/modules/job-dashboard-api/variables.tf b/iac/modules/job-dashboard-api/variables.tf index 063cff473f..10c2174600 100644 --- a/iac/modules/job-dashboard-api/variables.tf +++ b/iac/modules/job-dashboard-api/variables.tf @@ -23,6 +23,11 @@ variable "postgres_connection_string" { sensitive = true } +variable "admin_token" { + type = string + sensitive = true +} + variable "auth_db_connection_string" { type = string sensitive = true @@ -55,6 +60,10 @@ variable "enable_auth_user_sync_background_worker" { default = false } +variable "enable_billing_http_team_provision_sink" { + type = bool + default = false +} variable "otel_collector_grpc_port" { type = number default = 4317 @@ -76,6 +85,17 @@ variable "redis_tls_ca_base64" { default = "" } +variable "billing_server_url" { + type = string + default = "" +} + +variable "billing_server_api_token" { + type = string + sensitive = true + default = "" +} + variable "logs_proxy_port" { type = object({ name = string diff --git a/iac/provider-gcp/Makefile b/iac/provider-gcp/Makefile index 625f1febf9..5edb0e0767 100644 --- a/iac/provider-gcp/Makefile +++ b/iac/provider-gcp/Makefile @@ -78,6 +78,7 @@ tf_vars := \ $(call tfvar, DASHBOARD_API_COUNT) \ $(call tfvar, SUPABASE_DB_CONNECTION_STRING) \ $(call tfvar, ENABLE_AUTH_USER_SYNC_BACKGROUND_WORKER) \ + $(call tfvar, ENABLE_BILLING_HTTP_TEAM_PROVISION_SINK) \ $(call tfvar, DEFAULT_PERSISTENT_VOLUME_TYPE) \ $(call tfvar, PERSISTENT_VOLUME_TYPES) \ $(call tfvar, DB_MAX_OPEN_CONNECTIONS) \ diff --git a/iac/provider-gcp/api.tf b/iac/provider-gcp/api.tf index 26f1a1b408..576b953ba2 100644 --- a/iac/provider-gcp/api.tf +++ b/iac/provider-gcp/api.tf @@ -64,6 +64,24 @@ resource "google_secret_manager_secret_version" "api_admin_token_value" { secret_data = random_password.api_admin_secret.result } +resource "random_password" "dashboard_api_admin_secret" { + length = 32 + special = false +} + +resource "google_secret_manager_secret" "dashboard_api_admin_token" { + secret_id = "${var.prefix}dashboard-api-admin-token" + + replication { + auto {} + } +} + +resource "google_secret_manager_secret_version" "dashboard_api_admin_token_value" { + secret = google_secret_manager_secret.dashboard_api_admin_token.id + secret_data = random_password.dashboard_api_admin_secret.result +} + resource "random_password" "sandbox_access_token_hash_seed" { length = 32 special = false @@ -80,4 +98,4 @@ resource "google_secret_manager_secret" "sandbox_access_token_hash_seed" { resource "google_secret_manager_secret_version" "sandbox_access_token_hash_seed" { secret = google_secret_manager_secret.sandbox_access_token_hash_seed.id secret_data = random_password.sandbox_access_token_hash_seed.result -} \ No newline at end of file +} diff --git a/iac/provider-gcp/main.tf b/iac/provider-gcp/main.tf index 18855409c9..d5a1e32000 100644 --- a/iac/provider-gcp/main.tf +++ b/iac/provider-gcp/main.tf @@ -276,8 +276,10 @@ module "nomad" { # Dashboard API dashboard_api_count = var.dashboard_api_count + dashboard_api_admin_token = random_password.dashboard_api_admin_secret.result supabase_db_connection_string = var.supabase_db_connection_string enable_auth_user_sync_background_worker = var.enable_auth_user_sync_background_worker + enable_billing_http_team_provision_sink = var.enable_billing_http_team_provision_sink # Docker reverse proxy docker_reverse_proxy_port = var.docker_reverse_proxy_port diff --git a/iac/provider-gcp/nomad/main.tf b/iac/provider-gcp/nomad/main.tf index edd9e6a318..f7606ca4d7 100644 --- a/iac/provider-gcp/nomad/main.tf +++ b/iac/provider-gcp/nomad/main.tf @@ -1,8 +1,11 @@ locals { - clickhouse_connection_string = var.clickhouse_server_count > 0 ? "clickhouse://${var.clickhouse_username}:${random_password.clickhouse_password.result}@clickhouse.service.consul:${var.clickhouse_server_port.port}/${var.clickhouse_database}" : "" - redis_url = trimspace(data.google_secret_manager_secret_version.redis_cluster_url.secret_data) == "" ? "redis.service.consul:${var.redis_port.port}" : "" - redis_cluster_url = trimspace(data.google_secret_manager_secret_version.redis_cluster_url.secret_data) - loki_url = "http://loki.service.consul:${var.loki_service_port.port}" + clickhouse_connection_string = var.clickhouse_server_count > 0 ? "clickhouse://${var.clickhouse_username}:${random_password.clickhouse_password.result}@clickhouse.service.consul:${var.clickhouse_server_port.port}/${var.clickhouse_database}" : "" + redis_url = trimspace(data.google_secret_manager_secret_version.redis_cluster_url.secret_data) == "" ? "redis.service.consul:${var.redis_port.port}" : "" + redis_cluster_url = trimspace(data.google_secret_manager_secret_version.redis_cluster_url.secret_data) + loki_url = "http://loki.service.consul:${var.loki_service_port.port}" + enable_billing_http_team_provision_sink = var.enable_billing_http_team_provision_sink + dashboard_api_billing_server_url = local.enable_billing_http_team_provision_sink ? data.google_cloud_run_v2_service.billing_server[0].uri : "" + dashboard_api_billing_server_api_token = local.enable_billing_http_team_provision_sink ? data.google_secret_manager_secret_version.billing_server_api_token[0].secret_data : "" } # API @@ -35,6 +38,21 @@ data "google_secret_manager_secret_version" "launch_darkly_api_key" { secret = var.launch_darkly_api_key_secret_name } +data "google_secret_manager_secret_version" "billing_server_api_token" { + count = local.enable_billing_http_team_provision_sink ? 1 : 0 + + project = var.gcp_project_id + secret = "${var.prefix}billing-server-api-token" +} + +data "google_cloud_run_v2_service" "billing_server" { + count = local.enable_billing_http_team_provision_sink ? 1 : 0 + + project = var.gcp_project_id + location = var.gcp_region + name = "${var.prefix}billing-server" +} + provider "nomad" { address = "https://nomad.${var.domain_name}" secret_id = var.nomad_acl_token_secret @@ -130,6 +148,7 @@ module "dashboard_api" { image = data.google_artifact_registry_docker_image.dashboard_api_image[0].self_link + admin_token = var.dashboard_api_admin_token postgres_connection_string = data.google_secret_manager_secret_version.postgres_connection_string.secret_data auth_db_connection_string = data.google_secret_manager_secret_version.postgres_connection_string.secret_data auth_db_read_replica_connection_string = trimspace(data.google_secret_manager_secret_version.postgres_read_replica_connection_string.secret_data) @@ -140,6 +159,9 @@ module "dashboard_api" { redis_cluster_url = local.redis_cluster_url redis_tls_ca_base64 = trimspace(data.google_secret_manager_secret_version.redis_tls_ca_base64.secret_data) enable_auth_user_sync_background_worker = var.enable_auth_user_sync_background_worker + enable_billing_http_team_provision_sink = var.enable_billing_http_team_provision_sink + billing_server_url = local.dashboard_api_billing_server_url + billing_server_api_token = local.dashboard_api_billing_server_api_token otel_collector_grpc_port = var.otel_collector_grpc_port logs_proxy_port = var.logs_proxy_port diff --git a/iac/provider-gcp/nomad/variables.tf b/iac/provider-gcp/nomad/variables.tf index 074e4b0d1f..4b1b1b65a3 100644 --- a/iac/provider-gcp/nomad/variables.tf +++ b/iac/provider-gcp/nomad/variables.tf @@ -98,6 +98,10 @@ variable "api_admin_token" { type = string } +variable "dashboard_api_admin_token" { + type = string +} + variable "sandbox_access_token_hash_seed" { type = string } @@ -465,6 +469,10 @@ variable "enable_auth_user_sync_background_worker" { default = false } +variable "enable_billing_http_team_provision_sink" { + type = bool + default = false +} variable "volume_token_issuer" { type = string } diff --git a/iac/provider-gcp/variables.tf b/iac/provider-gcp/variables.tf index f9438b7080..3d8ebdeb72 100644 --- a/iac/provider-gcp/variables.tf +++ b/iac/provider-gcp/variables.tf @@ -241,6 +241,10 @@ variable "enable_auth_user_sync_background_worker" { default = false } +variable "enable_billing_http_team_provision_sink" { + type = bool + default = false +} variable "docker_reverse_proxy_port" { type = object({ name = string diff --git a/packages/auth/pkg/auth/middleware.go b/packages/auth/pkg/auth/middleware.go index 49fee38a81..6d4e9ecc2a 100644 --- a/packages/auth/pkg/auth/middleware.go +++ b/packages/auth/pkg/auth/middleware.go @@ -2,6 +2,7 @@ package auth import ( "context" + "crypto/subtle" "errors" "fmt" "net/http" @@ -127,7 +128,7 @@ func (a *CommonAuthenticator[T]) SecuritySchemeName() string { func adminValidationFunction(adminToken string) func(ctx context.Context, ginCtx *gin.Context, token string) (struct{}, *APIError) { return func(_ context.Context, _ *gin.Context, token string) (struct{}, *APIError) { - if token != adminToken { + if subtle.ConstantTimeCompare([]byte(token), []byte(adminToken)) != 1 { return struct{}{}, &APIError{ Code: http.StatusUnauthorized, Err: errors.New("invalid access token"), diff --git a/packages/auth/pkg/auth/middleware_test.go b/packages/auth/pkg/auth/middleware_test.go new file mode 100644 index 0000000000..c3ba82263b --- /dev/null +++ b/packages/auth/pkg/auth/middleware_test.go @@ -0,0 +1,28 @@ +package auth + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestAdminValidationFunction(t *testing.T) { + t.Parallel() + + validate := adminValidationFunction("super-secret-token") + + t.Run("accepts matching token", func(t *testing.T) { + t.Parallel() + + _, err := validate(t.Context(), nil, "super-secret-token") + require.Nil(t, err) + }) + + t.Run("rejects non-matching token", func(t *testing.T) { + t.Parallel() + + _, err := validate(t.Context(), nil, "super-secret-tokem") + require.NotNil(t, err) + require.Equal(t, 401, err.Code) + }) +} diff --git a/packages/dashboard-api/Makefile b/packages/dashboard-api/Makefile index 1b00fd5055..514252612b 100644 --- a/packages/dashboard-api/Makefile +++ b/packages/dashboard-api/Makefile @@ -32,12 +32,12 @@ build-and-upload: .PHONY: run run: make build - ./bin/dashboard-api + @./bin/dashboard-api .PHONY: run-local run-local: make build - NODE_ID=$(HOSTNAME) ./bin/dashboard-api + @NODE_ID=$(HOSTNAME) ./bin/dashboard-api .PHONY: test test: diff --git a/packages/dashboard-api/go.mod b/packages/dashboard-api/go.mod index 46742fe24e..cb4c50b26a 100644 --- a/packages/dashboard-api/go.mod +++ b/packages/dashboard-api/go.mod @@ -20,6 +20,7 @@ require ( github.com/gin-contrib/cors v1.7.6 github.com/gin-gonic/gin v1.12.0 github.com/google/uuid v1.6.0 + github.com/hashicorp/go-retryablehttp v0.7.7 github.com/jackc/pgx/v5 v5.9.1 github.com/oapi-codegen/gin-middleware v1.0.2 github.com/oapi-codegen/runtime v1.1.1 @@ -27,6 +28,7 @@ require ( github.com/riverqueue/river/riverdriver/riverpgxv5 v0.33.0 github.com/riverqueue/river/rivertype v0.33.0 github.com/stretchr/testify v1.11.1 + go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.64.0 go.opentelemetry.io/otel v1.43.0 go.opentelemetry.io/otel/metric v1.43.0 go.opentelemetry.io/otel/trace v1.43.0 @@ -82,6 +84,7 @@ require ( github.com/gorilla/mux v1.8.1 // indirect github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.3.2 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 // indirect + github.com/hashicorp/go-cleanhttp v0.5.2 // indirect github.com/jackc/pgerrcode v0.0.0-20250907135507-afb5586c32a6 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect @@ -149,7 +152,6 @@ require ( go.opentelemetry.io/auto/sdk v1.2.1 // indirect go.opentelemetry.io/contrib/bridges/otelzap v0.14.0 // indirect go.opentelemetry.io/contrib/instrumentation/github.com/gin-gonic/gin/otelgin v0.68.0 // indirect - go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.64.0 // indirect go.opentelemetry.io/contrib/instrumentation/runtime v0.66.0 // indirect go.opentelemetry.io/otel/exporters/otlp/otlplog/otlploggrpc v0.15.0 // indirect go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.39.0 // indirect diff --git a/packages/dashboard-api/go.sum b/packages/dashboard-api/go.sum index fe1f0c118f..c86d1d6a68 100644 --- a/packages/dashboard-api/go.sum +++ b/packages/dashboard-api/go.sum @@ -70,6 +70,8 @@ github.com/ebitengine/purego v0.9.0 h1:mh0zpKBIXDceC63hpvPuGLiJ8ZAa3DfrFTudmfi8A github.com/ebitengine/purego v0.9.0/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ= github.com/exaring/otelpgx v0.9.3 h1:4yO02tXC7ZJZ+hcqcUkfxblYNCIFGVhpUWI0iw1TzPU= github.com/exaring/otelpgx v0.9.3/go.mod h1:R5/M5LWsPPBZc1SrRE5e0DiU48bI78C1/GPTWs6I66U= +github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM= +github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU= github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= github.com/gabriel-vasile/mimetype v1.4.13 h1:46nXokslUBsAJE/wMsp5gtO500a4F3Nkz9Ufpk2AcUM= @@ -135,6 +137,12 @@ github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.3.2 h1:sGm2vDRFUrQJO/Veii4h4z github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.3.2/go.mod h1:wd1YpapPLivG6nQgbf7ZkG1hhSOXDhhn4MLTknx2aAc= github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 h1:NmZ1PKzSTQbuGHw9DGPFomqkkLWMC+vZCkfs+FHv1Vg= github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3/go.mod h1:zQrxl1YP88HQlA6i9c63DSVPFklWpGX4OWAc9bFuaH4= +github.com/hashicorp/go-cleanhttp v0.5.2 h1:035FKYIWjmULyFRBKPs8TBQoi0x6d9G4xc9neXJWAZQ= +github.com/hashicorp/go-cleanhttp v0.5.2/go.mod h1:kO/YDlP8L1346E6Sodw+PrpBSV4/SoxCXGY6BqNFT48= +github.com/hashicorp/go-hclog v1.6.3 h1:Qr2kF+eVWjTiYmU7Y31tYlP1h0q/X3Nl3tPGdaB11/k= +github.com/hashicorp/go-hclog v1.6.3/go.mod h1:W4Qnvbt70Wk/zYJryRzDRU/4r0kIg0PVHBcfoyhpF5M= +github.com/hashicorp/go-retryablehttp v0.7.7 h1:C8hUCYzor8PIfXHa4UrZkU4VvK8o9ISHxT2Q8+VepXU= +github.com/hashicorp/go-retryablehttp v0.7.7/go.mod h1:pkQpWZeYWskR+D1tR2O5OcBFOxfA7DoAO6xtkuQnHTk= github.com/jackc/pgerrcode v0.0.0-20250907135507-afb5586c32a6 h1:D/V0gu4zQ3cL2WKeVNVM4r2gLxGGf6McLwgXzRTo2RQ= github.com/jackc/pgerrcode v0.0.0-20250907135507-afb5586c32a6/go.mod h1:a/s9Lp5W7n/DD0VrVoyJ00FbP2ytTPDVOivvn2bMlds= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= @@ -176,6 +184,8 @@ github.com/magiconair/properties v1.8.10 h1:s31yESBquKXCV9a/ScB3ESkOjUYYv+X0rg8S github.com/magiconair/properties v1.8.10/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= github.com/mailru/easyjson v0.9.0 h1:PrnmzHw7262yW8sTBwxi1PdJA3Iw/EKBa8psRf7d9a4= github.com/mailru/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU= +github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= +github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI= diff --git a/packages/dashboard-api/internal/api/api.gen.go b/packages/dashboard-api/internal/api/api.gen.go index 35963afac3..00ac397e5c 100644 --- a/packages/dashboard-api/internal/api/api.gen.go +++ b/packages/dashboard-api/internal/api/api.gen.go @@ -21,6 +21,7 @@ import ( ) const ( + AdminTokenAuthScopes = "AdminTokenAuth.Scopes" Supabase1TokenAuthScopes = "Supabase1TokenAuth.Scopes" Supabase2TeamAuthScopes = "Supabase2TeamAuth.Scopes" ) @@ -90,6 +91,11 @@ type BuildsStatusesResponse struct { // CPUCount CPU cores for the sandbox type CPUCount = int64 +// CreateTeamRequest defines model for CreateTeamRequest. +type CreateTeamRequest struct { + Name string `json:"name"` +} + // DefaultTemplate defines model for DefaultTemplate. type DefaultTemplate struct { Aliases []DefaultTemplateAlias `json:"aliases"` @@ -327,6 +333,9 @@ type GetTeamsResolveParams struct { Slug TeamSlug `form:"slug" json:"slug"` } +// PostTeamsJSONRequestBody defines body for PostTeams for application/json ContentType. +type PostTeamsJSONRequestBody = CreateTeamRequest + // PatchTeamsTeamIDJSONRequestBody defines body for PatchTeamsTeamID for application/json ContentType. type PatchTeamsTeamIDJSONRequestBody = UpdateTeamRequest @@ -335,6 +344,9 @@ type PostTeamsTeamIDMembersJSONRequestBody = AddTeamMemberRequest // ServerInterface represents all server handlers. type ServerInterface interface { + // Bootstrap user + // (POST /admin/users/bootstrap) + PostAdminUsersBootstrap(c *gin.Context) // List team builds // (GET /builds) GetBuilds(c *gin.Context, params GetBuildsParams) @@ -353,6 +365,9 @@ type ServerInterface interface { // List user teams // (GET /teams) GetTeams(c *gin.Context) + // Create team + // (POST /teams) + PostTeams(c *gin.Context) // Resolve team identity // (GET /teams/resolve) GetTeamsResolve(c *gin.Context, params GetTeamsResolveParams) @@ -382,6 +397,23 @@ type ServerInterfaceWrapper struct { type MiddlewareFunc func(c *gin.Context) +// PostAdminUsersBootstrap operation middleware +func (siw *ServerInterfaceWrapper) PostAdminUsersBootstrap(c *gin.Context) { + + c.Set(AdminTokenAuthScopes, []string{}) + + c.Set(Supabase1TokenAuthScopes, []string{}) + + for _, middleware := range siw.HandlerMiddlewares { + middleware(c) + if c.IsAborted() { + return + } + } + + siw.Handler.PostAdminUsersBootstrap(c) +} + // GetBuilds operation middleware func (siw *ServerInterfaceWrapper) GetBuilds(c *gin.Context) { @@ -557,6 +589,21 @@ func (siw *ServerInterfaceWrapper) GetTeams(c *gin.Context) { siw.Handler.GetTeams(c) } +// PostTeams operation middleware +func (siw *ServerInterfaceWrapper) PostTeams(c *gin.Context) { + + c.Set(Supabase1TokenAuthScopes, []string{}) + + for _, middleware := range siw.HandlerMiddlewares { + middleware(c) + if c.IsAborted() { + return + } + } + + siw.Handler.PostTeams(c) +} + // GetTeamsResolve operation middleware func (siw *ServerInterfaceWrapper) GetTeamsResolve(c *gin.Context) { @@ -755,12 +802,14 @@ func RegisterHandlersWithOptions(router gin.IRouter, si ServerInterface, options ErrorHandler: errorHandler, } + router.POST(options.BaseURL+"/admin/users/bootstrap", wrapper.PostAdminUsersBootstrap) router.GET(options.BaseURL+"/builds", wrapper.GetBuilds) router.GET(options.BaseURL+"/builds/statuses", wrapper.GetBuildsStatuses) router.GET(options.BaseURL+"/builds/:build_id", wrapper.GetBuildsBuildId) router.GET(options.BaseURL+"/health", wrapper.GetHealth) router.GET(options.BaseURL+"/sandboxes/:sandboxID/record", wrapper.GetSandboxesSandboxIDRecord) router.GET(options.BaseURL+"/teams", wrapper.GetTeams) + router.POST(options.BaseURL+"/teams", wrapper.PostTeams) router.GET(options.BaseURL+"/teams/resolve", wrapper.GetTeamsResolve) router.PATCH(options.BaseURL+"/teams/:teamID", wrapper.PatchTeamsTeamID) router.GET(options.BaseURL+"/teams/:teamID/members", wrapper.GetTeamsTeamIDMembers) @@ -772,54 +821,56 @@ func RegisterHandlersWithOptions(router gin.IRouter, si ServerInterface, options // Base64 encoded, gzipped, json marshaled Swagger object var swaggerSpec = []string{ - "H4sIAAAAAAAC/+xcWXPcuBH+KygkVXmhZ0aHU4nedOyuVbESlSRvpcqlkjBkzwzWJMAFQB2rzH9P4eIx", - "BI+RJUfZ9ZM1ZAN9fd1oNEA/4ZhnOWfAlMQHTzgngmSgQJhf84KmyQ1N9N8JyFjQXFHO8AE+TYApuqAg", - "EF8gtQJkaCc4wlS/z4la4QgzkgE+qOaJsIBfCyogwQdKFBBhGa8gI5rBgouMKHyAi8JQqsdcj5VKULbE", - "63VUTnPDxY2CLE+JgrZo/zJ/kBQtaKpAoPmjlQ3RUuYI+eGNh1xUz0lKiSzV+bUA8djWpyFIXZdu2WVb", - "4GOeZeSdBG17BQlKqVTaqlbq0xOJFEdLUEgqogoJEi240KLBQ57yBPDBgqQS+kWVvbanCjI5wgkRzsjD", - "qSXemc3K90QIopkWjP5agCPQTNYRluox1TR6alxawuuyrTlKGyiOKIvTIoGxpihZBjX/s4AFPsB/mlYB", - "MbVkcnqkWV+a4VqDhtJdGsqbuBCSi4CC5jkSoArBINEA1QGUC7ijvJBWYQEy50wCogzdxgK0KW6I+o/3", - "5y2yruqCqGM+ApTyJqUZVW05z8gDzYoMsSKb2zg3xtKWt7KjHATKyRK6hLAT12VIYEGKVOGD97OoAhtl", - "am8XG3Bpjg5bGWXuV2lyyhQsQRjhJWHJnD+cnozJTo64Iz9VU/UFSdt+Ckg2jr+m7GDuJvm61KgnuUyL", - "ZVuWKyAZkmmxtH6TPL3r9Jcm29IEhQRxOmqB0JQdJnCTfI0J1nqwDRkTzvuzmf4n5kwBM+AmeZ7SmGj5", - "pr9ILeRTbf6+8P9BCC4sj6aSRyRBWmSQSsf9/mzn9XkeFmqlTWtnRWDpNPO912f+IxdzmiTALMf91+f4", - "T67Qghcs0RzffwunXoK4A+ENu/YgNKg6TBIdT2egM+KF87wumwTPQShqsQcZoWkDtPZJKHArxH92VNcl", - "GZ//ArFBllmATtmCt5m5teEwkMDNKGQINFQUzUAqkuV6Tbn48Xhvb+/vtVWkFDYhCt5p4tD6v6CMylUv", - "P57lKQxxjBBdID9ZJ3tWpCmZ68XV5oOWODqByFDSc2WceY8EpKaUUBypFZW2lDASkDtCDQeTmXwt0GYT", - "lqNWAZjaYLsywg46AynJMlDH/khoWghAmSVA9ytgrvxBVKLbBaEpJLcR4moF4p5KQLdaztvJsOE2gFdh", - "qOHgUq9NWTshelnaIYQMJ3xG8hwSjQOUELmacyISFKdU28rUckwv+p9tdaLFjbDVVctRxDFIWZOgclJN", - "Al2BtkPljWF3y33VYGn+fw5Co1QJuAAMB9EnP1KpLlwV0HZ/QtT4kl9PBYmZNlTyM3hQx/31veIoJ1La", - "pANIj/ClvVk3zH7TGkvjSdsPtE0Zt7S+sN7OikbJhnzd5rp0G6Juk80rsITS7Mfg1qyeSUci0cRry8wb", - "qjWFCal1fP7pmBcsEN7H559QzIXdO9d3BLi5DfnrPu7feET4xG5hrmoNiKbRTOvA/jnKDhsTHurhIcwZ", - "/Uv9WpuntqRmgC3OB5NHo4wYVwoAu0t+BiGpLbtG5rvW47yYpzSuvZpzngIxJaYg2dl8U1vjo7a2Mif3", - "LGiejgGKK5KeUPnlkv4GHWw6lKrNchfnxSiGoXTnoVL5yuvsJm5LGTVWa2e8BjgapgiFSRBwYRiHqyFd", - "VOUkhhFu39DaTjpCqJ6k5Dtuz46wwUxTcQhK6p1x1M4z+h2S9DfYzDO6ijijR73pZhbCl92ntMt+0+3a", - "ZG+IkX43wdGYFJF1Lfx2Jvd6Mrh1MeJU04XM9gFIqlbdbu0U5UOREfZOAEk0ztDKzIPiFcRfkABZpGpY", - "vj7B6kv99+3V9xK1IXH3OcNV46jAss0FSGCqzqs8UTg9meAeBuOaaJ56GPHWAdXhRI1P574u6toJhsLm", - "DDIuHkNJ0L55Tgbc2f1bKEtd2hkuIOYi6VmpNjplxi8bhgvWPnlR1g19uCzLy3WEk8Yi0Lv4VJR6HM8I", - "ZYHgJhKQfamhJKBhOiXIYkFjjWdiNsBUY3YEfLOak/qELJ35zMZ6R7CLjtR5RTMXqHUt74lEbtDohCkV", - "z/PtmZhBz06LZSydbBOzaCF4hu5XNF5pR9aFcmE3GNQ1xlHj1KKydQ3ONfc3ABuK5qqtGYivJIHk6DG0", - "jxg01fC+YnCKsp3asTwNrjpUnvhjp/YmI5Q2fbu2GlhXpN98sq/CMQSjy9aaT4YqVj91l2wX9uCnW7aR", - "ppTujGlMK0eThuT5lGv3W6nKDnpG2XlNoJ1oQz57UvSEM/LwEdhSrfDB7vv3Zunwv3cC8uaCL2gK5zRW", - "hYBPIh23ZemV+StN6DV5KVlbhjcMgoaXILQKgT5PyuMvkFwAkSM388/pF3x9HB8RxiAJ9wqoPLJadL3u", - "SQKRPaoejEhvwY+W+qW92RlfEVbUZuax/o/8Ya4ZWKW0tlh1y9VsHG2AopkNnbmG0uKGvQK7WBYXQgBT", - "rsQDObK3VY30dbjtaY4crlfDdsuna5Psc84HXgg5sruUkYeLUPeqm8fPgU5SkHoz9zfFi4JW7bFYxbwm", - "dWmiPrf2NmlINn6lKzPTcGdGT9uWSYcOxIWg6vFSz2mFuCxyMicSdq74F2CHhV4lnuwFhBWQxASGu4Lw", - "73ee+J0hrsxOcvoPMA1YT7GrRR09m1arNZkWmLrzXEWVuT70w+4ROikPxA7PT3GE73x/Fc8mO5OZloLn", - "wEhO8QHem8wmMx3TRK2MvtN5GQNLMIlOu8S0J/T2Ev8EqvR5/abf57B3KpJp8MbbOho5rjwZGDvC30ka", - "T+/uO62vN+6B7L7glYHAIVPo/oA9olwUafpYXfLKyZIycwhtBZ7YGxSzLp6lElNNVF0uGaLdqd0FGaLd", - "q92p6KfVRPUYM5AJRdfn62CYfL7WjpFFlhHx6A+OdCw7a+gAIUtZnvJIfK3ZOedO65cF+4F9WZ1BPQ/g", - "8ltAqHXwNhpGGydtf2QM/QSqffDYh6In7+P1MI6OyuOY58HoG6DIXAfaEjgJKEJTn3xeBwzuWtgQ7f4b", - "AI4zRxdu7ElDH1jsmQZ+RVdvnJoE/P2hfh4iS+dbk5VK16nMq6n0teH0qewkraei7LF26VzWlJd+lOvL", - "bhsrVf/qVYOl2TweHTC+Nfc9ZFzIeIMI720fMyWQXNiUpb9DUNPaF8bAEpE0NRWAbYSS6lYrJOaqMJpD", - "ytlSIsUjdE/VCtk9JyJMB67ZiKJFSpYTHLVBanYnrxmX7S3QaGQZ7YzqW4PqpZ0fqMoq6Woudtuuyr1T", - "d5O8x83mvUTE1nn+Brq/DP8X6T69UY8RuiMpTYiibFneFDdHHci2Nbs97LhsnXrK6/KvmnlCjddhlBj6", - "xH0x8DtNOk3cORtZoHhU9KLvyX40sbafrKl41V6pzvVjA5Ir/4HF9hgp1ybTpz7iyePLJZBWF3zdbHa4", - "z6deL4O1W9pD4CzMkAY2/6CbD2s8Y4hRQJ3Wzn66CqsaWN1R0tdh9hWz2uZR1+i1z4S4s8Xkd9zWyEoH", - "bmIjwjmXAQCcc/niCHj5rBX8AGZU4tpp1wgNiJhD5rrx3kyCefPV+WHSMNxWCWn6ZD++W1v3pGCvOjWx", - "eWKet9H5yX+39zyMDjd33YeBgXy2PwAnARm/e6OA+p+A5MIYZBRO3OXXqdtlDe/ldNHuP9P2W7NyGrt5", - "UyugApkHvvtC2YKb3Zy7Bd1R5rtpTrwwr7i0dd5BHr2+tbR/i1u8lpANJJRXn43S7vlT4/8TsAc5zY+n", - "ofHQAqrxwM+7vl7/NwAA//+FAA++dkIAAA==", + "H4sIAAAAAAAC/+xcWXPjuBH+KygmVXnhSPIxqcRvPnZ2XBknLh9bqXJN2RDZkrBDAhwAtK119N9TuHiI", + "4CHbcpzdeRoNCaCvrxvdDdBPQcTSjFGgUgQHT0GGOU5BAtf/m+YkiW9JrH7HICJOMkkYDQ6C0xioJDMC", + "HLEZkgtAeuwoCAOi3mdYLoIwoDiF4KBcJww4fM8Jhzg4kDyHMBDRAlKsCMwYT7EMDoI81yPlMlNzheSE", + "zoPVKiyWuWX8VkKaJVhCk7V/6R84QTOSSOBoujS8IVLwHCI3vfaQ8fI5TggWhTjfc+DLpjw1RqqytPMu", + "mgwfszTFHwQo3UuIUUKEVFo1XJ+eCCQZmoNEQmKZCxBoxrhiDR6zhMUQHMxwIqCbVdGpeyIhFQOMEAYp", + "fjw1g3cmk+I95hwrojkl33OwAxSRVRgIuUzUGLV0UGjCybKpOgodSIYIjZI8hqGqKEh6Jf8zh1lwEPxp", + "XDrE2AwT4yNF+lJPVxLUhG6TUNxGOReMewTUzxEHmXMKsQKocqCMwz1huTACcxAZowIQoegu4qBUcYvl", + "f5w975AxVRtELfEBoBS3CUmJbPJ5hh9JmqeI5unU+LlWltK84R1lwFGG59DGhFm4ykMMM5wnMjj4OAlL", + "sBEq93YDDS5F0WIrJdT+r1A5oRLmwDXzAtN4yh5PT4ZEJzu4JT6VS3U5SVN/EnA6jL4a2ULcLvKy0KgW", + "uUzyeZOXK8ApEkk+N3YTLLlvtZcatqEKcgH8dNAGoUa2qMAu8hIVrNRk4zLanfcnE/VPxKgEqsGNsywh", + "EVb8jX8Vismnyvpd7v8T54wbGnUhj3CMFMsgpPL7/cnO9mke5nKhVGtWRWDGKeJ72yf+ifEpiWOghuL+", + "9in+k0k0YzmNFcWPb2HUS+D3wJ1iVw6EGlWHcaz86QxURLywlldpE2cZcEkM9iDFJKmB1jzxOW6J+Bs7", + "6msxjE1/hUgjS29Ap3TGmsTs3nDoCeB6FtIDFFQkSUFInGZqT7n4dLy3t/f3yi5SMBtjCR/UYN/+PyOU", + "iEUnPZZmCfRRDBGZIbdYK3maJwmeqs3VxIMGOyqACF/Qs2mcfo84JDqVkAzJBREmldAc4HtMNAUdmVwu", + "0CTj56OSAejcYLM0wkw6AyHw3JPHfsIkyTmg1AxADwugNv1BRKC7GSYJxHchYnIB/IEIQHeKz7tRv+LW", + "gFdiqGbgQq51XlshelnowYcMy3yKswxihQMUY7GYMsxjFCVE6UrnclRt+jcmO1HshoGRVfGRRxEIUeGg", + "NFKFA5WBNl3lnWF3w7qqNzX/PwehFqoAnAeGvegTX4iQFzYLaJo/xnJ4yq+Wglgv60v5KTzK4+78XjKU", + "YSFM0AGkZrjUXu8but40ylJ4UvoDpVPKzFiXWG+mRS1kjb92dV3agqhdZdMSLL4w+8VbmlUj6UAkan9t", + "qHlNtDozPrGOz6+PWU497n18fo0ixk3tXK0IgnoZ8tf9oLvwCINjHSxVGtCaAJi09knVM1+AzuUiONj9", + "+FEv7P6/02dIvYZPyBNTQl1VGiB16rp1YX4OssPagodqug/zWv+FfhvFW1NTeoIpDnqDVy2NGZaKAL2P", + "fwEuiEn7BsbbxuMsnyYkqryaMpYA1ikux+nZdF1ajZGmtCLDD9SrnpYJkkmcnBDx7ZL8Bi1kWoSqrHIf", + "Zfkggr5w66BS2srJbBduchnWsgWrvBo4aqoYgGADOD+M/dmYSuoyHMEAs69JbRYdwFRHUHQdv2d7WG+k", + "Kyl4OXXGOGrGOfUOCfIbrMc5lcWckaPOcDfx4cvUSc2yQ3fb1snrwUi9GwXhkBCRtiUeZiX7etRbOml2", + "yuV8avsMOJGLdrO2svI5TzH9wAHHCmdooddB0QKib4iDyBPZz18XY9VU40d59yNFrnHcfs5xVTuqMGQz", + "DgKorNIqTjROT0ZBB4FhTTw3uh/xxgDl4UiFTmtdGbZVoj63OYOU8aUvCJo3z4mAO7t/80WpS7PCBUSM", + "xx071VqnTttlTXHe3CfLi7yhC5dFersKg7i2CXRuPuVINY+lmFCPc2MByLxUUOJQU53keDYjkcIz1gU4", + "UZgdAN+0YqQuJgtjPrOx3+LsvCV0XpHUOmpVygcskJ00OGAKybJscyJ60rPDYuFLJ5v4LJpxlqKHBYkW", + "ypBVpqzb9Tp1hXBYOzUpdV2Bc8X8NcD6vLlsq3r8K44hPlr66oheVfXXFb1LFO3clu2pd9ch4sQdezWL", + "DF/YdO3icmJVkG71ia4MRw8YnLZWbNKXsbql23i7MAdP7bwNVKWwZ1xDWklqqI+f6yxuFvApoecVhnbC", + "1yjp9SIzksA5iWTO4Zonw0qWTp5fqEInyWvx2lB8a+fiWgBXInj6TAmLvkF8AVgMLOaf0y94uR8fYUoh", + "9vcKiDgyUrS97ggCoTkq7/VIp8EvZvRrW7PVv8JAEhOZh9o/dIfJemIZ0ppsVTVX0XG4Bop6NLTq6guL", + "a/ryVLE0yjkHKm2KB2Jgb6uc6fJw01MdOF3ths2WT1uR7GLOZ5ZzMbC7lOLHC1/3qp3GL55Oknf0euyv", + "sxd6tdqhsZJ4hetCRV1m7WzS4HT4TldEpv7OjFq2yZNyHYhyTuTyUq0J9rg4JfSKfQN6mKsd4slcflgA", + "jrVT2OsP//6gB37QI0t944z8A3Tn9TLP8BQL2Bmylhvcv9yuEnnwako9jcWU4MSeS0si9TWon3aP0Elx", + "sHd4fhqEwb3r0waT0c5oorhgGVCckeAg2BtNRhMVG7BcaL2NsdLHOBfAxXjKmBSS40wbmZntWplatz1U", + "2RqcMyG1CpUdxVExYe1OyO4rXh/wJTW+ywTmvHKWJ8kSFZJkENvrMOWtER+xgvuxGlRegOgeqwZVARkc", + "3DShePPVD6ubr6uvYSDyNMV8qepCx7NmWAEAz0XFExSh8bQIfXPwmOdnkIWrVy+Y3vglKYeMvRctV+HA", + "ecWB1NAZ7irc8PH2mp3S2dag5jnb7ENacbcww3NC9d0Hw7BF3GQI4iabotNeQeobu/cyJPtR641q62DW", + "55UKuFYbFTjbB1U8j6t3VLuBfVkefT4P4OItINQ47x0Mo7UD3j8yhn4G2Tzv7kLRk7Pxqh9HR8Up3PNg", + "9AYo0rfQNgRODBKTRIy2CQZ7G7Fv7P47AI5VRxtuzAFTF1jMUdY205u1wzKPvT9Xj8FEYXyjskLo6ij9", + "aixcSTB+KhqIqzEvWuttMhelxKWbZdvxm/pK2bbcqrPUzwwGO4zryP5wGesyTiHcWdv5TAEk6zZFxWcR", + "VNf2hVawQDhJdAZg+t+4vExtU3I0hYTRuUCSheiByAUyrQaEqXJc3X9AswTPR0HYBKkuSrfpl83KdzCy", + "tHRa9DcsO4aUGDorK7nz1BhhR91Xqlz3co9YvHw1bTeveq3qDQH7idO7qjJtR8x+crLFVG27qDC611K0", + "VJ3699h+0tLh+Pq9QNhk/u5TGPdVzl+E/QZQLkN0jxMSY0novPhkRZ95InO+0e7zlsrGm1Hx3c5W96Ln", + "wMjqtYaj3982VMec1ZEBikNFJ/qezNdbK/PtrIwWniClHmuQXLkvvTbHSJGtvH6Qax6HvXGQ85xt9YEz", + "11PeIMa9+3LUKK8/TDqgjiuHwG2pdgWs9kz5ZZjdYlRbP/MenA1pF7e6GP2OG11pYcCNk6pXRMDrRy3v", + "l3iDAtdOM0eoQUTfNqkq790EmHdfrx3GNcVtFJDGT+Yr4JUxTwLmzmMdmyf6eROd1+4D4udhtL/db79Q", + "9sSz/R44cUjZ/TsF1P8EJBdaIYNwYm/Bj23d3V/dq6Td/b0IV6wXy5hyXi6AcKQfuH4coTOm63v7OURL", + "mm+XOXHMbHFra/0YYfD+1pD+PRb9DSZrSCi+gdBC2+dPtT9sYo726n/FAWoPDaBqD9y6q6+r/wYAAP//", + "OLVzw/9GAAA=", } // GetSwagger returns the content of the embedded swagger specification file diff --git a/packages/dashboard-api/internal/backgroundworker/auth_user_sync.go b/packages/dashboard-api/internal/backgroundworker/auth_user_sync.go index 4cc42ffbfe..b937cc6f99 100644 --- a/packages/dashboard-api/internal/backgroundworker/auth_user_sync.go +++ b/packages/dashboard-api/internal/backgroundworker/auth_user_sync.go @@ -12,10 +12,10 @@ import ( "go.opentelemetry.io/otel/trace" "go.uber.org/zap" - sqlcdb "github.com/e2b-dev/infra/packages/db/client" + authdb "github.com/e2b-dev/infra/packages/db/pkg/auth" + authqueries "github.com/e2b-dev/infra/packages/db/pkg/auth/queries" "github.com/e2b-dev/infra/packages/db/pkg/dberrors" supabasedb "github.com/e2b-dev/infra/packages/db/pkg/supabase" - "github.com/e2b-dev/infra/packages/db/queries" "github.com/e2b-dev/infra/packages/shared/pkg/logger" "github.com/e2b-dev/infra/packages/shared/pkg/telemetry" ) @@ -38,12 +38,12 @@ type AuthUserSyncWorker struct { river.WorkerDefaults[AuthUserSyncArgs] supabaseDB *supabasedb.Client - authDB *sqlcdb.Client + authDB *authdb.Client l logger.Logger jobsCounter metric.Int64Counter } -func NewAuthUserSyncWorker(ctx context.Context, supabaseDB *supabasedb.Client, authDB *sqlcdb.Client, l logger.Logger) *AuthUserSyncWorker { +func NewAuthUserSyncWorker(ctx context.Context, supabaseDB *supabasedb.Client, authDB *authdb.Client, l logger.Logger) *AuthUserSyncWorker { jobsCounter, err := workerMeter.Int64Counter( "jobs_total", metric.WithDescription("Total auth user sync jobs by operation and result."), @@ -90,7 +90,7 @@ func (w *AuthUserSyncWorker) Work(ctx context.Context, job *river.Job[AuthUserSy switch job.Args.Operation { case "delete": - if err := w.authDB.DeletePublicUser(ctx, userID); err != nil { + if err := w.authDB.Write.DeletePublicUser(ctx, userID); err != nil { telemetry.ReportError(ctx, "auth user sync delete public user", err) w.observeJob(ctx, job.Args.Operation, jobResultError) @@ -100,7 +100,7 @@ func (w *AuthUserSyncWorker) Work(ctx context.Context, job *river.Job[AuthUserSy case "upsert": supabaseUser, err := w.supabaseDB.Write.GetAuthUserByID(ctx, userID) if dberrors.IsNotFoundError(err) { - if err := w.authDB.DeletePublicUser(ctx, userID); err != nil { + if err := w.authDB.Write.DeletePublicUser(ctx, userID); err != nil { telemetry.ReportError(ctx, "auth user sync delete stale public user", err) w.observeJob(ctx, job.Args.Operation, jobResultError) @@ -127,7 +127,7 @@ func (w *AuthUserSyncWorker) Work(ctx context.Context, job *river.Job[AuthUserSy return river.JobCancel(err) } - if err := w.authDB.UpsertPublicUser(ctx, queries.UpsertPublicUserParams{ + if err := w.authDB.Write.UpsertPublicUser(ctx, authqueries.UpsertPublicUserParams{ ID: userID, Email: supabaseUser.Email, }); err != nil { diff --git a/packages/dashboard-api/internal/backgroundworker/auth_user_sync_test.go b/packages/dashboard-api/internal/backgroundworker/auth_user_sync_test.go index bbedbc12df..0c82cb2c2b 100644 --- a/packages/dashboard-api/internal/backgroundworker/auth_user_sync_test.go +++ b/packages/dashboard-api/internal/backgroundworker/auth_user_sync_test.go @@ -14,8 +14,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + authqueries "github.com/e2b-dev/infra/packages/db/pkg/auth/queries" "github.com/e2b-dev/infra/packages/db/pkg/testutils" - "github.com/e2b-dev/infra/packages/db/queries" "github.com/e2b-dev/infra/packages/shared/pkg/logger" ) @@ -23,7 +23,6 @@ const ( testEventuallyTimeout = 10 * time.Second testEventuallyTick = 50 * time.Millisecond testStopTimeout = 5 * time.Second - supabaseMigrationsDir = "packages/db/pkg/supabase/migrations" ) @@ -67,14 +66,14 @@ func TestAuthUserSyncWorker_UpsertDeletesStaleProjectedUser(t *testing.T) { userID := uuid.New() staleEmail := fmt.Sprintf("stale-%s@example.com", userID.String()[:8]) - err := db.SqlcClient.UpsertPublicUser(ctx, queries.UpsertPublicUserParams{ + err := db.AuthDB.Write.UpsertPublicUser(ctx, authqueries.UpsertPublicUserParams{ ID: userID, Email: staleEmail, }) require.NoError(t, err) require.Equal(t, 1, publicUserCount(t, ctx, db, userID)) - worker := NewAuthUserSyncWorker(ctx, db.SupabaseDB, db.SqlcClient, logger.NewNopLogger()) + worker := NewAuthUserSyncWorker(ctx, db.SupabaseDB, db.AuthDB, logger.NewNopLogger()) err = worker.Work(ctx, &river.Job[AuthUserSyncArgs]{ JobRow: &rivertype.JobRow{ID: 1, Attempt: 1}, @@ -127,7 +126,7 @@ func startRiverWorker(t *testing.T, db *testutils.Database) *riverProcess { l := logger.NewNopLogger() workers := river.NewWorkers() - river.AddWorker(workers, NewAuthUserSyncWorker(ctx, db.SupabaseDB, db.SqlcClient, l)) + river.AddWorker(workers, NewAuthUserSyncWorker(ctx, db.SupabaseDB, db.AuthDB, l)) client, err := NewRiverClient(db.SupabaseDB.WritePool(), workers) require.NoError(t, err) diff --git a/packages/dashboard-api/internal/backgroundworker/river.go b/packages/dashboard-api/internal/backgroundworker/river.go index a5e07635e9..9e2689fdbc 100644 --- a/packages/dashboard-api/internal/backgroundworker/river.go +++ b/packages/dashboard-api/internal/backgroundworker/river.go @@ -11,7 +11,7 @@ import ( "github.com/riverqueue/river/rivermigrate" "go.uber.org/zap" - sqlcdb "github.com/e2b-dev/infra/packages/db/client" + authdb "github.com/e2b-dev/infra/packages/db/pkg/auth" supabasedb "github.com/e2b-dev/infra/packages/db/pkg/supabase" "github.com/e2b-dev/infra/packages/shared/pkg/logger" ) @@ -41,7 +41,7 @@ func NewRiverClient(pool *pgxpool.Pool, workers *river.Workers) (*river.Client[p }) } -func StartAuthUserSyncWorker(setupCtx, runCtx context.Context, supabaseDB *supabasedb.Client, authDB *sqlcdb.Client, l logger.Logger) (*river.Client[pgx.Tx], error) { +func StartAuthUserSyncWorker(setupCtx, runCtx context.Context, supabaseDB *supabasedb.Client, authDB *authdb.Client, l logger.Logger) (*river.Client[pgx.Tx], error) { if err := RunRiverMigrations(setupCtx, supabaseDB.WritePool()); err != nil { return nil, fmt.Errorf("run River migrations on supabase DB: %w", err) } diff --git a/packages/dashboard-api/internal/cfg/model.go b/packages/dashboard-api/internal/cfg/model.go index 3743f45372..1b2cc5c4f0 100644 --- a/packages/dashboard-api/internal/cfg/model.go +++ b/packages/dashboard-api/internal/cfg/model.go @@ -10,6 +10,7 @@ type Config struct { Port int `env:"PORT" envDefault:"3010"` PostgresConnectionString string `env:"POSTGRES_CONNECTION_STRING,required,notEmpty"` ClickhouseConnectionString string `env:"CLICKHOUSE_CONNECTION_STRING"` + AdminToken string `env:"ADMIN_TOKEN,required,notEmpty"` SupabaseJWTSecrets []string `env:"SUPABASE_JWT_SECRETS"` AuthDBConnectionString string `env:"AUTH_DB_CONNECTION_STRING"` @@ -20,7 +21,10 @@ type Config struct { RedisClusterURL string `env:"REDIS_CLUSTER_URL"` RedisTLSCABase64 string `env:"REDIS_TLS_CA_BASE64"` - EnableAuthUserSyncBackgroundWorker bool `env:"ENABLE_AUTH_USER_SYNC_BACKGROUND_WORKER" envDefault:"false"` + EnableAuthUserSyncBackgroundWorker bool `env:"ENABLE_AUTH_USER_SYNC_BACKGROUND_WORKER" envDefault:"false"` + EnableBillingHTTPTeamProvisionSink bool `env:"ENABLE_BILLING_HTTP_TEAM_PROVISION_SINK" envDefault:"false"` + BillingServerURL string `env:"BILLING_SERVER_URL"` + BillingServerAPIToken string `env:"BILLING_SERVER_API_TOKEN"` } func Parse() (Config, error) { diff --git a/packages/dashboard-api/internal/handlers/store.go b/packages/dashboard-api/internal/handlers/store.go index 0cb45346a3..f64013f394 100644 --- a/packages/dashboard-api/internal/handlers/store.go +++ b/packages/dashboard-api/internal/handlers/store.go @@ -12,28 +12,34 @@ import ( clickhouse "github.com/e2b-dev/infra/packages/clickhouse/pkg" "github.com/e2b-dev/infra/packages/dashboard-api/internal/api" "github.com/e2b-dev/infra/packages/dashboard-api/internal/cfg" + internalteamprovision "github.com/e2b-dev/infra/packages/dashboard-api/internal/teamprovision" sqlcdb "github.com/e2b-dev/infra/packages/db/client" authdb "github.com/e2b-dev/infra/packages/db/pkg/auth" + supabasedb "github.com/e2b-dev/infra/packages/db/pkg/supabase" "github.com/e2b-dev/infra/packages/shared/pkg/apierrors" ) var _ api.ServerInterface = (*APIStore)(nil) type APIStore struct { - config cfg.Config - db *sqlcdb.Client - authDB *authdb.Client - clickhouse clickhouse.Clickhouse - authService *sharedauth.AuthService[*types.Team] + config cfg.Config + db *sqlcdb.Client + authDB *authdb.Client + supabaseDB *supabasedb.Client + clickhouse clickhouse.Clickhouse + authService *sharedauth.AuthService[*types.Team] + teamProvisionSink internalteamprovision.TeamProvisionSink } -func NewAPIStore(config cfg.Config, db *sqlcdb.Client, authDB *authdb.Client, ch clickhouse.Clickhouse, authService *sharedauth.AuthService[*types.Team]) *APIStore { +func NewAPIStore(config cfg.Config, db *sqlcdb.Client, authDB *authdb.Client, supabaseDB *supabasedb.Client, ch clickhouse.Clickhouse, authService *sharedauth.AuthService[*types.Team], teamProvisionSink internalteamprovision.TeamProvisionSink) *APIStore { return &APIStore{ - config: config, - db: db, - authDB: authDB, - clickhouse: ch, - authService: authService, + config: config, + db: db, + authDB: authDB, + supabaseDB: supabaseDB, + clickhouse: ch, + authService: authService, + teamProvisionSink: teamProvisionSink, } } diff --git a/packages/dashboard-api/internal/handlers/team_handlers_test.go b/packages/dashboard-api/internal/handlers/team_handlers_test.go index 6eb42fe775..7fd9e7709d 100644 --- a/packages/dashboard-api/internal/handlers/team_handlers_test.go +++ b/packages/dashboard-api/internal/handlers/team_handlers_test.go @@ -1,20 +1,27 @@ package handlers import ( + "context" + "encoding/json" + "errors" "net/http" "net/http/httptest" "strings" + "sync" "testing" "time" "github.com/gin-gonic/gin" "github.com/google/uuid" + "github.com/jackc/pgx/v5" "github.com/e2b-dev/infra/packages/auth/pkg/auth" authtypes "github.com/e2b-dev/infra/packages/auth/pkg/types" + internalteamprovision "github.com/e2b-dev/infra/packages/dashboard-api/internal/teamprovision" authqueries "github.com/e2b-dev/infra/packages/db/pkg/auth/queries" "github.com/e2b-dev/infra/packages/db/pkg/testutils" "github.com/e2b-dev/infra/packages/db/queries" + "github.com/e2b-dev/infra/packages/shared/pkg/teamprovision" ) func TestParseUpdateTeamBody_ProfilePictureNullClearsValue(t *testing.T) { @@ -252,13 +259,27 @@ func TestDeleteTeamsTeamIDMembersUserId_RechecksDefaultAfterLock(t *testing.T) { func createHandlerTestUser(t *testing.T, db *testutils.Database) uuid.UUID { t.Helper() + createdAt := time.Now().Add(-newUserNewTeamRequireBillingMethodThreshold - time.Hour) + + return createHandlerTestUserWithCreatedAt(t, db, &createdAt) +} + +func createHandlerTestUserAt(t *testing.T, db *testutils.Database, createdAt time.Time) uuid.UUID { + t.Helper() + + return createHandlerTestUserWithCreatedAt(t, db, &createdAt) +} + +func createHandlerTestUserWithCreatedAt(t *testing.T, db *testutils.Database, createdAt *time.Time) uuid.UUID { + t.Helper() + userID := uuid.New() email := handlerTestUserEmail(userID) - err := db.AuthDb.TestsRawSQL(t.Context(), ` -INSERT INTO auth.users (id, email) -VALUES ($1, $2) -`, userID, email) + err := db.SupabaseDB.TestsRawSQL(t.Context(), ` +INSERT INTO auth.users (id, email, created_at) +VALUES ($1, $2, $3) +`, userID, email, createdAt) if err != nil { t.Fatalf("failed to create test user: %v", err) } @@ -281,3 +302,643 @@ VALUES ($1, $2, $3) t.Fatalf("failed to create team member relation: %v", err) } } + +func TestPostUsersBootstrap_CreatesDefaultTeamAndCallsSink(t *testing.T) { + t.Parallel() + + testDB := testutils.SetupDatabase(t) + ctx := t.Context() + userID := createHandlerTestUser(t, testDB) + sink := &fakeTeamProvisionSink{} + + existingTeam, err := testDB.AuthDB.Write.GetDefaultTeamByUserID(ctx, userID) + if err != nil { + t.Fatalf("expected trigger-created default team: %v", err) + } + if err := testDB.AuthDB.Write.DeleteTeamByID(ctx, existingTeam.ID); err != nil { + t.Fatalf("failed to remove trigger-created default team: %v", err) + } + if err := testDB.AuthDB.Write.DeletePublicUser(ctx, userID); err != nil { + t.Fatalf("failed to remove trigger-created public user: %v", err) + } + + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + ginCtx.Request = httptest.NewRequestWithContext(ctx, http.MethodPost, "/", nil) + auth.SetUserID(ginCtx, userID) + + store := &APIStore{ + db: testDB.SqlcClient, + authDB: testDB.AuthDB, + supabaseDB: testDB.SupabaseDB, + teamProvisionSink: sink, + } + store.PostAdminUsersBootstrap(ginCtx) + + if recorder.Code != http.StatusOK { + t.Fatalf("expected status 200, got %d", recorder.Code) + } + + team, err := testDB.AuthDB.Write.GetDefaultTeamByUserID(ctx, userID) + if err != nil { + t.Fatalf("expected default team to be created: %v", err) + } + + if len(sink.requests) != 1 { + t.Fatalf("expected one billing provisioning call, got %d", len(sink.requests)) + } + + req := sink.requests[0] + if req.TeamID != team.ID { + t.Fatalf("expected sink team id %s, got %s", team.ID, req.TeamID) + } + if req.Reason != teamprovision.ReasonDefaultSignupTeam { + t.Fatalf("expected default signup reason, got %s", req.Reason) + } + + var responseBody map[string]any + if err := json.Unmarshal(recorder.Body.Bytes(), &responseBody); err != nil { + t.Fatalf("failed to parse response body: %v", err) + } + if responseBody["slug"] != team.Slug { + t.Fatalf("expected slug %s, got %v", team.Slug, responseBody["slug"]) + } +} + +func TestPostUsersBootstrap_ProvisioningFailureKeepsCreatedDefaultTeam(t *testing.T) { + t.Parallel() + + testDB := testutils.SetupDatabase(t) + ctx := t.Context() + userID := createHandlerTestUser(t, testDB) + sink := &fakeTeamProvisionSink{ + err: &internalteamprovision.ProvisionError{ + StatusCode: http.StatusInternalServerError, + Message: "boom", + }, + } + + existingTeam, err := testDB.AuthDB.Write.GetDefaultTeamByUserID(ctx, userID) + if err != nil { + t.Fatalf("expected trigger-created default team: %v", err) + } + if err := testDB.AuthDB.Write.DeleteTeamByID(ctx, existingTeam.ID); err != nil { + t.Fatalf("failed to remove trigger-created default team: %v", err) + } + if err := testDB.AuthDB.Write.DeletePublicUser(ctx, userID); err != nil { + t.Fatalf("failed to remove trigger-created public user: %v", err) + } + + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + ginCtx.Request = httptest.NewRequestWithContext(ctx, http.MethodPost, "/", nil) + auth.SetUserID(ginCtx, userID) + + store := &APIStore{ + db: testDB.SqlcClient, + authDB: testDB.AuthDB, + supabaseDB: testDB.SupabaseDB, + teamProvisionSink: sink, + } + store.PostAdminUsersBootstrap(ginCtx) + + if recorder.Code != http.StatusOK { + t.Fatalf("expected status 200, got %d", recorder.Code) + } + if len(sink.requests) != 1 { + t.Fatalf("expected one provisioning call, got %d", len(sink.requests)) + } + + team, err := testDB.AuthDB.Write.GetDefaultTeamByUserID(ctx, userID) + if err != nil { + t.Fatalf("expected default team to remain after provisioning failure: %v", err) + } + + rows, err := testDB.AuthDB.Read.GetTeamsWithUsersTeamsWithTier(ctx, userID) + if err != nil { + t.Fatalf("failed to query user teams: %v", err) + } + if len(rows) != 1 { + t.Fatalf("expected one default team to remain, got %d rows", len(rows)) + } + if rows[0].Team.ID != team.ID { + t.Fatalf("expected remaining team %s, got %s", team.ID, rows[0].Team.ID) + } + if !rows[0].IsDefault { + t.Fatal("expected remaining team to be the default team") + } +} + +func TestBootstrapUser_ConcurrentRequestsCreateSingleDefaultTeam(t *testing.T) { + t.Parallel() + + testDB := testutils.SetupDatabase(t) + ctx := t.Context() + userID := createHandlerTestUser(t, testDB) + sink := &fakeTeamProvisionSink{} + + existingTeam, err := testDB.AuthDB.Write.GetDefaultTeamByUserID(ctx, userID) + if err != nil { + t.Fatalf("expected trigger-created default team: %v", err) + } + if err := testDB.AuthDB.Write.DeleteTeamByID(ctx, existingTeam.ID); err != nil { + t.Fatalf("failed to remove trigger-created default team: %v", err) + } + + store := &APIStore{ + db: testDB.SqlcClient, + authDB: testDB.AuthDB, + supabaseDB: testDB.SupabaseDB, + teamProvisionSink: sink, + } + + var wg sync.WaitGroup + results := make(chan provisionedTeam, 2) + errs := make(chan error, 2) + + for range 2 { + wg.Go(func() { + team, err := store.bootstrapUser(ctx, userID) + if err != nil { + errs <- err + + return + } + + results <- team + }) + } + + wg.Wait() + close(results) + close(errs) + + for err := range errs { + if err != nil { + t.Fatalf("expected bootstrap to succeed, got %v", err) + } + } + + var teamIDs []uuid.UUID + for team := range results { + teamIDs = append(teamIDs, team.ID) + } + if len(teamIDs) != 2 { + t.Fatalf("expected two bootstrap results, got %d", len(teamIDs)) + } + if teamIDs[0] != teamIDs[1] { + t.Fatalf("expected both bootstrap requests to resolve to the same team, got %s and %s", teamIDs[0], teamIDs[1]) + } + + var defaultTeamCount int + err = testDB.AuthDB.TestsRawSQLQuery(ctx, + `SELECT count(*) + FROM public.users_teams + WHERE user_id = $1 AND is_default = true`, + func(rows pgx.Rows) error { + if !rows.Next() { + return errors.New("missing default team count row") + } + + return rows.Scan(&defaultTeamCount) + }, + userID, + ) + if err != nil { + t.Fatalf("failed to count default team memberships: %v", err) + } + if defaultTeamCount != 1 { + t.Fatalf("expected exactly one default team membership, got %d", defaultTeamCount) + } +} + +func TestCreateTeam_RecentUserCreatesBlockedTeam(t *testing.T) { + t.Parallel() + + testDB := testutils.SetupDatabase(t) + ctx := t.Context() + userID := createHandlerTestUserAt(t, testDB, time.Now().Add(-time.Hour)) + + store := &APIStore{ + db: testDB.SqlcClient, + authDB: testDB.AuthDB, + supabaseDB: testDB.SupabaseDB, + teamProvisionSink: &fakeTeamProvisionSink{}, + } + + team, err := store.createTeam(ctx, userID, "Acme") + if err != nil { + t.Fatalf("expected team creation to succeed for recent user, got %v", err) + } + if !team.IsBlocked { + t.Fatal("expected recent user team to be blocked") + } + if team.BlockedReason == nil || *team.BlockedReason != blockedReasonMissingPayment { + t.Fatalf("expected blocked reason %q, got %v", blockedReasonMissingPayment, team.BlockedReason) + } +} + +func TestCreateTeam_NullCreatedAtLeavesTeamUnblocked(t *testing.T) { + t.Parallel() + + testDB := testutils.SetupDatabase(t) + ctx := t.Context() + userID := createHandlerTestUserWithCreatedAt(t, testDB, nil) + + store := &APIStore{ + db: testDB.SqlcClient, + authDB: testDB.AuthDB, + supabaseDB: testDB.SupabaseDB, + teamProvisionSink: &fakeTeamProvisionSink{}, + } + + team, err := store.createTeam(ctx, userID, "Acme") + if err != nil { + t.Fatalf("expected team creation to succeed with nil created_at, got %v", err) + } + if team.IsBlocked { + t.Fatal("expected nil created_at team to remain unblocked") + } + if team.BlockedReason != nil { + t.Fatalf("expected nil blocked reason, got %v", team.BlockedReason) + } +} + +func TestPostTeams_LocalPolicyDeniedReturnsBadRequestWithoutCreatingTeam(t *testing.T) { + t.Parallel() + + testDB := testutils.SetupDatabase(t) + ctx := t.Context() + userID := createHandlerTestUser(t, testDB) + sink := &fakeTeamProvisionSink{} + + for range 2 { + team, err := testDB.AuthDB.Write.CreateTeam(ctx, authqueries.CreateTeamParams{ + Name: "extra", + Tier: baseTierID, + Email: handlerTestUserEmail(userID), + }) + if err != nil { + t.Fatalf("failed to create extra team: %v", err) + } + if err := testDB.AuthDB.Write.CreateTeamMembership(ctx, authqueries.CreateTeamMembershipParams{ + UserID: userID, + TeamID: team.ID, + IsDefault: false, + AddedBy: &userID, + }); err != nil { + t.Fatalf("failed to attach extra team membership: %v", err) + } + } + + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + ginCtx.Request = httptest.NewRequestWithContext(ctx, http.MethodPost, "/", strings.NewReader(`{"name":"Acme"}`)) + ginCtx.Request.Header.Set("Content-Type", "application/json") + auth.SetUserID(ginCtx, userID) + + store := &APIStore{ + db: testDB.SqlcClient, + authDB: testDB.AuthDB, + supabaseDB: testDB.SupabaseDB, + teamProvisionSink: sink, + } + store.PostTeams(ginCtx) + + if recorder.Code != http.StatusBadRequest { + t.Fatalf("expected status 400, got %d", recorder.Code) + } + if len(sink.requests) != 0 { + t.Fatalf("expected no provisioning call, got %d", len(sink.requests)) + } + + rows, err := testDB.AuthDB.Read.GetTeamsWithUsersTeamsWithTier(ctx, userID) + if err != nil { + t.Fatalf("failed to query user teams: %v", err) + } + if len(rows) != 3 { + t.Fatalf("expected existing teams to remain unchanged, got %d rows", len(rows)) + } +} + +func TestPostTeams_InvalidNameReturnsBadRequest(t *testing.T) { + t.Parallel() + + testDB := testutils.SetupDatabase(t) + ctx := t.Context() + userID := createHandlerTestUser(t, testDB) + + for _, body := range []string{`{}`, `{"name":""}`, `{"name":" "}`} { + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + ginCtx.Request = httptest.NewRequestWithContext(ctx, http.MethodPost, "/", strings.NewReader(body)) + ginCtx.Request.Header.Set("Content-Type", "application/json") + auth.SetUserID(ginCtx, userID) + + sink := &fakeTeamProvisionSink{} + store := &APIStore{ + db: testDB.SqlcClient, + authDB: testDB.AuthDB, + supabaseDB: testDB.SupabaseDB, + teamProvisionSink: sink, + } + store.PostTeams(ginCtx) + + if recorder.Code != http.StatusBadRequest { + t.Fatalf("expected status 400 for body %s, got %d", body, recorder.Code) + } + if len(sink.requests) != 0 { + t.Fatalf("expected no provisioning call for body %s, got %d", body, len(sink.requests)) + } + } +} + +func TestPostTeams_InvalidRequestBodyReturnsBadRequest(t *testing.T) { + t.Parallel() + + testDB := testutils.SetupDatabase(t) + ctx := t.Context() + userID := createHandlerTestUser(t, testDB) + sink := &fakeTeamProvisionSink{} + + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + ginCtx.Request = httptest.NewRequestWithContext(ctx, http.MethodPost, "/", strings.NewReader(`{"name":`)) + ginCtx.Request.Header.Set("Content-Type", "application/json") + auth.SetUserID(ginCtx, userID) + + store := &APIStore{ + db: testDB.SqlcClient, + authDB: testDB.AuthDB, + supabaseDB: testDB.SupabaseDB, + teamProvisionSink: sink, + } + store.PostTeams(ginCtx) + + if recorder.Code != http.StatusBadRequest { + t.Fatalf("PostTeams(invalid JSON) status = %d, want %d", recorder.Code, http.StatusBadRequest) + } + if !strings.Contains(recorder.Body.String(), "Invalid request body") { + t.Fatalf("PostTeams(invalid JSON) body = %q, want message containing %q", recorder.Body.String(), "Invalid request body") + } + if len(sink.requests) != 0 { + t.Fatalf("PostTeams(invalid JSON) provisioning calls = %d, want %d", len(sink.requests), 0) + } +} + +func TestPostTeams_TrimsNameBeforeCreate(t *testing.T) { + t.Parallel() + + testDB := testutils.SetupDatabase(t) + ctx := t.Context() + userID := createHandlerTestUser(t, testDB) + sink := &fakeTeamProvisionSink{} + + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + ginCtx.Request = httptest.NewRequestWithContext(ctx, http.MethodPost, "/", strings.NewReader(`{"name":" Acme "}`)) + ginCtx.Request.Header.Set("Content-Type", "application/json") + auth.SetUserID(ginCtx, userID) + + store := &APIStore{ + db: testDB.SqlcClient, + authDB: testDB.AuthDB, + supabaseDB: testDB.SupabaseDB, + teamProvisionSink: sink, + } + store.PostTeams(ginCtx) + + if recorder.Code != http.StatusOK { + t.Fatalf("expected status 200, got %d", recorder.Code) + } + + rows, err := testDB.AuthDB.Read.GetTeamsWithUsersTeamsWithTier(ctx, userID) + if err != nil { + t.Fatalf("failed to query user teams: %v", err) + } + + foundCreatedTeam := false + for _, row := range rows { + if row.IsDefault { + continue + } + + foundCreatedTeam = true + if row.Team.Name != "Acme" { + t.Fatalf("expected trimmed team name %q, got %q", "Acme", row.Team.Name) + } + } + if !foundCreatedTeam { + t.Fatal("expected created team to exist") + } +} + +func TestPostTeams_ProvisioningFailureRollsBackCreatedTeam(t *testing.T) { + t.Parallel() + + testDB := testutils.SetupDatabase(t) + ctx := t.Context() + userID := createHandlerTestUser(t, testDB) + sink := &fakeTeamProvisionSink{ + err: &internalteamprovision.ProvisionError{ + StatusCode: http.StatusBadRequest, + Message: "limit reached", + }, + } + + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + ginCtx.Request = httptest.NewRequestWithContext(ctx, http.MethodPost, "/", strings.NewReader(`{"name":"Acme"}`)) + ginCtx.Request.Header.Set("Content-Type", "application/json") + auth.SetUserID(ginCtx, userID) + + store := &APIStore{ + db: testDB.SqlcClient, + authDB: testDB.AuthDB, + supabaseDB: testDB.SupabaseDB, + teamProvisionSink: sink, + } + store.PostTeams(ginCtx) + + if recorder.Code != http.StatusBadRequest { + t.Fatalf("expected status 400, got %d", recorder.Code) + } + if len(sink.requests) != 1 { + t.Fatalf("expected one provisioning call, got %d", len(sink.requests)) + } + + rows, err := testDB.AuthDB.Read.GetTeamsWithUsersTeamsWithTier(ctx, userID) + if err != nil { + t.Fatalf("failed to query user teams: %v", err) + } + if len(rows) != 1 { + t.Fatalf("expected only the default team to remain, got %d rows", len(rows)) + } + if !rows[0].IsDefault { + t.Fatal("expected remaining team to be the default team") + } +} + +func TestPostTeams_ProvisioningFailurePreservesProvisionErrorStatus(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + status int + message string + }{ + {name: "too_many_requests", status: http.StatusTooManyRequests, message: "rate limited"}, + {name: "service_unavailable", status: http.StatusServiceUnavailable, message: "billing unavailable"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + testDB := testutils.SetupDatabase(t) + ctx := t.Context() + userID := createHandlerTestUser(t, testDB) + sink := &fakeTeamProvisionSink{ + err: &internalteamprovision.ProvisionError{ + StatusCode: tt.status, + Message: tt.message, + }, + } + + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + ginCtx.Request = httptest.NewRequestWithContext(ctx, http.MethodPost, "/", strings.NewReader(`{"name":"Acme"}`)) + ginCtx.Request.Header.Set("Content-Type", "application/json") + auth.SetUserID(ginCtx, userID) + + store := &APIStore{ + db: testDB.SqlcClient, + authDB: testDB.AuthDB, + supabaseDB: testDB.SupabaseDB, + teamProvisionSink: sink, + } + store.PostTeams(ginCtx) + + if recorder.Code != tt.status { + t.Fatalf("PostTeams(provision status %d) status = %d, want %d", tt.status, recorder.Code, tt.status) + } + if len(sink.requests) != 1 { + t.Fatalf("PostTeams(provision status %d) provisioning calls = %d, want %d", tt.status, len(sink.requests), 1) + } + + var responseBody map[string]any + if err := json.Unmarshal(recorder.Body.Bytes(), &responseBody); err != nil { + t.Fatalf("json.Unmarshal(PostTeams response) error = %v, want nil", err) + } + + codeValue, ok := responseBody["code"].(float64) + if !ok { + t.Fatalf("PostTeams(provision status %d) response code type = %T, want float64", tt.status, responseBody["code"]) + } + if got := int(codeValue); got != tt.status { + t.Fatalf("PostTeams(provision status %d) response code = %d, want %d", tt.status, got, tt.status) + } + + messageValue, ok := responseBody["message"].(string) + if !ok { + t.Fatalf("PostTeams(provision status %d) response message type = %T, want string", tt.status, responseBody["message"]) + } + if messageValue != tt.message { + t.Fatalf("PostTeams(provision status %d) response message = %q, want %q", tt.status, messageValue, tt.message) + } + + rows, err := testDB.AuthDB.Read.GetTeamsWithUsersTeamsWithTier(ctx, userID) + if err != nil { + t.Fatalf("GetTeamsWithUsersTeamsWithTier(userID=%s) error = %v, want nil", userID, err) + } + if len(rows) != 1 { + t.Fatalf("GetTeamsWithUsersTeamsWithTier(userID=%s) rows = %d, want %d", userID, len(rows), 1) + } + if !rows[0].IsDefault { + t.Fatal("expected remaining team to be the default team") + } + }) + } +} + +func TestCreateTeam_ConcurrentRequestsRespectLocalPolicyWithZeroMemberships(t *testing.T) { + t.Parallel() + + testDB := testutils.SetupDatabase(t) + ctx := t.Context() + userID := createHandlerTestUser(t, testDB) + + existingTeam, err := testDB.AuthDB.Write.GetDefaultTeamByUserID(ctx, userID) + if err != nil { + t.Fatalf("expected trigger-created default team: %v", err) + } + if err := testDB.AuthDB.Write.DeleteTeamByID(ctx, existingTeam.ID); err != nil { + t.Fatalf("failed to remove default team: %v", err) + } + + store := &APIStore{ + db: testDB.SqlcClient, + authDB: testDB.AuthDB, + supabaseDB: testDB.SupabaseDB, + teamProvisionSink: &fakeTeamProvisionSink{}, + } + + var wg sync.WaitGroup + results := make(chan error, 4) + + for _, name := range []string{"Acme-1", "Acme-2", "Acme-3", "Acme-4"} { + wg.Add(1) + go func(teamName string) { + defer wg.Done() + _, err := store.createTeam(ctx, userID, teamName) + results <- err + }(name) + } + + wg.Wait() + close(results) + + var successCount int + var badRequestCount int + for err := range results { + if err == nil { + successCount++ + + continue + } + + var provisionErr *internalteamprovision.ProvisionError + if !errors.As(err, &provisionErr) { + t.Fatalf("expected provisioning error, got %T: %v", err, err) + } + if provisionErr.StatusCode == http.StatusBadRequest { + badRequestCount++ + + continue + } + + t.Fatalf("expected bad request or success, got %d", provisionErr.StatusCode) + } + + if successCount != maxTeamsPerUser { + t.Fatalf("expected %d successes, got %d", maxTeamsPerUser, successCount) + } + if badRequestCount != 1 { + t.Fatalf("expected one bad request, got %d", badRequestCount) + } +} + +type fakeTeamProvisionSink struct { + mu sync.Mutex + requests []teamprovision.TeamBillingProvisionRequestedV1 + err error +} + +func (s *fakeTeamProvisionSink) ProvisionTeam(_ context.Context, req teamprovision.TeamBillingProvisionRequestedV1) error { + s.mu.Lock() + defer s.mu.Unlock() + + s.requests = append(s.requests, req) + + return s.err +} diff --git a/packages/dashboard-api/internal/handlers/team_provisioning.go b/packages/dashboard-api/internal/handlers/team_provisioning.go new file mode 100644 index 0000000000..70e1f51e45 --- /dev/null +++ b/packages/dashboard-api/internal/handlers/team_provisioning.go @@ -0,0 +1,353 @@ +package handlers + +import ( + "context" + "errors" + "fmt" + "net/http" + "strings" + "time" + + "github.com/gin-gonic/gin" + "github.com/google/uuid" + "go.opentelemetry.io/otel/attribute" + + "github.com/e2b-dev/infra/packages/auth/pkg/auth" + "github.com/e2b-dev/infra/packages/dashboard-api/internal/api" + internalteamprovision "github.com/e2b-dev/infra/packages/dashboard-api/internal/teamprovision" + authqueries "github.com/e2b-dev/infra/packages/db/pkg/auth/queries" + "github.com/e2b-dev/infra/packages/db/pkg/dberrors" + "github.com/e2b-dev/infra/packages/shared/pkg/ginutils" + "github.com/e2b-dev/infra/packages/shared/pkg/teamprovision" + "github.com/e2b-dev/infra/packages/shared/pkg/telemetry" +) + +const ( + baseTierID = "base_v1" + maxTeamsPerUser = 3 + maxTeamsPerUserWithProTier = 10 + newUserNewTeamRequireBillingMethodThreshold = 3 * 24 * time.Hour + blockedReasonMissingPayment = "missing_payment" + teamProvisionRollbackTimeout = 5 * time.Second +) + +type provisionedTeam struct { + ID uuid.UUID + Name string + Email string + Slug string + IsBlocked bool + BlockedReason *string +} + +func (s *APIStore) PostAdminUsersBootstrap(c *gin.Context) { + ctx := c.Request.Context() + telemetry.ReportEvent(ctx, "bootstrap user") + + userID := auth.MustGetUserID(c) + team, err := s.bootstrapUser(ctx, userID) + if err != nil { + s.handleProvisioningError(ctx, c, "bootstrap user", err) + + return + } + + c.JSON(http.StatusOK, api.TeamResolveResponse{ + Id: team.ID, + Slug: team.Slug, + }) +} + +func (s *APIStore) PostTeams(c *gin.Context) { + ctx := c.Request.Context() + telemetry.ReportEvent(ctx, "create team") + + userID := auth.MustGetUserID(c) + attrs := []attribute.KeyValue{ + attribute.String("team.provision.operation", "create team"), + } + body, err := ginutils.ParseBody[api.CreateTeamRequest](ctx, c) + if err != nil { + telemetry.ReportErrorByCode(ctx, http.StatusBadRequest, "create team failed", fmt.Errorf("parse create team request: %w", err), attrs...) + s.sendAPIStoreError(c, http.StatusBadRequest, "Invalid request body") + + return + } + name := strings.TrimSpace(body.Name) + if name == "" { + telemetry.ReportErrorByCode(ctx, http.StatusBadRequest, "create team failed", errors.New("team name is required"), attrs...) + s.sendAPIStoreError(c, http.StatusBadRequest, "Team name is required") + + return + } + + team, err := s.createTeam(ctx, userID, name) + if err != nil { + s.handleProvisioningError(ctx, c, "create team", err) + + return + } + + c.JSON(http.StatusOK, api.TeamResolveResponse{ + Id: team.ID, + Slug: team.Slug, + }) +} + +func (s *APIStore) bootstrapUser(ctx context.Context, userID uuid.UUID) (provisionedTeam, error) { + authUser, err := s.supabaseDB.Write.GetAuthUserByID(ctx, userID) + if err != nil { + return provisionedTeam{}, fmt.Errorf("get auth user: %w", err) + } + + authTxDB, tx, err := s.authDB.WithTx(ctx) + if err != nil { + return provisionedTeam{}, fmt.Errorf("start transaction: %w", err) + } + defer func() { + _ = tx.Rollback(ctx) + }() + + if err := authTxDB.UpsertPublicUser(ctx, authqueries.UpsertPublicUserParams{ + ID: authUser.ID, + Email: authUser.Email, + }); err != nil { + return provisionedTeam{}, fmt.Errorf("upsert public user: %w", err) + } + + // Serialize bootstrap for a user even when they have no team memberships yet. + if _, err := authTxDB.LockPublicUserForUpdate(ctx, authUser.ID); err != nil { + return provisionedTeam{}, fmt.Errorf("lock public user: %w", err) + } + + existingTeam, err := authTxDB.GetDefaultTeamByUserID(ctx, userID) + if err == nil { + if err := tx.Commit(ctx); err != nil { + return provisionedTeam{}, fmt.Errorf("commit existing user bootstrap transaction: %w", err) + } + + req := teamprovision.TeamBillingProvisionRequestedV1{ + TeamID: existingTeam.ID, + TeamName: existingTeam.Name, + TeamEmail: existingTeam.Email, + OwnerUserID: userID, + Reason: teamprovision.ReasonDefaultSignupTeam, + } + _ = s.teamProvisionSink.ProvisionTeam(ctx, req) + + return provisionedTeam{ + ID: existingTeam.ID, + Name: existingTeam.Name, + Email: existingTeam.Email, + Slug: existingTeam.Slug, + IsBlocked: existingTeam.IsBlocked, + BlockedReason: existingTeam.BlockedReason, + }, nil + } + if !dberrors.IsNotFoundError(err) { + return provisionedTeam{}, fmt.Errorf("get default team: %w", err) + } + + team, err := authTxDB.CreateTeam(ctx, authqueries.CreateTeamParams{ + Name: authUser.Email, + Tier: baseTierID, + Email: authUser.Email, + IsBlocked: false, + BlockedReason: nil, + }) + if err != nil { + return provisionedTeam{}, fmt.Errorf("create default team: %w", err) + } + + if err := authTxDB.CreateTeamMembership(ctx, authqueries.CreateTeamMembershipParams{ + UserID: userID, + TeamID: team.ID, + IsDefault: true, + AddedBy: nil, + }); err != nil { + return provisionedTeam{}, fmt.Errorf("create default team membership: %w", err) + } + + if err := tx.Commit(ctx); err != nil { + return provisionedTeam{}, fmt.Errorf("commit user bootstrap transaction: %w", err) + } + + req := teamprovision.TeamBillingProvisionRequestedV1{ + TeamID: team.ID, + TeamName: team.Name, + TeamEmail: team.Email, + OwnerUserID: userID, + Reason: teamprovision.ReasonDefaultSignupTeam, + } + _ = s.teamProvisionSink.ProvisionTeam(ctx, req) + + return provisionedTeam{ + ID: team.ID, + Name: team.Name, + Email: team.Email, + Slug: team.Slug, + IsBlocked: team.IsBlocked, + BlockedReason: team.BlockedReason, + }, nil +} + +func (s *APIStore) createTeam(ctx context.Context, userID uuid.UUID, name string) (provisionedTeam, error) { + authUser, err := s.supabaseDB.Write.GetAuthUserByID(ctx, userID) + if err != nil { + return provisionedTeam{}, fmt.Errorf("get auth user: %w", err) + } + + authTxDB, tx, err := s.authDB.WithTx(ctx) + if err != nil { + return provisionedTeam{}, fmt.Errorf("start transaction: %w", err) + } + defer func() { + _ = tx.Rollback(ctx) + }() + + if err := authTxDB.UpsertPublicUser(ctx, authqueries.UpsertPublicUserParams{ + ID: authUser.ID, + Email: authUser.Email, + }); err != nil { + return provisionedTeam{}, fmt.Errorf("upsert public user: %w", err) + } + + // Serialize team creation even when the user currently has no team memberships. + if _, err := authTxDB.LockPublicUserForUpdate(ctx, authUser.ID); err != nil { + return provisionedTeam{}, fmt.Errorf("lock public user: %w", err) + } + + if err := validateTeamCreationAllowed(ctx, authTxDB, userID); err != nil { + return provisionedTeam{}, err + } + + isBlocked, blockedReason := teamBlockPolicy(authUser.CreatedAt, time.Now()) + + team, err := authTxDB.CreateTeam(ctx, authqueries.CreateTeamParams{ + Name: name, + Tier: baseTierID, + Email: authUser.Email, + IsBlocked: isBlocked, + BlockedReason: blockedReason, + }) + if err != nil { + return provisionedTeam{}, fmt.Errorf("create team: %w", err) + } + + if err := authTxDB.CreateTeamMembership(ctx, authqueries.CreateTeamMembershipParams{ + UserID: userID, + TeamID: team.ID, + IsDefault: false, + AddedBy: &userID, + }); err != nil { + return provisionedTeam{}, fmt.Errorf("create team membership: %w", err) + } + + if err := tx.Commit(ctx); err != nil { + return provisionedTeam{}, fmt.Errorf("commit team creation transaction: %w", err) + } + + req := teamprovision.TeamBillingProvisionRequestedV1{ + TeamID: team.ID, + TeamName: team.Name, + TeamEmail: team.Email, + OwnerUserID: userID, + Reason: teamprovision.ReasonAdditionalTeam, + } + if err := s.teamProvisionSink.ProvisionTeam(ctx, req); err != nil { + rollbackCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), teamProvisionRollbackTimeout) + defer cancel() + + if deleteErr := s.authDB.Write.DeleteTeamByID(rollbackCtx, team.ID); deleteErr != nil { + return provisionedTeam{}, fmt.Errorf("delete team after provisioning failure: provision=%s delete=%w", err.Error(), deleteErr) + } + + return provisionedTeam{}, err + } + + return provisionedTeam{ + ID: team.ID, + Name: team.Name, + Email: team.Email, + Slug: team.Slug, + IsBlocked: team.IsBlocked, + BlockedReason: team.BlockedReason, + }, nil +} + +func validateTeamCreationAllowed(ctx context.Context, authTxDB *authqueries.Queries, ownerUserID uuid.UUID) error { + teams, err := authTxDB.GetTeamsWithUsersTeamsWithTierForUpdate(ctx, ownerUserID) + if err != nil { + return fmt.Errorf("query user teams for limit check: %w", err) + } + + hasProTier := false + for _, row := range teams { + if row.Tier != baseTierID { + hasProTier = true + } + if row.IsBanned { + return &internalteamprovision.ProvisionError{ + StatusCode: http.StatusBadRequest, + Message: "You're unable to create a team right now. Please contact support if this persists.", + } + } + } + + if hasProTier { + if len(teams) >= maxTeamsPerUserWithProTier { + return &internalteamprovision.ProvisionError{ + StatusCode: http.StatusBadRequest, + Message: fmt.Sprintf("You can't create more than %d teams", maxTeamsPerUserWithProTier), + } + } + } else { + if len(teams) >= maxTeamsPerUser { + return &internalteamprovision.ProvisionError{ + StatusCode: http.StatusBadRequest, + Message: fmt.Sprintf( + "You can't create more than %d teams, you can upgrade to Pro tier to create up to %d teams", + maxTeamsPerUser, + maxTeamsPerUserWithProTier, + ), + } + } + } + + return nil +} + +func teamBlockPolicy(userCreatedAt *time.Time, now time.Time) (bool, *string) { + // Some Supabase users have a NULL created_at; unknown age should not trigger the new-user block. + if userCreatedAt != nil && userCreatedAt.After(now.Add(-newUserNewTeamRequireBillingMethodThreshold)) { + reason := blockedReasonMissingPayment + + return true, &reason + } + + return false, nil +} + +func (s *APIStore) handleProvisioningError(ctx context.Context, c *gin.Context, operation string, err error) { + attrs := []attribute.KeyValue{ + attribute.String("team.provision.operation", operation), + } + + var provisionErr *internalteamprovision.ProvisionError + if errors.As(err, &provisionErr) { + if provisionErr.StatusCode < http.StatusBadRequest || provisionErr.StatusCode >= 600 { + telemetry.ReportErrorByCode(ctx, http.StatusInternalServerError, operation+" failed", err, attrs...) + s.sendAPIStoreError(c, http.StatusInternalServerError, "Failed to "+operation) + + return + } + + telemetry.ReportErrorByCode(ctx, provisionErr.StatusCode, operation+" failed", err, attrs...) + s.sendAPIStoreError(c, provisionErr.StatusCode, provisionErr.Error()) + + return + } + + telemetry.ReportErrorByCode(ctx, http.StatusInternalServerError, operation+" failed", err, attrs...) + s.sendAPIStoreError(c, http.StatusInternalServerError, "Failed to "+operation) +} diff --git a/packages/dashboard-api/internal/teamprovision/factory.go b/packages/dashboard-api/internal/teamprovision/factory.go new file mode 100644 index 0000000000..05bc4625ce --- /dev/null +++ b/packages/dashboard-api/internal/teamprovision/factory.go @@ -0,0 +1,42 @@ +package teamprovision + +import ( + "context" + "errors" + + "go.uber.org/zap" + + "github.com/e2b-dev/infra/packages/shared/pkg/logger" +) + +var ( + ErrMissingBaseURL = errors.New("billing server url is required when billing http team provision sink is enabled") + ErrMissingAPIToken = errors.New("billing server api token is required when billing http team provision sink is enabled") +) + +func NewProvisionSink(ctx context.Context, enabled bool, baseURL, apiToken string) (TeamProvisionSink, error) { + if !enabled { + logger.L().Info(ctx, "team provision sink configured", + zap.String("sink", "noop"), + zap.String("result", "disabled"), + ) + + return NewNoopProvisionSink(), nil + } + + if baseURL == "" { + return nil, ErrMissingBaseURL + } + + if apiToken == "" { + return nil, ErrMissingAPIToken + } + + logger.L().Info(ctx, "team provision sink configured", + zap.String("sink", "http"), + zap.String("result", "enabled"), + zap.String("base_url", baseURL), + ) + + return NewHTTPProvisionSink(baseURL, apiToken), nil +} diff --git a/packages/dashboard-api/internal/teamprovision/http_sink.go b/packages/dashboard-api/internal/teamprovision/http_sink.go new file mode 100644 index 0000000000..b0c7a8df05 --- /dev/null +++ b/packages/dashboard-api/internal/teamprovision/http_sink.go @@ -0,0 +1,231 @@ +package teamprovision + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "math/rand" + "net/http" + "strings" + "time" + + "github.com/hashicorp/go-retryablehttp" + "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" + "go.opentelemetry.io/otel/attribute" + "go.uber.org/zap" + + "github.com/e2b-dev/infra/packages/shared/pkg/logger" + sharedteamprovision "github.com/e2b-dev/infra/packages/shared/pkg/teamprovision" + "github.com/e2b-dev/infra/packages/shared/pkg/telemetry" +) + +const billingServerAPIKeyHeader = "X-Billing-Server-API-Key" + +const ( + defaultProvisionTimeout = 30 * time.Second + defaultProvisionRetryMaxAttempts = 3 + defaultProvisionRetryInitialWait = 100 * time.Millisecond + defaultProvisionRetryWaitCeiling = 2 * time.Second + defaultProvisionAttemptTimeout = defaultProvisionTimeout / defaultProvisionRetryMaxAttempts + provisionBackoffMultiplier = 2.0 + // Error responses only need enough body to extract a short API message without buffering large upstream payloads. + provisionErrorMessageReadLimit = 2 * 1024 +) + +type HTTPProvisionSink struct { + baseURL string + apiToken string + client *retryablehttp.Client + timeout time.Duration +} + +var _ TeamProvisionSink = (*HTTPProvisionSink)(nil) + +type errorResponse struct { + Message string `json:"message"` +} + +func NewHTTPProvisionSink(baseURL, apiToken string) *HTTPProvisionSink { + return &HTTPProvisionSink{ + baseURL: strings.TrimRight(baseURL, "/"), + apiToken: apiToken, + client: newRetryableProvisionClient(defaultProvisionAttemptTimeout), + timeout: defaultProvisionTimeout, + } +} + +func (s *HTTPProvisionSink) ProvisionTeam(ctx context.Context, req sharedteamprovision.TeamBillingProvisionRequestedV1) error { + baseAttrs := provisionTelemetryAttrs(req, provisionSinkHTTP) + telemetry.SetAttributes(ctx, baseAttrs...) + telemetry.ReportEvent(ctx, "team_provision.started", baseAttrs...) + + if s.baseURL == "" || s.apiToken == "" { + err := &ProvisionError{ + StatusCode: http.StatusServiceUnavailable, + Message: "billing provisioning sink is not configured", + } + failureAttrs := provisionTelemetryAttrs(req, provisionSinkHTTP, + attribute.String("team.provision.result", "failed"), + attribute.Int64("http.response.status_code", int64(err.StatusCode)), + ) + telemetry.ReportErrorByCode(ctx, err.StatusCode, "team provisioning failed", err, failureAttrs...) + + return err + } + + body, err := json.Marshal(req) + if err != nil { + telemetry.ReportCriticalError(ctx, "marshal billing provisioning request", err, baseAttrs...) + + return fmt.Errorf("marshal billing provisioning request: %w", err) + } + + retryCtx, cancel := context.WithTimeout(ctx, s.timeout) + defer cancel() + startedAt := time.Now() + resp, err := s.provisionTeamOnce(retryCtx, body) + if resp != nil { + defer resp.Body.Close() + } + + duration := time.Since(startedAt) + if err == nil && resp != nil && resp.StatusCode == http.StatusOK { + successAttrs := provisionTelemetryAttrs(req, provisionSinkHTTP, + attribute.String("team.provision.result", "success"), + attribute.Int64("http.response.status_code", int64(http.StatusOK)), + attribute.Int64("team.provision.duration_ms", duration.Milliseconds()), + ) + telemetry.ReportEvent(ctx, "team_provision.completed", successAttrs...) + + fields := append(provisionLogFields(req, provisionSinkHTTP), + zap.String("team.provision.result", "success"), + zap.Int("http.response.status_code", http.StatusOK), + zap.Duration("team.provision.duration", duration), + ) + logger.L().Info(ctx, "team provisioning completed", fields...) + + return nil + } + + provisionErr := buildProvisionError(resp, err) + failureAttrs := provisionTelemetryAttrs(req, provisionSinkHTTP, + attribute.String("team.provision.result", "failed"), + attribute.Int64("team.provision.duration_ms", duration.Milliseconds()), + attribute.Int64("http.response.status_code", int64(provisionErr.StatusCode)), + ) + telemetry.ReportErrorByCode(ctx, provisionErr.StatusCode, "team provisioning failed", provisionErr, failureAttrs...) + + return provisionErr +} + +func (s *HTTPProvisionSink) provisionTeamOnce(ctx context.Context, body []byte) (*http.Response, error) { + httpReq, err := retryablehttp.NewRequestWithContext( + ctx, + http.MethodPost, + s.baseURL+"/internal/teams/provision", + body, + ) + if err != nil { + return nil, fmt.Errorf("create billing provisioning request: %w", err) + } + + httpReq.Header.Set("Content-Type", "application/json") + httpReq.Header.Set(billingServerAPIKeyHeader, s.apiToken) + + return s.client.Do(httpReq) +} + +func buildProvisionError(resp *http.Response, err error) *ProvisionError { + if resp != nil { + message, readErr := readProvisionErrorMessage(resp) + if readErr != nil { + return &ProvisionError{ + StatusCode: http.StatusBadGateway, + Message: "billing provisioning response was unreadable", + Err: fmt.Errorf("read billing provisioning error response: %w", readErr), + } + } + + return &ProvisionError{ + StatusCode: resp.StatusCode, + Message: message, + Err: err, + } + } + + if errors.Is(err, context.DeadlineExceeded) { + return &ProvisionError{ + StatusCode: http.StatusGatewayTimeout, + Message: "billing provisioning request timed out", + Err: err, + } + } + if errors.Is(err, context.Canceled) { + return &ProvisionError{ + StatusCode: http.StatusServiceUnavailable, + Message: "billing provisioning request was canceled", + Err: err, + } + } + + return &ProvisionError{ + StatusCode: http.StatusServiceUnavailable, + Message: "billing provisioning request failed", + Err: err, + } +} + +func newRetryableProvisionClient(timeout time.Duration) *retryablehttp.Client { + client := retryablehttp.NewClient() + client.Logger = nil + client.RetryMax = defaultProvisionRetryMaxAttempts - 1 + client.RetryWaitMin = defaultProvisionRetryInitialWait + client.RetryWaitMax = defaultProvisionRetryWaitCeiling + client.ErrorHandler = retryablehttp.PassthroughErrorHandler + client.Backoff = func(minWait, maxWait time.Duration, attemptNum int, resp *http.Response) time.Duration { + if resp != nil && (resp.StatusCode == http.StatusTooManyRequests || resp.StatusCode == http.StatusServiceUnavailable) { + return retryablehttp.DefaultBackoff(minWait, maxWait, attemptNum, resp) + } + + backoff := minWait + for range attemptNum { + backoff = time.Duration(float64(backoff) * provisionBackoffMultiplier) + if backoff > maxWait { + backoff = maxWait + + break + } + } + + if backoff > 0 { + return time.Duration(rand.Int63n(int64(backoff))) + } + + return backoff + } + client.HTTPClient.Timeout = timeout + client.HTTPClient.Transport = otelhttp.NewTransport(client.HTTPClient.Transport) + + return client +} + +func readProvisionErrorMessage(resp *http.Response) (string, error) { + body, err := io.ReadAll(io.LimitReader(resp.Body, provisionErrorMessageReadLimit)) + if err != nil { + return "", err + } + + var apiErr errorResponse + if err := json.Unmarshal(body, &apiErr); err == nil && apiErr.Message != "" { + return apiErr.Message, nil + } + + message := strings.TrimSpace(string(body)) + if message != "" { + return message, nil + } + + return http.StatusText(resp.StatusCode), nil +} diff --git a/packages/dashboard-api/internal/teamprovision/http_sink_test.go b/packages/dashboard-api/internal/teamprovision/http_sink_test.go new file mode 100644 index 0000000000..ca8fb22d7c --- /dev/null +++ b/packages/dashboard-api/internal/teamprovision/http_sink_test.go @@ -0,0 +1,96 @@ +package teamprovision + +import ( + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/require" + + sharedteamprovision "github.com/e2b-dev/infra/packages/shared/pkg/teamprovision" +) + +func TestHTTPProvisionSink_ReturnsJSONErrorMessage(t *testing.T) { + t.Parallel() + + var requestCount atomic.Int32 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + requestCount.Add(1) + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte(`{"message":"invalid payload"}`)) + })) + defer server.Close() + + sink := NewHTTPProvisionSink(server.URL, "token") + err := sink.ProvisionTeam(t.Context(), testProvisionRequest()) + require.Error(t, err) + + var provisionErr *ProvisionError + require.ErrorAs(t, err, &provisionErr) + require.Equal(t, http.StatusBadRequest, provisionErr.StatusCode) + require.Equal(t, "invalid payload", provisionErr.Message) + require.EqualValues(t, 1, requestCount.Load()) +} + +func TestHTTPProvisionSink_RetriesRetryableResponsesAndSucceeds(t *testing.T) { + t.Parallel() + + var requestCount atomic.Int32 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + attempt := requestCount.Add(1) + if attempt == 1 { + w.WriteHeader(http.StatusServiceUnavailable) + _, _ = w.Write([]byte("temporary outage")) + + return + } + + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + sink := NewHTTPProvisionSink(server.URL, "token") + sink.client.RetryWaitMin = time.Millisecond + sink.client.RetryWaitMax = time.Millisecond + err := sink.ProvisionTeam(t.Context(), testProvisionRequest()) + require.NoError(t, err) + require.EqualValues(t, 2, requestCount.Load()) +} + +func TestHTTPProvisionSink_RetriesRequestTimeoutWithinOverallBudget(t *testing.T) { + t.Parallel() + + var requestCount atomic.Int32 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + attempt := requestCount.Add(1) + if attempt == 1 { + time.Sleep(40 * time.Millisecond) + } + + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + sink := NewHTTPProvisionSink(server.URL, "token") + sink.timeout = 80 * time.Millisecond + sink.client.HTTPClient.Timeout = 25 * time.Millisecond + sink.client.RetryWaitMin = time.Millisecond + sink.client.RetryWaitMax = time.Millisecond + + err := sink.ProvisionTeam(t.Context(), testProvisionRequest()) + require.NoError(t, err) + require.EqualValues(t, 2, requestCount.Load()) +} + +func testProvisionRequest() sharedteamprovision.TeamBillingProvisionRequestedV1 { + return sharedteamprovision.TeamBillingProvisionRequestedV1{ + TeamID: uuid.New(), + TeamName: "Acme", + TeamEmail: "acme@example.com", + OwnerUserID: uuid.New(), + Reason: sharedteamprovision.ReasonAdditionalTeam, + } +} diff --git a/packages/dashboard-api/internal/teamprovision/noop_sink.go b/packages/dashboard-api/internal/teamprovision/noop_sink.go new file mode 100644 index 0000000000..60df22408a --- /dev/null +++ b/packages/dashboard-api/internal/teamprovision/noop_sink.go @@ -0,0 +1,35 @@ +package teamprovision + +import ( + "context" + + "go.opentelemetry.io/otel/attribute" + "go.uber.org/zap" + + "github.com/e2b-dev/infra/packages/shared/pkg/logger" + sharedteamprovision "github.com/e2b-dev/infra/packages/shared/pkg/teamprovision" + "github.com/e2b-dev/infra/packages/shared/pkg/telemetry" +) + +type NoopProvisionSink struct{} + +var _ TeamProvisionSink = (*NoopProvisionSink)(nil) + +func NewNoopProvisionSink() *NoopProvisionSink { + return &NoopProvisionSink{} +} + +func (s *NoopProvisionSink) ProvisionTeam(ctx context.Context, req sharedteamprovision.TeamBillingProvisionRequestedV1) error { + attrs := provisionTelemetryAttrs(req, provisionSinkNoop, + attribute.String("team.provision.result", "skipped"), + ) + telemetry.SetAttributes(ctx, attrs...) + telemetry.ReportEvent(ctx, "team_provision.skipped", attrs...) + + fields := append(provisionLogFields(req, provisionSinkNoop), + zap.String("team.provision.result", "skipped"), + ) + logger.L().Info(ctx, "team provisioning skipped", fields...) + + return nil +} diff --git a/packages/dashboard-api/internal/teamprovision/sink.go b/packages/dashboard-api/internal/teamprovision/sink.go new file mode 100644 index 0000000000..7fe61002bb --- /dev/null +++ b/packages/dashboard-api/internal/teamprovision/sink.go @@ -0,0 +1,64 @@ +package teamprovision + +import ( + "context" + "fmt" + + "go.opentelemetry.io/otel/attribute" + "go.uber.org/zap" + + "github.com/e2b-dev/infra/packages/shared/pkg/logger" + sharedteamprovision "github.com/e2b-dev/infra/packages/shared/pkg/teamprovision" + "github.com/e2b-dev/infra/packages/shared/pkg/telemetry" +) + +type TeamProvisionSink interface { + ProvisionTeam(ctx context.Context, req sharedteamprovision.TeamBillingProvisionRequestedV1) error +} + +const ( + provisionSinkHTTP = "http" + provisionSinkNoop = "noop" +) + +type ProvisionError struct { + StatusCode int + Message string + Err error +} + +func (e *ProvisionError) Error() string { + if e.Message != "" { + return e.Message + } + + return fmt.Sprintf("billing provisioning failed with status %d", e.StatusCode) +} + +func (e *ProvisionError) Unwrap() error { + if e == nil { + return nil + } + + return e.Err +} + +func provisionLogFields(req sharedteamprovision.TeamBillingProvisionRequestedV1, sink string) []zap.Field { + return []zap.Field{ + logger.WithTeamID(req.TeamID.String()), + logger.WithUserID(req.OwnerUserID.String()), + zap.String("team.provision.reason", req.Reason), + zap.String("team.provision.sink", sink), + } +} + +func provisionTelemetryAttrs(req sharedteamprovision.TeamBillingProvisionRequestedV1, sink string, attrs ...attribute.KeyValue) []attribute.KeyValue { + base := []attribute.KeyValue{ + telemetry.WithTeamID(req.TeamID.String()), + telemetry.WithUserID(req.OwnerUserID.String()), + attribute.String("team.provision.reason", req.Reason), + attribute.String("team.provision.sink", sink), + } + + return append(base, attrs...) +} diff --git a/packages/dashboard-api/internal/utils/builds.go b/packages/dashboard-api/internal/utils/builds.go index 1ad18210fb..b3afd1c5b2 100644 --- a/packages/dashboard-api/internal/utils/builds.go +++ b/packages/dashboard-api/internal/utils/builds.go @@ -1,4 +1,4 @@ -package utils +package buildutils import ( "github.com/e2b-dev/infra/packages/dashboard-api/internal/api" diff --git a/packages/dashboard-api/main.go b/packages/dashboard-api/main.go index 23a52a1003..c8b3c1b7f6 100644 --- a/packages/dashboard-api/main.go +++ b/packages/dashboard-api/main.go @@ -32,6 +32,7 @@ import ( "github.com/e2b-dev/infra/packages/dashboard-api/internal/backgroundworker" "github.com/e2b-dev/infra/packages/dashboard-api/internal/cfg" "github.com/e2b-dev/infra/packages/dashboard-api/internal/handlers" + internalteamprovision "github.com/e2b-dev/infra/packages/dashboard-api/internal/teamprovision" sqlcdb "github.com/e2b-dev/infra/packages/db/client" authdb "github.com/e2b-dev/infra/packages/db/pkg/auth" "github.com/e2b-dev/infra/packages/db/pkg/pool" @@ -71,7 +72,9 @@ func run() int { tel, err := telemetry.New(ctx, nodeID, serviceName, commitSHA, serviceVersion, serviceInstanceID) if err != nil { - logger.L().Fatal(ctx, "failed to create telemetry", zap.Error(err)) + log.Printf("failed to create telemetry: %v\n", err) + + return 1 } defer func() { if err := tel.Shutdown(ctx); err != nil { @@ -87,14 +90,18 @@ func run() int { EnableConsole: true, }) if err != nil { - logger.L().Fatal(ctx, "failed to create logger", zap.Error(err)) + log.Printf("failed to create logger: %v\n", err) + + return 1 } defer l.Sync() logger.ReplaceGlobals(ctx, l) config, err := cfg.Parse() if err != nil { - l.Fatal(ctx, "failed to parse config", zap.Error(err)) + l.Error(ctx, "failed to parse config", zap.Error(err)) + + return 1 } l.Info(ctx, "Starting dashboard-api service...", zap.String("commit_sha", commitSHA), zap.String("instance_id", serviceInstanceID)) @@ -107,7 +114,9 @@ func run() int { err = sqlcdb.CheckMigrationVersion(ctx, config.PostgresConnectionString, expectedMigration) if err != nil { - l.Fatal(ctx, "failed to check migration version", zap.Error(err)) + l.Error(ctx, "failed to check migration version", zap.Error(err)) + + return 1 } db, err := sqlcdb.NewClient( @@ -116,7 +125,9 @@ func run() int { pool.WithMaxConnections(8), ) if err != nil { - l.Fatal(ctx, "Initializing database client", zap.Error(err)) + l.Error(ctx, "Initializing database client", zap.Error(err)) + + return 1 } defer db.Close() @@ -127,17 +138,33 @@ func run() int { pool.WithMaxConnections(8), ) if err != nil { - l.Fatal(ctx, "Initializing auth database client", zap.Error(err)) + l.Error(ctx, "Initializing auth database client", zap.Error(err)) + + return 1 } defer authDB.Close() + supabaseDB, err := supabasedb.NewClient( + ctx, + config.SupabaseDBConnectionString, + pool.WithMaxConnections(4), + ) + if err != nil { + l.Error(ctx, "Initializing supabase database client", zap.Error(err)) + + return 1 + } + defer supabaseDB.Close() + var clickhouseClient clickhouse.Clickhouse if config.ClickhouseConnectionString == "" { clickhouseClient = clickhouse.NewNoopClient() } else { clickhouseClient, err = clickhouse.New(config.ClickhouseConnectionString) if err != nil { - l.Fatal(ctx, "Initializing ClickHouse client", zap.Error(err)) + l.Error(ctx, "Initializing ClickHouse client", zap.Error(err)) + + return 1 } defer clickhouseClient.Close(ctx) } @@ -148,7 +175,9 @@ func run() int { RedisTLSCABase64: config.RedisTLSCABase64, }) if err != nil { - l.Fatal(ctx, "Initializing Redis client", zap.Error(err)) + l.Error(ctx, "Initializing Redis client", zap.Error(err)) + + return 1 } defer func() { if err := factories.CloseCleanly(redisClient); err != nil { @@ -161,16 +190,31 @@ func run() int { authService := sharedauth.NewAuthService[*types.Team](authStore, authCache, config.SupabaseJWTSecrets) defer authService.Close(ctx) - apiStore := handlers.NewAPIStore(config, db, authDB, clickhouseClient, authService) + teamProvisionSink, err := internalteamprovision.NewProvisionSink( + ctx, + config.EnableBillingHTTPTeamProvisionSink, + config.BillingServerURL, + config.BillingServerAPIToken, + ) + if err != nil { + l.Error(ctx, "initializing team provision sink", zap.Error(err)) + + return 1 + } + + apiStore := handlers.NewAPIStore(config, db, authDB, supabaseDB, clickhouseClient, authService, teamProvisionSink) swagger, err := api.GetSwagger() if err != nil { - l.Fatal(ctx, "Error loading swagger spec", zap.Error(err)) + l.Error(ctx, "Error loading swagger spec", zap.Error(err)) + + return 1 } swagger.Servers = nil authenticationFunc := sharedauth.CreateAuthenticationFunc( []sharedauth.Authenticator{ + sharedauth.NewAdminTokenAuthenticator(config.AdminToken), sharedauth.NewSupabaseTokenAuthenticator(apiStore.GetUserIDFromSupabaseToken), sharedauth.NewSupabaseTeamAuthenticator(apiStore.GetTeamFromSupabaseToken), }, @@ -183,28 +227,18 @@ func run() int { defer sigCancel() var riverClient *river.Client[pgx.Tx] - var supabaseDB *supabasedb.Client - if config.EnableAuthUserSyncBackgroundWorker { - supabaseDB, err = supabasedb.NewClient( - ctx, - config.SupabaseDBConnectionString, - pool.WithMaxConnections(4), - ) - if err != nil { - l.Fatal(ctx, "Initializing supabase database client", zap.Error(err)) - } - defer supabaseDB.Close() - riverClient, err = backgroundworker.StartAuthUserSyncWorker( ctx, signalCtx, supabaseDB, - db, + authDB, l, ) if err != nil { - l.Fatal(ctx, "failed to start auth user sync worker", zap.Error(err)) + l.Error(ctx, "failed to start auth user sync worker", zap.Error(err)) + + return 1 } } @@ -255,6 +289,7 @@ func newHTTPServer( "Origin", "Content-Length", "Content-Type", + sharedauth.HeaderAdminToken, sharedauth.HeaderSupabaseToken, sharedauth.HeaderSupabaseTeam, } diff --git a/packages/db/migrations/20000101000000_auth.sql b/packages/db/migrations/20000101000000_auth.sql index 1a7d0d90b0..e420affef4 100644 --- a/packages/db/migrations/20000101000000_auth.sql +++ b/packages/db/migrations/20000101000000_auth.sql @@ -16,6 +16,7 @@ GRANT EXECUTE ON FUNCTION auth.uid() TO postgres; CREATE TABLE auth.users ( id uuid NOT NULL DEFAULT gen_random_uuid(), email text NOT NULL, + created_at timestamptz NOT NULL DEFAULT now(), PRIMARY KEY (id) ); -- +goose StatementEnd diff --git a/packages/db/queries/delete_public_user.sql.go b/packages/db/pkg/auth/queries/delete_public_user.sql.go similarity index 95% rename from packages/db/queries/delete_public_user.sql.go rename to packages/db/pkg/auth/queries/delete_public_user.sql.go index 585d8c977c..b7a3c975a3 100644 --- a/packages/db/queries/delete_public_user.sql.go +++ b/packages/db/pkg/auth/queries/delete_public_user.sql.go @@ -3,7 +3,7 @@ // sqlc v1.29.0 // source: delete_public_user.sql -package queries +package authqueries import ( "context" diff --git a/packages/db/pkg/auth/queries/get_user.sql.go b/packages/db/pkg/auth/queries/get_user.sql.go index 5d7ee53330..8a842acc2f 100644 --- a/packages/db/pkg/auth/queries/get_user.sql.go +++ b/packages/db/pkg/auth/queries/get_user.sql.go @@ -12,12 +12,12 @@ import ( ) const getUser = `-- name: GetUser :one -SELECT id, email FROM "auth"."users" where id = $1 +SELECT id, email, created_at FROM "auth"."users" where id = $1 ` func (q *Queries) GetUser(ctx context.Context, userID uuid.UUID) (AuthUser, error) { row := q.db.QueryRow(ctx, getUser, userID) var i AuthUser - err := row.Scan(&i.ID, &i.Email) + err := row.Scan(&i.ID, &i.Email, &i.CreatedAt) return i, err } diff --git a/packages/db/pkg/auth/queries/lock_public_user_for_update.sql.go b/packages/db/pkg/auth/queries/lock_public_user_for_update.sql.go new file mode 100644 index 0000000000..d255400753 --- /dev/null +++ b/packages/db/pkg/auth/queries/lock_public_user_for_update.sql.go @@ -0,0 +1,25 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.29.0 +// source: lock_public_user_for_update.sql + +package authqueries + +import ( + "context" + + "github.com/google/uuid" +) + +const lockPublicUserForUpdate = `-- name: LockPublicUserForUpdate :one +SELECT id +FROM public.users +WHERE id = $1 +FOR UPDATE +` + +func (q *Queries) LockPublicUserForUpdate(ctx context.Context, id uuid.UUID) (uuid.UUID, error) { + row := q.db.QueryRow(ctx, lockPublicUserForUpdate, id) + err := row.Scan(&id) + return id, err +} diff --git a/packages/db/pkg/auth/queries/models.go b/packages/db/pkg/auth/queries/models.go index de91154ef2..e6a3a542f3 100644 --- a/packages/db/pkg/auth/queries/models.go +++ b/packages/db/pkg/auth/queries/models.go @@ -50,8 +50,9 @@ type Addon struct { } type AuthUser struct { - ID uuid.UUID - Email string + ID uuid.UUID + Email string + CreatedAt time.Time } type BillingSandboxLog struct { diff --git a/packages/db/pkg/auth/queries/team_creation_guard.sql.go b/packages/db/pkg/auth/queries/team_creation_guard.sql.go new file mode 100644 index 0000000000..f4abb68baf --- /dev/null +++ b/packages/db/pkg/auth/queries/team_creation_guard.sql.go @@ -0,0 +1,103 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.29.0 +// source: team_creation_guard.sql + +package authqueries + +import ( + "context" + "time" + + "github.com/google/uuid" +) + +const getTeamsWithUsersTeamsWithTierForUpdate = `-- name: GetTeamsWithUsersTeamsWithTierForUpdate :many +SELECT + t.id, + t.created_at, + t.is_blocked, + t.name, + t.tier, + t.email, + t.is_banned, + t.blocked_reason, + t.cluster_id, + t.sandbox_scheduling_labels, + t.slug, + ut.is_default, + tl.id, + tl.max_length_hours, + tl.concurrent_sandboxes, + tl.concurrent_template_builds, + tl.max_vcpu, + tl.max_ram_mb, + tl.disk_mb +FROM public.teams t +JOIN public.users_teams ut ON ut.team_id = t.id +JOIN public.team_limits tl ON tl.id = t.id +WHERE ut.user_id = $1::uuid +FOR UPDATE OF ut +` + +type GetTeamsWithUsersTeamsWithTierForUpdateRow struct { + ID uuid.UUID + CreatedAt time.Time + IsBlocked bool + Name string + Tier string + Email string + IsBanned bool + BlockedReason *string + ClusterID *uuid.UUID + SandboxSchedulingLabels []string + Slug string + IsDefault bool + ID_2 uuid.UUID + MaxLengthHours int64 + ConcurrentSandboxes int32 + ConcurrentTemplateBuilds int32 + MaxVcpu int32 + MaxRamMb int32 + DiskMb int32 +} + +func (q *Queries) GetTeamsWithUsersTeamsWithTierForUpdate(ctx context.Context, userID uuid.UUID) ([]GetTeamsWithUsersTeamsWithTierForUpdateRow, error) { + rows, err := q.db.Query(ctx, getTeamsWithUsersTeamsWithTierForUpdate, userID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []GetTeamsWithUsersTeamsWithTierForUpdateRow + for rows.Next() { + var i GetTeamsWithUsersTeamsWithTierForUpdateRow + if err := rows.Scan( + &i.ID, + &i.CreatedAt, + &i.IsBlocked, + &i.Name, + &i.Tier, + &i.Email, + &i.IsBanned, + &i.BlockedReason, + &i.ClusterID, + &i.SandboxSchedulingLabels, + &i.Slug, + &i.IsDefault, + &i.ID_2, + &i.MaxLengthHours, + &i.ConcurrentSandboxes, + &i.ConcurrentTemplateBuilds, + &i.MaxVcpu, + &i.MaxRamMb, + &i.DiskMb, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/packages/db/pkg/auth/queries/team_lifecycle.sql.go b/packages/db/pkg/auth/queries/team_lifecycle.sql.go new file mode 100644 index 0000000000..f9328f2af2 --- /dev/null +++ b/packages/db/pkg/auth/queries/team_lifecycle.sql.go @@ -0,0 +1,143 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.29.0 +// source: team_lifecycle.sql + +package authqueries + +import ( + "context" + + "github.com/google/uuid" +) + +const createTeam = `-- name: CreateTeam :one +INSERT INTO public.teams (name, tier, email, is_blocked, blocked_reason) +VALUES ( + $1::text, + $2::text, + $3::text, + $4::boolean, + $5::text +) +RETURNING + id, + created_at, + is_blocked, + name, + tier, + email, + is_banned, + blocked_reason, + cluster_id, + sandbox_scheduling_labels, + slug +` + +type CreateTeamParams struct { + Name string + Tier string + Email string + IsBlocked bool + BlockedReason *string +} + +func (q *Queries) CreateTeam(ctx context.Context, arg CreateTeamParams) (Team, error) { + row := q.db.QueryRow(ctx, createTeam, + arg.Name, + arg.Tier, + arg.Email, + arg.IsBlocked, + arg.BlockedReason, + ) + var i Team + err := row.Scan( + &i.ID, + &i.CreatedAt, + &i.IsBlocked, + &i.Name, + &i.Tier, + &i.Email, + &i.IsBanned, + &i.BlockedReason, + &i.ClusterID, + &i.SandboxSchedulingLabels, + &i.Slug, + ) + return i, err +} + +const createTeamMembership = `-- name: CreateTeamMembership :exec +INSERT INTO public.users_teams (user_id, team_id, is_default, added_by) +VALUES ( + $1::uuid, + $2::uuid, + $3::boolean, + $4::uuid +) +` + +type CreateTeamMembershipParams struct { + UserID uuid.UUID + TeamID uuid.UUID + IsDefault bool + AddedBy *uuid.UUID +} + +func (q *Queries) CreateTeamMembership(ctx context.Context, arg CreateTeamMembershipParams) error { + _, err := q.db.Exec(ctx, createTeamMembership, + arg.UserID, + arg.TeamID, + arg.IsDefault, + arg.AddedBy, + ) + return err +} + +const deleteTeamByID = `-- name: DeleteTeamByID :exec +DELETE FROM public.teams +WHERE id = $1::uuid +` + +func (q *Queries) DeleteTeamByID(ctx context.Context, id uuid.UUID) error { + _, err := q.db.Exec(ctx, deleteTeamByID, id) + return err +} + +const getDefaultTeamByUserID = `-- name: GetDefaultTeamByUserID :one +SELECT + t.id, + t.created_at, + t.is_blocked, + t.name, + t.tier, + t.email, + t.is_banned, + t.blocked_reason, + t.cluster_id, + t.sandbox_scheduling_labels, + t.slug +FROM public.teams t +JOIN public.users_teams ut ON ut.team_id = t.id +WHERE ut.user_id = $1::uuid + AND ut.is_default = true +` + +func (q *Queries) GetDefaultTeamByUserID(ctx context.Context, userID uuid.UUID) (Team, error) { + row := q.db.QueryRow(ctx, getDefaultTeamByUserID, userID) + var i Team + err := row.Scan( + &i.ID, + &i.CreatedAt, + &i.IsBlocked, + &i.Name, + &i.Tier, + &i.Email, + &i.IsBanned, + &i.BlockedReason, + &i.ClusterID, + &i.SandboxSchedulingLabels, + &i.Slug, + ) + return i, err +} diff --git a/packages/db/queries/upsert_public_user.sql.go b/packages/db/pkg/auth/queries/upsert_public_user.sql.go similarity index 96% rename from packages/db/queries/upsert_public_user.sql.go rename to packages/db/pkg/auth/queries/upsert_public_user.sql.go index dc0fd6eba3..6c0f05ec5a 100644 --- a/packages/db/queries/upsert_public_user.sql.go +++ b/packages/db/pkg/auth/queries/upsert_public_user.sql.go @@ -3,7 +3,7 @@ // sqlc v1.29.0 // source: upsert_public_user.sql -package queries +package authqueries import ( "context" diff --git a/packages/db/pkg/auth/sql_queries/teams/team_creation_guard.sql b/packages/db/pkg/auth/sql_queries/teams/team_creation_guard.sql new file mode 100644 index 0000000000..8a027583d1 --- /dev/null +++ b/packages/db/pkg/auth/sql_queries/teams/team_creation_guard.sql @@ -0,0 +1,26 @@ +-- name: GetTeamsWithUsersTeamsWithTierForUpdate :many +SELECT + t.id, + t.created_at, + t.is_blocked, + t.name, + t.tier, + t.email, + t.is_banned, + t.blocked_reason, + t.cluster_id, + t.sandbox_scheduling_labels, + t.slug, + ut.is_default, + tl.id, + tl.max_length_hours, + tl.concurrent_sandboxes, + tl.concurrent_template_builds, + tl.max_vcpu, + tl.max_ram_mb, + tl.disk_mb +FROM public.teams t +JOIN public.users_teams ut ON ut.team_id = t.id +JOIN public.team_limits tl ON tl.id = t.id +WHERE ut.user_id = sqlc.arg(user_id)::uuid +FOR UPDATE OF ut; diff --git a/packages/db/pkg/auth/sql_queries/teams/team_lifecycle.sql b/packages/db/pkg/auth/sql_queries/teams/team_lifecycle.sql new file mode 100644 index 0000000000..26f7094156 --- /dev/null +++ b/packages/db/pkg/auth/sql_queries/teams/team_lifecycle.sql @@ -0,0 +1,52 @@ +-- name: CreateTeam :one +INSERT INTO public.teams (name, tier, email, is_blocked, blocked_reason) +VALUES ( + sqlc.arg(name)::text, + sqlc.arg(tier)::text, + sqlc.arg(email)::text, + sqlc.arg(is_blocked)::boolean, + sqlc.narg(blocked_reason)::text +) +RETURNING + id, + created_at, + is_blocked, + name, + tier, + email, + is_banned, + blocked_reason, + cluster_id, + sandbox_scheduling_labels, + slug; + +-- name: CreateTeamMembership :exec +INSERT INTO public.users_teams (user_id, team_id, is_default, added_by) +VALUES ( + sqlc.arg(user_id)::uuid, + sqlc.arg(team_id)::uuid, + sqlc.arg(is_default)::boolean, + sqlc.narg(added_by)::uuid +); + +-- name: GetDefaultTeamByUserID :one +SELECT + t.id, + t.created_at, + t.is_blocked, + t.name, + t.tier, + t.email, + t.is_banned, + t.blocked_reason, + t.cluster_id, + t.sandbox_scheduling_labels, + t.slug +FROM public.teams t +JOIN public.users_teams ut ON ut.team_id = t.id +WHERE ut.user_id = sqlc.arg(user_id)::uuid + AND ut.is_default = true; + +-- name: DeleteTeamByID :exec +DELETE FROM public.teams +WHERE id = sqlc.arg(id)::uuid; diff --git a/packages/db/queries/users/delete_public_user.sql b/packages/db/pkg/auth/sql_queries/users/delete_public_user.sql similarity index 100% rename from packages/db/queries/users/delete_public_user.sql rename to packages/db/pkg/auth/sql_queries/users/delete_public_user.sql diff --git a/packages/db/pkg/auth/sql_queries/users/lock_public_user_for_update.sql b/packages/db/pkg/auth/sql_queries/users/lock_public_user_for_update.sql new file mode 100644 index 0000000000..d8201156b8 --- /dev/null +++ b/packages/db/pkg/auth/sql_queries/users/lock_public_user_for_update.sql @@ -0,0 +1,5 @@ +-- name: LockPublicUserForUpdate :one +SELECT id +FROM public.users +WHERE id = @id +FOR UPDATE; diff --git a/packages/db/queries/users/upsert_public_user.sql b/packages/db/pkg/auth/sql_queries/users/upsert_public_user.sql similarity index 100% rename from packages/db/queries/users/upsert_public_user.sql rename to packages/db/pkg/auth/sql_queries/users/upsert_public_user.sql diff --git a/packages/db/pkg/supabase/queries/get_auth_user.sql.go b/packages/db/pkg/supabase/queries/get_auth_user.sql.go index 4f881e7ff3..11c863f8dd 100644 --- a/packages/db/pkg/supabase/queries/get_auth_user.sql.go +++ b/packages/db/pkg/supabase/queries/get_auth_user.sql.go @@ -12,7 +12,7 @@ import ( ) const getAuthUserByID = `-- name: GetAuthUserByID :one -SELECT id, COALESCE(email, '') AS email +SELECT id, COALESCE(email, '') AS email, created_at FROM auth.users WHERE id = $1::uuid ` @@ -20,6 +20,6 @@ WHERE id = $1::uuid func (q *Queries) GetAuthUserByID(ctx context.Context, dollar_1 uuid.UUID) (AuthUser, error) { row := q.db.QueryRow(ctx, getAuthUserByID, dollar_1) var i AuthUser - err := row.Scan(&i.ID, &i.Email) + err := row.Scan(&i.ID, &i.Email, &i.CreatedAt) return i, err } diff --git a/packages/db/pkg/supabase/queries/models.go b/packages/db/pkg/supabase/queries/models.go index 4bbb324fa3..8e9daed53f 100644 --- a/packages/db/pkg/supabase/queries/models.go +++ b/packages/db/pkg/supabase/queries/models.go @@ -5,10 +5,13 @@ package supabasequeries import ( + "time" + "github.com/google/uuid" ) type AuthUser struct { - ID uuid.UUID - Email string + ID uuid.UUID + Email string + CreatedAt *time.Time } diff --git a/packages/db/pkg/supabase/schema/auth_users_override.sql b/packages/db/pkg/supabase/schema/auth_users_override.sql index 83cc73104d..3974df32e7 100644 --- a/packages/db/pkg/supabase/schema/auth_users_override.sql +++ b/packages/db/pkg/supabase/schema/auth_users_override.sql @@ -14,5 +14,6 @@ GRANT EXECUTE ON FUNCTION auth.uid() TO postgres; CREATE TABLE auth.users ( id uuid NOT NULL DEFAULT gen_random_uuid(), email text NOT NULL, + created_at timestamptz DEFAULT now(), PRIMARY KEY (id) ); diff --git a/packages/db/pkg/supabase/sql_queries/users/get_auth_user.sql b/packages/db/pkg/supabase/sql_queries/users/get_auth_user.sql index f01b84a05c..ca83c61b33 100644 --- a/packages/db/pkg/supabase/sql_queries/users/get_auth_user.sql +++ b/packages/db/pkg/supabase/sql_queries/users/get_auth_user.sql @@ -1,4 +1,4 @@ -- name: GetAuthUserByID :one -SELECT id, COALESCE(email, '') AS email +SELECT id, COALESCE(email, '') AS email, created_at FROM auth.users WHERE id = $1::uuid; diff --git a/packages/db/queries/models.go b/packages/db/queries/models.go index 6c960a7a59..3f6cb7f4cd 100644 --- a/packages/db/queries/models.go +++ b/packages/db/queries/models.go @@ -50,8 +50,9 @@ type Addon struct { } type AuthUser struct { - ID uuid.UUID - Email string + ID uuid.UUID + Email string + CreatedAt time.Time } type BillingSandboxLog struct { diff --git a/packages/shared/pkg/teamprovision/request.go b/packages/shared/pkg/teamprovision/request.go new file mode 100644 index 0000000000..0dc183b28a --- /dev/null +++ b/packages/shared/pkg/teamprovision/request.go @@ -0,0 +1,16 @@ +package teamprovision + +import "github.com/google/uuid" + +const ( + ReasonDefaultSignupTeam = "default_signup_team" + ReasonAdditionalTeam = "additional_team" +) + +type TeamBillingProvisionRequestedV1 struct { + TeamID uuid.UUID `json:"team_id"` + TeamName string `json:"team_name"` + TeamEmail string `json:"team_email"` + OwnerUserID uuid.UUID `json:"owner_user_id"` + Reason string `json:"reason"` +} diff --git a/spec/openapi-dashboard.yml b/spec/openapi-dashboard.yml index f3725972e5..f848845da2 100644 --- a/spec/openapi-dashboard.yml +++ b/spec/openapi-dashboard.yml @@ -6,6 +6,10 @@ info: components: securitySchemes: + AdminTokenAuth: + type: apiKey + in: header + name: X-Admin-Token Supabase1TokenAuth: type: apiKey in: header @@ -497,6 +501,16 @@ components: type: string format: email + CreateTeamRequest: + type: object + required: + - name + properties: + name: + type: string + minLength: 1 + maxLength: 255 + DefaultTemplateAlias: type: object required: @@ -718,6 +732,49 @@ paths: $ref: "#/components/responses/401" "500": $ref: "#/components/responses/500" + post: + summary: Create team + tags: [teams] + security: + - Supabase1TokenAuth: [] + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/CreateTeamRequest" + responses: + "200": + description: Successfully created team. + content: + application/json: + schema: + $ref: "#/components/schemas/TeamResolveResponse" + "400": + $ref: "#/components/responses/400" + "401": + $ref: "#/components/responses/401" + "500": + $ref: "#/components/responses/500" + + /admin/users/bootstrap: + post: + summary: Bootstrap user + tags: [teams] + security: + - AdminTokenAuth: [] + Supabase1TokenAuth: [] + responses: + "200": + description: Successfully bootstrapped user. + content: + application/json: + schema: + $ref: "#/components/schemas/TeamResolveResponse" + "401": + $ref: "#/components/responses/401" + "500": + $ref: "#/components/responses/500" /teams/resolve: get: