Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions backend/driver.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package backend

import (
"fmt"
"sort"
"sync"
)

// Driver is the factory interface each backend package registers.
// It mirrors the database/sql driver pattern: import the driver package for
// its side-effect (init registers the driver), then open it by name.
type Driver interface {
Open(dsn string) (Backend, error)
}

var (
driversMu sync.RWMutex
drivers = make(map[string]Driver)
)

// Register makes a backend driver available under the given name.
// It panics if name is empty or the same name is registered twice, matching
// the database/sql convention so mis-wired init calls fail loudly at startup.
func Register(name string, d Driver) {
driversMu.Lock()
defer driversMu.Unlock()
if name == "" {
panic("backend: Register called with empty name")
}
if _, dup := drivers[name]; dup {
panic("backend: Register called twice for driver " + name)
}
drivers[name] = d
}

// Open opens a Backend using the named driver and the given DSN.
// The driver must have been registered (typically by importing its package).
func Open(name, dsn string) (Backend, error) {
driversMu.RLock()
d, ok := drivers[name]
driversMu.RUnlock()
if !ok {
return nil, fmt.Errorf("backend: unknown driver %q (forgotten import?)", name)
}
return d.Open(dsn)
}

// Drivers returns a sorted list of registered driver names.
func Drivers() []string {
driversMu.RLock()
defer driversMu.RUnlock()
list := make([]string, 0, len(drivers))
for name := range drivers {
list = append(list, name)
}
sort.Strings(list)
return list
}
31 changes: 8 additions & 23 deletions backend/mongo/execute.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (

"go.mongodb.org/mongo-driver/v2/bson"
mgodriver "go.mongodb.org/mongo-driver/v2/mongo"
mgoptions "go.mongodb.org/mongo-driver/v2/mongo/options"

"github.com/tamnd/dbrest/backend"
"github.com/tamnd/dbrest/ir"
Expand Down Expand Up @@ -139,7 +138,7 @@ func (b *Backend) executeUpdate(ctx context.Context, plan *ir.Plan, rc *reqctx.C
q := plan.Query
coll := b.db.Collection(q.Relation.Name)
colTypes := columnTypes(plan.Rel)
res := &bodyResult{controls: rc.Controls()}
res := &bodyResult{controls: rc.Controls(), rows: newDocRowStream(nil)}

filter := filterDoc(q.Where, colTypes)
setDoc := writePayloadToSetDoc(q.Write, plan.Rel)
Expand All @@ -157,9 +156,6 @@ func (b *Backend) executeUpdate(ctx context.Context, plan *ir.Plan, rc *reqctx.C
}
res.rows = rows
}
if res.rows == nil {
res.rows = newDocRowStream(nil)
}
return res, nil
}

Expand All @@ -168,50 +164,39 @@ func (b *Backend) executeDelete(ctx context.Context, plan *ir.Plan, rc *reqctx.C
q := plan.Query
coll := b.db.Collection(q.Relation.Name)
colTypes := columnTypes(plan.Rel)
res := &bodyResult{controls: rc.Controls()}
res := &bodyResult{controls: rc.Controls(), rows: newDocRowStream(nil)}

filter := filterDoc(q.Where, colTypes)

var returnDocs []map[string]any
if q.Write != nil && q.Write.Return == ir.ReturnRepresentation {
// Capture rows before deleting.
var err error
returnDocs, err = b.findDocs(ctx, coll, filter, nil)
returnDocs, err := b.findDocs(ctx, coll, filter)
if err != nil {
return nil, err
}
res.rows = newDocRowStream(convertDocs(returnDocs))
}

out, err := coll.DeleteMany(ctx, filter)
if err != nil {
return nil, b.MapError(err)
}
res.affected, res.hasAff = out.DeletedCount, true

if returnDocs != nil {
res.rows = newDocRowStream(convertDocs(returnDocs))
} else {
res.rows = newDocRowStream(nil)
}
return res, nil
}

// readForReturn re-queries after a write to produce the RETURNING row stream.
func (b *Backend) readForReturn(ctx context.Context, coll *mgodriver.Collection, filter bson.D) (*docRowStream, error) {
docs, err := b.findDocs(ctx, coll, filter, nil)
docs, err := b.findDocs(ctx, coll, filter)
if err != nil {
return nil, err
}
return newDocRowStream(convertDocs(docs)), nil
}

// findDocs runs a find with the given filter and project, returning raw BSON maps.
func (b *Backend) findDocs(ctx context.Context, coll *mgodriver.Collection, filter bson.D, project bson.D) ([]map[string]any, error) {
opts := mgoptions.Find()
if project != nil {
opts.SetProjection(project)
}
cur, err := coll.Find(ctx, filter, opts)
// findDocs runs a find with the given filter, returning raw BSON maps.
func (b *Backend) findDocs(ctx context.Context, coll *mgodriver.Collection, filter bson.D) ([]map[string]any, error) {
cur, err := coll.Find(ctx, filter)
if err != nil {
return nil, b.MapError(err)
}
Expand Down
6 changes: 6 additions & 0 deletions backend/mongo/mongo.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,3 +135,9 @@ func (b *Backend) MapError(err error) *pgerr.APIError {
}
return pgerr.ErrInternal(err.Error())
}

func init() { backend.Register("mongodb", mongoDriver{}) }

type mongoDriver struct{}

func (mongoDriver) Open(dsn string) (backend.Backend, error) { return Open(dsn) }
6 changes: 6 additions & 0 deletions backend/mysql/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,3 +212,9 @@ func buildBoolCols(rel *schema.Relation) map[string]bool {
}
return m
}

func init() { backend.Register("mysql", mysqlDriver{}) }

type mysqlDriver struct{}

func (mysqlDriver) Open(dsn string) (backend.Backend, error) { return Open(dsn) }
6 changes: 6 additions & 0 deletions backend/postgres/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -244,3 +244,9 @@ func statusForSQLState(code string) int {
}
return 400
}

func init() { backend.Register("postgres", postgresDriver{}) }

type postgresDriver struct{}

func (postgresDriver) Open(dsn string) (backend.Backend, error) { return Open(dsn) }
6 changes: 6 additions & 0 deletions backend/sqlite/sqlite.go
Original file line number Diff line number Diff line change
Expand Up @@ -379,3 +379,9 @@ func drain(rows *sql.Rows, ncols int) ([][]any, error) {
}
return out, rows.Err()
}

func init() { backend.Register("sqlite", sqliteDriver{}) }

type sqliteDriver struct{}

func (sqliteDriver) Open(dsn string) (backend.Backend, error) { return Open(dsn) }
6 changes: 6 additions & 0 deletions backend/sqlserver/sqlserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,3 +212,9 @@ func asMSSQLError(err error) (*mssql.Error, bool) {
ok := errors.As(err, &me)
return me, ok
}

func init() { backend.Register("sqlserver", sqlserverDriver{}) }

type sqlserverDriver struct{}

func (sqlserverDriver) Open(dsn string) (backend.Backend, error) { return Open(dsn) }
53 changes: 14 additions & 39 deletions cmd/dbrest/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@ import (

"github.com/tamnd/dbrest/auth"
"github.com/tamnd/dbrest/backend"
mongobackend "github.com/tamnd/dbrest/backend/mongo"
"github.com/tamnd/dbrest/backend/mysql"
"github.com/tamnd/dbrest/backend/postgres"
"github.com/tamnd/dbrest/backend/sqlite"
"github.com/tamnd/dbrest/backend/sqlserver"
_ "github.com/tamnd/dbrest/backend/mongo"
_ "github.com/tamnd/dbrest/backend/mysql"
_ "github.com/tamnd/dbrest/backend/postgres"
_ "github.com/tamnd/dbrest/backend/sqlite"
_ "github.com/tamnd/dbrest/backend/sqlserver"
"github.com/tamnd/dbrest/config"
"github.com/tamnd/dbrest/httpapi"
)
Expand Down Expand Up @@ -69,42 +69,17 @@ func run() error {
}

// openBackend opens the engine the configuration selected.
// Each backend driver self-registers via its package init function; this file
// imports them as blank imports so their init functions run.
func openBackend(cfg *config.Config) (backend.Backend, error) {
switch cfg.Backend {
case config.BackendSQLite:
be, err := sqlite.Open(cfg.DBURI)
if err != nil {
return nil, fmt.Errorf("open database: %w", err)
}
return be, nil
case config.BackendPostgres:
be, err := postgres.Open(cfg.DBURI)
if err != nil {
return nil, fmt.Errorf("open database: %w", err)
}
be.SetSchemas(cfg.Schemas)
return be, nil
case config.BackendMySQL:
be, err := mysql.Open(cfg.DBURI)
if err != nil {
return nil, fmt.Errorf("open database: %w", err)
}
return be, nil
case config.BackendSQLServer:
be, err := sqlserver.Open(cfg.DBURI)
if err != nil {
return nil, fmt.Errorf("open database: %w", err)
}
return be, nil
case config.BackendMongoDB:
be, err := mongobackend.Open(cfg.DBURI)
if err != nil {
return nil, fmt.Errorf("open database: %w", err)
}
return be, nil
default:
return nil, fmt.Errorf("db-backend %q is unknown", cfg.Backend)
be, err := backend.Open(cfg.Backend, cfg.DBURI)
if err != nil {
return nil, fmt.Errorf("open database: %w", err)
}
if sc, ok := be.(interface{ SetSchemas([]string) }); ok {
sc.SetSchemas(cfg.Schemas)
}
return be, nil
}

// attachAuth wires a JWT verifier onto the server when a key is configured.
Expand Down
6 changes: 3 additions & 3 deletions compat/compat_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -463,18 +463,18 @@ var cases = []compatCase{
}

// resetTestDB deletes all non-seed rows from both servers so each TestCompatibility
// run starts from the same known state (3 todos, 2 persons, 2 assignments).
// run starts from the same known state (3 todos, 3 persons, 2 assignments).
func resetTestDB(t *testing.T, pgrest, dbrest string) {
t.Helper()
client := &http.Client{Timeout: 5 * time.Second}
cleanup := []struct{ method, url string }{
{"DELETE", pgrest + "/todos?id=gt.3"},
{"DELETE", pgrest + "/assignments?id=gt.2"},
{"DELETE", pgrest + "/persons?id=gt.2"},
{"DELETE", pgrest + "/persons?id=gt.3"},
{"DELETE", pgrest + "/private_todos?id=gt.2"},
{"DELETE", dbrest + "/todos?id=gt.3"},
{"DELETE", dbrest + "/assignments?id=gt.2"},
{"DELETE", dbrest + "/persons?id=gt.2"},
{"DELETE", dbrest + "/persons?id=gt.3"},
{"DELETE", dbrest + "/private_todos?id=gt.2"},
// undo any modifications to seed rows
{"PATCH", pgrest + "/todos?id=eq.1"},
Expand Down
3 changes: 2 additions & 1 deletion docker/seed/03-data.sql
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ ON CONFLICT (id) DO UPDATE SET

INSERT INTO api.persons (id, name, age, email) VALUES
(1, 'Alice', 30, 'alice@example.com'),
(2, 'Bob', 25, 'bob@example.com')
(2, 'Bob', 25, 'bob@example.com'),
(3, 'Carol', 35, 'carol@example.com')
ON CONFLICT (id) DO NOTHING;

INSERT INTO api.assignments (id, person_id, todo_id) VALUES
Expand Down
Loading