diff --git a/packages/db/pkg/testutils/db.go b/packages/db/pkg/testutils/db.go index edde569887..5d936d23b5 100644 --- a/packages/db/pkg/testutils/db.go +++ b/packages/db/pkg/testutils/db.go @@ -41,17 +41,43 @@ type Database struct { TestQueries *queries.Queries } +var ( + oneDB *Database + dblock sync.RWMutex +) + // SetupDatabase creates a fresh PostgreSQL container with migrations applied func SetupDatabase(t *testing.T) *Database { t.Helper() + ctx := context.WithoutCancel(t.Context()) + if testing.Short() { t.Skip("Skipping integration test in short mode") } + // cheap lookup + dblock.RLock() + if oneDB != nil { + dblock.RUnlock() + + return oneDB + } + dblock.RUnlock() + + // expensive lookup + dblock.Lock() + defer dblock.Unlock() + + if oneDB != nil { + return oneDB + } + + // lookup failed, create new + // Start PostgreSQL container container, err := postgres.Run( - t.Context(), + ctx, testPostgresImage, postgres.WithDatabase(testDatabaseName), postgres.WithUsername(testUsername), @@ -63,48 +89,33 @@ func SetupDatabase(t *testing.T) *Database { ), ) require.NoError(t, err, "Failed to start postgres container") - t.Cleanup(func() { - ctx := t.Context() - ctx = context.WithoutCancel(ctx) - err := container.Terminate(ctx) - assert.NoError(t, err) - }) - connStr, err := container.ConnectionString(t.Context(), "sslmode=disable") + connStr, err := container.ConnectionString(ctx, "sslmode=disable") require.NoError(t, err, "Failed to get connection string") // Setup environment and run migrations runDatabaseMigrations(t, connStr) // create test queries client - dbClient, connPool, err := pool.New(t.Context(), connStr, "tests") + dbClient, _, err := pool.New(ctx, connStr, "tests") require.NoError(t, err) - t.Cleanup(func() { - connPool.Close() - }) testQueries := queries.New(dbClient) // Create app db client - sqlcClient, err := db.NewClient(t.Context(), connStr) + sqlcClient, err := db.NewClient(ctx, connStr) require.NoError(t, err, "Failed to create sqlc client") - t.Cleanup(func() { - err := sqlcClient.Close() - assert.NoError(t, err) - }) // Create the auth db client - authDb, err := authdb.NewClient(t.Context(), connStr, connStr) + authDb, err := authdb.NewClient(ctx, connStr, connStr) require.NoError(t, err, "Failed to create auth db client") - t.Cleanup(func() { - err := authDb.Close() - assert.NoError(t, err) - }) - return &Database{ + oneDB = &Database{ SqlcClient: sqlcClient, AuthDb: authDb, TestQueries: testQueries, } + + return oneDB } // gooseMu serializes goose operations across parallel tests. @@ -125,10 +136,10 @@ func runDatabaseMigrations(t *testing.T, connStr string) { gooseMu.Lock() defer gooseMu.Unlock() - db, err := goose.OpenDBWithDriver("pgx", connStr) + gooseDB, err := goose.OpenDBWithDriver("pgx", connStr) require.NoError(t, err) t.Cleanup(func() { - err := db.Close() + err := gooseDB.Close() assert.NoError(t, err) }) @@ -136,7 +147,7 @@ func runDatabaseMigrations(t *testing.T, connStr string) { err = goose.RunWithOptionsContext( t.Context(), "up", - db, + gooseDB, filepath.Join(repoRoot, "packages", "db", "migrations"), nil, ) diff --git a/packages/docker-reverse-proxy/internal/auth/validate_test.go b/packages/docker-reverse-proxy/internal/auth/validate_test.go index 29ff34d215..357d59a61c 100644 --- a/packages/docker-reverse-proxy/internal/auth/validate_test.go +++ b/packages/docker-reverse-proxy/internal/auth/validate_test.go @@ -25,6 +25,44 @@ func TestValidate(t *testing.T) { teamID := uuid.New() envID := "test-env-id" + dbClient := testutils.SetupDatabase(t) + + // Create team + err = dbClient.AuthDb.TestsRawSQL(t.Context(), ` + INSERT INTO "auth"."users" (id, email) + VALUES ($1, 'test@e2b.dev') + ON CONFLICT DO NOTHING + `, userID) + require.NoError(t, err) + + err = dbClient.AuthDb.TestsRawSQL(t.Context(), ` + INSERT INTO teams (id, name, email, tier, slug) + VALUES ($1, 'test-team', 'test@e2b.dev', 'base_v1', 'test-team-slug') + ON CONFLICT DO NOTHING + `, teamID) + require.NoError(t, err) + + // Link user to team + err = dbClient.AuthDb.TestsRawSQL(t.Context(), ` + INSERT INTO users_teams (user_id, team_id, is_default) + VALUES ($1, $2, true) + ON CONFLICT DO NOTHING + `, userID, teamID) + require.NoError(t, err) + + // Create access token + _, err = dbClient.AuthDb.Write.CreateAccessToken(t.Context(), authqueries.CreateAccessTokenParams{ + ID: uuid.New(), + UserID: userID, + AccessTokenHash: accessToken.HashedValue, + AccessTokenPrefix: accessToken.Masked.Prefix, + AccessTokenLength: int32(accessToken.Masked.ValueLength), + AccessTokenMaskPrefix: accessToken.Masked.MaskedValuePrefix, + AccessTokenMaskSuffix: accessToken.Masked.MaskedValueSuffix, + Name: "Test token", + }) + require.NoError(t, err) + testcases := []struct { name string valid bool @@ -70,15 +108,6 @@ func TestValidate(t *testing.T) { accessTokenUsed: accessToken.PrefixedRawValue, error: false, }, - { - name: "completed build status", - valid: false, - createdEnvId: envID, - createdEnvStatus: "uploaded", - validateEnvId: envID, - accessTokenUsed: accessToken.PrefixedRawValue, - error: false, - }, { name: "invalid access token", valid: false, @@ -93,8 +122,7 @@ func TestValidate(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - dbClient := testutils.SetupDatabase(t) - setupValidateTest(t, dbClient, userID, teamID, accessToken, tc.createdEnvId, tc.createdEnvStatus) + setupValidateTest(t, dbClient, teamID, tc.createdEnvId, tc.createdEnvStatus) valid, err := Validate(t.Context(), dbClient.SqlcClient, tc.accessTokenUsed, tc.validateEnvId) if tc.error { @@ -105,48 +133,24 @@ func TestValidate(t *testing.T) { assert.Equal(t, tc.valid, valid) }) } + + t.Run("completed build status", func(t *testing.T) { + envID := uuid.NewString() + setupValidateTest(t, dbClient, teamID, envID, "uploaded") + valid, err := Validate(t.Context(), dbClient.SqlcClient, accessToken.PrefixedRawValue, envID) + assert.False(t, valid) + assert.NoError(t, err) + }) } -func setupValidateTest(tb testing.TB, db *testutils.Database, userID, teamID uuid.UUID, accessToken keys.Key, envID, createdEnvStatus string) { +func setupValidateTest(tb testing.TB, db *testutils.Database, teamID uuid.UUID, envID, createdEnvStatus string) { tb.Helper() - // Create team - err := db.AuthDb.TestsRawSQL(tb.Context(), ` - INSERT INTO "auth"."users" (id, email) - VALUES ($1, 'test@e2b.dev') - `, userID) - require.NoError(tb, err) - - err = db.AuthDb.TestsRawSQL(tb.Context(), ` - INSERT INTO teams (id, name, email, tier, slug) - VALUES ($1, 'test-team', 'test@e2b.dev', 'base_v1', 'test-team-slug') - `, teamID) - require.NoError(tb, err) - - // Link user to team - err = db.AuthDb.TestsRawSQL(tb.Context(), ` - INSERT INTO users_teams (user_id, team_id, is_default) - VALUES ($1, $2, true) - `, userID, teamID) - require.NoError(tb, err) - - // Create access token - _, err = db.AuthDb.Write.CreateAccessToken(tb.Context(), authqueries.CreateAccessTokenParams{ - ID: uuid.New(), - UserID: userID, - AccessTokenHash: accessToken.HashedValue, - AccessTokenPrefix: accessToken.Masked.Prefix, - AccessTokenLength: int32(accessToken.Masked.ValueLength), - AccessTokenMaskPrefix: accessToken.Masked.MaskedValuePrefix, - AccessTokenMaskSuffix: accessToken.Masked.MaskedValueSuffix, - Name: "Test token", - }) - require.NoError(tb, err) - // Create env - err = db.SqlcClient.TestsRawSQL(tb.Context(), ` + err := db.SqlcClient.TestsRawSQL(tb.Context(), ` INSERT INTO envs (id, team_id, updated_at, source) VALUES ($1, $2, NOW(), 'template') + ON CONFLICT DO NOTHING `, envID, teamID) require.NoError(tb, err)