diff --git a/backend/driver.go b/backend/driver.go new file mode 100644 index 0000000..42a7aa7 --- /dev/null +++ b/backend/driver.go @@ -0,0 +1,58 @@ +package backend + +import ( + "fmt" + "sort" + "sync" +) + +// Driver is the factory interface each backend package registers. +// It mirrors the database/sql driver pattern: import the driver package for +// its side-effect (init registers the driver), then open it by name. +type Driver interface { + Open(dsn string) (Backend, error) +} + +var ( + driversMu sync.RWMutex + drivers = make(map[string]Driver) +) + +// Register makes a backend driver available under the given name. +// It panics if name is empty or the same name is registered twice, matching +// the database/sql convention so mis-wired init calls fail loudly at startup. +func Register(name string, d Driver) { + driversMu.Lock() + defer driversMu.Unlock() + if name == "" { + panic("backend: Register called with empty name") + } + if _, dup := drivers[name]; dup { + panic("backend: Register called twice for driver " + name) + } + drivers[name] = d +} + +// Open opens a Backend using the named driver and the given DSN. +// The driver must have been registered (typically by importing its package). +func Open(name, dsn string) (Backend, error) { + driversMu.RLock() + d, ok := drivers[name] + driversMu.RUnlock() + if !ok { + return nil, fmt.Errorf("backend: unknown driver %q (forgotten import?)", name) + } + return d.Open(dsn) +} + +// Drivers returns a sorted list of registered driver names. +func Drivers() []string { + driversMu.RLock() + defer driversMu.RUnlock() + list := make([]string, 0, len(drivers)) + for name := range drivers { + list = append(list, name) + } + sort.Strings(list) + return list +} diff --git a/backend/mongo/execute.go b/backend/mongo/execute.go index 9955c0a..d59f544 100644 --- a/backend/mongo/execute.go +++ b/backend/mongo/execute.go @@ -7,7 +7,6 @@ import ( "go.mongodb.org/mongo-driver/v2/bson" mgodriver "go.mongodb.org/mongo-driver/v2/mongo" - mgoptions "go.mongodb.org/mongo-driver/v2/mongo/options" "github.com/tamnd/dbrest/backend" "github.com/tamnd/dbrest/ir" @@ -139,7 +138,7 @@ func (b *Backend) executeUpdate(ctx context.Context, plan *ir.Plan, rc *reqctx.C q := plan.Query coll := b.db.Collection(q.Relation.Name) colTypes := columnTypes(plan.Rel) - res := &bodyResult{controls: rc.Controls()} + res := &bodyResult{controls: rc.Controls(), rows: newDocRowStream(nil)} filter := filterDoc(q.Where, colTypes) setDoc := writePayloadToSetDoc(q.Write, plan.Rel) @@ -157,9 +156,6 @@ func (b *Backend) executeUpdate(ctx context.Context, plan *ir.Plan, rc *reqctx.C } res.rows = rows } - if res.rows == nil { - res.rows = newDocRowStream(nil) - } return res, nil } @@ -168,18 +164,17 @@ func (b *Backend) executeDelete(ctx context.Context, plan *ir.Plan, rc *reqctx.C q := plan.Query coll := b.db.Collection(q.Relation.Name) colTypes := columnTypes(plan.Rel) - res := &bodyResult{controls: rc.Controls()} + res := &bodyResult{controls: rc.Controls(), rows: newDocRowStream(nil)} filter := filterDoc(q.Where, colTypes) - var returnDocs []map[string]any if q.Write != nil && q.Write.Return == ir.ReturnRepresentation { // Capture rows before deleting. - var err error - returnDocs, err = b.findDocs(ctx, coll, filter, nil) + returnDocs, err := b.findDocs(ctx, coll, filter) if err != nil { return nil, err } + res.rows = newDocRowStream(convertDocs(returnDocs)) } out, err := coll.DeleteMany(ctx, filter) @@ -187,31 +182,21 @@ func (b *Backend) executeDelete(ctx context.Context, plan *ir.Plan, rc *reqctx.C return nil, b.MapError(err) } res.affected, res.hasAff = out.DeletedCount, true - - if returnDocs != nil { - res.rows = newDocRowStream(convertDocs(returnDocs)) - } else { - res.rows = newDocRowStream(nil) - } return res, nil } // readForReturn re-queries after a write to produce the RETURNING row stream. func (b *Backend) readForReturn(ctx context.Context, coll *mgodriver.Collection, filter bson.D) (*docRowStream, error) { - docs, err := b.findDocs(ctx, coll, filter, nil) + docs, err := b.findDocs(ctx, coll, filter) if err != nil { return nil, err } return newDocRowStream(convertDocs(docs)), nil } -// findDocs runs a find with the given filter and project, returning raw BSON maps. -func (b *Backend) findDocs(ctx context.Context, coll *mgodriver.Collection, filter bson.D, project bson.D) ([]map[string]any, error) { - opts := mgoptions.Find() - if project != nil { - opts.SetProjection(project) - } - cur, err := coll.Find(ctx, filter, opts) +// findDocs runs a find with the given filter, returning raw BSON maps. +func (b *Backend) findDocs(ctx context.Context, coll *mgodriver.Collection, filter bson.D) ([]map[string]any, error) { + cur, err := coll.Find(ctx, filter) if err != nil { return nil, b.MapError(err) } diff --git a/backend/mongo/mongo.go b/backend/mongo/mongo.go index efc8492..356433d 100644 --- a/backend/mongo/mongo.go +++ b/backend/mongo/mongo.go @@ -135,3 +135,9 @@ func (b *Backend) MapError(err error) *pgerr.APIError { } return pgerr.ErrInternal(err.Error()) } + +func init() { backend.Register("mongodb", mongoDriver{}) } + +type mongoDriver struct{} + +func (mongoDriver) Open(dsn string) (backend.Backend, error) { return Open(dsn) } diff --git a/backend/mysql/mysql.go b/backend/mysql/mysql.go index 2d2a315..e5e3996 100644 --- a/backend/mysql/mysql.go +++ b/backend/mysql/mysql.go @@ -212,3 +212,9 @@ func buildBoolCols(rel *schema.Relation) map[string]bool { } return m } + +func init() { backend.Register("mysql", mysqlDriver{}) } + +type mysqlDriver struct{} + +func (mysqlDriver) Open(dsn string) (backend.Backend, error) { return Open(dsn) } diff --git a/backend/postgres/postgres.go b/backend/postgres/postgres.go index 202a0cb..5441023 100644 --- a/backend/postgres/postgres.go +++ b/backend/postgres/postgres.go @@ -244,3 +244,9 @@ func statusForSQLState(code string) int { } return 400 } + +func init() { backend.Register("postgres", postgresDriver{}) } + +type postgresDriver struct{} + +func (postgresDriver) Open(dsn string) (backend.Backend, error) { return Open(dsn) } diff --git a/backend/sqlite/sqlite.go b/backend/sqlite/sqlite.go index 8d59bd7..e817051 100644 --- a/backend/sqlite/sqlite.go +++ b/backend/sqlite/sqlite.go @@ -379,3 +379,9 @@ func drain(rows *sql.Rows, ncols int) ([][]any, error) { } return out, rows.Err() } + +func init() { backend.Register("sqlite", sqliteDriver{}) } + +type sqliteDriver struct{} + +func (sqliteDriver) Open(dsn string) (backend.Backend, error) { return Open(dsn) } diff --git a/backend/sqlserver/sqlserver.go b/backend/sqlserver/sqlserver.go index 4a74011..c31afb6 100644 --- a/backend/sqlserver/sqlserver.go +++ b/backend/sqlserver/sqlserver.go @@ -212,3 +212,9 @@ func asMSSQLError(err error) (*mssql.Error, bool) { ok := errors.As(err, &me) return me, ok } + +func init() { backend.Register("sqlserver", sqlserverDriver{}) } + +type sqlserverDriver struct{} + +func (sqlserverDriver) Open(dsn string) (backend.Backend, error) { return Open(dsn) } diff --git a/cmd/dbrest/main.go b/cmd/dbrest/main.go index 5dcae8f..8fa4fcb 100644 --- a/cmd/dbrest/main.go +++ b/cmd/dbrest/main.go @@ -14,11 +14,11 @@ import ( "github.com/tamnd/dbrest/auth" "github.com/tamnd/dbrest/backend" - mongobackend "github.com/tamnd/dbrest/backend/mongo" - "github.com/tamnd/dbrest/backend/mysql" - "github.com/tamnd/dbrest/backend/postgres" - "github.com/tamnd/dbrest/backend/sqlite" - "github.com/tamnd/dbrest/backend/sqlserver" + _ "github.com/tamnd/dbrest/backend/mongo" + _ "github.com/tamnd/dbrest/backend/mysql" + _ "github.com/tamnd/dbrest/backend/postgres" + _ "github.com/tamnd/dbrest/backend/sqlite" + _ "github.com/tamnd/dbrest/backend/sqlserver" "github.com/tamnd/dbrest/config" "github.com/tamnd/dbrest/httpapi" ) @@ -69,42 +69,17 @@ func run() error { } // openBackend opens the engine the configuration selected. +// Each backend driver self-registers via its package init function; this file +// imports them as blank imports so their init functions run. func openBackend(cfg *config.Config) (backend.Backend, error) { - switch cfg.Backend { - case config.BackendSQLite: - be, err := sqlite.Open(cfg.DBURI) - if err != nil { - return nil, fmt.Errorf("open database: %w", err) - } - return be, nil - case config.BackendPostgres: - be, err := postgres.Open(cfg.DBURI) - if err != nil { - return nil, fmt.Errorf("open database: %w", err) - } - be.SetSchemas(cfg.Schemas) - return be, nil - case config.BackendMySQL: - be, err := mysql.Open(cfg.DBURI) - if err != nil { - return nil, fmt.Errorf("open database: %w", err) - } - return be, nil - case config.BackendSQLServer: - be, err := sqlserver.Open(cfg.DBURI) - if err != nil { - return nil, fmt.Errorf("open database: %w", err) - } - return be, nil - case config.BackendMongoDB: - be, err := mongobackend.Open(cfg.DBURI) - if err != nil { - return nil, fmt.Errorf("open database: %w", err) - } - return be, nil - default: - return nil, fmt.Errorf("db-backend %q is unknown", cfg.Backend) + be, err := backend.Open(cfg.Backend, cfg.DBURI) + if err != nil { + return nil, fmt.Errorf("open database: %w", err) + } + if sc, ok := be.(interface{ SetSchemas([]string) }); ok { + sc.SetSchemas(cfg.Schemas) } + return be, nil } // attachAuth wires a JWT verifier onto the server when a key is configured. diff --git a/compat/compat_test.go b/compat/compat_test.go index 94fcf62..2a64778 100644 --- a/compat/compat_test.go +++ b/compat/compat_test.go @@ -463,18 +463,18 @@ var cases = []compatCase{ } // resetTestDB deletes all non-seed rows from both servers so each TestCompatibility -// run starts from the same known state (3 todos, 2 persons, 2 assignments). +// run starts from the same known state (3 todos, 3 persons, 2 assignments). func resetTestDB(t *testing.T, pgrest, dbrest string) { t.Helper() client := &http.Client{Timeout: 5 * time.Second} cleanup := []struct{ method, url string }{ {"DELETE", pgrest + "/todos?id=gt.3"}, {"DELETE", pgrest + "/assignments?id=gt.2"}, - {"DELETE", pgrest + "/persons?id=gt.2"}, + {"DELETE", pgrest + "/persons?id=gt.3"}, {"DELETE", pgrest + "/private_todos?id=gt.2"}, {"DELETE", dbrest + "/todos?id=gt.3"}, {"DELETE", dbrest + "/assignments?id=gt.2"}, - {"DELETE", dbrest + "/persons?id=gt.2"}, + {"DELETE", dbrest + "/persons?id=gt.3"}, {"DELETE", dbrest + "/private_todos?id=gt.2"}, // undo any modifications to seed rows {"PATCH", pgrest + "/todos?id=eq.1"}, diff --git a/docker/seed/03-data.sql b/docker/seed/03-data.sql index 2f6aaf2..2a16f50 100644 --- a/docker/seed/03-data.sql +++ b/docker/seed/03-data.sql @@ -11,7 +11,8 @@ ON CONFLICT (id) DO UPDATE SET INSERT INTO api.persons (id, name, age, email) VALUES (1, 'Alice', 30, 'alice@example.com'), - (2, 'Bob', 25, 'bob@example.com') + (2, 'Bob', 25, 'bob@example.com'), + (3, 'Carol', 35, 'carol@example.com') ON CONFLICT (id) DO NOTHING; INSERT INTO api.assignments (id, person_id, todo_id) VALUES