diff --git a/pkg/dotc1z/c1file.go b/pkg/dotc1z/c1file.go index 38a1f685b..9d7aba6b3 100644 --- a/pkg/dotc1z/c1file.go +++ b/pkg/dotc1z/c1file.go @@ -67,6 +67,11 @@ type C1File struct { slowQueryThreshold time.Duration slowQueryLogFrequency time.Duration + // Prepared statement cache: keyed by SQL text, lazily populated. + // With MaxOpenConns(1) all stmts are bound to the single connection. + stmtCache map[string]*sql.Stmt + stmtCacheMu sync.Mutex + // Sync cleanup settings syncLimit int skipCleanup bool @@ -184,6 +189,7 @@ func NewC1File(ctx context.Context, dbFilePath string, opts ...C1FOption) (*C1Fi slowQueryThreshold: 5 * time.Second, slowQueryLogFrequency: 1 * time.Minute, encoderConcurrency: 1, + stmtCache: make(map[string]*sql.Stmt), } for _, opt := range opts { @@ -492,12 +498,53 @@ func (c *C1File) closeRawDB(ctx context.Context) error { _, span := tracer.Start(ctx, "C1File.closeRawDB") var err error defer func() { uotel.EndSpanWithError(span, err) }() + + c.stmtCacheMu.Lock() + for _, stmt := range c.stmtCache { + stmt.Close() + } + c.stmtCache = nil + c.stmtCacheMu.Unlock() + err = c.rawDb.Close() c.rawDb = nil c.db = nil return err } +func (c *C1File) getOrPrepare(ctx context.Context, query string) (*sql.Stmt, error) { + c.stmtCacheMu.Lock() + if c.stmtCache == nil { + c.stmtCacheMu.Unlock() + return nil, ErrDbNotOpen + } + if stmt, ok := c.stmtCache[query]; ok { + c.stmtCacheMu.Unlock() + return stmt, nil + } + c.stmtCacheMu.Unlock() + + stmt, err := c.rawDb.PrepareContext(ctx, query) + if err != nil { + return nil, fmt.Errorf("getOrPrepare: %w", err) + } + + c.stmtCacheMu.Lock() + if c.stmtCache == nil { + c.stmtCacheMu.Unlock() + stmt.Close() + return nil, ErrDbNotOpen + } + if existing, ok := c.stmtCache[query]; ok { + c.stmtCacheMu.Unlock() + stmt.Close() + return existing, nil + } + c.stmtCache[query] = stmt + c.stmtCacheMu.Unlock() + return stmt, nil +} + // truncateWAL truncates the WAL file. // Returns the busy, log, and checkpointed values. func (c *C1File) truncateWAL(ctx context.Context) (int, int, int, error) { diff --git a/pkg/dotc1z/grants.go b/pkg/dotc1z/grants.go index 5b0bedcd9..2aef3fba7 100644 --- a/pkg/dotc1z/grants.go +++ b/pkg/dotc1z/grants.go @@ -229,7 +229,11 @@ func listGrantsGeneric(ctx context.Context, c *C1File, req listRequest) ([]*v2.G } queryStart := time.Now() - rows, err := c.db.QueryContext(ctx, query, args...) + stmt, err := c.getOrPrepare(ctx, query) + if err != nil { + return nil, "", err + } + rows, err := stmt.QueryContext(ctx, args...) if err != nil { return nil, "", err } diff --git a/pkg/dotc1z/grants_expandable_query.go b/pkg/dotc1z/grants_expandable_query.go index 049d58a01..f3d8f4c83 100644 --- a/pkg/dotc1z/grants_expandable_query.go +++ b/pkg/dotc1z/grants_expandable_query.go @@ -73,7 +73,11 @@ func (c *C1File) listExpandableGrantsInternal( return nil, "", err } - rows, err := c.db.QueryContext(ctx, query, args...) + stmt, err := c.getOrPrepare(ctx, query) + if err != nil { + return nil, "", err + } + rows, err := stmt.QueryContext(ctx, args...) if err != nil { return nil, "", err } @@ -196,7 +200,11 @@ func (c *C1File) listGrantsWithExpansionInternal(ctx context.Context, opts grant return nil, err } - rows, err := c.db.QueryContext(ctx, query, args...) + stmt, err := c.getOrPrepare(ctx, query) + if err != nil { + return nil, err + } + rows, err := stmt.QueryContext(ctx, args...) if err != nil { return nil, err } diff --git a/pkg/dotc1z/grants_hydrate.go b/pkg/dotc1z/grants_hydrate.go index 1a47a970c..593f5a6d5 100644 --- a/pkg/dotc1z/grants_hydrate.go +++ b/pkg/dotc1z/grants_hydrate.go @@ -64,10 +64,14 @@ func hydrateSingleGrant(ctx context.Context, c *C1File, syncID string, g *v2.Gra return nil } - row := c.db.QueryRowContext(ctx, fmt.Sprintf( + q := fmt.Sprintf( `SELECT entitlement_id, resource_type_id, resource_id, principal_resource_type_id, principal_resource_id - FROM %s WHERE external_id = ? AND sync_id = ?`, grants.Name(), - ), g.GetId(), syncID) + FROM %s WHERE external_id = ? AND sync_id = ?`, grants.Name()) + stmt, err := c.getOrPrepare(ctx, q) + if err != nil { + return err + } + row := stmt.QueryRowContext(ctx, g.GetId(), syncID) var k grantJoinKeys if err := row.Scan( &k.EntitlementID, diff --git a/pkg/dotc1z/sql_helpers.go b/pkg/dotc1z/sql_helpers.go index cb1f613cc..bccb58d9a 100644 --- a/pkg/dotc1z/sql_helpers.go +++ b/pkg/dotc1z/sql_helpers.go @@ -244,8 +244,12 @@ func listConnectorObjects[T proto.Message](ctx context.Context, c *C1File, table // Start timing the query execution queryStartTime := time.Now() - // Execute the query - rows, err := c.db.QueryContext(ctx, query, args...) + // Execute the query via prepared statement cache + stmt, err := c.getOrPrepare(ctx, query) + if err != nil { + return nil, "", err + } + rows, err := stmt.QueryContext(ctx, args...) if err != nil { return nil, "", err } @@ -640,7 +644,11 @@ func (c *C1File) getResourceObject(ctx context.Context, resourceID *v2.ResourceI } data := make([]byte, 0) - row := c.db.QueryRowContext(ctx, query, args...) + stmt, err := c.getOrPrepare(ctx, query) + if err != nil { + return err + } + row := stmt.QueryRowContext(ctx, args...) err = row.Scan(&data) if err != nil { return err @@ -700,7 +708,11 @@ func (c *C1File) getConnectorObject(ctx context.Context, tableName string, id st } var data []byte - row := c.db.QueryRowContext(ctx, query, args...) + stmt, err := c.getOrPrepare(ctx, query) + if err != nil { + return err + } + row := stmt.QueryRowContext(ctx, args...) err = row.Scan(&data) if err != nil { return err