Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions extensions/tn_local/extension.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package tn_local

import (
"sync"
"sync/atomic"

"github.com/trufnetwork/kwil-db/core/log"
"github.com/trufnetwork/kwil-db/node/types/sql"
Expand All @@ -12,7 +13,7 @@ type Extension struct {
logger log.Logger
db sql.DB
localDB *LocalDB
isEnabled bool
isEnabled atomic.Bool
}

var (
Expand All @@ -26,8 +27,7 @@ var (
func GetExtension() *Extension {
once.Do(func() {
extensionInstance = &Extension{
logger: log.New(log.WithLevel(log.LevelInfo)),
isEnabled: false,
logger: log.New(log.WithLevel(log.LevelInfo)),
}
})
return extensionInstance
Expand All @@ -41,7 +41,7 @@ func (e *Extension) configure(logger log.Logger, db sql.DB, localDB *LocalDB) {
e.logger = logger
e.db = db
e.localDB = localDB
e.isEnabled = true
e.isEnabled.Store(true)
}

// Close closes the extension's connection pool.
Expand Down
39 changes: 39 additions & 0 deletions extensions/tn_local/handlers.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package tn_local

import (
"context"

jsonrpc "github.com/trufnetwork/kwil-db/core/rpc/json"
)

// Handler stubs — implementations will be added in Tasks 3-6.

// CreateStream creates a local stream. (Task 3)
func (ext *Extension) CreateStream(ctx context.Context, req *CreateStreamRequest) (*CreateStreamResponse, *jsonrpc.Error) {
return nil, jsonrpc.NewError(jsonrpc.ErrorInternal, "not implemented", nil)
}

// InsertRecords inserts records into a local primitive stream. (Task 4)
func (ext *Extension) InsertRecords(ctx context.Context, req *InsertRecordsRequest) (*InsertRecordsResponse, *jsonrpc.Error) {
return nil, jsonrpc.NewError(jsonrpc.ErrorInternal, "not implemented", nil)
}

// InsertTaxonomy adds a taxonomy entry to a local composed stream. (Task 5)
func (ext *Extension) InsertTaxonomy(ctx context.Context, req *InsertTaxonomyRequest) (*InsertTaxonomyResponse, *jsonrpc.Error) {
return nil, jsonrpc.NewError(jsonrpc.ErrorInternal, "not implemented", nil)
}

// GetRecord queries records from a local primitive stream. (Task 6)
func (ext *Extension) GetRecord(ctx context.Context, req *GetRecordRequest) (*GetRecordResponse, *jsonrpc.Error) {
return nil, jsonrpc.NewError(jsonrpc.ErrorInternal, "not implemented", nil)
}

// GetIndex queries computed index values from a local stream. (Task 6)
func (ext *Extension) GetIndex(ctx context.Context, req *GetIndexRequest) (*GetIndexResponse, *jsonrpc.Error) {
return nil, jsonrpc.NewError(jsonrpc.ErrorInternal, "not implemented", nil)
}

// ListStreams lists all local streams. (Task 6)
func (ext *Extension) ListStreams(ctx context.Context, req *ListStreamsRequest) (*ListStreamsResponse, *jsonrpc.Error) {
return nil, jsonrpc.NewError(jsonrpc.ErrorInternal, "not implemented", nil)
}
31 changes: 31 additions & 0 deletions extensions/tn_local/service.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package tn_local

import (
"context"
"encoding/json"

jsonrpc "github.com/trufnetwork/kwil-db/core/rpc/json"
rpcserver "github.com/trufnetwork/kwil-db/node/services/jsonrpc"
)

// Name implements rpcserver.Svc.
func (ext *Extension) Name() string { return ServiceName }

// Methods implements rpcserver.Svc.
func (ext *Extension) Methods() map[jsonrpc.Method]rpcserver.MethodDef {
return map[jsonrpc.Method]rpcserver.MethodDef{
"local.create_stream": rpcserver.MakeMethodDef(ext.CreateStream, "create a local stream", ""),
"local.insert_records": rpcserver.MakeMethodDef(ext.InsertRecords, "insert records into local stream", "count"),
"local.insert_taxonomy": rpcserver.MakeMethodDef(ext.InsertTaxonomy, "add taxonomy to local composed stream", ""),
"local.get_record": rpcserver.MakeMethodDef(ext.GetRecord, "query local stream records", "records"),
"local.get_index": rpcserver.MakeMethodDef(ext.GetIndex, "query local stream index", "records"),
"local.list_streams": rpcserver.MakeMethodDef(ext.ListStreams, "list all local streams", "streams"),
}
}

// Health implements rpcserver.Svc.
func (ext *Extension) Health(ctx context.Context) (json.RawMessage, bool) {
enabled := ext.isEnabled.Load()
resp, _ := json.Marshal(struct{ Enabled bool }{enabled})
return resp, enabled
}
114 changes: 114 additions & 0 deletions extensions/tn_local/tn_local.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
package tn_local

import (
"context"
"fmt"
"time"

"github.com/jackc/pgx/v5/pgxpool"
"github.com/trufnetwork/kwil-db/common"
"github.com/trufnetwork/kwil-db/core/log"
"github.com/trufnetwork/kwil-db/extensions/hooks"
rpcserver "github.com/trufnetwork/kwil-db/node/services/jsonrpc"
)

// InitializeExtension registers the tn_local hooks.
// Called from extensions/register.go during init().
func InitializeExtension() {
err := hooks.RegisterEngineReadyHook("tn_local_engine_ready", engineReadyHook)
if err != nil {
panic(fmt.Sprintf("failed to register tn_local engine ready hook: %v", err))
}

err = hooks.RegisterAdminServerHook("tn_local_admin", adminServerHook)
if err != nil {
panic(fmt.Sprintf("failed to register tn_local admin server hook: %v", err))
}
}

// adminServerHook registers the local storage Svc on the admin JSON-RPC server.
func adminServerHook(server *rpcserver.Server) error {
ext := GetExtension()
server.RegisterSvc(ext)
return nil
}

// engineReadyHook initializes the extension's database and schema.
func engineReadyHook(ctx context.Context, app *common.App) error {
logger := app.Service.Logger.New("tn_local")

var localDB *LocalDB
if testDB := getTestDB(); testDB != nil {
localDB = NewLocalDB(testDB, logger)
} else {
pool, err := createIndependentConnectionPool(ctx, app.Service, logger)
if err != nil {
return fmt.Errorf("failed to create connection pool: %w", err)
}

// Close pool on any subsequent failure to prevent connection leak.
success := false
defer func() {
if !success {
pool.Close()
}
}()

wrapper := NewPoolDBWrapper(pool)
localDB = NewLocalDB(wrapper, logger)

if err := localDB.SetupSchema(ctx); err != nil {
return fmt.Errorf("failed to setup local schema: %w", err)
}

success = true
}

// Update existing singleton in-place to preserve the pointer registered
// with the admin server's RegisterSvc.
ext := GetExtension()
ext.configure(logger, localDB.db, localDB)

logger.Info("tn_local extension initialized")
return nil
}

// createIndependentConnectionPool creates a dedicated connection pool for local storage.
func createIndependentConnectionPool(ctx context.Context, service *common.Service, logger log.Logger) (*pgxpool.Pool, error) {
dbConfig := service.LocalConfig.DB

connStr := fmt.Sprintf("host=%s port=%s user=%s database=%s sslmode=disable",
dbConfig.Host, dbConfig.Port, dbConfig.User, dbConfig.DBName)

if dbConfig.Pass != "" {
connStr += " password=" + dbConfig.Pass
}

poolConfig, err := pgxpool.ParseConfig(connStr)
if err != nil {
return nil, fmt.Errorf("parse pool config: %w", err)
}

poolConfig.MaxConns = 10
poolConfig.MinConns = 2
poolConfig.MaxConnLifetime = 30 * time.Minute
poolConfig.MaxConnIdleTime = 5 * time.Minute

pool, err := pgxpool.NewWithConfig(ctx, poolConfig)
if err != nil {
return nil, fmt.Errorf("create pool: %w", err)
}

conn, err := pool.Acquire(ctx)
if err != nil {
pool.Close()
return nil, fmt.Errorf("test connection: %w", err)
}
conn.Release()

logger.Info("created independent connection pool for local storage",
"max_conns", poolConfig.MaxConns,
"min_conns", poolConfig.MinConns)

return pool, nil
}
144 changes: 144 additions & 0 deletions extensions/tn_local/tn_local_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
package tn_local

import (
"context"
"io"
"strings"
"sync"
"testing"

"github.com/stretchr/testify/require"
"github.com/trufnetwork/kwil-db/core/log"
jsonrpc "github.com/trufnetwork/kwil-db/core/rpc/json"
kwilsql "github.com/trufnetwork/kwil-db/node/types/sql"
"github.com/trufnetwork/node/tests/utils"
)

func TestSetupSchema(t *testing.T) {
var statements []string
mockTx := &utils.MockTx{
ExecuteFn: func(ctx context.Context, stmt string, args ...any) (*kwilsql.ResultSet, error) {
statements = append(statements, stmt)
return &kwilsql.ResultSet{}, nil
},
}
mockDB := &utils.MockDB{
BeginTxFn: func(ctx context.Context) (kwilsql.Tx, error) {
return mockTx, nil
},
}

logger := log.New(log.WithWriter(io.Discard))
localDB := NewLocalDB(mockDB, logger)

err := localDB.SetupSchema(context.Background())
require.NoError(t, err)

// Verify schema creation
require.True(t, containsSQL(statements, "CREATE SCHEMA IF NOT EXISTS "+SchemaName),
"should create schema")

// Verify streams table
require.True(t, containsSQL(statements, SchemaName+".streams"),
"should create streams table")
require.True(t, containsSQL(statements, "data_provider TEXT NOT NULL"),
"streams should have data_provider column")
require.True(t, containsSQL(statements, "stream_type TEXT NOT NULL"),
"streams should have stream_type column")

// Verify primitive_events table
require.True(t, containsSQL(statements, SchemaName+".primitive_events"),
"should create primitive_events table")
require.True(t, containsSQL(statements, "NUMERIC(36,18)"),
"primitive_events should use NUMERIC(36,18) for value")

// Verify primitive_events index
require.True(t, containsSQL(statements, "local_pe_stream_time_idx"),
"should create primitive_events index")

// Verify taxonomies table
require.True(t, containsSQL(statements, SchemaName+".taxonomies"),
"should create taxonomies table")
require.True(t, containsSQL(statements, "taxonomy_id UUID PRIMARY KEY"),
"taxonomies should have UUID primary key")
}

func TestSetupSchema_RollbackOnError(t *testing.T) {
rolledBack := false
mockTx := &utils.MockTx{
ExecuteFn: func(ctx context.Context, stmt string, args ...any) (*kwilsql.ResultSet, error) {
if strings.Contains(stmt, "streams") {
return nil, context.DeadlineExceeded
}
return &kwilsql.ResultSet{}, nil
},
RollbackFn: func(ctx context.Context) error {
rolledBack = true
return nil
},
}
mockDB := &utils.MockDB{
BeginTxFn: func(ctx context.Context) (kwilsql.Tx, error) {
return mockTx, nil
},
}

logger := log.New(log.WithWriter(io.Discard))
localDB := NewLocalDB(mockDB, logger)

err := localDB.SetupSchema(context.Background())
require.Error(t, err)
require.True(t, rolledBack, "transaction should be rolled back on error")
}

func TestExtensionSingleton(t *testing.T) {
// Reset for test isolation — never copy sync.Once (contains mutex)
prev := extensionInstance
extensionInstance = nil
once = sync.Once{}
t.Cleanup(func() {
extensionInstance = prev
once = sync.Once{}
if prev != nil {
once.Do(func() {}) // mark as done since instance already exists
}
})

ext1 := GetExtension()
ext2 := GetExtension()
require.Same(t, ext1, ext2, "GetExtension should return same instance")
require.False(t, ext1.isEnabled.Load(), "default extension should be disabled")

// configure updates the existing instance in-place (preserves pointer identity)
ext1.configure(ext1.logger, nil, nil)
require.True(t, ext1.isEnabled.Load())
require.Same(t, ext1, GetExtension(), "still same pointer after configure")
}

func TestServiceInterface(t *testing.T) {
ext := &Extension{}
ext.isEnabled.Store(true)

require.Equal(t, ServiceName, ext.Name())

methods := ext.Methods()
require.Contains(t, methods, jsonrpc.Method("local.create_stream"))
require.Contains(t, methods, jsonrpc.Method("local.insert_records"))
require.Contains(t, methods, jsonrpc.Method("local.insert_taxonomy"))
require.Contains(t, methods, jsonrpc.Method("local.get_record"))
require.Contains(t, methods, jsonrpc.Method("local.get_index"))
require.Contains(t, methods, jsonrpc.Method("local.list_streams"))

health, ok := ext.Health(context.Background())
require.True(t, ok)
require.NotNil(t, health)
}

func containsSQL(statements []string, substr string) bool {
for _, s := range statements {
if strings.Contains(s, substr) {
return true
}
}
return false
}
Loading
Loading