Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions pkg/dotc1z/c1file.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,11 @@ type C1File struct {
slowQueryThreshold time.Duration
slowQueryLogFrequency time.Duration

// Prepared statement cache: keyed by SQL text, lazily populated.
// With MaxOpenConns(1) all stmts are bound to the single connection.
stmtCache map[string]*sql.Stmt
stmtCacheMu sync.Mutex

// Sync cleanup settings
syncLimit int
skipCleanup bool
Expand Down Expand Up @@ -184,6 +189,7 @@ func NewC1File(ctx context.Context, dbFilePath string, opts ...C1FOption) (*C1Fi
slowQueryThreshold: 5 * time.Second,
slowQueryLogFrequency: 1 * time.Minute,
encoderConcurrency: 1,
stmtCache: make(map[string]*sql.Stmt),
}

for _, opt := range opts {
Expand Down Expand Up @@ -492,12 +498,53 @@ func (c *C1File) closeRawDB(ctx context.Context) error {
_, span := tracer.Start(ctx, "C1File.closeRawDB")
var err error
defer func() { uotel.EndSpanWithError(span, err) }()

c.stmtCacheMu.Lock()
for _, stmt := range c.stmtCache {
stmt.Close()
}
c.stmtCache = nil
c.stmtCacheMu.Unlock()

err = c.rawDb.Close()
c.rawDb = nil
c.db = nil
return err
}

func (c *C1File) getOrPrepare(ctx context.Context, query string) (*sql.Stmt, error) {
c.stmtCacheMu.Lock()
if c.stmtCache == nil {
c.stmtCacheMu.Unlock()
return nil, ErrDbNotOpen
}
if stmt, ok := c.stmtCache[query]; ok {
c.stmtCacheMu.Unlock()
return stmt, nil
}
c.stmtCacheMu.Unlock()

stmt, err := c.rawDb.PrepareContext(ctx, query)
if err != nil {
return nil, fmt.Errorf("getOrPrepare: %w", err)
}

c.stmtCacheMu.Lock()
if c.stmtCache == nil {
c.stmtCacheMu.Unlock()
stmt.Close()
return nil, ErrDbNotOpen
}
if existing, ok := c.stmtCache[query]; ok {
c.stmtCacheMu.Unlock()
stmt.Close()
return existing, nil
}
c.stmtCache[query] = stmt
Copy link
Copy Markdown
Contributor

@alan-lee-12 alan-lee-12 May 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[🤖] P1: getOrPrepare can still panic on the cache-miss path if closeRawDB runs between preparing the new statement and re-acquiring stmtCacheMu. closeRawDB sets c.stmtCache = nil, so this write can become a write to a nil map. Please mirror the entry nil check after the second lock, close the newly prepared statement, and return ErrDbNotOpen before checking/storing the cache entry.

c.stmtCacheMu.Unlock()
Comment on lines +539 to +544
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 Suggestion: The new nil guard at the top of getOrPrepare prevents a nil-map read on entry, but this second lock section can still write to a nil map if closeRawDB sets stmtCache = nil between the first unlock and this re-lock (the window exists because closeRawDB nils the cache before closing rawDb). Adding a nil check here would make the function fully safe against that shutdown race.

Suggested change
c.stmtCacheMu.Unlock()
stmt.Close()
return existing, nil
}
c.stmtCache[query] = stmt
c.stmtCacheMu.Unlock()
c.stmtCacheMu.Lock()
if c.stmtCache == nil {
c.stmtCacheMu.Unlock()
stmt.Close()
return nil, ErrDbNotOpen
}
if existing, ok := c.stmtCache[query]; ok {
c.stmtCacheMu.Unlock()
stmt.Close()
return existing, nil
}
c.stmtCache[query] = stmt
c.stmtCacheMu.Unlock()

return stmt, nil
}

// truncateWAL truncates the WAL file.
// Returns the busy, log, and checkpointed values.
func (c *C1File) truncateWAL(ctx context.Context) (int, int, int, error) {
Expand Down
6 changes: 5 additions & 1 deletion pkg/dotc1z/grants.go
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,11 @@ func listGrantsGeneric(ctx context.Context, c *C1File, req listRequest) ([]*v2.G
}

queryStart := time.Now()
rows, err := c.db.QueryContext(ctx, query, args...)
stmt, err := c.getOrPrepare(ctx, query)
if err != nil {
return nil, "", err
}
rows, err := stmt.QueryContext(ctx, args...)
if err != nil {
return nil, "", err
}
Expand Down
12 changes: 10 additions & 2 deletions pkg/dotc1z/grants_expandable_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,11 @@ func (c *C1File) listExpandableGrantsInternal(
return nil, "", err
}

rows, err := c.db.QueryContext(ctx, query, args...)
stmt, err := c.getOrPrepare(ctx, query)
if err != nil {
return nil, "", err
}
rows, err := stmt.QueryContext(ctx, args...)
if err != nil {
return nil, "", err
}
Expand Down Expand Up @@ -196,7 +200,11 @@ func (c *C1File) listGrantsWithExpansionInternal(ctx context.Context, opts grant
return nil, err
}

rows, err := c.db.QueryContext(ctx, query, args...)
stmt, err := c.getOrPrepare(ctx, query)
if err != nil {
return nil, err
}
rows, err := stmt.QueryContext(ctx, args...)
if err != nil {
return nil, err
}
Expand Down
10 changes: 7 additions & 3 deletions pkg/dotc1z/grants_hydrate.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,14 @@ func hydrateSingleGrant(ctx context.Context, c *C1File, syncID string, g *v2.Gra
return nil
}

row := c.db.QueryRowContext(ctx, fmt.Sprintf(
q := fmt.Sprintf(
`SELECT entitlement_id, resource_type_id, resource_id, principal_resource_type_id, principal_resource_id
FROM %s WHERE external_id = ? AND sync_id = ?`, grants.Name(),
), g.GetId(), syncID)
FROM %s WHERE external_id = ? AND sync_id = ?`, grants.Name())
stmt, err := c.getOrPrepare(ctx, q)
if err != nil {
return err
}
row := stmt.QueryRowContext(ctx, g.GetId(), syncID)
var k grantJoinKeys
if err := row.Scan(
&k.EntitlementID,
Expand Down
20 changes: 16 additions & 4 deletions pkg/dotc1z/sql_helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -244,8 +244,12 @@ func listConnectorObjects[T proto.Message](ctx context.Context, c *C1File, table
// Start timing the query execution
queryStartTime := time.Now()

// Execute the query
rows, err := c.db.QueryContext(ctx, query, args...)
// Execute the query via prepared statement cache
stmt, err := c.getOrPrepare(ctx, query)
if err != nil {
return nil, "", err
}
rows, err := stmt.QueryContext(ctx, args...)
if err != nil {
return nil, "", err
}
Expand Down Expand Up @@ -640,7 +644,11 @@ func (c *C1File) getResourceObject(ctx context.Context, resourceID *v2.ResourceI
}

data := make([]byte, 0)
row := c.db.QueryRowContext(ctx, query, args...)
stmt, err := c.getOrPrepare(ctx, query)
if err != nil {
return err
}
row := stmt.QueryRowContext(ctx, args...)
err = row.Scan(&data)
if err != nil {
return err
Expand Down Expand Up @@ -700,7 +708,11 @@ func (c *C1File) getConnectorObject(ctx context.Context, tableName string, id st
}

var data []byte
row := c.db.QueryRowContext(ctx, query, args...)
stmt, err := c.getOrPrepare(ctx, query)
if err != nil {
return err
}
row := stmt.QueryRowContext(ctx, args...)
err = row.Scan(&data)
if err != nil {
return err
Expand Down
Loading