diff --git a/extensions/tn_local/db_ops.go b/extensions/tn_local/db_ops.go index 90455061..e173e330 100644 --- a/extensions/tn_local/db_ops.go +++ b/extensions/tn_local/db_ops.go @@ -3,6 +3,7 @@ package tn_local import ( "context" "fmt" + "time" "github.com/trufnetwork/kwil-db/core/log" "github.com/trufnetwork/kwil-db/node/types/sql" @@ -19,6 +20,15 @@ func NewLocalDB(db sql.DB, logger log.Logger) *LocalDB { return &LocalDB{db: db, logger: logger} } +// dbCreateStream inserts a new stream into ext_tn_local.streams. +func (ext *Extension) dbCreateStream(ctx context.Context, dataProvider, streamID, streamType string) error { + _, err := ext.db.Execute(ctx, fmt.Sprintf( + `INSERT INTO %s.streams (data_provider, stream_id, stream_type, created_at) + VALUES ($1, $2, $3, $4)`, SchemaName), + dataProvider, streamID, streamType, time.Now().Unix()) + return err +} + // SetupSchema creates the ext_tn_local schema and all tables within a single transaction. func (l *LocalDB) SetupSchema(ctx context.Context) error { l.logger.Info("setting up local storage schema") diff --git a/extensions/tn_local/handlers.go b/extensions/tn_local/handlers.go index aafd1e5f..28c6f919 100644 --- a/extensions/tn_local/handlers.go +++ b/extensions/tn_local/handlers.go @@ -2,15 +2,90 @@ package tn_local import ( "context" + "errors" + "fmt" + "regexp" + "strings" + "github.com/jackc/pgx/v5/pgconn" jsonrpc "github.com/trufnetwork/kwil-db/core/rpc/json" ) -// Handler stubs — implementations will be added in Tasks 3-6. +// pgUniqueViolation is the PostgreSQL error code for unique_violation. +const pgUniqueViolation = "23505" -// CreateStream creates a local stream. (Task 3) +var ethAddrRegex = regexp.MustCompile(`^0x[0-9a-fA-F]{40}$`) + +// validateStreamID checks that stream_id is 32 chars and starts with "st". +func validateStreamID(streamID string) error { + if len(streamID) != 32 { + return fmt.Errorf("stream_id must be exactly 32 characters, got %d", len(streamID)) + } + if !strings.HasPrefix(streamID, "st") { + return fmt.Errorf("stream_id must start with 'st'") + } + return nil +} + +// validateStreamType checks that stream_type is "primitive" or "composed". +func validateStreamType(streamType string) error { + if streamType != "primitive" && streamType != "composed" { + return fmt.Errorf("stream_type must be 'primitive' or 'composed', got %q", streamType) + } + return nil +} + +// validateDataProvider checks that data_provider is a valid Ethereum address (0x + 40 hex). +func validateDataProvider(dataProvider string) error { + if !ethAddrRegex.MatchString(dataProvider) { + return fmt.Errorf("data_provider must be a valid Ethereum address (0x + 40 hex chars)") + } + return nil +} + +// isDuplicateKeyError checks if err is a PostgreSQL unique constraint violation (23505). +func isDuplicateKeyError(err error) bool { + if err == nil { + return false + } + var pgErr *pgconn.PgError + if errors.As(err, &pgErr) { + return pgErr.Code == pgUniqueViolation + } + // Fallback for non-pgx drivers (e.g. mocks): case-insensitive string match. + msg := strings.ToLower(err.Error()) + return strings.Contains(msg, "duplicate key") || strings.Contains(msg, "unique constraint") +} + +// CreateStream creates a local stream. func (ext *Extension) CreateStream(ctx context.Context, req *CreateStreamRequest) (*CreateStreamResponse, *jsonrpc.Error) { - return nil, jsonrpc.NewError(jsonrpc.ErrorInternal, "not implemented", nil) + if req == nil { + return nil, jsonrpc.NewError(jsonrpc.ErrorInvalidParams, "missing request", nil) + } + + // Normalize data_provider to lowercase to match consensus behavior + // (consensus uses LOWER() in 001-common-actions.sql before insertion). + dataProvider := strings.ToLower(req.DataProvider) + + if err := validateDataProvider(dataProvider); err != nil { + return nil, jsonrpc.NewError(jsonrpc.ErrorInvalidParams, err.Error(), nil) + } + if err := validateStreamID(req.StreamID); err != nil { + return nil, jsonrpc.NewError(jsonrpc.ErrorInvalidParams, err.Error(), nil) + } + if err := validateStreamType(req.StreamType); err != nil { + return nil, jsonrpc.NewError(jsonrpc.ErrorInvalidParams, err.Error(), nil) + } + + if err := ext.dbCreateStream(ctx, dataProvider, req.StreamID, req.StreamType); err != nil { + if isDuplicateKeyError(err) { + return nil, jsonrpc.NewError(jsonrpc.ErrorInvalidParams, fmt.Sprintf("stream already exists: %s/%s", dataProvider, req.StreamID), nil) + } + ext.logger.Error("failed to create local stream", "error", err) + return nil, jsonrpc.NewError(jsonrpc.ErrorInternal, "failed to create stream", nil) + } + + return &CreateStreamResponse{}, nil } // InsertRecords inserts records into a local primitive stream. (Task 4) diff --git a/extensions/tn_local/tn_local_test.go b/extensions/tn_local/tn_local_test.go index 6b6795ea..ffd5eb4e 100644 --- a/extensions/tn_local/tn_local_test.go +++ b/extensions/tn_local/tn_local_test.go @@ -2,11 +2,13 @@ package tn_local import ( "context" + "fmt" "io" "strings" "sync" "testing" + "github.com/jackc/pgx/v5/pgconn" "github.com/stretchr/testify/require" "github.com/trufnetwork/kwil-db/core/log" jsonrpc "github.com/trufnetwork/kwil-db/core/rpc/json" @@ -134,6 +136,197 @@ func TestServiceInterface(t *testing.T) { require.NotNil(t, health) } +// newTestExtension creates an Extension with a mock DB for handler tests. +func newTestExtension(db kwilsql.DB) *Extension { + ext := &Extension{ + logger: log.New(log.WithWriter(io.Discard)), + db: db, + } + ext.isEnabled.Store(true) + return ext +} + +func TestCreateStream_NilRequest(t *testing.T) { + ext := newTestExtension(&utils.MockDB{}) + + resp, rpcErr := ext.CreateStream(context.Background(), nil) + require.Nil(t, resp) + require.NotNil(t, rpcErr) + require.Equal(t, jsonrpc.ErrorCode(jsonrpc.ErrorInvalidParams), rpcErr.Code) + require.Contains(t, rpcErr.Message, "missing request") +} + +func TestCreateStream_Success(t *testing.T) { + var capturedStmt string + var capturedArgs []any + mockDB := &utils.MockDB{ + ExecuteFn: func(ctx context.Context, stmt string, args ...any) (*kwilsql.ResultSet, error) { + capturedStmt = stmt + capturedArgs = args + return &kwilsql.ResultSet{}, nil + }, + } + ext := newTestExtension(mockDB) + + resp, rpcErr := ext.CreateStream(context.Background(), &CreateStreamRequest{ + DataProvider: "0xEC36224A679218Ae28FCeCe8d3c68595B87Dd832", + StreamID: "st00000000000000000000000000test", + StreamType: "primitive", + }) + + require.Nil(t, rpcErr, "expected no error") + require.NotNil(t, resp) + require.Contains(t, capturedStmt, "INSERT INTO "+SchemaName+".streams") + require.Len(t, capturedArgs, 4, "INSERT should have 4 parameters") + // data_provider should be lowercased (matching consensus behavior) + require.Equal(t, "0xec36224a679218ae28fcece8d3c68595b87dd832", capturedArgs[0]) + require.Equal(t, "st00000000000000000000000000test", capturedArgs[1]) + require.Equal(t, "primitive", capturedArgs[2]) + // created_at should be a non-zero unix timestamp + createdAt, ok := capturedArgs[3].(int64) + require.True(t, ok, "created_at should be int64") + require.NotZero(t, createdAt, "created_at should be non-zero") +} + +func TestCreateStream_ComposedType(t *testing.T) { + mockDB := &utils.MockDB{ + ExecuteFn: func(ctx context.Context, stmt string, args ...any) (*kwilsql.ResultSet, error) { + return &kwilsql.ResultSet{}, nil + }, + } + ext := newTestExtension(mockDB) + + resp, rpcErr := ext.CreateStream(context.Background(), &CreateStreamRequest{ + DataProvider: "0xEC36224A679218Ae28FCeCe8d3c68595B87Dd832", + StreamID: "st00000000000000000000000000test", + StreamType: "composed", + }) + + require.Nil(t, rpcErr) + require.NotNil(t, resp) +} + +func TestCreateStream_InvalidStreamID(t *testing.T) { + ext := newTestExtension(&utils.MockDB{}) + + tests := []struct { + name string + streamID string + wantMsg string + }{ + {"too short", "st00", "must be exactly 32 characters"}, + {"too long", "st000000000000000000000000000test1", "must be exactly 32 characters"}, + {"wrong prefix", "xx00000000000000000000000000test", "must start with 'st'"}, + {"empty", "", "must be exactly 32 characters"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, rpcErr := ext.CreateStream(context.Background(), &CreateStreamRequest{ + DataProvider: "0xEC36224A679218Ae28FCeCe8d3c68595B87Dd832", + StreamID: tt.streamID, + StreamType: "primitive", + }) + require.NotNil(t, rpcErr) + require.Equal(t, jsonrpc.ErrorCode(jsonrpc.ErrorInvalidParams), rpcErr.Code) + require.Contains(t, rpcErr.Message, tt.wantMsg) + }) + } +} + +func TestCreateStream_InvalidStreamType(t *testing.T) { + ext := newTestExtension(&utils.MockDB{}) + + _, rpcErr := ext.CreateStream(context.Background(), &CreateStreamRequest{ + DataProvider: "0xEC36224A679218Ae28FCeCe8d3c68595B87Dd832", + StreamID: "st00000000000000000000000000test", + StreamType: "invalid", + }) + require.NotNil(t, rpcErr) + require.Equal(t, jsonrpc.ErrorCode(jsonrpc.ErrorInvalidParams), rpcErr.Code) + require.Contains(t, rpcErr.Message, "must be 'primitive' or 'composed'") +} + +func TestCreateStream_InvalidDataProvider(t *testing.T) { + ext := newTestExtension(&utils.MockDB{}) + + tests := []struct { + name string + dataProvider string + }{ + {"no 0x prefix", "EC36224A679218Ae28FCeCe8d3c68595B87Dd832"}, + {"too short", "0xEC36224A679218Ae28"}, + {"too long", "0xEC36224A679218Ae28FCeCe8d3c68595B87Dd832FF"}, + {"invalid chars", "0xGG36224A679218Ae28FCeCe8d3c68595B87Dd832"}, + {"empty", ""}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, rpcErr := ext.CreateStream(context.Background(), &CreateStreamRequest{ + DataProvider: tt.dataProvider, + StreamID: "st00000000000000000000000000test", + StreamType: "primitive", + }) + require.NotNil(t, rpcErr) + require.Equal(t, jsonrpc.ErrorCode(jsonrpc.ErrorInvalidParams), rpcErr.Code) + require.Contains(t, rpcErr.Message, "data_provider must be a valid Ethereum address") + }) + } +} + +func TestCreateStream_DuplicateStream(t *testing.T) { + mockDB := &utils.MockDB{ + ExecuteFn: func(ctx context.Context, stmt string, args ...any) (*kwilsql.ResultSet, error) { + return nil, fmt.Errorf("duplicate key value violates unique constraint") + }, + } + ext := newTestExtension(mockDB) + + _, rpcErr := ext.CreateStream(context.Background(), &CreateStreamRequest{ + DataProvider: "0xEC36224A679218Ae28FCeCe8d3c68595B87Dd832", + StreamID: "st00000000000000000000000000test", + StreamType: "primitive", + }) + require.NotNil(t, rpcErr) + require.Equal(t, jsonrpc.ErrorCode(jsonrpc.ErrorInvalidParams), rpcErr.Code) + require.Contains(t, rpcErr.Message, "stream already exists") +} + +func TestCreateStream_DuplicateStream_PgError(t *testing.T) { + mockDB := &utils.MockDB{ + ExecuteFn: func(ctx context.Context, stmt string, args ...any) (*kwilsql.ResultSet, error) { + return nil, &pgconn.PgError{Code: pgUniqueViolation, Message: "unique_violation"} + }, + } + ext := newTestExtension(mockDB) + + _, rpcErr := ext.CreateStream(context.Background(), &CreateStreamRequest{ + DataProvider: "0xEC36224A679218Ae28FCeCe8d3c68595B87Dd832", + StreamID: "st00000000000000000000000000test", + StreamType: "primitive", + }) + require.NotNil(t, rpcErr) + require.Equal(t, jsonrpc.ErrorCode(jsonrpc.ErrorInvalidParams), rpcErr.Code) + require.Contains(t, rpcErr.Message, "stream already exists") +} + +func TestCreateStream_DBError(t *testing.T) { + mockDB := &utils.MockDB{ + ExecuteFn: func(ctx context.Context, stmt string, args ...any) (*kwilsql.ResultSet, error) { + return nil, fmt.Errorf("connection refused") + }, + } + ext := newTestExtension(mockDB) + + _, rpcErr := ext.CreateStream(context.Background(), &CreateStreamRequest{ + DataProvider: "0xEC36224A679218Ae28FCeCe8d3c68595B87Dd832", + StreamID: "st00000000000000000000000000test", + StreamType: "primitive", + }) + require.NotNil(t, rpcErr) + require.Equal(t, jsonrpc.ErrorCode(jsonrpc.ErrorInternal), rpcErr.Code) + require.Contains(t, rpcErr.Message, "failed to create stream") +} + func containsSQL(statements []string, substr string) bool { for _, s := range statements { if strings.Contains(s, substr) {