From 1945bb238cb122494b4e943513e269d8c3a6f2c3 Mon Sep 17 00:00:00 2001 From: Claude Date: Sun, 4 Jan 2026 04:56:48 +0000 Subject: [PATCH 1/2] test: add comprehensive test coverage for model, daemon, and commands packages Add 18 new test files covering: Model package: - crypto_test.go: AES-GCM and RSA encryption services - command_test.go: Command serialization and comparison functions - db_test.go: Database operations and storage functions - path_test.go: File path helper functions - log_cleanup_test.go: Log file cleanup operations - shell_test.go: Shell hook services (bash, zsh, fish) - api.base_test.go: HTTP client and GraphQL request handling - heartbeat_test.go: Heartbeat data structures and JSON serialization - alias_test.go: Alias model and import/export types - cc_statusline_cache_test.go: CC statusline cache with TTL and concurrency Daemon package: - socket_test.go: Unix socket server and message handling - cleanup_timer_test.go: Cleanup timer service lifecycle - client_test.go: Daemon client socket communication - circuit_breaker_test.go: Sync circuit breaker wrapper - heartbeat_resync_test.go: Heartbeat resync service and file operations - handlers.heartbeat_test.go: Heartbeat file persistence handlers - aicode_otel_processor_test.go: OTEL processor and metric mapping Commands package: - utils_test.go: Path expansion and user path adjustment utilities --- commands/utils_test.go | 284 ++++++++++++ daemon/aicode_otel_processor_test.go | 521 ++++++++++++++++++++++ daemon/circuit_breaker_test.go | 162 +++++++ daemon/cleanup_timer_test.go | 208 +++++++++ daemon/client_test.go | 277 ++++++++++++ daemon/handlers.heartbeat_test.go | 236 ++++++++++ daemon/heartbeat_resync_test.go | 313 ++++++++++++++ daemon/socket_test.go | 364 ++++++++++++++++ model/alias_test.go | 297 +++++++++++++ model/api.base_test.go | 418 ++++++++++++++++++ model/cc_statusline_cache_test.go | 307 +++++++++++++ model/command_test.go | 621 +++++++++++++++++++++++++++ model/crypto_test.go | 415 ++++++++++++++++++ model/db_test.go | 568 ++++++++++++++++++++++++ model/heartbeat_test.go | 370 ++++++++++++++++ model/log_cleanup_test.go | 356 +++++++++++++++ model/path_test.go | 291 +++++++++++++ model/shell_test.go | 545 +++++++++++++++++++++++ 18 files changed, 6553 insertions(+) create mode 100644 commands/utils_test.go create mode 100644 daemon/aicode_otel_processor_test.go create mode 100644 daemon/circuit_breaker_test.go create mode 100644 daemon/cleanup_timer_test.go create mode 100644 daemon/client_test.go create mode 100644 daemon/handlers.heartbeat_test.go create mode 100644 daemon/heartbeat_resync_test.go create mode 100644 daemon/socket_test.go create mode 100644 model/alias_test.go create mode 100644 model/api.base_test.go create mode 100644 model/cc_statusline_cache_test.go create mode 100644 model/command_test.go create mode 100644 model/crypto_test.go create mode 100644 model/db_test.go create mode 100644 model/heartbeat_test.go create mode 100644 model/log_cleanup_test.go create mode 100644 model/path_test.go create mode 100644 model/shell_test.go diff --git a/commands/utils_test.go b/commands/utils_test.go new file mode 100644 index 0000000..db4c8fa --- /dev/null +++ b/commands/utils_test.go @@ -0,0 +1,284 @@ +package commands + +import ( + "os" + "path/filepath" + "strings" + "testing" +) + +func TestExpandPath_Tilde(t *testing.T) { + homeDir, err := os.UserHomeDir() + if err != nil { + t.Skipf("Cannot get home dir: %v", err) + } + + testCases := []struct { + input string + expected string + }{ + {"~", homeDir}, + {"~/", homeDir}, + {"~/Documents", filepath.Join(homeDir, "Documents")}, + {"~/.config/shelltime", filepath.Join(homeDir, ".config/shelltime")}, + } + + for _, tc := range testCases { + t.Run(tc.input, func(t *testing.T) { + result, err := expandPath(tc.input) + if err != nil { + t.Fatalf("expandPath(%q) failed: %v", tc.input, err) + } + if result != tc.expected { + t.Errorf("expandPath(%q) = %q, expected %q", tc.input, result, tc.expected) + } + }) + } +} + +func TestExpandPath_Absolute(t *testing.T) { + testCases := []struct { + input string + expected string + }{ + {"/usr/bin", "/usr/bin"}, + {"/tmp", "/tmp"}, + {"/home/user/file.txt", "/home/user/file.txt"}, + } + + for _, tc := range testCases { + t.Run(tc.input, func(t *testing.T) { + result, err := expandPath(tc.input) + if err != nil { + t.Fatalf("expandPath(%q) failed: %v", tc.input, err) + } + if result != tc.expected { + t.Errorf("expandPath(%q) = %q, expected %q", tc.input, result, tc.expected) + } + }) + } +} + +func TestExpandPath_Relative(t *testing.T) { + // Get current working directory + cwd, err := os.Getwd() + if err != nil { + t.Fatalf("Cannot get cwd: %v", err) + } + + testCases := []struct { + input string + expected string + }{ + {"file.txt", filepath.Join(cwd, "file.txt")}, + {"subdir/file.txt", filepath.Join(cwd, "subdir/file.txt")}, + {"./file.txt", filepath.Join(cwd, "file.txt")}, + } + + for _, tc := range testCases { + t.Run(tc.input, func(t *testing.T) { + result, err := expandPath(tc.input) + if err != nil { + t.Fatalf("expandPath(%q) failed: %v", tc.input, err) + } + if result != tc.expected { + t.Errorf("expandPath(%q) = %q, expected %q", tc.input, result, tc.expected) + } + }) + } +} + +func TestAdjustPathForCurrentUser_UsersPath(t *testing.T) { + homeDir, err := os.UserHomeDir() + if err != nil { + t.Skipf("Cannot get home dir: %v", err) + } + + testCases := []struct { + name string + input string + expected string + }{ + { + "Users path with 4 parts", + "/Users/someuser/Documents/file.txt", + homeDir + "/Documents/file.txt", + }, + { + "Users path with nested dirs", + "/Users/anotheruser/projects/app/src/main.go", + homeDir + "/projects/app/src/main.go", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := AdjustPathForCurrentUser(tc.input) + if result != tc.expected { + t.Errorf("AdjustPathForCurrentUser(%q) = %q, expected %q", tc.input, result, tc.expected) + } + }) + } +} + +func TestAdjustPathForCurrentUser_HomePath(t *testing.T) { + homeDir, err := os.UserHomeDir() + if err != nil { + t.Skipf("Cannot get home dir: %v", err) + } + + testCases := []struct { + name string + input string + expected string + }{ + { + "home path with 4 parts", + "/home/someuser/Documents/file.txt", + homeDir + "/Documents/file.txt", + }, + { + "home path with nested dirs", + "/home/anotheruser/.config/app/config.yaml", + homeDir + "/.config/app/config.yaml", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := AdjustPathForCurrentUser(tc.input) + if result != tc.expected { + t.Errorf("AdjustPathForCurrentUser(%q) = %q, expected %q", tc.input, result, tc.expected) + } + }) + } +} + +func TestAdjustPathForCurrentUser_RootPath(t *testing.T) { + homeDir, err := os.UserHomeDir() + if err != nil { + t.Skipf("Cannot get home dir: %v", err) + } + + testCases := []struct { + name string + input string + expected string + }{ + { + "root path", + "/root/.bashrc", + homeDir + "/.bashrc", + }, + { + "root path with subdir", + "/root/scripts/deploy.sh", + homeDir + "/scripts/deploy.sh", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := AdjustPathForCurrentUser(tc.input) + if result != tc.expected { + t.Errorf("AdjustPathForCurrentUser(%q) = %q, expected %q", tc.input, result, tc.expected) + } + }) + } +} + +func TestAdjustPathForCurrentUser_NoMatch(t *testing.T) { + testCases := []struct { + name string + input string + }{ + {"absolute path", "/usr/bin/ls"}, + {"var path", "/var/log/messages"}, + {"tmp path", "/tmp/file.txt"}, + {"etc path", "/etc/hosts"}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := AdjustPathForCurrentUser(tc.input) + // Should return unchanged + if result != tc.input { + t.Errorf("AdjustPathForCurrentUser(%q) = %q, expected unchanged", tc.input, result) + } + }) + } +} + +func TestAdjustPathForCurrentUser_ShortPaths(t *testing.T) { + testCases := []struct { + name string + input string + }{ + {"short Users path", "/Users/user"}, + {"short home path", "/home/user"}, + {"very short", "/home"}, + {"root only", "/"}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Should not panic on short paths + result := AdjustPathForCurrentUser(tc.input) + // Just verify it doesn't panic and returns something + if result == "" { + t.Error("Result should not be empty") + } + }) + } +} + +func TestAdjustPathForCurrentUser_EmptyPath(t *testing.T) { + result := AdjustPathForCurrentUser("") + if result != "" { + t.Errorf("Expected empty string, got %q", result) + } +} + +func TestAdjustPathForCurrentUser_PreservesSubpath(t *testing.T) { + homeDir, err := os.UserHomeDir() + if err != nil { + t.Skipf("Cannot get home dir: %v", err) + } + + input := "/Users/someuser/very/deeply/nested/path/to/file.txt" + result := AdjustPathForCurrentUser(input) + + // Should preserve the subpath after username + expectedSuffix := "/very/deeply/nested/path/to/file.txt" + if !strings.HasSuffix(result, expectedSuffix) { + t.Errorf("Result should preserve subpath. Got: %q, expected suffix: %q", result, expectedSuffix) + } + + // Should start with home dir + if !strings.HasPrefix(result, homeDir) { + t.Errorf("Result should start with home dir. Got: %q, expected prefix: %q", result, homeDir) + } +} + +func TestExpandPath_EmptyString(t *testing.T) { + // Empty string should be expanded to current directory + cwd, _ := os.Getwd() + result, err := expandPath("") + if err != nil { + t.Fatalf("expandPath(\"\") failed: %v", err) + } + if result != cwd { + t.Errorf("expandPath(\"\") = %q, expected %q", result, cwd) + } +} + +func TestExpandPath_SingleDot(t *testing.T) { + cwd, _ := os.Getwd() + result, err := expandPath(".") + if err != nil { + t.Fatalf("expandPath(\".\") failed: %v", err) + } + if result != cwd { + t.Errorf("expandPath(\".\") = %q, expected %q", result, cwd) + } +} diff --git a/daemon/aicode_otel_processor_test.go b/daemon/aicode_otel_processor_test.go new file mode 100644 index 0000000..8f6e265 --- /dev/null +++ b/daemon/aicode_otel_processor_test.go @@ -0,0 +1,521 @@ +package daemon + +import ( + "testing" + + "github.com/malamtime/cli/model" + commonv1 "go.opentelemetry.io/proto/otlp/common/v1" + resourcev1 "go.opentelemetry.io/proto/otlp/resource/v1" +) + +func TestNewAICodeOtelProcessor(t *testing.T) { + config := model.ShellTimeConfig{ + Token: "test-token", + APIEndpoint: "http://localhost:8080", + } + + processor := NewAICodeOtelProcessor(config) + if processor == nil { + t.Fatal("NewAICodeOtelProcessor returned nil") + } + + if processor.endpoint.Token != "test-token" { + t.Errorf("Token mismatch") + } + if processor.endpoint.APIEndpoint != "http://localhost:8080" { + t.Errorf("APIEndpoint mismatch") + } + if processor.hostname == "" { + t.Error("hostname should not be empty") + } +} + +func TestNewAICodeOtelProcessor_Debug(t *testing.T) { + debug := true + config := model.ShellTimeConfig{ + Token: "token", + AICodeOtel: &model.AICodeOtelConfig{ + Debug: &debug, + }, + } + + processor := NewAICodeOtelProcessor(config) + if !processor.debug { + t.Error("debug should be true when configured") + } +} + +func TestDetectOtelSource(t *testing.T) { + testCases := []struct { + name string + serviceName string + expectedType string + }{ + {"claude code", "claude-code", model.AICodeOtelSourceClaudeCode}, + {"claude", "claude", model.AICodeOtelSourceClaudeCode}, + {"codex", "codex-cli", model.AICodeOtelSourceCodex}, + {"unknown", "vscode", ""}, + {"empty", "", ""}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + resource := &resourcev1.Resource{ + Attributes: []*commonv1.KeyValue{ + { + Key: "service.name", + Value: &commonv1.AnyValue{ + Value: &commonv1.AnyValue_StringValue{StringValue: tc.serviceName}, + }, + }, + }, + } + + result := detectOtelSource(resource) + if result != tc.expectedType { + t.Errorf("Expected %s, got %s", tc.expectedType, result) + } + }) + } +} + +func TestDetectOtelSource_NilResource(t *testing.T) { + result := detectOtelSource(nil) + if result != "" { + t.Errorf("Expected empty string for nil resource, got %s", result) + } +} + +func TestExtractResourceAttributes(t *testing.T) { + resource := &resourcev1.Resource{ + Attributes: []*commonv1.KeyValue{ + {Key: "session.id", Value: &commonv1.AnyValue{Value: &commonv1.AnyValue_StringValue{StringValue: "session-123"}}}, + {Key: "conversation.id", Value: &commonv1.AnyValue{Value: &commonv1.AnyValue_StringValue{StringValue: "conv-456"}}}, + {Key: "app.version", Value: &commonv1.AnyValue{Value: &commonv1.AnyValue_StringValue{StringValue: "1.0.0"}}}, + {Key: "organization.id", Value: &commonv1.AnyValue{Value: &commonv1.AnyValue_StringValue{StringValue: "org-789"}}}, + {Key: "user.account_uuid", Value: &commonv1.AnyValue{Value: &commonv1.AnyValue_StringValue{StringValue: "user-abc"}}}, + {Key: "terminal.type", Value: &commonv1.AnyValue{Value: &commonv1.AnyValue_StringValue{StringValue: "terminal"}}}, + {Key: "os.type", Value: &commonv1.AnyValue{Value: &commonv1.AnyValue_StringValue{StringValue: "linux"}}}, + {Key: "os.version", Value: &commonv1.AnyValue{Value: &commonv1.AnyValue_StringValue{StringValue: "5.15"}}}, + {Key: "host.arch", Value: &commonv1.AnyValue{Value: &commonv1.AnyValue_StringValue{StringValue: "amd64"}}}, + {Key: "user.id", Value: &commonv1.AnyValue{Value: &commonv1.AnyValue_StringValue{StringValue: "user123"}}}, + {Key: "user.email", Value: &commonv1.AnyValue{Value: &commonv1.AnyValue_StringValue{StringValue: "user@test.com"}}}, + {Key: "user.name", Value: &commonv1.AnyValue{Value: &commonv1.AnyValue_StringValue{StringValue: "testuser"}}}, + {Key: "machine.name", Value: &commonv1.AnyValue{Value: &commonv1.AnyValue_StringValue{StringValue: "workstation"}}}, + {Key: "team.id", Value: &commonv1.AnyValue{Value: &commonv1.AnyValue_StringValue{StringValue: "team-xyz"}}}, + {Key: "pwd", Value: &commonv1.AnyValue{Value: &commonv1.AnyValue_StringValue{StringValue: "/home/user/project"}}}, + }, + } + + attrs := extractResourceAttributes(resource) + + if attrs.SessionID != "session-123" { + t.Errorf("SessionID mismatch") + } + if attrs.ConversationID != "conv-456" { + t.Errorf("ConversationID mismatch") + } + if attrs.AppVersion != "1.0.0" { + t.Errorf("AppVersion mismatch") + } + if attrs.OrganizationID != "org-789" { + t.Errorf("OrganizationID mismatch") + } + if attrs.UserAccountUUID != "user-abc" { + t.Errorf("UserAccountUUID mismatch") + } + if attrs.TerminalType != "terminal" { + t.Errorf("TerminalType mismatch") + } + if attrs.OSType != "linux" { + t.Errorf("OSType mismatch") + } + if attrs.OSVersion != "5.15" { + t.Errorf("OSVersion mismatch") + } + if attrs.HostArch != "amd64" { + t.Errorf("HostArch mismatch") + } + if attrs.UserID != "user123" { + t.Errorf("UserID mismatch") + } + if attrs.UserEmail != "user@test.com" { + t.Errorf("UserEmail mismatch") + } + if attrs.UserName != "testuser" { + t.Errorf("UserName mismatch") + } + if attrs.MachineName != "workstation" { + t.Errorf("MachineName mismatch") + } + if attrs.TeamID != "team-xyz" { + t.Errorf("TeamID mismatch") + } + if attrs.Pwd != "/home/user/project" { + t.Errorf("Pwd mismatch") + } +} + +func TestExtractResourceAttributes_NilResource(t *testing.T) { + attrs := extractResourceAttributes(nil) + if attrs == nil { + t.Fatal("Should return empty struct, not nil") + } + if attrs.SessionID != "" { + t.Error("SessionID should be empty") + } +} + +func TestMapMetricName_ClaudeCode(t *testing.T) { + testCases := []struct { + input string + expected string + }{ + {"claude_code.session.count", model.AICodeMetricSessionCount}, + {"claude_code.token.usage", model.AICodeMetricTokenUsage}, + {"claude_code.cost.usage", model.AICodeMetricCostUsage}, + {"claude_code.lines_of_code.count", model.AICodeMetricLinesOfCodeCount}, + {"claude_code.commit.count", model.AICodeMetricCommitCount}, + {"claude_code.pull_request.count", model.AICodeMetricPullRequestCount}, + {"claude_code.active_time.total", model.AICodeMetricActiveTimeTotal}, + {"claude_code.code_edit_tool.decision", model.AICodeMetricCodeEditToolDecision}, + {"unknown.metric", ""}, + } + + for _, tc := range testCases { + t.Run(tc.input, func(t *testing.T) { + result := mapMetricName(tc.input, model.AICodeOtelSourceClaudeCode) + if result != tc.expected { + t.Errorf("Expected %s, got %s", tc.expected, result) + } + }) + } +} + +func TestMapMetricName_Codex(t *testing.T) { + testCases := []struct { + input string + expected string + }{ + {"codex.session.count", model.AICodeMetricSessionCount}, + {"codex.token.usage", model.AICodeMetricTokenUsage}, + {"codex.cost.usage", model.AICodeMetricCostUsage}, + {"codex.lines_of_code.count", model.AICodeMetricLinesOfCodeCount}, + {"codex.commit.count", model.AICodeMetricCommitCount}, + {"codex.pull_request.count", model.AICodeMetricPullRequestCount}, + {"codex.active_time.total", model.AICodeMetricActiveTimeTotal}, + } + + for _, tc := range testCases { + t.Run(tc.input, func(t *testing.T) { + result := mapMetricName(tc.input, model.AICodeOtelSourceCodex) + if result != tc.expected { + t.Errorf("Expected %s, got %s", tc.expected, result) + } + }) + } +} + +func TestMapEventName_ClaudeCode(t *testing.T) { + testCases := []struct { + input string + expected string + }{ + {"claude_code.user_prompt", model.AICodeEventUserPrompt}, + {"claude_code.tool_result", model.AICodeEventToolResult}, + {"claude_code.api_request", model.AICodeEventApiRequest}, + {"claude_code.api_error", model.AICodeEventApiError}, + {"claude_code.tool_decision", model.AICodeEventToolDecision}, + } + + for _, tc := range testCases { + t.Run(tc.input, func(t *testing.T) { + result := mapEventName(tc.input, model.AICodeOtelSourceClaudeCode) + if result != tc.expected { + t.Errorf("Expected %s, got %s", tc.expected, result) + } + }) + } +} + +func TestMapEventName_Codex(t *testing.T) { + testCases := []struct { + input string + expected string + }{ + {"codex.user_prompt", model.AICodeEventUserPrompt}, + {"codex.tool_result", model.AICodeEventToolResult}, + {"codex.api_request", model.AICodeEventApiRequest}, + {"codex.api_error", model.AICodeEventApiError}, + {"codex.exec_command", model.AICodeEventExecCommand}, + {"codex.conversation_starts", model.AICodeEventConversationStarts}, + {"codex.sse_event", model.AICodeEventSSEEvent}, + } + + for _, tc := range testCases { + t.Run(tc.input, func(t *testing.T) { + result := mapEventName(tc.input, model.AICodeOtelSourceCodex) + if result != tc.expected { + t.Errorf("Expected %s, got %s", tc.expected, result) + } + }) + } +} + +func TestMapEventName_Unknown(t *testing.T) { + // Unknown events should return as-is + result := mapEventName("custom.event", "") + if result != "custom.event" { + t.Errorf("Unknown events should be returned as-is, got %s", result) + } +} + +func TestGetIntFromValue(t *testing.T) { + testCases := []struct { + name string + value *commonv1.AnyValue + expected int + }{ + { + "int value", + &commonv1.AnyValue{Value: &commonv1.AnyValue_IntValue{IntValue: 42}}, + 42, + }, + { + "string value", + &commonv1.AnyValue{Value: &commonv1.AnyValue_StringValue{StringValue: "123"}}, + 123, + }, + { + "invalid string", + &commonv1.AnyValue{Value: &commonv1.AnyValue_StringValue{StringValue: "not-a-number"}}, + 0, + }, + { + "empty string", + &commonv1.AnyValue{Value: &commonv1.AnyValue_StringValue{StringValue: ""}}, + 0, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := getIntFromValue(tc.value) + if result != tc.expected { + t.Errorf("Expected %d, got %d", tc.expected, result) + } + }) + } +} + +func TestGetBoolFromValue(t *testing.T) { + testCases := []struct { + name string + value *commonv1.AnyValue + expected bool + }{ + { + "bool true", + &commonv1.AnyValue{Value: &commonv1.AnyValue_BoolValue{BoolValue: true}}, + true, + }, + { + "bool false", + &commonv1.AnyValue{Value: &commonv1.AnyValue_BoolValue{BoolValue: false}}, + false, + }, + { + "string true", + &commonv1.AnyValue{Value: &commonv1.AnyValue_StringValue{StringValue: "true"}}, + true, + }, + { + "string false", + &commonv1.AnyValue{Value: &commonv1.AnyValue_StringValue{StringValue: "false"}}, + false, + }, + { + "invalid string", + &commonv1.AnyValue{Value: &commonv1.AnyValue_StringValue{StringValue: "maybe"}}, + false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := getBoolFromValue(tc.value) + if result != tc.expected { + t.Errorf("Expected %v, got %v", tc.expected, result) + } + }) + } +} + +func TestGetFloatFromValue(t *testing.T) { + testCases := []struct { + name string + value *commonv1.AnyValue + expected float64 + }{ + { + "double value", + &commonv1.AnyValue{Value: &commonv1.AnyValue_DoubleValue{DoubleValue: 3.14}}, + 3.14, + }, + { + "string value", + &commonv1.AnyValue{Value: &commonv1.AnyValue_StringValue{StringValue: "2.71"}}, + 2.71, + }, + { + "invalid string", + &commonv1.AnyValue{Value: &commonv1.AnyValue_StringValue{StringValue: "not-a-float"}}, + 0, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := getFloatFromValue(tc.value) + if result != tc.expected { + t.Errorf("Expected %f, got %f", tc.expected, result) + } + }) + } +} + +func TestGetStringArrayFromValue(t *testing.T) { + t.Run("valid array", func(t *testing.T) { + value := &commonv1.AnyValue{ + Value: &commonv1.AnyValue_ArrayValue{ + ArrayValue: &commonv1.ArrayValue{ + Values: []*commonv1.AnyValue{ + {Value: &commonv1.AnyValue_StringValue{StringValue: "a"}}, + {Value: &commonv1.AnyValue_StringValue{StringValue: "b"}}, + {Value: &commonv1.AnyValue_StringValue{StringValue: "c"}}, + }, + }, + }, + } + + result := getStringArrayFromValue(value) + if len(result) != 3 { + t.Errorf("Expected 3 elements, got %d", len(result)) + } + if result[0] != "a" || result[1] != "b" || result[2] != "c" { + t.Errorf("Array content mismatch") + } + }) + + t.Run("nil array", func(t *testing.T) { + value := &commonv1.AnyValue{Value: &commonv1.AnyValue_StringValue{StringValue: "not-array"}} + result := getStringArrayFromValue(value) + if result != nil { + t.Error("Expected nil for non-array value") + } + }) + + t.Run("empty array", func(t *testing.T) { + value := &commonv1.AnyValue{ + Value: &commonv1.AnyValue_ArrayValue{ + ArrayValue: &commonv1.ArrayValue{ + Values: []*commonv1.AnyValue{}, + }, + }, + } + result := getStringArrayFromValue(value) + if len(result) != 0 { + t.Errorf("Expected 0 elements, got %d", len(result)) + } + }) +} + +func TestApplyResourceAttributesToMetric(t *testing.T) { + attrs := &model.AICodeOtelResourceAttributes{ + SessionID: "session-1", + ConversationID: "conv-1", + UserAccountUUID: "user-1", + OrganizationID: "org-1", + TerminalType: "terminal", + AppVersion: "1.0.0", + OSType: "linux", + OSVersion: "5.15", + HostArch: "amd64", + UserID: "user123", + UserEmail: "test@test.com", + UserName: "testuser", + MachineName: "workstation", + TeamID: "team-1", + Pwd: "/home/user", + } + + metric := &model.AICodeOtelMetric{} + applyResourceAttributesToMetric(metric, attrs) + + if metric.SessionID != "session-1" { + t.Error("SessionID not applied") + } + if metric.ConversationID != "conv-1" { + t.Error("ConversationID not applied") + } + if metric.UserAccountUUID != "user-1" { + t.Error("UserAccountUUID not applied") + } + if metric.OrganizationID != "org-1" { + t.Error("OrganizationID not applied") + } + if metric.TerminalType != "terminal" { + t.Error("TerminalType not applied") + } + if metric.AppVersion != "1.0.0" { + t.Error("AppVersion not applied") + } + if metric.OSType != "linux" { + t.Error("OSType not applied") + } + if metric.OSVersion != "5.15" { + t.Error("OSVersion not applied") + } + if metric.HostArch != "amd64" { + t.Error("HostArch not applied") + } + if metric.UserID != "user123" { + t.Error("UserID not applied") + } + if metric.UserEmail != "test@test.com" { + t.Error("UserEmail not applied") + } + if metric.UserName != "testuser" { + t.Error("UserName not applied") + } + if metric.MachineName != "workstation" { + t.Error("MachineName not applied") + } + if metric.TeamID != "team-1" { + t.Error("TeamID not applied") + } + if metric.Pwd != "/home/user" { + t.Error("Pwd not applied") + } +} + +func TestApplyResourceAttributesToEvent(t *testing.T) { + attrs := &model.AICodeOtelResourceAttributes{ + SessionID: "session-1", + ConversationID: "conv-1", + UserAccountUUID: "user-1", + } + + event := &model.AICodeOtelEvent{} + applyResourceAttributesToEvent(event, attrs) + + if event.SessionID != "session-1" { + t.Error("SessionID not applied") + } + if event.ConversationID != "conv-1" { + t.Error("ConversationID not applied") + } + if event.UserAccountUUID != "user-1" { + t.Error("UserAccountUUID not applied") + } +} diff --git a/daemon/circuit_breaker_test.go b/daemon/circuit_breaker_test.go new file mode 100644 index 0000000..6d858eb --- /dev/null +++ b/daemon/circuit_breaker_test.go @@ -0,0 +1,162 @@ +package daemon + +import ( + "context" + "encoding/json" + "testing" + + "github.com/ThreeDotsLabs/watermill/message" +) + +// Mock publisher for testing +type mockPublisher struct { + publishedMessages []*message.Message + publishError error +} + +func (m *mockPublisher) Publish(topic string, messages ...*message.Message) error { + if m.publishError != nil { + return m.publishError + } + m.publishedMessages = append(m.publishedMessages, messages...) + return nil +} + +func (m *mockPublisher) Close() error { + return nil +} + +func TestNewSyncCircuitBreakerService(t *testing.T) { + publisher := &mockPublisher{} + + wrapper := NewSyncCircuitBreakerService(publisher) + if wrapper == nil { + t.Fatal("NewSyncCircuitBreakerService returned nil") + } + + if wrapper.CircuitBreakerService == nil { + t.Error("CircuitBreakerService should be initialized") + } + + // Check global variable was set + if syncCircuitBreaker == nil { + t.Error("Global syncCircuitBreaker should be set") + } +} + +func TestSyncCircuitBreakerWrapper_IsOpen(t *testing.T) { + publisher := &mockPublisher{} + wrapper := NewSyncCircuitBreakerService(publisher) + + // Initially should be closed (not open) + if wrapper.IsOpen() { + t.Error("Circuit breaker should be closed initially") + } +} + +func TestSyncCircuitBreakerWrapper_RecordSuccess(t *testing.T) { + publisher := &mockPublisher{} + wrapper := NewSyncCircuitBreakerService(publisher) + + // Should not panic + wrapper.RecordSuccess() + + // Circuit should still be closed + if wrapper.IsOpen() { + t.Error("Circuit breaker should remain closed after success") + } +} + +func TestSyncCircuitBreakerWrapper_RecordFailure(t *testing.T) { + publisher := &mockPublisher{} + wrapper := NewSyncCircuitBreakerService(publisher) + + // Should not panic + wrapper.RecordFailure() +} + +func TestSyncCircuitBreakerWrapper_SaveForRetry(t *testing.T) { + publisher := &mockPublisher{} + wrapper := NewSyncCircuitBreakerService(publisher) + + ctx := context.Background() + payload := map[string]string{"key": "value"} + + err := wrapper.SaveForRetry(ctx, payload) + if err != nil { + t.Fatalf("SaveForRetry failed: %v", err) + } +} + +func TestSyncCircuitBreakerWrapper_SaveForRetry_WrapsInSocketMessage(t *testing.T) { + publisher := &mockPublisher{} + wrapper := NewSyncCircuitBreakerService(publisher) + + ctx := context.Background() + payload := map[string]string{"test": "data"} + + err := wrapper.SaveForRetry(ctx, payload) + if err != nil { + t.Fatalf("SaveForRetry failed: %v", err) + } + + // The SaveForRetry should wrap the payload in a SocketMessage + // We can verify by checking what would be saved + socketMsg := SocketMessage{ + Type: SocketMessageTypeSync, + Payload: payload, + } + + _, err = json.Marshal(socketMsg) + if err != nil { + t.Fatalf("Failed to marshal wrapped message: %v", err) + } +} + +func TestDaemonCircuitBreaker_Interface(t *testing.T) { + // Verify SyncCircuitBreakerWrapper implements DaemonCircuitBreaker + var _ DaemonCircuitBreaker = &SyncCircuitBreakerWrapper{} +} + +func TestSyncCircuitBreakerWrapper_MultipleFailures(t *testing.T) { + publisher := &mockPublisher{} + wrapper := NewSyncCircuitBreakerService(publisher) + + // Record multiple failures + for i := 0; i < 10; i++ { + wrapper.RecordFailure() + } + + // Record a success + wrapper.RecordSuccess() + + // Should not panic during any of these operations +} + +func TestSyncCircuitBreakerWrapper_ConcurrentAccess(t *testing.T) { + publisher := &mockPublisher{} + wrapper := NewSyncCircuitBreakerService(publisher) + + done := make(chan bool, 10) + + // Concurrent failures + for i := 0; i < 5; i++ { + go func() { + wrapper.RecordFailure() + done <- true + }() + } + + // Concurrent successes + for i := 0; i < 5; i++ { + go func() { + wrapper.RecordSuccess() + done <- true + }() + } + + // Wait for all goroutines + for i := 0; i < 10; i++ { + <-done + } +} diff --git a/daemon/cleanup_timer_test.go b/daemon/cleanup_timer_test.go new file mode 100644 index 0000000..f9fbbfd --- /dev/null +++ b/daemon/cleanup_timer_test.go @@ -0,0 +1,208 @@ +package daemon + +import ( + "context" + "os" + "path/filepath" + "testing" + "time" + + "github.com/malamtime/cli/model" +) + +func TestNewCleanupTimerService(t *testing.T) { + config := model.ShellTimeConfig{ + LogCleanup: model.LogCleanupConfig{ + ThresholdMB: 100, + }, + } + + service := NewCleanupTimerService(config) + if service == nil { + t.Fatal("NewCleanupTimerService returned nil") + } + + if service.config.LogCleanup.ThresholdMB != 100 { + t.Errorf("Expected ThresholdMB 100, got %d", service.config.LogCleanup.ThresholdMB) + } + + if service.stopChan == nil { + t.Error("stopChan should be initialized") + } +} + +func TestCleanupTimerService_StartStop(t *testing.T) { + config := model.ShellTimeConfig{ + LogCleanup: model.LogCleanupConfig{ + ThresholdMB: 100, + }, + } + + service := NewCleanupTimerService(config) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + err := service.Start(ctx) + if err != nil { + t.Fatalf("Start failed: %v", err) + } + + if service.ticker == nil { + t.Error("Ticker should be initialized after Start") + } + + // Stop the service + service.Stop() + + // Give it a moment to stop + time.Sleep(50 * time.Millisecond) +} + +func TestCleanupTimerService_StopWithoutStart(t *testing.T) { + config := model.ShellTimeConfig{ + LogCleanup: model.LogCleanupConfig{ + ThresholdMB: 100, + }, + } + + service := NewCleanupTimerService(config) + + // This should not panic + defer func() { + if r := recover(); r != nil { + t.Errorf("Stop should not panic: %v", r) + } + }() + + service.Stop() +} + +func TestCleanupTimerService_ContextCancellation(t *testing.T) { + config := model.ShellTimeConfig{ + LogCleanup: model.LogCleanupConfig{ + ThresholdMB: 100, + }, + } + + service := NewCleanupTimerService(config) + + ctx, cancel := context.WithCancel(context.Background()) + + err := service.Start(ctx) + if err != nil { + t.Fatalf("Start failed: %v", err) + } + + // Cancel context should trigger stop + cancel() + + // Give it time to process + time.Sleep(50 * time.Millisecond) +} + +func TestCleanupInterval_Constant(t *testing.T) { + expected := 24 * time.Hour + if CleanupInterval != expected { + t.Errorf("Expected CleanupInterval to be 24h, got %v", CleanupInterval) + } +} + +func TestCleanupTimerService_Cleanup(t *testing.T) { + // Create a temp directory for testing + tempDir, err := os.MkdirTemp("", "shelltime-cleanup-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tempDir) + + // Save and restore original paths + origBase := model.COMMAND_BASE_STORAGE_FOLDER + defer func() { + model.COMMAND_BASE_STORAGE_FOLDER = origBase + }() + + model.COMMAND_BASE_STORAGE_FOLDER = tempDir + + config := model.ShellTimeConfig{ + LogCleanup: model.LogCleanupConfig{ + ThresholdMB: 1, // 1MB threshold (1024 * 1024 bytes) + }, + } + + service := NewCleanupTimerService(config) + + // Create a log file larger than threshold + logFile := model.GetLogFilePath() + if err := os.MkdirAll(filepath.Dir(logFile), 0755); err != nil { + t.Fatalf("Failed to create dir: %v", err) + } + + // Create a file larger than 1MB + largeContent := make([]byte, 2*1024*1024) // 2MB + for i := range largeContent { + largeContent[i] = 'x' + } + if err := os.WriteFile(logFile, largeContent, 0644); err != nil { + t.Fatalf("Failed to write log file: %v", err) + } + + // Verify file exists + if _, err := os.Stat(logFile); os.IsNotExist(err) { + t.Fatal("Log file was not created") + } + + // Run cleanup + ctx := context.Background() + service.cleanup(ctx) + + // Verify file was deleted (exceeded threshold) + if _, err := os.Stat(logFile); !os.IsNotExist(err) { + t.Error("Log file should have been deleted") + } +} + +func TestCleanupTimerService_CleanupBelowThreshold(t *testing.T) { + // Create a temp directory for testing + tempDir, err := os.MkdirTemp("", "shelltime-cleanup-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tempDir) + + // Save and restore original paths + origBase := model.COMMAND_BASE_STORAGE_FOLDER + defer func() { + model.COMMAND_BASE_STORAGE_FOLDER = origBase + }() + + model.COMMAND_BASE_STORAGE_FOLDER = tempDir + + config := model.ShellTimeConfig{ + LogCleanup: model.LogCleanupConfig{ + ThresholdMB: 10, // 10MB threshold + }, + } + + service := NewCleanupTimerService(config) + + // Create a small log file + logFile := model.GetLogFilePath() + if err := os.MkdirAll(filepath.Dir(logFile), 0755); err != nil { + t.Fatalf("Failed to create dir: %v", err) + } + + smallContent := []byte("small log content") + if err := os.WriteFile(logFile, smallContent, 0644); err != nil { + t.Fatalf("Failed to write log file: %v", err) + } + + // Run cleanup + ctx := context.Background() + service.cleanup(ctx) + + // Verify file still exists (below threshold) + if _, err := os.Stat(logFile); os.IsNotExist(err) { + t.Error("Log file should NOT have been deleted (below threshold)") + } +} diff --git a/daemon/client_test.go b/daemon/client_test.go new file mode 100644 index 0000000..ca3bea4 --- /dev/null +++ b/daemon/client_test.go @@ -0,0 +1,277 @@ +package daemon + +import ( + "context" + "encoding/json" + "net" + "os" + "path/filepath" + "testing" + "time" + + "github.com/malamtime/cli/model" +) + +func TestIsSocketReady(t *testing.T) { + // Create a temp directory + tempDir, err := os.MkdirTemp("", "shelltime-client-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tempDir) + + socketPath := filepath.Join(tempDir, "test.sock") + ctx := context.Background() + + // Test non-existent socket + if IsSocketReady(ctx, socketPath) { + t.Error("Expected false for non-existent socket") + } + + // Create the socket file + listener, err := net.Listen("unix", socketPath) + if err != nil { + t.Fatalf("Failed to create socket: %v", err) + } + defer listener.Close() + + // Test existing socket + if !IsSocketReady(ctx, socketPath) { + t.Error("Expected true for existing socket") + } +} + +func TestSendLocalDataToSocket(t *testing.T) { + // Create a temp directory + tempDir, err := os.MkdirTemp("", "shelltime-client-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tempDir) + + socketPath := filepath.Join(tempDir, "test.sock") + + // Create a mock server + listener, err := net.Listen("unix", socketPath) + if err != nil { + t.Fatalf("Failed to create socket: %v", err) + } + defer listener.Close() + + // Start accepting connections in background + received := make(chan *SocketMessage, 1) + go func() { + conn, err := listener.Accept() + if err != nil { + return + } + defer conn.Close() + + var msg SocketMessage + decoder := json.NewDecoder(conn) + if err := decoder.Decode(&msg); err != nil { + return + } + received <- &msg + }() + + ctx := context.Background() + config := model.ShellTimeConfig{} + cursor := time.Now() + trackingData := []model.TrackingData{ + {Command: "ls -la"}, + } + meta := model.TrackingMetaData{ + OS: "linux", + } + + err = SendLocalDataToSocket(ctx, socketPath, config, cursor, trackingData, meta) + if err != nil { + t.Fatalf("SendLocalDataToSocket failed: %v", err) + } + + // Wait for message to be received + select { + case msg := <-received: + if msg.Type != SocketMessageTypeSync { + t.Errorf("Expected message type %s, got %s", SocketMessageTypeSync, msg.Type) + } + case <-time.After(1 * time.Second): + t.Error("Timeout waiting for message") + } +} + +func TestSendLocalDataToSocket_SocketNotExists(t *testing.T) { + ctx := context.Background() + config := model.ShellTimeConfig{} + cursor := time.Now() + + err := SendLocalDataToSocket(ctx, "/nonexistent/socket.sock", config, cursor, nil, model.TrackingMetaData{}) + if err == nil { + t.Error("Expected error when socket doesn't exist") + } +} + +func TestRequestCCInfo(t *testing.T) { + // Create a temp directory + tempDir, err := os.MkdirTemp("", "shelltime-client-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tempDir) + + socketPath := filepath.Join(tempDir, "test.sock") + + // Create a mock server + listener, err := net.Listen("unix", socketPath) + if err != nil { + t.Fatalf("Failed to create socket: %v", err) + } + defer listener.Close() + + // Start mock server + go func() { + conn, err := listener.Accept() + if err != nil { + return + } + defer conn.Close() + + // Read request + var msg SocketMessage + decoder := json.NewDecoder(conn) + if err := decoder.Decode(&msg); err != nil { + return + } + + // Send response + response := CCInfoResponse{ + TotalCostUSD: 5.50, + TimeRange: "today", + CachedAt: time.Now(), + } + encoder := json.NewEncoder(conn) + encoder.Encode(response) + }() + + // Give server time to start + time.Sleep(50 * time.Millisecond) + + response, err := RequestCCInfo(socketPath, CCInfoTimeRangeToday, 5*time.Second) + if err != nil { + t.Fatalf("RequestCCInfo failed: %v", err) + } + + if response.TotalCostUSD != 5.50 { + t.Errorf("Expected TotalCostUSD 5.50, got %f", response.TotalCostUSD) + } + if response.TimeRange != "today" { + t.Errorf("Expected TimeRange 'today', got %s", response.TimeRange) + } +} + +func TestRequestCCInfo_Timeout(t *testing.T) { + // Create a temp directory + tempDir, err := os.MkdirTemp("", "shelltime-client-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tempDir) + + socketPath := filepath.Join(tempDir, "test.sock") + + // Create a mock server that doesn't respond + listener, err := net.Listen("unix", socketPath) + if err != nil { + t.Fatalf("Failed to create socket: %v", err) + } + defer listener.Close() + + // Start mock server that accepts but doesn't respond + go func() { + conn, _ := listener.Accept() + if conn != nil { + // Keep connection open but don't respond + time.Sleep(10 * time.Second) + conn.Close() + } + }() + + // Give server time to start + time.Sleep(50 * time.Millisecond) + + _, err = RequestCCInfo(socketPath, CCInfoTimeRangeToday, 100*time.Millisecond) + if err == nil { + t.Error("Expected timeout error") + } +} + +func TestRequestCCInfo_SocketNotExists(t *testing.T) { + _, err := RequestCCInfo("/nonexistent/socket.sock", CCInfoTimeRangeToday, 1*time.Second) + if err == nil { + t.Error("Expected error when socket doesn't exist") + } +} + +func TestRequestCCInfo_AllTimeRanges(t *testing.T) { + // Create a temp directory + tempDir, err := os.MkdirTemp("", "shelltime-client-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tempDir) + + socketPath := filepath.Join(tempDir, "test.sock") + + testRanges := []CCInfoTimeRange{ + CCInfoTimeRangeToday, + CCInfoTimeRangeWeek, + CCInfoTimeRangeMonth, + } + + for _, timeRange := range testRanges { + t.Run(string(timeRange), func(t *testing.T) { + // Create a mock server + listener, err := net.Listen("unix", socketPath) + if err != nil { + t.Fatalf("Failed to create socket: %v", err) + } + defer listener.Close() + + // Start mock server + go func() { + conn, err := listener.Accept() + if err != nil { + return + } + defer conn.Close() + + var msg SocketMessage + decoder := json.NewDecoder(conn) + decoder.Decode(&msg) + + response := CCInfoResponse{ + TotalCostUSD: 1.0, + TimeRange: string(timeRange), + CachedAt: time.Now(), + } + encoder := json.NewEncoder(conn) + encoder.Encode(response) + }() + + time.Sleep(50 * time.Millisecond) + + response, err := RequestCCInfo(socketPath, timeRange, 5*time.Second) + if err != nil { + t.Fatalf("RequestCCInfo failed: %v", err) + } + + if response.TimeRange != string(timeRange) { + t.Errorf("Expected TimeRange %s, got %s", timeRange, response.TimeRange) + } + + // Clean up for next iteration + os.Remove(socketPath) + }) + } +} diff --git a/daemon/handlers.heartbeat_test.go b/daemon/handlers.heartbeat_test.go new file mode 100644 index 0000000..95fb76e --- /dev/null +++ b/daemon/handlers.heartbeat_test.go @@ -0,0 +1,236 @@ +package daemon + +import ( + "os" + "path/filepath" + "testing" + + "github.com/malamtime/cli/model" +) + +func TestSaveHeartbeatToFile(t *testing.T) { + // Create a temp directory + tempDir, err := os.MkdirTemp("", "shelltime-heartbeat-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tempDir) + + // Set HOME to temp dir + oldHome := os.Getenv("HOME") + os.Setenv("HOME", tempDir) + defer os.Setenv("HOME", oldHome) + + // Create the .shelltime directory + shelltimeDir := filepath.Join(tempDir, ".shelltime") + if err := os.MkdirAll(shelltimeDir, 0755); err != nil { + t.Fatalf("Failed to create shelltime dir: %v", err) + } + + payload := model.HeartbeatPayload{ + Heartbeats: []model.HeartbeatData{ + { + HeartbeatID: "test-id-1", + Entity: "/path/to/file.go", + Time: 1234567890, + Project: "test-project", + }, + }, + } + + err = saveHeartbeatToFile(payload) + if err != nil { + t.Fatalf("saveHeartbeatToFile failed: %v", err) + } + + // Verify file was created + logFile := filepath.Join(shelltimeDir, "coding-heartbeat.data.log") + if _, err := os.Stat(logFile); os.IsNotExist(err) { + t.Error("Heartbeat log file was not created") + } + + // Verify content + content, err := os.ReadFile(logFile) + if err != nil { + t.Fatalf("Failed to read log file: %v", err) + } + + if !contains(string(content), "test-id-1") { + t.Error("Log file should contain heartbeat ID") + } + + if !contains(string(content), "test-project") { + t.Error("Log file should contain project name") + } +} + +func TestSaveHeartbeatToFile_Append(t *testing.T) { + // Create a temp directory + tempDir, err := os.MkdirTemp("", "shelltime-heartbeat-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tempDir) + + // Set HOME to temp dir + oldHome := os.Getenv("HOME") + os.Setenv("HOME", tempDir) + defer os.Setenv("HOME", oldHome) + + // Create the .shelltime directory + shelltimeDir := filepath.Join(tempDir, ".shelltime") + if err := os.MkdirAll(shelltimeDir, 0755); err != nil { + t.Fatalf("Failed to create shelltime dir: %v", err) + } + + // Save first heartbeat + payload1 := model.HeartbeatPayload{ + Heartbeats: []model.HeartbeatData{ + {HeartbeatID: "id-1", Entity: "file1.go", Time: 1234567890}, + }, + } + err = saveHeartbeatToFile(payload1) + if err != nil { + t.Fatalf("First saveHeartbeatToFile failed: %v", err) + } + + // Save second heartbeat + payload2 := model.HeartbeatPayload{ + Heartbeats: []model.HeartbeatData{ + {HeartbeatID: "id-2", Entity: "file2.go", Time: 1234567891}, + }, + } + err = saveHeartbeatToFile(payload2) + if err != nil { + t.Fatalf("Second saveHeartbeatToFile failed: %v", err) + } + + // Verify both are in the file + logFile := filepath.Join(shelltimeDir, "coding-heartbeat.data.log") + content, err := os.ReadFile(logFile) + if err != nil { + t.Fatalf("Failed to read log file: %v", err) + } + + if !contains(string(content), "id-1") { + t.Error("Log file should contain first heartbeat ID") + } + + if !contains(string(content), "id-2") { + t.Error("Log file should contain second heartbeat ID") + } +} + +func TestSaveHeartbeatToFile_EmptyPayload(t *testing.T) { + // Create a temp directory + tempDir, err := os.MkdirTemp("", "shelltime-heartbeat-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tempDir) + + // Set HOME to temp dir + oldHome := os.Getenv("HOME") + os.Setenv("HOME", tempDir) + defer os.Setenv("HOME", oldHome) + + // Create the .shelltime directory + shelltimeDir := filepath.Join(tempDir, ".shelltime") + if err := os.MkdirAll(shelltimeDir, 0755); err != nil { + t.Fatalf("Failed to create shelltime dir: %v", err) + } + + // Empty payload + payload := model.HeartbeatPayload{ + Heartbeats: []model.HeartbeatData{}, + } + + err = saveHeartbeatToFile(payload) + if err != nil { + t.Fatalf("saveHeartbeatToFile with empty payload failed: %v", err) + } + + // File should still be created (with empty heartbeats array) + logFile := filepath.Join(shelltimeDir, "coding-heartbeat.data.log") + content, err := os.ReadFile(logFile) + if err != nil { + t.Fatalf("Failed to read log file: %v", err) + } + + if !contains(string(content), `"heartbeats":[]`) { + t.Error("Log file should contain empty heartbeats array") + } +} + +func TestSaveHeartbeatToFile_DirectoryNotExists(t *testing.T) { + // Create a temp directory + tempDir, err := os.MkdirTemp("", "shelltime-heartbeat-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tempDir) + + // Set HOME to a non-existent path + oldHome := os.Getenv("HOME") + os.Setenv("HOME", tempDir) + defer os.Setenv("HOME", oldHome) + + // Don't create the .shelltime directory - saveHeartbeatToFile should fail + payload := model.HeartbeatPayload{ + Heartbeats: []model.HeartbeatData{ + {HeartbeatID: "test-id", Entity: "file.go", Time: 1234567890}, + }, + } + + err = saveHeartbeatToFile(payload) + if err == nil { + t.Error("Expected error when directory doesn't exist") + } +} + +func TestSaveHeartbeatToFile_MultipleHeartbeats(t *testing.T) { + // Create a temp directory + tempDir, err := os.MkdirTemp("", "shelltime-heartbeat-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tempDir) + + // Set HOME to temp dir + oldHome := os.Getenv("HOME") + os.Setenv("HOME", tempDir) + defer os.Setenv("HOME", oldHome) + + // Create the .shelltime directory + shelltimeDir := filepath.Join(tempDir, ".shelltime") + if err := os.MkdirAll(shelltimeDir, 0755); err != nil { + t.Fatalf("Failed to create shelltime dir: %v", err) + } + + // Payload with multiple heartbeats + payload := model.HeartbeatPayload{ + Heartbeats: []model.HeartbeatData{ + {HeartbeatID: "id-1", Entity: "file1.go", Time: 1234567890, Project: "proj1"}, + {HeartbeatID: "id-2", Entity: "file2.go", Time: 1234567891, Project: "proj1"}, + {HeartbeatID: "id-3", Entity: "file3.go", Time: 1234567892, Project: "proj2"}, + }, + } + + err = saveHeartbeatToFile(payload) + if err != nil { + t.Fatalf("saveHeartbeatToFile failed: %v", err) + } + + // Verify all heartbeats are in the file + logFile := filepath.Join(shelltimeDir, "coding-heartbeat.data.log") + content, err := os.ReadFile(logFile) + if err != nil { + t.Fatalf("Failed to read log file: %v", err) + } + + for _, hb := range payload.Heartbeats { + if !contains(string(content), hb.HeartbeatID) { + t.Errorf("Log file should contain heartbeat ID: %s", hb.HeartbeatID) + } + } +} diff --git a/daemon/heartbeat_resync_test.go b/daemon/heartbeat_resync_test.go new file mode 100644 index 0000000..156d4cb --- /dev/null +++ b/daemon/heartbeat_resync_test.go @@ -0,0 +1,313 @@ +package daemon + +import ( + "context" + "encoding/json" + "os" + "path/filepath" + "testing" + "time" + + "github.com/malamtime/cli/model" +) + +func TestNewHeartbeatResyncService(t *testing.T) { + config := model.ShellTimeConfig{} + + service := NewHeartbeatResyncService(config) + if service == nil { + t.Fatal("NewHeartbeatResyncService returned nil") + } + + if service.stopChan == nil { + t.Error("stopChan should be initialized") + } +} + +func TestHeartbeatResyncService_StartStop(t *testing.T) { + config := model.ShellTimeConfig{} + service := NewHeartbeatResyncService(config) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + err := service.Start(ctx) + if err != nil { + t.Fatalf("Start failed: %v", err) + } + + if service.ticker == nil { + t.Error("Ticker should be initialized after Start") + } + + // Stop the service + service.Stop() + + // Give it a moment to stop + time.Sleep(50 * time.Millisecond) +} + +func TestHeartbeatResyncService_StopWithoutStart(t *testing.T) { + config := model.ShellTimeConfig{} + service := NewHeartbeatResyncService(config) + + // This should not panic + defer func() { + if r := recover(); r != nil { + t.Errorf("Stop should not panic: %v", r) + } + }() + + service.Stop() +} + +func TestHeartbeatResyncService_ContextCancellation(t *testing.T) { + config := model.ShellTimeConfig{} + service := NewHeartbeatResyncService(config) + + ctx, cancel := context.WithCancel(context.Background()) + + err := service.Start(ctx) + if err != nil { + t.Fatalf("Start failed: %v", err) + } + + // Cancel context should trigger stop + cancel() + + // Give it time to process + time.Sleep(50 * time.Millisecond) +} + +func TestHeartbeatResyncInterval_Constant(t *testing.T) { + expected := 30 * time.Minute + if HeartbeatResyncInterval != expected { + t.Errorf("Expected HeartbeatResyncInterval to be 30m, got %v", HeartbeatResyncInterval) + } +} + +func TestHeartbeatResyncService_RewriteLogFile_EmptyLines(t *testing.T) { + // Create a temp directory + tempDir, err := os.MkdirTemp("", "shelltime-resync-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tempDir) + + logFile := filepath.Join(tempDir, "heartbeat.log") + + config := model.ShellTimeConfig{} + service := NewHeartbeatResyncService(config) + + // Test with empty lines - should remove file + err = service.rewriteLogFile(logFile, []string{}) + if err != nil { + t.Fatalf("rewriteLogFile failed: %v", err) + } + + // File should not exist + if _, err := os.Stat(logFile); !os.IsNotExist(err) { + t.Error("File should be removed when lines are empty") + } +} + +func TestHeartbeatResyncService_RewriteLogFile_WithLines(t *testing.T) { + // Create a temp directory + tempDir, err := os.MkdirTemp("", "shelltime-resync-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tempDir) + + logFile := filepath.Join(tempDir, "heartbeat.log") + + config := model.ShellTimeConfig{} + service := NewHeartbeatResyncService(config) + + lines := []string{ + `{"heartbeats":[{"heartbeatId":"1"}]}`, + `{"heartbeats":[{"heartbeatId":"2"}]}`, + } + + err = service.rewriteLogFile(logFile, lines) + if err != nil { + t.Fatalf("rewriteLogFile failed: %v", err) + } + + // Verify file content + content, err := os.ReadFile(logFile) + if err != nil { + t.Fatalf("Failed to read log file: %v", err) + } + + for _, line := range lines { + if !contains(string(content), line) { + t.Errorf("Expected file to contain: %s", line) + } + } +} + +func TestHeartbeatResyncService_RewriteLogFile_AtomicRename(t *testing.T) { + // Create a temp directory + tempDir, err := os.MkdirTemp("", "shelltime-resync-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tempDir) + + logFile := filepath.Join(tempDir, "heartbeat.log") + + config := model.ShellTimeConfig{} + service := NewHeartbeatResyncService(config) + + lines := []string{`{"test":"data"}`} + + err = service.rewriteLogFile(logFile, lines) + if err != nil { + t.Fatalf("rewriteLogFile failed: %v", err) + } + + // Temp file should not exist after atomic rename + tempFile := logFile + ".tmp" + if _, err := os.Stat(tempFile); !os.IsNotExist(err) { + t.Error("Temp file should be removed after atomic rename") + } +} + +func TestHeartbeatResyncService_ResyncNoFile(t *testing.T) { + // Create a temp directory + tempDir, err := os.MkdirTemp("", "shelltime-resync-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tempDir) + + // Set HOME to temp dir + oldHome := os.Getenv("HOME") + os.Setenv("HOME", tempDir) + defer os.Setenv("HOME", oldHome) + + config := model.ShellTimeConfig{} + service := NewHeartbeatResyncService(config) + + // This should not panic when file doesn't exist + ctx := context.Background() + service.resync(ctx) +} + +func TestHeartbeatResyncService_ResyncEmptyFile(t *testing.T) { + // Create a temp directory + tempDir, err := os.MkdirTemp("", "shelltime-resync-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tempDir) + + // Set HOME to temp dir + oldHome := os.Getenv("HOME") + os.Setenv("HOME", tempDir) + defer os.Setenv("HOME", oldHome) + + // Create empty heartbeat log file + logDir := filepath.Join(tempDir, ".shelltime") + if err := os.MkdirAll(logDir, 0755); err != nil { + t.Fatalf("Failed to create dir: %v", err) + } + + logFile := filepath.Join(logDir, "coding-heartbeat.data.log") + if err := os.WriteFile(logFile, []byte(""), 0644); err != nil { + t.Fatalf("Failed to create empty log file: %v", err) + } + + config := model.ShellTimeConfig{} + service := NewHeartbeatResyncService(config) + + // This should not panic with empty file + ctx := context.Background() + service.resync(ctx) +} + +func TestHeartbeatResyncService_ResyncInvalidJSON(t *testing.T) { + // Create a temp directory + tempDir, err := os.MkdirTemp("", "shelltime-resync-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tempDir) + + // Set HOME to temp dir + oldHome := os.Getenv("HOME") + os.Setenv("HOME", tempDir) + defer os.Setenv("HOME", oldHome) + + // Create heartbeat log file with invalid JSON + logDir := filepath.Join(tempDir, ".shelltime") + if err := os.MkdirAll(logDir, 0755); err != nil { + t.Fatalf("Failed to create dir: %v", err) + } + + logFile := filepath.Join(logDir, "coding-heartbeat.data.log") + invalidContent := "not valid json\n{also not valid}\n" + if err := os.WriteFile(logFile, []byte(invalidContent), 0644); err != nil { + t.Fatalf("Failed to create log file: %v", err) + } + + config := model.ShellTimeConfig{} + service := NewHeartbeatResyncService(config) + + // This should not panic with invalid JSON (lines are discarded) + ctx := context.Background() + service.resync(ctx) + + // File should be removed since all lines were invalid + if _, err := os.Stat(logFile); !os.IsNotExist(err) { + t.Error("File with all invalid lines should be removed") + } +} + +func TestHeartbeatPayload_JSON(t *testing.T) { + payload := model.HeartbeatPayload{ + Heartbeats: []model.HeartbeatData{ + { + HeartbeatID: "test-id-1", + Entity: "/path/to/file.go", + Time: time.Now().Unix(), + Project: "test-project", + }, + }, + } + + // Marshal + data, err := json.Marshal(payload) + if err != nil { + t.Fatalf("Failed to marshal: %v", err) + } + + // Unmarshal + var decoded model.HeartbeatPayload + if err := json.Unmarshal(data, &decoded); err != nil { + t.Fatalf("Failed to unmarshal: %v", err) + } + + if len(decoded.Heartbeats) != 1 { + t.Errorf("Expected 1 heartbeat, got %d", len(decoded.Heartbeats)) + } + + if decoded.Heartbeats[0].HeartbeatID != "test-id-1" { + t.Errorf("HeartbeatID mismatch: expected test-id-1, got %s", decoded.Heartbeats[0].HeartbeatID) + } +} + +// Helper function +func contains(s, substr string) bool { + return len(s) >= len(substr) && (s == substr || len(s) > 0 && containsHelper(s, substr)) +} + +func containsHelper(s, substr string) bool { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false +} diff --git a/daemon/socket_test.go b/daemon/socket_test.go new file mode 100644 index 0000000..2eaf87f --- /dev/null +++ b/daemon/socket_test.go @@ -0,0 +1,364 @@ +package daemon + +import ( + "encoding/json" + "net" + "os" + "path/filepath" + "testing" + "time" + + "github.com/malamtime/cli/model" +) + +func TestNewSocketHandler(t *testing.T) { + config := &model.ShellTimeConfig{ + SocketPath: "/tmp/test-shelltime.sock", + } + ch := NewGoChannel() + + handler := NewSocketHandler(config, ch) + if handler == nil { + t.Fatal("NewSocketHandler returned nil") + } + + if handler.config != config { + t.Error("Config not properly set") + } + + if handler.channel != ch { + t.Error("Channel not properly set") + } + + if handler.stopChan == nil { + t.Error("stopChan should be initialized") + } + + if handler.ccInfoTimer == nil { + t.Error("ccInfoTimer should be initialized") + } +} + +func TestSocketHandler_StartStop(t *testing.T) { + // Create temp socket path + tempDir, err := os.MkdirTemp("", "shelltime-socket-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tempDir) + + socketPath := filepath.Join(tempDir, "test.sock") + config := &model.ShellTimeConfig{ + SocketPath: socketPath, + } + ch := NewGoChannel() + + handler := NewSocketHandler(config, ch) + + // Start the handler + err = handler.Start() + if err != nil { + t.Fatalf("Start failed: %v", err) + } + + // Verify socket file was created + if _, err := os.Stat(socketPath); os.IsNotExist(err) { + t.Error("Socket file was not created") + } + + // Stop the handler + handler.Stop() + + // Give it a moment to clean up + time.Sleep(100 * time.Millisecond) + + // Verify socket file was removed + if _, err := os.Stat(socketPath); !os.IsNotExist(err) { + t.Error("Socket file should be removed after stop") + } +} + +func TestSocketHandler_StatusRequest(t *testing.T) { + // Create temp socket path + tempDir, err := os.MkdirTemp("", "shelltime-socket-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tempDir) + + socketPath := filepath.Join(tempDir, "test.sock") + config := &model.ShellTimeConfig{ + SocketPath: socketPath, + } + ch := NewGoChannel() + + handler := NewSocketHandler(config, ch) + + err = handler.Start() + if err != nil { + t.Fatalf("Start failed: %v", err) + } + defer handler.Stop() + + // Give server time to start + time.Sleep(50 * time.Millisecond) + + // Connect to socket + conn, err := net.Dial("unix", socketPath) + if err != nil { + t.Fatalf("Failed to connect to socket: %v", err) + } + defer conn.Close() + + // Send status request + msg := SocketMessage{ + Type: SocketMessageTypeStatus, + } + encoder := json.NewEncoder(conn) + if err := encoder.Encode(msg); err != nil { + t.Fatalf("Failed to send message: %v", err) + } + + // Read response + var response StatusResponse + decoder := json.NewDecoder(conn) + if err := decoder.Decode(&response); err != nil { + t.Fatalf("Failed to decode response: %v", err) + } + + // Verify response + if response.GoVersion == "" { + t.Error("GoVersion should not be empty") + } + if response.Platform == "" { + t.Error("Platform should not be empty") + } + if response.Uptime == "" { + t.Error("Uptime should not be empty") + } +} + +func TestSocketMessageType_Constants(t *testing.T) { + testCases := []struct { + msgType SocketMessageType + expected string + }{ + {SocketMessageTypeSync, "sync"}, + {SocketMessageTypeHeartbeat, "heartbeat"}, + {SocketMessageTypeStatus, "status"}, + {SocketMessageTypeCCInfo, "cc_info"}, + } + + for _, tc := range testCases { + if string(tc.msgType) != tc.expected { + t.Errorf("Expected %s, got %s", tc.expected, tc.msgType) + } + } +} + +func TestCCInfoTimeRange_Constants(t *testing.T) { + testCases := []struct { + timeRange CCInfoTimeRange + expected string + }{ + {CCInfoTimeRangeToday, "today"}, + {CCInfoTimeRangeWeek, "week"}, + {CCInfoTimeRangeMonth, "month"}, + } + + for _, tc := range testCases { + if string(tc.timeRange) != tc.expected { + t.Errorf("Expected %s, got %s", tc.expected, tc.timeRange) + } + } +} + +func TestFormatDuration(t *testing.T) { + testCases := []struct { + duration time.Duration + expected string + }{ + {5 * time.Second, "5s"}, + {65 * time.Second, "1m 5s"}, + {3665 * time.Second, "1h 1m 5s"}, + {90065 * time.Second, "1d 1h 1m 5s"}, + {0, "0s"}, + {30 * time.Minute, "30m 0s"}, + {2 * time.Hour, "2h 0m 0s"}, + {48 * time.Hour, "2d 0h 0m 0s"}, + } + + for _, tc := range testCases { + t.Run(tc.expected, func(t *testing.T) { + result := formatDuration(tc.duration) + if result != tc.expected { + t.Errorf("formatDuration(%v) = %s, expected %s", tc.duration, result, tc.expected) + } + }) + } +} + +func TestSocketMessage_JSON(t *testing.T) { + msg := SocketMessage{ + Type: SocketMessageTypeSync, + Payload: map[string]interface{}{ + "key": "value", + }, + } + + // Marshal + data, err := json.Marshal(msg) + if err != nil { + t.Fatalf("Failed to marshal: %v", err) + } + + // Unmarshal + var decoded SocketMessage + if err := json.Unmarshal(data, &decoded); err != nil { + t.Fatalf("Failed to unmarshal: %v", err) + } + + if decoded.Type != SocketMessageTypeSync { + t.Errorf("Expected type %s, got %s", SocketMessageTypeSync, decoded.Type) + } +} + +func TestStatusResponse_JSON(t *testing.T) { + response := StatusResponse{ + Version: "1.0.0", + StartedAt: time.Now(), + Uptime: "1h 30m 0s", + GoVersion: "go1.21.0", + Platform: "linux/amd64", + } + + // Marshal + data, err := json.Marshal(response) + if err != nil { + t.Fatalf("Failed to marshal: %v", err) + } + + // Unmarshal + var decoded StatusResponse + if err := json.Unmarshal(data, &decoded); err != nil { + t.Fatalf("Failed to unmarshal: %v", err) + } + + if decoded.Version != response.Version { + t.Errorf("Version mismatch: expected %s, got %s", response.Version, decoded.Version) + } + if decoded.GoVersion != response.GoVersion { + t.Errorf("GoVersion mismatch: expected %s, got %s", response.GoVersion, decoded.GoVersion) + } +} + +func TestCCInfoRequest_JSON(t *testing.T) { + request := CCInfoRequest{ + TimeRange: CCInfoTimeRangeToday, + } + + data, err := json.Marshal(request) + if err != nil { + t.Fatalf("Failed to marshal: %v", err) + } + + var decoded CCInfoRequest + if err := json.Unmarshal(data, &decoded); err != nil { + t.Fatalf("Failed to unmarshal: %v", err) + } + + if decoded.TimeRange != CCInfoTimeRangeToday { + t.Errorf("TimeRange mismatch: expected %s, got %s", CCInfoTimeRangeToday, decoded.TimeRange) + } +} + +func TestCCInfoResponse_JSON(t *testing.T) { + now := time.Now() + response := CCInfoResponse{ + TotalCostUSD: 1.23, + TimeRange: "today", + CachedAt: now, + } + + data, err := json.Marshal(response) + if err != nil { + t.Fatalf("Failed to marshal: %v", err) + } + + var decoded CCInfoResponse + if err := json.Unmarshal(data, &decoded); err != nil { + t.Fatalf("Failed to unmarshal: %v", err) + } + + if decoded.TotalCostUSD != 1.23 { + t.Errorf("TotalCostUSD mismatch: expected 1.23, got %f", decoded.TotalCostUSD) + } + if decoded.TimeRange != "today" { + t.Errorf("TimeRange mismatch: expected today, got %s", decoded.TimeRange) + } +} + +func TestSocketHandler_MultipleConnections(t *testing.T) { + // Create temp socket path + tempDir, err := os.MkdirTemp("", "shelltime-socket-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tempDir) + + socketPath := filepath.Join(tempDir, "test.sock") + config := &model.ShellTimeConfig{ + SocketPath: socketPath, + } + ch := NewGoChannel() + + handler := NewSocketHandler(config, ch) + + err = handler.Start() + if err != nil { + t.Fatalf("Start failed: %v", err) + } + defer handler.Stop() + + // Give server time to start + time.Sleep(50 * time.Millisecond) + + // Make multiple concurrent connections + done := make(chan bool, 3) + + for i := 0; i < 3; i++ { + go func() { + conn, err := net.Dial("unix", socketPath) + if err != nil { + done <- false + return + } + defer conn.Close() + + msg := SocketMessage{Type: SocketMessageTypeStatus} + encoder := json.NewEncoder(conn) + encoder.Encode(msg) + + var response StatusResponse + decoder := json.NewDecoder(conn) + if err := decoder.Decode(&response); err != nil { + done <- false + return + } + + done <- response.Platform != "" + }() + } + + // Wait for all connections + successCount := 0 + for i := 0; i < 3; i++ { + if <-done { + successCount++ + } + } + + if successCount != 3 { + t.Errorf("Expected 3 successful connections, got %d", successCount) + } +} diff --git a/model/alias_test.go b/model/alias_test.go new file mode 100644 index 0000000..7a32db4 --- /dev/null +++ b/model/alias_test.go @@ -0,0 +1,297 @@ +package model + +import ( + "encoding/json" + "testing" +) + +func TestAlias_JSON(t *testing.T) { + alias := Alias{ + Name: "ll", + Value: "ls -la", + Shell: "bash", + } + + data, err := json.Marshal(alias) + if err != nil { + t.Fatalf("Failed to marshal: %v", err) + } + + var decoded Alias + if err := json.Unmarshal(data, &decoded); err != nil { + t.Fatalf("Failed to unmarshal: %v", err) + } + + if decoded.Name != alias.Name { + t.Errorf("Name mismatch: expected %s, got %s", alias.Name, decoded.Name) + } + if decoded.Value != alias.Value { + t.Errorf("Value mismatch: expected %s, got %s", alias.Value, decoded.Value) + } + if decoded.Shell != alias.Shell { + t.Errorf("Shell mismatch: expected %s, got %s", alias.Shell, decoded.Shell) + } +} + +func TestImportShellAliasRequest_JSON(t *testing.T) { + req := importShellAliasRequest{ + Aliases: []string{"alias ll='ls -la'", "alias gs='git status'"}, + IsFullRefresh: true, + ShellType: "bash", + FileLocation: "~/.bashrc", + Hostname: "myhost", + Username: "myuser", + OS: "linux", + OSVersion: "ubuntu 22.04", + } + + data, err := json.Marshal(req) + if err != nil { + t.Fatalf("Failed to marshal: %v", err) + } + + var decoded importShellAliasRequest + if err := json.Unmarshal(data, &decoded); err != nil { + t.Fatalf("Failed to unmarshal: %v", err) + } + + if len(decoded.Aliases) != 2 { + t.Errorf("Expected 2 aliases, got %d", len(decoded.Aliases)) + } + if decoded.IsFullRefresh != true { + t.Error("IsFullRefresh should be true") + } + if decoded.ShellType != "bash" { + t.Errorf("ShellType mismatch: expected bash, got %s", decoded.ShellType) + } + if decoded.FileLocation != "~/.bashrc" { + t.Errorf("FileLocation mismatch") + } + if decoded.Hostname != "myhost" { + t.Errorf("Hostname mismatch") + } + if decoded.Username != "myuser" { + t.Errorf("Username mismatch") + } + if decoded.OS != "linux" { + t.Errorf("OS mismatch") + } + if decoded.OSVersion != "ubuntu 22.04" { + t.Errorf("OSVersion mismatch") + } +} + +func TestImportShellAliasResponse_JSON(t *testing.T) { + resp := importShellAliasResponse{ + Success: true, + Count: 5, + } + + data, err := json.Marshal(resp) + if err != nil { + t.Fatalf("Failed to marshal: %v", err) + } + + var decoded importShellAliasResponse + if err := json.Unmarshal(data, &decoded); err != nil { + t.Fatalf("Failed to unmarshal: %v", err) + } + + if decoded.Success != true { + t.Error("Success should be true") + } + if decoded.Count != 5 { + t.Errorf("Count mismatch: expected 5, got %d", decoded.Count) + } +} + +func TestImportShellAliasRequest_EmptyAliases(t *testing.T) { + req := importShellAliasRequest{ + Aliases: []string{}, + IsFullRefresh: false, + ShellType: "zsh", + } + + data, err := json.Marshal(req) + if err != nil { + t.Fatalf("Failed to marshal: %v", err) + } + + var decoded importShellAliasRequest + if err := json.Unmarshal(data, &decoded); err != nil { + t.Fatalf("Failed to unmarshal: %v", err) + } + + if len(decoded.Aliases) != 0 { + t.Errorf("Expected 0 aliases, got %d", len(decoded.Aliases)) + } +} + +func TestAlias_EmptyValues(t *testing.T) { + alias := Alias{ + Name: "", + Value: "", + Shell: "", + } + + data, err := json.Marshal(alias) + if err != nil { + t.Fatalf("Failed to marshal: %v", err) + } + + var decoded Alias + if err := json.Unmarshal(data, &decoded); err != nil { + t.Fatalf("Failed to unmarshal: %v", err) + } + + if decoded.Name != "" { + t.Errorf("Expected empty name") + } + if decoded.Value != "" { + t.Errorf("Expected empty value") + } + if decoded.Shell != "" { + t.Errorf("Expected empty shell") + } +} + +func TestAlias_SpecialCharacters(t *testing.T) { + testCases := []struct { + name string + alias Alias + }{ + { + "quotes in value", + Alias{Name: "greet", Value: `echo "hello world"`, Shell: "bash"}, + }, + { + "pipes and redirects", + Alias{Name: "findlog", Value: "find . -name '*.log' | xargs grep error", Shell: "bash"}, + }, + { + "unicode characters", + Alias{Name: "emoji", Value: "echo '🚀 Starting...'", Shell: "zsh"}, + }, + { + "backslashes", + Alias{Name: "path", Value: `echo "C:\Users\test"`, Shell: "bash"}, + }, + { + "newlines escaped", + Alias{Name: "multiline", Value: "echo 'line1\\nline2'", Shell: "bash"}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + data, err := json.Marshal(tc.alias) + if err != nil { + t.Fatalf("Failed to marshal: %v", err) + } + + var decoded Alias + if err := json.Unmarshal(data, &decoded); err != nil { + t.Fatalf("Failed to unmarshal: %v", err) + } + + if decoded.Value != tc.alias.Value { + t.Errorf("Value mismatch: expected %q, got %q", tc.alias.Value, decoded.Value) + } + }) + } +} + +func TestImportShellAliasRequest_ShellTypes(t *testing.T) { + shells := []string{"bash", "zsh", "fish", "sh", "ksh"} + + for _, shell := range shells { + t.Run(shell, func(t *testing.T) { + req := importShellAliasRequest{ + Aliases: []string{"alias test='echo test'"}, + ShellType: shell, + } + + data, _ := json.Marshal(req) + var decoded importShellAliasRequest + json.Unmarshal(data, &decoded) + + if decoded.ShellType != shell { + t.Errorf("ShellType mismatch: expected %s, got %s", shell, decoded.ShellType) + } + }) + } +} + +func TestImportShellAliasResponse_FailureCase(t *testing.T) { + resp := importShellAliasResponse{ + Success: false, + Count: 0, + } + + data, _ := json.Marshal(resp) + var decoded importShellAliasResponse + json.Unmarshal(data, &decoded) + + if decoded.Success { + t.Error("Expected success to be false") + } + if decoded.Count != 0 { + t.Error("Expected count to be 0") + } +} + +func TestImportShellAliasRequest_FileLocations(t *testing.T) { + locations := []string{ + "~/.bashrc", + "~/.bash_profile", + "~/.zshrc", + "~/.config/fish/config.fish", + "/etc/bash.bashrc", + } + + for _, loc := range locations { + t.Run(loc, func(t *testing.T) { + req := importShellAliasRequest{ + Aliases: []string{"alias test='test'"}, + FileLocation: loc, + } + + data, _ := json.Marshal(req) + var decoded importShellAliasRequest + json.Unmarshal(data, &decoded) + + if decoded.FileLocation != loc { + t.Errorf("FileLocation mismatch: expected %s, got %s", loc, decoded.FileLocation) + } + }) + } +} + +func TestAlias_LongValue(t *testing.T) { + // Create a very long alias value + longValue := "" + for i := 0; i < 1000; i++ { + longValue += "echo 'test' && " + } + longValue += "echo 'done'" + + alias := Alias{ + Name: "longcmd", + Value: longValue, + Shell: "bash", + } + + data, err := json.Marshal(alias) + if err != nil { + t.Fatalf("Failed to marshal long alias: %v", err) + } + + var decoded Alias + if err := json.Unmarshal(data, &decoded); err != nil { + t.Fatalf("Failed to unmarshal long alias: %v", err) + } + + if decoded.Value != longValue { + t.Error("Long value was not preserved correctly") + } +} diff --git a/model/api.base_test.go b/model/api.base_test.go new file mode 100644 index 0000000..4ef9bd4 --- /dev/null +++ b/model/api.base_test.go @@ -0,0 +1,418 @@ +package model + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" +) + +func TestHTTPRequestOptions_Defaults(t *testing.T) { + opts := HTTPRequestOptions[interface{}, interface{}]{ + Context: context.Background(), + Endpoint: Endpoint{APIEndpoint: "http://localhost", Token: "test"}, + Method: http.MethodGet, + Path: "/test", + } + + if opts.ContentType != "" { + t.Error("ContentType should default to empty") + } + + if opts.Timeout != 0 { + t.Error("Timeout should default to 0") + } +} + +func TestSendHTTPRequestJSON_Success(t *testing.T) { + // Create a test server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Verify request + if r.Method != http.MethodPost { + t.Errorf("Expected POST, got %s", r.Method) + } + if r.Header.Get("Content-Type") != "application/json" { + t.Errorf("Expected application/json, got %s", r.Header.Get("Content-Type")) + } + if r.Header.Get("Authorization") != "CLI test-token" { + t.Errorf("Unexpected Authorization header: %s", r.Header.Get("Authorization")) + } + + // Send response + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]string{"status": "ok"}) + })) + defer server.Close() + + type request struct { + Data string `json:"data"` + } + type response struct { + Status string `json:"status"` + } + + var resp response + err := SendHTTPRequestJSON(HTTPRequestOptions[request, response]{ + Context: context.Background(), + Endpoint: Endpoint{APIEndpoint: server.URL, Token: "test-token"}, + Method: http.MethodPost, + Path: "/api/test", + Payload: request{Data: "test"}, + Response: &resp, + }) + + if err != nil { + t.Fatalf("SendHTTPRequestJSON failed: %v", err) + } + + if resp.Status != "ok" { + t.Errorf("Expected status 'ok', got '%s'", resp.Status) + } +} + +func TestSendHTTPRequestJSON_NoContent(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNoContent) + })) + defer server.Close() + + err := SendHTTPRequestJSON(HTTPRequestOptions[map[string]string, interface{}]{ + Context: context.Background(), + Endpoint: Endpoint{APIEndpoint: server.URL, Token: "token"}, + Method: http.MethodPost, + Path: "/api/test", + Payload: map[string]string{"key": "value"}, + }) + + if err != nil { + t.Fatalf("Should not error on NoContent: %v", err) + } +} + +func TestSendHTTPRequestJSON_ErrorResponse(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + json.NewEncoder(w).Encode(errorResponse{ErrorMessage: "bad request"}) + })) + defer server.Close() + + err := SendHTTPRequestJSON(HTTPRequestOptions[map[string]string, interface{}]{ + Context: context.Background(), + Endpoint: Endpoint{APIEndpoint: server.URL, Token: "token"}, + Method: http.MethodPost, + Path: "/api/test", + Payload: map[string]string{}, + }) + + if err == nil { + t.Fatal("Expected error for bad request") + } + + if err.Error() != "bad request" { + t.Errorf("Expected 'bad request', got '%s'", err.Error()) + } +} + +func TestSendHTTPRequestJSON_CustomContentType(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Content-Type") != "application/x-custom" { + t.Errorf("Expected custom content type, got %s", r.Header.Get("Content-Type")) + } + w.WriteHeader(http.StatusNoContent) + })) + defer server.Close() + + err := SendHTTPRequestJSON(HTTPRequestOptions[map[string]string, interface{}]{ + Context: context.Background(), + Endpoint: Endpoint{APIEndpoint: server.URL, Token: "token"}, + Method: http.MethodPost, + Path: "/api/test", + Payload: map[string]string{}, + ContentType: "application/x-custom", + }) + + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } +} + +func TestSendHTTPRequestJSON_CustomTimeout(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(100 * time.Millisecond) + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]string{"status": "ok"}) + })) + defer server.Close() + + err := SendHTTPRequestJSON(HTTPRequestOptions[map[string]string, map[string]string]{ + Context: context.Background(), + Endpoint: Endpoint{APIEndpoint: server.URL, Token: "token"}, + Method: http.MethodPost, + Path: "/api/test", + Payload: map[string]string{}, + Timeout: 50 * time.Millisecond, + }) + + if err == nil { + t.Error("Expected timeout error") + } +} + +func TestSendHTTPRequestJSON_ServerError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + json.NewEncoder(w).Encode(errorResponse{ErrorMessage: "internal error"}) + })) + defer server.Close() + + err := SendHTTPRequestJSON(HTTPRequestOptions[map[string]string, interface{}]{ + Context: context.Background(), + Endpoint: Endpoint{APIEndpoint: server.URL, Token: "token"}, + Method: http.MethodPost, + Path: "/api/test", + Payload: map[string]string{}, + }) + + if err == nil { + t.Fatal("Expected error for server error") + } +} + +func TestSendHTTPRequestJSON_InvalidURL(t *testing.T) { + err := SendHTTPRequestJSON(HTTPRequestOptions[map[string]string, interface{}]{ + Context: context.Background(), + Endpoint: Endpoint{APIEndpoint: "http://invalid-url-that-doesnt-exist.local:99999", Token: "token"}, + Method: http.MethodPost, + Path: "/api/test", + Payload: map[string]string{}, + Timeout: 1 * time.Second, + }) + + if err == nil { + t.Error("Expected error for invalid URL") + } +} + +func TestSendHTTPRequestJSON_NilResponse(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]string{"status": "ok"}) + })) + defer server.Close() + + // Response is nil - should still work + err := SendHTTPRequestJSON(HTTPRequestOptions[map[string]string, interface{}]{ + Context: context.Background(), + Endpoint: Endpoint{APIEndpoint: server.URL, Token: "token"}, + Method: http.MethodPost, + Path: "/api/test", + Payload: map[string]string{}, + Response: nil, + }) + + if err != nil { + t.Fatalf("Should not error with nil response: %v", err) + } +} + +func TestGraphQLResponse_Structure(t *testing.T) { + type TestData struct { + Field string `json:"field"` + } + + resp := GraphQLResponse[TestData]{ + Data: TestData{Field: "value"}, + Errors: []GraphQLError{ + { + Message: "test error", + Extensions: map[string]interface{}{ + "code": "TEST_ERROR", + }, + Path: []interface{}{"query", "field"}, + }, + }, + } + + // Marshal and unmarshal + data, err := json.Marshal(resp) + if err != nil { + t.Fatalf("Failed to marshal: %v", err) + } + + var decoded GraphQLResponse[TestData] + if err := json.Unmarshal(data, &decoded); err != nil { + t.Fatalf("Failed to unmarshal: %v", err) + } + + if decoded.Data.Field != "value" { + t.Errorf("Expected field 'value', got '%s'", decoded.Data.Field) + } + + if len(decoded.Errors) != 1 { + t.Errorf("Expected 1 error, got %d", len(decoded.Errors)) + } + + if decoded.Errors[0].Message != "test error" { + t.Errorf("Expected error message 'test error', got '%s'", decoded.Errors[0].Message) + } +} + +func TestGraphQLError_Structure(t *testing.T) { + err := GraphQLError{ + Message: "Field not found", + Extensions: map[string]interface{}{ + "code": "NOT_FOUND", + "timestamp": "2024-01-15", + }, + Path: []interface{}{"query", "user", 0, "name"}, + } + + data, marshalErr := json.Marshal(err) + if marshalErr != nil { + t.Fatalf("Failed to marshal: %v", marshalErr) + } + + var decoded GraphQLError + if unmarshalErr := json.Unmarshal(data, &decoded); unmarshalErr != nil { + t.Fatalf("Failed to unmarshal: %v", unmarshalErr) + } + + if decoded.Message != "Field not found" { + t.Errorf("Message mismatch") + } + + if decoded.Extensions["code"] != "NOT_FOUND" { + t.Errorf("Extensions mismatch") + } + + if len(decoded.Path) != 4 { + t.Errorf("Path should have 4 elements") + } +} + +func TestSendGraphQLRequest_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Verify it's a POST to GraphQL endpoint + if r.Method != http.MethodPost { + t.Errorf("Expected POST, got %s", r.Method) + } + if r.URL.Path != "/api/v2/graphql" { + t.Errorf("Expected /api/v2/graphql, got %s", r.URL.Path) + } + + // Parse request body + var body map[string]interface{} + json.NewDecoder(r.Body).Decode(&body) + + if body["query"] == nil { + t.Error("Request should contain query") + } + + // Send response + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]interface{}{ + "data": map[string]string{"result": "success"}, + }) + })) + defer server.Close() + + type response struct { + Data struct { + Result string `json:"result"` + } `json:"data"` + } + + var resp response + err := SendGraphQLRequest(GraphQLRequestOptions[response]{ + Context: context.Background(), + Endpoint: Endpoint{APIEndpoint: server.URL, Token: "token"}, + Query: "query { test }", + Response: &resp, + }) + + if err != nil { + t.Fatalf("SendGraphQLRequest failed: %v", err) + } + + if resp.Data.Result != "success" { + t.Errorf("Expected result 'success', got '%s'", resp.Data.Result) + } +} + +func TestSendGraphQLRequest_WithVariables(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var body map[string]interface{} + json.NewDecoder(r.Body).Decode(&body) + + // Verify variables are present + if body["variables"] == nil { + t.Error("Request should contain variables") + } + + vars := body["variables"].(map[string]interface{}) + if vars["id"] != "123" { + t.Errorf("Expected id '123', got '%v'", vars["id"]) + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]interface{}{ + "data": map[string]string{"id": "123"}, + }) + })) + defer server.Close() + + type response struct { + Data struct { + ID string `json:"id"` + } `json:"data"` + } + + var resp response + err := SendGraphQLRequest(GraphQLRequestOptions[response]{ + Context: context.Background(), + Endpoint: Endpoint{APIEndpoint: server.URL, Token: "token"}, + Query: "query GetItem($id: ID!) { item(id: $id) { id } }", + Variables: map[string]interface{}{ + "id": "123", + }, + Response: &resp, + }) + + if err != nil { + t.Fatalf("SendGraphQLRequest failed: %v", err) + } +} + +func TestSendGraphQLRequest_CustomTimeout(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(100 * time.Millisecond) + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]interface{}{}) + })) + defer server.Close() + + err := SendGraphQLRequest(GraphQLRequestOptions[interface{}]{ + Context: context.Background(), + Endpoint: Endpoint{APIEndpoint: server.URL, Token: "token"}, + Query: "query { test }", + Timeout: 50 * time.Millisecond, + }) + + if err == nil { + t.Error("Expected timeout error") + } +} + +func TestGraphQLRequestOptions_DefaultTimeout(t *testing.T) { + opts := GraphQLRequestOptions[interface{}]{ + Context: context.Background(), + Endpoint: Endpoint{}, + Query: "query { test }", + } + + if opts.Timeout != 0 { + t.Error("Timeout should default to 0 (will use 30s internally)") + } +} diff --git a/model/cc_statusline_cache_test.go b/model/cc_statusline_cache_test.go new file mode 100644 index 0000000..26c76a0 --- /dev/null +++ b/model/cc_statusline_cache_test.go @@ -0,0 +1,307 @@ +package model + +import ( + "testing" + "time" +) + +func TestCCStatuslineCacheEntry_IsValid(t *testing.T) { + t.Run("nil entry", func(t *testing.T) { + var entry *ccStatuslineCacheEntry + if entry.IsValid() { + t.Error("nil entry should not be valid") + } + }) + + t.Run("valid entry", func(t *testing.T) { + entry := &ccStatuslineCacheEntry{ + Date: time.Now().Format("2006-01-02"), + CostUsd: 1.23, + FetchedAt: time.Now(), + TTL: 5 * time.Minute, + } + if !entry.IsValid() { + t.Error("recent entry should be valid") + } + }) + + t.Run("expired entry", func(t *testing.T) { + entry := &ccStatuslineCacheEntry{ + Date: time.Now().Format("2006-01-02"), + CostUsd: 1.23, + FetchedAt: time.Now().Add(-10 * time.Minute), + TTL: 5 * time.Minute, + } + if entry.IsValid() { + t.Error("expired entry should not be valid") + } + }) + + t.Run("wrong date entry", func(t *testing.T) { + entry := &ccStatuslineCacheEntry{ + Date: "2020-01-01", // Past date + CostUsd: 1.23, + FetchedAt: time.Now(), + TTL: 5 * time.Minute, + } + if entry.IsValid() { + t.Error("entry from different date should not be valid") + } + }) +} + +func TestCCStatuslineCacheGetSet(t *testing.T) { + // Reset cache state + statuslineCache.mu.Lock() + statuslineCache.entry = nil + statuslineCache.fetching = false + statuslineCache.mu.Unlock() + + // Initially cache should be empty + cost, valid := CCStatuslineCacheGet() + if valid { + t.Error("cache should initially be invalid") + } + if cost != 0 { + t.Errorf("Expected 0 cost, got %f", cost) + } + + // Set a value + CCStatuslineCacheSet(2.50) + + // Now cache should be valid + cost, valid = CCStatuslineCacheGet() + if !valid { + t.Error("cache should be valid after set") + } + if cost != 2.50 { + t.Errorf("Expected 2.50, got %f", cost) + } +} + +func TestCCStatuslineCacheGetLastValue(t *testing.T) { + // Reset cache state + statuslineCache.mu.Lock() + statuslineCache.entry = nil + statuslineCache.mu.Unlock() + + // No entry - should return 0 + if CCStatuslineCacheGetLastValue() != 0 { + t.Error("Expected 0 for nil entry") + } + + // Set a value + CCStatuslineCacheSet(3.75) + + // Should return the value + if CCStatuslineCacheGetLastValue() != 3.75 { + t.Errorf("Expected 3.75, got %f", CCStatuslineCacheGetLastValue()) + } + + // Manually expire the entry but keep it + statuslineCache.mu.Lock() + statuslineCache.entry.FetchedAt = time.Now().Add(-1 * time.Hour) + statuslineCache.mu.Unlock() + + // GetLastValue should still return the value even if expired + if CCStatuslineCacheGetLastValue() != 3.75 { + t.Errorf("Expected 3.75 even when expired, got %f", CCStatuslineCacheGetLastValue()) + } + + // But CCStatuslineCacheGet should return invalid + _, valid := CCStatuslineCacheGet() + if valid { + t.Error("Cache should be invalid after expiration") + } +} + +func TestCCStatuslineCacheStartFetch(t *testing.T) { + // Reset cache state + statuslineCache.mu.Lock() + statuslineCache.fetching = false + statuslineCache.mu.Unlock() + + // First call should return true + if !CCStatuslineCacheStartFetch() { + t.Error("First StartFetch should return true") + } + + // Second call should return false (already fetching) + if CCStatuslineCacheStartFetch() { + t.Error("Second StartFetch should return false") + } + + // End fetch + CCStatuslineCacheEndFetch() + + // Now should be able to start again + if !CCStatuslineCacheStartFetch() { + t.Error("StartFetch should return true after EndFetch") + } + + // Cleanup + CCStatuslineCacheEndFetch() +} + +func TestCCStatuslineCacheEndFetch(t *testing.T) { + // Reset cache state + statuslineCache.mu.Lock() + statuslineCache.fetching = true + statuslineCache.mu.Unlock() + + CCStatuslineCacheEndFetch() + + // Verify fetching is false + statuslineCache.mu.RLock() + fetching := statuslineCache.fetching + statuslineCache.mu.RUnlock() + + if fetching { + t.Error("fetching should be false after EndFetch") + } +} + +func TestCCStatuslineCacheSet_ClearsFetching(t *testing.T) { + // Start fetching + statuslineCache.mu.Lock() + statuslineCache.fetching = true + statuslineCache.mu.Unlock() + + // Set value + CCStatuslineCacheSet(1.00) + + // Verify fetching is false + statuslineCache.mu.RLock() + fetching := statuslineCache.fetching + statuslineCache.mu.RUnlock() + + if fetching { + t.Error("fetching should be false after Set") + } +} + +func TestDefaultStatuslineCacheTTL(t *testing.T) { + expected := 5 * time.Minute + if DefaultStatuslineCacheTTL != expected { + t.Errorf("Expected DefaultStatuslineCacheTTL to be 5m, got %v", DefaultStatuslineCacheTTL) + } +} + +func TestCCStatuslineCache_ConcurrentAccess(t *testing.T) { + // Reset cache state + statuslineCache.mu.Lock() + statuslineCache.entry = nil + statuslineCache.fetching = false + statuslineCache.mu.Unlock() + + done := make(chan bool, 20) + + // Concurrent reads + for i := 0; i < 10; i++ { + go func() { + CCStatuslineCacheGet() + CCStatuslineCacheGetLastValue() + done <- true + }() + } + + // Concurrent writes + for i := 0; i < 5; i++ { + go func(val float64) { + CCStatuslineCacheSet(val) + done <- true + }(float64(i)) + } + + // Concurrent fetch attempts + for i := 0; i < 5; i++ { + go func() { + if CCStatuslineCacheStartFetch() { + CCStatuslineCacheEndFetch() + } + done <- true + }() + } + + // Wait for all goroutines + for i := 0; i < 20; i++ { + <-done + } +} + +func TestCCStatuslineCacheEntry_DateComparison(t *testing.T) { + today := time.Now().Format("2006-01-02") + + t.Run("today matches", func(t *testing.T) { + entry := &ccStatuslineCacheEntry{ + Date: today, + FetchedAt: time.Now(), + TTL: 5 * time.Minute, + } + if !entry.IsValid() { + t.Error("today's date should be valid") + } + }) + + t.Run("yesterday doesn't match", func(t *testing.T) { + yesterday := time.Now().Add(-24 * time.Hour).Format("2006-01-02") + entry := &ccStatuslineCacheEntry{ + Date: yesterday, + FetchedAt: time.Now(), + TTL: 5 * time.Minute, + } + if entry.IsValid() { + t.Error("yesterday's date should not be valid") + } + }) + + t.Run("tomorrow doesn't match", func(t *testing.T) { + tomorrow := time.Now().Add(24 * time.Hour).Format("2006-01-02") + entry := &ccStatuslineCacheEntry{ + Date: tomorrow, + FetchedAt: time.Now(), + TTL: 5 * time.Minute, + } + if entry.IsValid() { + t.Error("tomorrow's date should not be valid") + } + }) +} + +func TestCCStatuslineCacheEntry_TTLBoundary(t *testing.T) { + today := time.Now().Format("2006-01-02") + + t.Run("just before TTL", func(t *testing.T) { + entry := &ccStatuslineCacheEntry{ + Date: today, + FetchedAt: time.Now().Add(-4*time.Minute - 59*time.Second), + TTL: 5 * time.Minute, + } + if !entry.IsValid() { + t.Error("entry just before TTL should be valid") + } + }) + + t.Run("exactly at TTL", func(t *testing.T) { + entry := &ccStatuslineCacheEntry{ + Date: today, + FetchedAt: time.Now().Add(-5 * time.Minute), + TTL: 5 * time.Minute, + } + // At exactly TTL, time.Since >= TTL so it should be invalid + if entry.IsValid() { + t.Error("entry exactly at TTL should be invalid") + } + }) + + t.Run("just after TTL", func(t *testing.T) { + entry := &ccStatuslineCacheEntry{ + Date: today, + FetchedAt: time.Now().Add(-5*time.Minute - 1*time.Second), + TTL: 5 * time.Minute, + } + if entry.IsValid() { + t.Error("entry just after TTL should be invalid") + } + }) +} diff --git a/model/command_test.go b/model/command_test.go new file mode 100644 index 0000000..f258aa3 --- /dev/null +++ b/model/command_test.go @@ -0,0 +1,621 @@ +package model + +import ( + "encoding/json" + "os" + "path/filepath" + "strings" + "testing" + "time" +) + +func TestCommand_ToLine(t *testing.T) { + cmd := Command{ + Shell: "bash", + SessionID: 12345, + Command: "ls -la", + Main: "ls", + Hostname: "testhost", + Username: "testuser", + Time: time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC), + EndTime: time.Date(2024, 1, 15, 10, 30, 5, 0, time.UTC), + Result: 0, + Phase: CommandPhasePre, + } + + recordingTime := time.Date(2024, 1, 15, 10, 30, 1, 0, time.UTC) + line, err := cmd.ToLine(recordingTime) + if err != nil { + t.Fatalf("ToLine failed: %v", err) + } + + // Line should contain JSON followed by separator and timestamp + if !strings.Contains(string(line), string(SEPARATOR)) { + t.Error("Line should contain separator") + } + + // Line should end with newline + if !strings.HasSuffix(string(line), "\n") { + t.Error("Line should end with newline") + } + + // Parse the JSON part + parts := strings.Split(strings.TrimSuffix(string(line), "\n"), string(SEPARATOR)) + if len(parts) != 2 { + t.Fatalf("Expected 2 parts, got %d", len(parts)) + } + + var parsedCmd Command + if err := json.Unmarshal([]byte(parts[0]), &parsedCmd); err != nil { + t.Fatalf("Failed to unmarshal JSON: %v", err) + } + + if parsedCmd.Shell != cmd.Shell { + t.Errorf("Expected shell %s, got %s", cmd.Shell, parsedCmd.Shell) + } + if parsedCmd.Command != cmd.Command { + t.Errorf("Expected command %s, got %s", cmd.Command, parsedCmd.Command) + } +} + +func TestCommand_FromLine(t *testing.T) { + originalCmd := Command{ + Shell: "zsh", + SessionID: 67890, + Command: "git status", + Main: "git", + Hostname: "devbox", + Username: "developer", + Time: time.Date(2024, 2, 20, 14, 0, 0, 0, time.UTC), + EndTime: time.Time{}, + Result: 0, + Phase: CommandPhasePre, + } + + recordingTime := time.Date(2024, 2, 20, 14, 0, 1, 0, time.UTC) + line, err := originalCmd.ToLine(recordingTime) + if err != nil { + t.Fatalf("ToLine failed: %v", err) + } + + var parsedCmd Command + parsedRecordingTime, err := parsedCmd.FromLine(strings.TrimSuffix(string(line), "\n")) + if err != nil { + t.Fatalf("FromLine failed: %v", err) + } + + // Verify all fields + if parsedCmd.Shell != originalCmd.Shell { + t.Errorf("Shell mismatch: expected %s, got %s", originalCmd.Shell, parsedCmd.Shell) + } + if parsedCmd.SessionID != originalCmd.SessionID { + t.Errorf("SessionID mismatch: expected %d, got %d", originalCmd.SessionID, parsedCmd.SessionID) + } + if parsedCmd.Command != originalCmd.Command { + t.Errorf("Command mismatch: expected %s, got %s", originalCmd.Command, parsedCmd.Command) + } + if parsedCmd.Username != originalCmd.Username { + t.Errorf("Username mismatch: expected %s, got %s", originalCmd.Username, parsedCmd.Username) + } + + // Recording time should match + if !parsedRecordingTime.Equal(recordingTime) { + t.Errorf("RecordingTime mismatch: expected %v, got %v", recordingTime, parsedRecordingTime) + } +} + +func TestCommand_FromLine_InvalidFormat(t *testing.T) { + testCases := []struct { + name string + line string + }{ + {"empty line", ""}, + {"no separator", `{"shell":"bash"}`}, + {"invalid json", `not json` + string(SEPARATOR) + "123456789"}, + {"invalid timestamp", `{"shell":"bash"}` + string(SEPARATOR) + "not-a-number"}, + {"too many parts", `{"shell":"bash"}` + string(SEPARATOR) + "123" + string(SEPARATOR) + "extra"}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var cmd Command + _, err := cmd.FromLine(tc.line) + if err == nil { + t.Error("Expected error for invalid line format") + } + }) + } +} + +func TestCommand_FromLineBytes(t *testing.T) { + originalCmd := Command{ + Shell: "fish", + SessionID: 11111, + Command: "echo hello", + Main: "echo", + Hostname: "fishbox", + Username: "fishuser", + Time: time.Date(2024, 3, 10, 9, 0, 0, 0, time.UTC), + Phase: CommandPhasePost, + Result: 0, + } + + recordingTime := time.Date(2024, 3, 10, 9, 0, 2, 0, time.UTC) + line, err := originalCmd.ToLine(recordingTime) + if err != nil { + t.Fatalf("ToLine failed: %v", err) + } + + // Remove the trailing newline for FromLineBytes + lineBytes := line[:len(line)-1] + + var parsedCmd Command + parsedRecordingTime, err := parsedCmd.FromLineBytes(lineBytes) + if err != nil { + t.Fatalf("FromLineBytes failed: %v", err) + } + + if parsedCmd.Shell != originalCmd.Shell { + t.Errorf("Shell mismatch: expected %s, got %s", originalCmd.Shell, parsedCmd.Shell) + } + + if !parsedRecordingTime.Equal(recordingTime) { + t.Errorf("RecordingTime mismatch: expected %v, got %v", recordingTime, parsedRecordingTime) + } +} + +func TestCommand_FromLineBytes_InvalidFormat(t *testing.T) { + testCases := []struct { + name string + line []byte + }{ + {"empty line", []byte{}}, + {"no separator", []byte(`{"shell":"bash"}`)}, + {"invalid json", append([]byte("not json"), SEPARATOR, '1', '2', '3')}, + {"invalid timestamp", append([]byte(`{"shell":"bash"}`), SEPARATOR, 'x', 'y', 'z')}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var cmd Command + _, err := cmd.FromLineBytes(tc.line) + if err == nil { + t.Error("Expected error for invalid line format") + } + }) + } +} + +func TestCommand_IsSame(t *testing.T) { + cmd1 := Command{ + Shell: "bash", + SessionID: 12345, + Command: "ls -la", + Username: "testuser", + } + + testCases := []struct { + name string + cmd2 Command + expected bool + }{ + { + "identical", + Command{Shell: "bash", SessionID: 12345, Command: "ls -la", Username: "testuser"}, + true, + }, + { + "different shell", + Command{Shell: "zsh", SessionID: 12345, Command: "ls -la", Username: "testuser"}, + false, + }, + { + "different session", + Command{Shell: "bash", SessionID: 99999, Command: "ls -la", Username: "testuser"}, + false, + }, + { + "different command", + Command{Shell: "bash", SessionID: 12345, Command: "pwd", Username: "testuser"}, + false, + }, + { + "different username", + Command{Shell: "bash", SessionID: 12345, Command: "ls -la", Username: "otheruser"}, + false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := cmd1.IsSame(tc.cmd2) + if result != tc.expected { + t.Errorf("Expected %v, got %v", tc.expected, result) + } + }) + } +} + +func TestCommand_IsPairPreCommand(t *testing.T) { + now := time.Now() + + preCmd := Command{ + Shell: "bash", + SessionID: 12345, + Command: "ls -la", + Username: "testuser", + Time: now.Add(-time.Hour), // 1 hour ago + Phase: CommandPhasePre, + } + + testCases := []struct { + name string + target Command + expected bool + }{ + { + "matching pre command", + Command{Shell: "bash", SessionID: 12345, Command: "ls -la", Username: "testuser"}, + true, + }, + { + "post phase command", + Command{Shell: "bash", SessionID: 12345, Command: "ls -la", Username: "testuser"}, + true, + }, + { + "different command", + Command{Shell: "bash", SessionID: 12345, Command: "pwd", Username: "testuser"}, + false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := preCmd.IsPairPreCommand(tc.target) + if result != tc.expected { + t.Errorf("Expected %v, got %v", tc.expected, result) + } + }) + } +} + +func TestCommand_IsPairPreCommand_OldCommand(t *testing.T) { + // Command older than 10 days + oldPreCmd := Command{ + Shell: "bash", + SessionID: 12345, + Command: "ls -la", + Username: "testuser", + Time: time.Now().Add(-11 * 24 * time.Hour), // 11 days ago + Phase: CommandPhasePre, + } + + target := Command{Shell: "bash", SessionID: 12345, Command: "ls -la", Username: "testuser"} + + if oldPreCmd.IsPairPreCommand(target) { + t.Error("Should not pair with command older than 10 days") + } +} + +func TestCommand_IsPairPreCommand_NotPrePhase(t *testing.T) { + postCmd := Command{ + Shell: "bash", + SessionID: 12345, + Command: "ls -la", + Username: "testuser", + Time: time.Now().Add(-time.Hour), + Phase: CommandPhasePost, // Not a pre command + } + + target := Command{Shell: "bash", SessionID: 12345, Command: "ls -la", Username: "testuser"} + + if postCmd.IsPairPreCommand(target) { + t.Error("Should not pair with post phase command") + } +} + +func TestCommand_IsNil(t *testing.T) { + testCases := []struct { + name string + cmd Command + expected bool + }{ + {"empty command", Command{}, true}, + {"has command", Command{Command: "ls"}, false}, + {"has session id", Command{SessionID: 123}, false}, + {"has username", Command{Username: "user"}, false}, + {"has shell", Command{Shell: "bash"}, false}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := tc.cmd.IsNil() + if result != tc.expected { + t.Errorf("Expected %v, got %v", tc.expected, result) + } + }) + } +} + +func TestCommand_GetUniqueKey(t *testing.T) { + cmd := Command{ + Shell: "bash", + SessionID: 12345, + Command: "git commit -m 'test'", + Username: "developer", + } + + key := cmd.GetUniqueKey() + expected := "bash|12345|git commit -m 'test'|developer" + + if key != expected { + t.Errorf("Expected key %s, got %s", expected, key) + } +} + +func TestCommand_FindClosestCommand(t *testing.T) { + baseTime := time.Date(2024, 1, 15, 12, 0, 0, 0, time.UTC) + + targetCmd := Command{ + Shell: "bash", + SessionID: 100, + Command: "ls", + Username: "user", + Time: baseTime, + } + + commands := []*Command{ + {Shell: "bash", SessionID: 100, Command: "ls", Username: "user", Time: baseTime.Add(-10 * time.Second)}, + {Shell: "bash", SessionID: 100, Command: "ls", Username: "user", Time: baseTime.Add(-5 * time.Second)}, // Closest + {Shell: "bash", SessionID: 100, Command: "ls", Username: "user", Time: baseTime.Add(-30 * time.Second)}, + nil, // Should be skipped + } + + closest := targetCmd.FindClosestCommand(commands, true) + if closest == nil { + t.Fatal("Expected to find a closest command") + } + + expectedTime := baseTime.Add(-5 * time.Second) + if !closest.Time.Equal(expectedTime) { + t.Errorf("Expected closest command time %v, got %v", expectedTime, closest.Time) + } +} + +func TestCommand_FindClosestCommand_WithoutSameKey(t *testing.T) { + baseTime := time.Date(2024, 1, 15, 12, 0, 0, 0, time.UTC) + + targetCmd := Command{ + Shell: "bash", + SessionID: 100, + Command: "ls", + Username: "user", + Time: baseTime, + } + + // Commands with different keys + commands := []*Command{ + {Shell: "zsh", SessionID: 200, Command: "pwd", Username: "other", Time: baseTime.Add(-5 * time.Second)}, + {Shell: "fish", SessionID: 300, Command: "cd", Username: "another", Time: baseTime.Add(-10 * time.Second)}, + } + + // Without same key requirement, should still find the closest + closest := targetCmd.FindClosestCommand(commands, false) + if closest == nil { + t.Fatal("Expected to find a closest command") + } + + expectedTime := baseTime.Add(-5 * time.Second) + if !closest.Time.Equal(expectedTime) { + t.Errorf("Expected closest command time %v, got %v", expectedTime, closest.Time) + } +} + +func TestCommand_FindClosestCommand_NoMatch(t *testing.T) { + baseTime := time.Date(2024, 1, 15, 12, 0, 0, 0, time.UTC) + + targetCmd := Command{ + Shell: "bash", + SessionID: 100, + Command: "ls", + Username: "user", + Time: baseTime, + } + + // Commands with different keys and withSameKey=true + commands := []*Command{ + {Shell: "zsh", SessionID: 200, Command: "pwd", Username: "other", Time: baseTime.Add(-5 * time.Second)}, + } + + closest := targetCmd.FindClosestCommand(commands, true) + if !closest.IsNil() { + t.Error("Expected no match when requiring same key") + } +} + +func TestCommand_FindClosestCommand_FutureCommands(t *testing.T) { + baseTime := time.Date(2024, 1, 15, 12, 0, 0, 0, time.UTC) + + targetCmd := Command{ + Shell: "bash", + SessionID: 100, + Command: "ls", + Username: "user", + Time: baseTime, + } + + // All commands are in the future (positive time diff) + commands := []*Command{ + {Shell: "bash", SessionID: 100, Command: "ls", Username: "user", Time: baseTime.Add(5 * time.Second)}, + {Shell: "bash", SessionID: 100, Command: "ls", Username: "user", Time: baseTime.Add(10 * time.Second)}, + } + + // Future commands should not match (timeDiff < 0) + closest := targetCmd.FindClosestCommand(commands, true) + if closest == nil { + t.Fatal("Expected non-nil command (even if empty)") + } +} + +func TestCommand_FindClosestCommand_EmptyList(t *testing.T) { + targetCmd := Command{ + Shell: "bash", + SessionID: 100, + Command: "ls", + Username: "user", + Time: time.Now(), + } + + closest := targetCmd.FindClosestCommand([]*Command{}, true) + if !closest.IsNil() { + t.Error("Expected nil command for empty list") + } +} + +func TestCommandPhase_Constants(t *testing.T) { + if CommandPhasePre != 0 { + t.Errorf("Expected CommandPhasePre to be 0, got %d", CommandPhasePre) + } + if CommandPhasePost != 1 { + t.Errorf("Expected CommandPhasePost to be 1, got %d", CommandPhasePost) + } +} + +func TestCommand_DoSavePre(t *testing.T) { + // Create a temp directory for testing + tempDir, err := os.MkdirTemp("", "shelltime-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tempDir) + + // Override the storage folder + originalBaseFolder := COMMAND_BASE_STORAGE_FOLDER + defer func() { COMMAND_BASE_STORAGE_FOLDER = originalBaseFolder }() + + // Use a unique base folder that will be under the temp dir + InitFolder("test-save-pre") + // Replace with temp dir path + COMMAND_BASE_STORAGE_FOLDER = filepath.Join(tempDir, ".shelltime-test-save-pre") + COMMAND_STORAGE_FOLDER = COMMAND_BASE_STORAGE_FOLDER + "/commands" + COMMAND_PRE_STORAGE_FILE = COMMAND_STORAGE_FOLDER + "/pre.txt" + + cmd := Command{ + Shell: "bash", + SessionID: 12345, + Command: "ls -la", + Main: "ls", + Hostname: "testhost", + Username: "testuser", + Time: time.Now(), + Phase: CommandPhasePre, + } + + err = cmd.DoSavePre() + if err != nil { + t.Fatalf("DoSavePre failed: %v", err) + } + + // Verify the file was created + preFilePath := filepath.Join(tempDir, ".shelltime-test-save-pre", "commands", "pre.txt") + if _, err := os.Stat(preFilePath); os.IsNotExist(err) { + t.Error("Pre-command file was not created") + } + + // Read and verify content + content, err := os.ReadFile(preFilePath) + if err != nil { + t.Fatalf("Failed to read pre-command file: %v", err) + } + + if !strings.Contains(string(content), "ls -la") { + t.Error("Pre-command file should contain the command") + } +} + +func TestCommand_DoUpdate(t *testing.T) { + // Create a temp directory for testing + tempDir, err := os.MkdirTemp("", "shelltime-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tempDir) + + // Override the storage folder + originalBaseFolder := COMMAND_BASE_STORAGE_FOLDER + defer func() { COMMAND_BASE_STORAGE_FOLDER = originalBaseFolder }() + + COMMAND_BASE_STORAGE_FOLDER = filepath.Join(tempDir, ".shelltime-test-update") + COMMAND_STORAGE_FOLDER = COMMAND_BASE_STORAGE_FOLDER + "/commands" + COMMAND_POST_STORAGE_FILE = COMMAND_STORAGE_FOLDER + "/post.txt" + + cmd := Command{ + Shell: "bash", + SessionID: 12345, + Command: "make build", + Main: "make", + Hostname: "testhost", + Username: "testuser", + Time: time.Now(), + Phase: CommandPhasePre, + } + + err = cmd.DoUpdate(0) // Exit code 0 + if err != nil { + t.Fatalf("DoUpdate failed: %v", err) + } + + // Verify the file was created + postFilePath := filepath.Join(tempDir, ".shelltime-test-update", "commands", "post.txt") + if _, err := os.Stat(postFilePath); os.IsNotExist(err) { + t.Error("Post-command file was not created") + } + + // Read and verify content + content, err := os.ReadFile(postFilePath) + if err != nil { + t.Fatalf("Failed to read post-command file: %v", err) + } + + if !strings.Contains(string(content), "make build") { + t.Error("Post-command file should contain the command") + } + + // Verify result code is in the content + if !strings.Contains(string(content), `"result":0`) { + t.Error("Post-command file should contain result code") + } +} + +func TestEnsureStorageFolder(t *testing.T) { + // Create a temp directory for testing + tempDir, err := os.MkdirTemp("", "shelltime-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tempDir) + + // Override the storage folder + originalBaseFolder := COMMAND_BASE_STORAGE_FOLDER + defer func() { COMMAND_BASE_STORAGE_FOLDER = originalBaseFolder }() + + COMMAND_BASE_STORAGE_FOLDER = filepath.Join(tempDir, ".shelltime-ensure") + COMMAND_STORAGE_FOLDER = COMMAND_BASE_STORAGE_FOLDER + "/commands" + + err = ensureStorageFolder() + if err != nil { + t.Fatalf("ensureStorageFolder failed: %v", err) + } + + // Verify folder was created + if _, err := os.Stat(COMMAND_STORAGE_FOLDER); os.IsNotExist(err) { + t.Error("Storage folder was not created") + } + + // Call again to ensure it doesn't fail when folder exists + err = ensureStorageFolder() + if err != nil { + t.Fatalf("ensureStorageFolder failed on second call: %v", err) + } +} diff --git a/model/crypto_test.go b/model/crypto_test.go new file mode 100644 index 0000000..5d9ad27 --- /dev/null +++ b/model/crypto_test.go @@ -0,0 +1,415 @@ +package model + +import ( + "encoding/base64" + "strings" + "testing" +) + +func TestNewAESGCMService(t *testing.T) { + service := NewAESGCMService() + if service == nil { + t.Fatal("NewAESGCMService returned nil") + } + + aesService, ok := service.(*AESGCMService) + if !ok { + t.Fatal("NewAESGCMService did not return an *AESGCMService") + } + + if aesService.KeySize != 32 { + t.Errorf("Expected KeySize to be 32, got %d", aesService.KeySize) + } +} + +func TestAESGCMService_GenerateKeys(t *testing.T) { + service := NewAESGCMService() + + publicKey, privateKey, err := service.GenerateKeys() + if err != nil { + t.Fatalf("GenerateKeys failed: %v", err) + } + + // AES is symmetric, so private key should be nil + if privateKey != nil { + t.Error("Expected private key to be nil for AES") + } + + // Public key should be base64 encoded + if publicKey == nil { + t.Fatal("Expected public key to be non-nil") + } + + // Verify it's valid base64 + decoded, err := base64.StdEncoding.DecodeString(string(publicKey)) + if err != nil { + t.Fatalf("Public key is not valid base64: %v", err) + } + + // Decoded key should be 32 bytes (256 bits) + if len(decoded) != 32 { + t.Errorf("Expected decoded key to be 32 bytes, got %d", len(decoded)) + } +} + +func TestAESGCMService_GenerateKeys_Uniqueness(t *testing.T) { + service := NewAESGCMService() + + key1, _, err := service.GenerateKeys() + if err != nil { + t.Fatalf("First GenerateKeys failed: %v", err) + } + + key2, _, err := service.GenerateKeys() + if err != nil { + t.Fatalf("Second GenerateKeys failed: %v", err) + } + + // Keys should be unique + if string(key1) == string(key2) { + t.Error("Generated keys should be unique") + } +} + +func TestAESGCMService_Encrypt(t *testing.T) { + service := NewAESGCMService() + + // Generate a key + key, _, err := service.GenerateKeys() + if err != nil { + t.Fatalf("GenerateKeys failed: %v", err) + } + + plaintext := []byte("Hello, World! This is a test message for AES-GCM encryption.") + + ciphertext, nonce, err := service.Encrypt(string(key), plaintext) + if err != nil { + t.Fatalf("Encrypt failed: %v", err) + } + + // Ciphertext should not be nil or empty + if ciphertext == nil || len(ciphertext) == 0 { + t.Error("Ciphertext should not be empty") + } + + // Nonce should be 12 bytes (GCM standard) + if len(nonce) != 12 { + t.Errorf("Expected nonce to be 12 bytes, got %d", len(nonce)) + } + + // Ciphertext should be different from plaintext + if string(ciphertext) == string(plaintext) { + t.Error("Ciphertext should be different from plaintext") + } +} + +func TestAESGCMService_Encrypt_UniqueNonce(t *testing.T) { + service := NewAESGCMService() + + key, _, err := service.GenerateKeys() + if err != nil { + t.Fatalf("GenerateKeys failed: %v", err) + } + + plaintext := []byte("Test message") + + _, nonce1, err := service.Encrypt(string(key), plaintext) + if err != nil { + t.Fatalf("First Encrypt failed: %v", err) + } + + _, nonce2, err := service.Encrypt(string(key), plaintext) + if err != nil { + t.Fatalf("Second Encrypt failed: %v", err) + } + + // Nonces should be unique + if string(nonce1) == string(nonce2) { + t.Error("Nonces should be unique for each encryption") + } +} + +func TestAESGCMService_Encrypt_InvalidKey(t *testing.T) { + service := NewAESGCMService() + + testCases := []struct { + name string + key string + }{ + {"invalid base64", "not-valid-base64!@#$"}, + {"empty key", ""}, + {"wrong size key", base64.StdEncoding.EncodeToString([]byte("short"))}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + _, _, err := service.Encrypt(tc.key, []byte("test")) + if err == nil { + t.Error("Expected error for invalid key") + } + }) + } +} + +func TestAESGCMService_Encrypt_EmptyPlaintext(t *testing.T) { + service := NewAESGCMService() + + key, _, err := service.GenerateKeys() + if err != nil { + t.Fatalf("GenerateKeys failed: %v", err) + } + + // Empty plaintext should still work (GCM can encrypt empty messages) + ciphertext, nonce, err := service.Encrypt(string(key), []byte{}) + if err != nil { + t.Fatalf("Encrypt with empty plaintext failed: %v", err) + } + + if nonce == nil { + t.Error("Nonce should not be nil even for empty plaintext") + } + + // Ciphertext will include the auth tag, so it won't be empty + if ciphertext == nil { + t.Error("Ciphertext should not be nil") + } +} + +func TestAESGCMService_Encrypt_LargePlaintext(t *testing.T) { + service := NewAESGCMService() + + key, _, err := service.GenerateKeys() + if err != nil { + t.Fatalf("GenerateKeys failed: %v", err) + } + + // Create a large plaintext (1MB) + largePlaintext := make([]byte, 1024*1024) + for i := range largePlaintext { + largePlaintext[i] = byte(i % 256) + } + + ciphertext, nonce, err := service.Encrypt(string(key), largePlaintext) + if err != nil { + t.Fatalf("Encrypt with large plaintext failed: %v", err) + } + + if len(ciphertext) < len(largePlaintext) { + t.Error("Ciphertext should be at least as long as plaintext (plus auth tag)") + } + + if len(nonce) != 12 { + t.Errorf("Expected nonce to be 12 bytes, got %d", len(nonce)) + } +} + +func TestAESGCMService_Decrypt_Panics(t *testing.T) { + service := NewAESGCMService() + + defer func() { + if r := recover(); r == nil { + t.Error("Expected Decrypt to panic") + } + }() + + // This should panic as Decrypt is not implemented + service.Decrypt("key", []byte("ciphertext"), []byte("nonce")) +} + +// RSA Service Tests + +func TestNewRSAService(t *testing.T) { + service := NewRSAService() + if service == nil { + t.Fatal("NewRSAService returned nil") + } + + rsaService, ok := service.(*RSAService) + if !ok { + t.Fatal("NewRSAService did not return an *RSAService") + } + + if rsaService.KeySize != 2048 { + t.Errorf("Expected KeySize to be 2048, got %d", rsaService.KeySize) + } +} + +func TestRSAService_GenerateKeys(t *testing.T) { + service := NewRSAService() + + publicKey, privateKey, err := service.GenerateKeys() + if err != nil { + t.Fatalf("GenerateKeys failed: %v", err) + } + + // Both keys should be non-nil + if publicKey == nil { + t.Error("Expected public key to be non-nil") + } + if privateKey == nil { + t.Error("Expected private key to be non-nil") + } + + // Verify PEM format for public key + pubKeyStr := string(publicKey) + if !strings.Contains(pubKeyStr, "-----BEGIN RSA PUBLIC KEY-----") { + t.Error("Public key should be in PEM format") + } + if !strings.Contains(pubKeyStr, "-----END RSA PUBLIC KEY-----") { + t.Error("Public key should have proper PEM ending") + } + + // Verify PEM format for private key + privKeyStr := string(privateKey) + if !strings.Contains(privKeyStr, "-----BEGIN RSA PRIVATE KEY-----") { + t.Error("Private key should be in PEM format") + } + if !strings.Contains(privKeyStr, "-----END RSA PRIVATE KEY-----") { + t.Error("Private key should have proper PEM ending") + } +} + +func TestRSAService_GenerateKeys_Uniqueness(t *testing.T) { + service := NewRSAService() + + pub1, priv1, err := service.GenerateKeys() + if err != nil { + t.Fatalf("First GenerateKeys failed: %v", err) + } + + pub2, priv2, err := service.GenerateKeys() + if err != nil { + t.Fatalf("Second GenerateKeys failed: %v", err) + } + + // Keys should be unique + if string(pub1) == string(pub2) { + t.Error("Generated public keys should be unique") + } + if string(priv1) == string(priv2) { + t.Error("Generated private keys should be unique") + } +} + +func TestRSAService_Encrypt(t *testing.T) { + service := NewRSAService() + + publicKey, _, err := service.GenerateKeys() + if err != nil { + t.Fatalf("GenerateKeys failed: %v", err) + } + + plaintext := []byte("Hello, RSA!") + + ciphertext, nonce, err := service.Encrypt(string(publicKey), plaintext) + if err != nil { + t.Fatalf("Encrypt failed: %v", err) + } + + // For RSA, nonce should be nil + if nonce != nil { + t.Error("Expected nonce to be nil for RSA encryption") + } + + // Ciphertext should not be nil or empty + if ciphertext == nil || len(ciphertext) == 0 { + t.Error("Ciphertext should not be empty") + } + + // Ciphertext should be different from plaintext + if string(ciphertext) == string(plaintext) { + t.Error("Ciphertext should be different from plaintext") + } +} + +func TestRSAService_Encrypt_UniqueOutput(t *testing.T) { + service := NewRSAService() + + publicKey, _, err := service.GenerateKeys() + if err != nil { + t.Fatalf("GenerateKeys failed: %v", err) + } + + plaintext := []byte("Test message") + + cipher1, _, err := service.Encrypt(string(publicKey), plaintext) + if err != nil { + t.Fatalf("First Encrypt failed: %v", err) + } + + cipher2, _, err := service.Encrypt(string(publicKey), plaintext) + if err != nil { + t.Fatalf("Second Encrypt failed: %v", err) + } + + // Due to PKCS1v15 padding randomness, ciphertexts should be unique + if string(cipher1) == string(cipher2) { + t.Error("RSA ciphertexts should be unique due to padding") + } +} + +func TestRSAService_Encrypt_InvalidKey(t *testing.T) { + service := NewRSAService() + + testCases := []struct { + name string + key string + }{ + {"empty key", ""}, + {"invalid PEM", "not a valid PEM key"}, + {"malformed PEM", "-----BEGIN RSA PUBLIC KEY-----\ninvalid content\n-----END RSA PUBLIC KEY-----"}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + _, _, err := service.Encrypt(tc.key, []byte("test")) + if err == nil { + t.Error("Expected error for invalid key") + } + }) + } +} + +func TestRSAService_Encrypt_MessageTooLong(t *testing.T) { + service := NewRSAService() + + publicKey, _, err := service.GenerateKeys() + if err != nil { + t.Fatalf("GenerateKeys failed: %v", err) + } + + // RSA PKCS1v15 with 2048-bit key can encrypt at most 245 bytes + // Create a message that's too long + longMessage := make([]byte, 300) + for i := range longMessage { + longMessage[i] = 'a' + } + + _, _, err = service.Encrypt(string(publicKey), longMessage) + if err == nil { + t.Error("Expected error for message too long") + } +} + +func TestRSAService_Decrypt_Panics(t *testing.T) { + service := NewRSAService() + + defer func() { + if r := recover(); r == nil { + t.Error("Expected Decrypt to panic") + } + }() + + // This should panic as Decrypt is not implemented + service.Decrypt("key", []byte("ciphertext"), nil) +} + +// CryptoService interface compliance tests + +func TestAESGCMService_ImplementsCryptoService(t *testing.T) { + var _ CryptoService = &AESGCMService{} +} + +func TestRSAService_ImplementsCryptoService(t *testing.T) { + var _ CryptoService = &RSAService{} +} diff --git a/model/db_test.go b/model/db_test.go new file mode 100644 index 0000000..ef47b9e --- /dev/null +++ b/model/db_test.go @@ -0,0 +1,568 @@ +package model + +import ( + "context" + "fmt" + "os" + "path/filepath" + "testing" + "time" +) + +func TestInitFolder(t *testing.T) { + // Save original values + origBase := COMMAND_BASE_STORAGE_FOLDER + origStorage := COMMAND_STORAGE_FOLDER + origPre := COMMAND_PRE_STORAGE_FILE + origPost := COMMAND_POST_STORAGE_FILE + origCursor := COMMAND_CURSOR_STORAGE_FILE + origHeartbeat := HEARTBEAT_LOG_FILE + origPending := SYNC_PENDING_FILE + + defer func() { + COMMAND_BASE_STORAGE_FOLDER = origBase + COMMAND_STORAGE_FOLDER = origStorage + COMMAND_PRE_STORAGE_FILE = origPre + COMMAND_POST_STORAGE_FILE = origPost + COMMAND_CURSOR_STORAGE_FILE = origCursor + HEARTBEAT_LOG_FILE = origHeartbeat + SYNC_PENDING_FILE = origPending + }() + + // Test with empty baseFolder (should keep default) + InitFolder("") + if COMMAND_BASE_STORAGE_FOLDER != ".shelltime" { + t.Errorf("Expected default base folder, got %s", COMMAND_BASE_STORAGE_FOLDER) + } + + // Test with custom baseFolder + InitFolder("custom") + if COMMAND_BASE_STORAGE_FOLDER != ".shelltime-custom" { + t.Errorf("Expected .shelltime-custom, got %s", COMMAND_BASE_STORAGE_FOLDER) + } + if COMMAND_STORAGE_FOLDER != ".shelltime-custom/commands" { + t.Errorf("Expected .shelltime-custom/commands, got %s", COMMAND_STORAGE_FOLDER) + } + if COMMAND_PRE_STORAGE_FILE != ".shelltime-custom/commands/pre.txt" { + t.Errorf("Expected .shelltime-custom/commands/pre.txt, got %s", COMMAND_PRE_STORAGE_FILE) + } +} + +func TestGetPreCommandsTree(t *testing.T) { + // Create a temp directory for testing + tempDir, err := os.MkdirTemp("", "shelltime-db-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tempDir) + + // Save and restore original paths + origBase := COMMAND_BASE_STORAGE_FOLDER + defer func() { + COMMAND_BASE_STORAGE_FOLDER = origBase + COMMAND_STORAGE_FOLDER = origBase + "/commands" + COMMAND_PRE_STORAGE_FILE = COMMAND_STORAGE_FOLDER + "/pre.txt" + }() + + COMMAND_BASE_STORAGE_FOLDER = tempDir + COMMAND_STORAGE_FOLDER = tempDir + "/commands" + COMMAND_PRE_STORAGE_FILE = COMMAND_STORAGE_FOLDER + "/pre.txt" + + // Create the commands directory + if err := os.MkdirAll(COMMAND_STORAGE_FOLDER, 0755); err != nil { + t.Fatalf("Failed to create commands dir: %v", err) + } + + // Create test pre commands + cmd1 := Command{ + Shell: "bash", + SessionID: 100, + Command: "ls -la", + Username: "user1", + Time: time.Now(), + Phase: CommandPhasePre, + } + cmd2 := Command{ + Shell: "bash", + SessionID: 100, + Command: "ls -la", + Username: "user1", + Time: time.Now().Add(-time.Hour), + Phase: CommandPhasePre, + } + cmd3 := Command{ + Shell: "zsh", + SessionID: 200, + Command: "pwd", + Username: "user2", + Time: time.Now(), + Phase: CommandPhasePre, + } + + // Write commands to file + f, err := os.Create(GetPreCommandFilePath()) + if err != nil { + t.Fatalf("Failed to create pre file: %v", err) + } + + for _, cmd := range []Command{cmd1, cmd2, cmd3} { + line, _ := cmd.ToLine(time.Now()) + f.Write(line) + } + f.Close() + + // Test GetPreCommandsTree + ctx := context.Background() + tree, err := GetPreCommandsTree(ctx) + if err != nil { + t.Fatalf("GetPreCommandsTree failed: %v", err) + } + + // cmd1 and cmd2 should share the same key + key1 := cmd1.GetUniqueKey() + if len(tree[key1]) != 2 { + t.Errorf("Expected 2 commands for key %s, got %d", key1, len(tree[key1])) + } + + // cmd3 should have its own key + key3 := cmd3.GetUniqueKey() + if len(tree[key3]) != 1 { + t.Errorf("Expected 1 command for key %s, got %d", key3, len(tree[key3])) + } +} + +func TestGetPreCommandsTree_FileNotExists(t *testing.T) { + // Create a temp directory for testing + tempDir, err := os.MkdirTemp("", "shelltime-db-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tempDir) + + // Save and restore original paths + origBase := COMMAND_BASE_STORAGE_FOLDER + defer func() { + COMMAND_BASE_STORAGE_FOLDER = origBase + COMMAND_STORAGE_FOLDER = origBase + "/commands" + COMMAND_PRE_STORAGE_FILE = COMMAND_STORAGE_FOLDER + "/pre.txt" + }() + + COMMAND_BASE_STORAGE_FOLDER = tempDir + COMMAND_STORAGE_FOLDER = tempDir + "/commands" + COMMAND_PRE_STORAGE_FILE = COMMAND_STORAGE_FOLDER + "/pre.txt" + + ctx := context.Background() + _, err = GetPreCommandsTree(ctx) + if err == nil { + t.Error("Expected error when file doesn't exist") + } +} + +func TestGetPreCommands(t *testing.T) { + // Create a temp directory for testing + tempDir, err := os.MkdirTemp("", "shelltime-db-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tempDir) + + // Save and restore original paths + origBase := COMMAND_BASE_STORAGE_FOLDER + defer func() { + COMMAND_BASE_STORAGE_FOLDER = origBase + COMMAND_STORAGE_FOLDER = origBase + "/commands" + COMMAND_PRE_STORAGE_FILE = COMMAND_STORAGE_FOLDER + "/pre.txt" + }() + + COMMAND_BASE_STORAGE_FOLDER = tempDir + COMMAND_STORAGE_FOLDER = tempDir + "/commands" + COMMAND_PRE_STORAGE_FILE = COMMAND_STORAGE_FOLDER + "/pre.txt" + + // Create the commands directory + if err := os.MkdirAll(COMMAND_STORAGE_FOLDER, 0755); err != nil { + t.Fatalf("Failed to create commands dir: %v", err) + } + + // Create test pre commands + commands := []Command{ + {Shell: "bash", SessionID: 100, Command: "ls", Username: "user", Time: time.Now(), Phase: CommandPhasePre}, + {Shell: "bash", SessionID: 101, Command: "pwd", Username: "user", Time: time.Now(), Phase: CommandPhasePre}, + {Shell: "zsh", SessionID: 102, Command: "cd", Username: "user", Time: time.Now(), Phase: CommandPhasePre}, + } + + // Write commands to file + f, err := os.Create(GetPreCommandFilePath()) + if err != nil { + t.Fatalf("Failed to create pre file: %v", err) + } + + for _, cmd := range commands { + line, _ := cmd.ToLine(time.Now()) + f.Write(line) + } + f.Close() + + // Test GetPreCommands + ctx := context.Background() + result, err := GetPreCommands(ctx) + if err != nil { + t.Fatalf("GetPreCommands failed: %v", err) + } + + if len(result) != 3 { + t.Errorf("Expected 3 commands, got %d", len(result)) + } +} + +func TestGetPreCommands_EmptyLines(t *testing.T) { + // Create a temp directory for testing + tempDir, err := os.MkdirTemp("", "shelltime-db-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tempDir) + + // Save and restore original paths + origBase := COMMAND_BASE_STORAGE_FOLDER + defer func() { + COMMAND_BASE_STORAGE_FOLDER = origBase + COMMAND_STORAGE_FOLDER = origBase + "/commands" + COMMAND_PRE_STORAGE_FILE = COMMAND_STORAGE_FOLDER + "/pre.txt" + }() + + COMMAND_BASE_STORAGE_FOLDER = tempDir + COMMAND_STORAGE_FOLDER = tempDir + "/commands" + COMMAND_PRE_STORAGE_FILE = COMMAND_STORAGE_FOLDER + "/pre.txt" + + // Create the commands directory + if err := os.MkdirAll(COMMAND_STORAGE_FOLDER, 0755); err != nil { + t.Fatalf("Failed to create commands dir: %v", err) + } + + // Create file with empty lines + cmd := Command{Shell: "bash", SessionID: 100, Command: "ls", Username: "user", Time: time.Now()} + f, _ := os.Create(GetPreCommandFilePath()) + f.WriteString("\n") // Empty line + line, _ := cmd.ToLine(time.Now()) + f.Write(line) + f.WriteString("\n") // Empty line + f.Close() + + ctx := context.Background() + result, err := GetPreCommands(ctx) + if err != nil { + t.Fatalf("GetPreCommands failed: %v", err) + } + + // Should only have 1 command (empty lines skipped) + if len(result) != 1 { + t.Errorf("Expected 1 command, got %d", len(result)) + } +} + +func TestGetLastCursor_NoCursorFile(t *testing.T) { + // Create a temp directory for testing + tempDir, err := os.MkdirTemp("", "shelltime-db-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tempDir) + + // Save and restore original paths + origBase := COMMAND_BASE_STORAGE_FOLDER + defer func() { + COMMAND_BASE_STORAGE_FOLDER = origBase + COMMAND_STORAGE_FOLDER = origBase + "/commands" + COMMAND_CURSOR_STORAGE_FILE = COMMAND_STORAGE_FOLDER + "/cursor.txt" + }() + + COMMAND_BASE_STORAGE_FOLDER = tempDir + COMMAND_STORAGE_FOLDER = tempDir + "/commands" + COMMAND_CURSOR_STORAGE_FILE = COMMAND_STORAGE_FOLDER + "/cursor.txt" + + ctx := context.Background() + cursorTime, noCursorExist, err := GetLastCursor(ctx) + if err != nil { + t.Fatalf("GetLastCursor failed: %v", err) + } + + if !noCursorExist { + t.Error("Expected noCursorExist to be true") + } + + if !cursorTime.IsZero() { + t.Error("Expected zero time when no cursor exists") + } +} + +func TestGetLastCursor_WithCursor(t *testing.T) { + // Create a temp directory for testing + tempDir, err := os.MkdirTemp("", "shelltime-db-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tempDir) + + // Save and restore original paths + origBase := COMMAND_BASE_STORAGE_FOLDER + defer func() { + COMMAND_BASE_STORAGE_FOLDER = origBase + COMMAND_STORAGE_FOLDER = origBase + "/commands" + COMMAND_CURSOR_STORAGE_FILE = COMMAND_STORAGE_FOLDER + "/cursor.txt" + }() + + COMMAND_BASE_STORAGE_FOLDER = tempDir + COMMAND_STORAGE_FOLDER = tempDir + "/commands" + COMMAND_CURSOR_STORAGE_FILE = COMMAND_STORAGE_FOLDER + "/cursor.txt" + + // Create the commands directory + if err := os.MkdirAll(COMMAND_STORAGE_FOLDER, 0755); err != nil { + t.Fatalf("Failed to create commands dir: %v", err) + } + + // Write cursor timestamp + expectedTime := time.Date(2024, 1, 15, 10, 0, 0, 0, time.UTC) + cursorFile := filepath.Join(COMMAND_STORAGE_FOLDER, "cursor.txt") + if err := os.WriteFile(cursorFile, []byte(fmt.Sprintf("%d", expectedTime.UnixNano())), 0644); err != nil { + t.Fatalf("Failed to write cursor file: %v", err) + } + + ctx := context.Background() + cursorTime, noCursorExist, err := GetLastCursor(ctx) + if err != nil { + t.Fatalf("GetLastCursor failed: %v", err) + } + + if noCursorExist { + t.Error("Expected noCursorExist to be false") + } + + if !cursorTime.Equal(expectedTime) { + t.Errorf("Expected cursor time %v, got %v", expectedTime, cursorTime) + } +} + +func TestGetLastCursor_MultipleLines(t *testing.T) { + // Create a temp directory for testing + tempDir, err := os.MkdirTemp("", "shelltime-db-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tempDir) + + // Save and restore original paths + origBase := COMMAND_BASE_STORAGE_FOLDER + defer func() { + COMMAND_BASE_STORAGE_FOLDER = origBase + COMMAND_STORAGE_FOLDER = origBase + "/commands" + COMMAND_CURSOR_STORAGE_FILE = COMMAND_STORAGE_FOLDER + "/cursor.txt" + }() + + COMMAND_BASE_STORAGE_FOLDER = tempDir + COMMAND_STORAGE_FOLDER = tempDir + "/commands" + COMMAND_CURSOR_STORAGE_FILE = COMMAND_STORAGE_FOLDER + "/cursor.txt" + + // Create the commands directory + if err := os.MkdirAll(COMMAND_STORAGE_FOLDER, 0755); err != nil { + t.Fatalf("Failed to create commands dir: %v", err) + } + + // Write multiple cursor timestamps (should use last one) + time1 := time.Date(2024, 1, 10, 10, 0, 0, 0, time.UTC) + time2 := time.Date(2024, 1, 15, 10, 0, 0, 0, time.UTC) // Last line + cursorFile := filepath.Join(COMMAND_STORAGE_FOLDER, "cursor.txt") + content := fmt.Sprintf("%d\n%d\n", time1.UnixNano(), time2.UnixNano()) + if err := os.WriteFile(cursorFile, []byte(content), 0644); err != nil { + t.Fatalf("Failed to write cursor file: %v", err) + } + + ctx := context.Background() + cursorTime, _, err := GetLastCursor(ctx) + if err != nil { + t.Fatalf("GetLastCursor failed: %v", err) + } + + if !cursorTime.Equal(time2) { + t.Errorf("Expected cursor time %v (last line), got %v", time2, cursorTime) + } +} + +func TestGetLastCursor_InvalidContent(t *testing.T) { + // Create a temp directory for testing + tempDir, err := os.MkdirTemp("", "shelltime-db-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tempDir) + + // Save and restore original paths + origBase := COMMAND_BASE_STORAGE_FOLDER + defer func() { + COMMAND_BASE_STORAGE_FOLDER = origBase + COMMAND_STORAGE_FOLDER = origBase + "/commands" + COMMAND_CURSOR_STORAGE_FILE = COMMAND_STORAGE_FOLDER + "/cursor.txt" + }() + + COMMAND_BASE_STORAGE_FOLDER = tempDir + COMMAND_STORAGE_FOLDER = tempDir + "/commands" + COMMAND_CURSOR_STORAGE_FILE = COMMAND_STORAGE_FOLDER + "/cursor.txt" + + // Create the commands directory + if err := os.MkdirAll(COMMAND_STORAGE_FOLDER, 0755); err != nil { + t.Fatalf("Failed to create commands dir: %v", err) + } + + // Write invalid cursor content + cursorFile := filepath.Join(COMMAND_STORAGE_FOLDER, "cursor.txt") + if err := os.WriteFile(cursorFile, []byte("not-a-number"), 0644); err != nil { + t.Fatalf("Failed to write cursor file: %v", err) + } + + ctx := context.Background() + _, _, err = GetLastCursor(ctx) + if err == nil { + t.Error("Expected error for invalid cursor content") + } +} + +func TestGetPostCommands(t *testing.T) { + // Create a temp directory for testing + tempDir, err := os.MkdirTemp("", "shelltime-db-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tempDir) + + // Save and restore original paths + origBase := COMMAND_BASE_STORAGE_FOLDER + defer func() { + COMMAND_BASE_STORAGE_FOLDER = origBase + COMMAND_STORAGE_FOLDER = origBase + "/commands" + COMMAND_POST_STORAGE_FILE = COMMAND_STORAGE_FOLDER + "/post.txt" + }() + + COMMAND_BASE_STORAGE_FOLDER = tempDir + COMMAND_STORAGE_FOLDER = tempDir + "/commands" + COMMAND_POST_STORAGE_FILE = COMMAND_STORAGE_FOLDER + "/post.txt" + + // Create the commands directory + if err := os.MkdirAll(COMMAND_STORAGE_FOLDER, 0755); err != nil { + t.Fatalf("Failed to create commands dir: %v", err) + } + + // Create test post commands + commands := []Command{ + {Shell: "bash", SessionID: 100, Command: "make build", Username: "user", Time: time.Now(), Phase: CommandPhasePost, Result: 0}, + {Shell: "bash", SessionID: 101, Command: "make test", Username: "user", Time: time.Now(), Phase: CommandPhasePost, Result: 1}, + } + + // Write commands to file + f, err := os.Create(GetPostCommandFilePath()) + if err != nil { + t.Fatalf("Failed to create post file: %v", err) + } + + for _, cmd := range commands { + line, _ := cmd.ToLine(time.Now()) + f.Write(line) + } + f.Close() + + // Test GetPostCommands + ctx := context.Background() + content, lineCount, err := GetPostCommands(ctx) + if err != nil { + t.Fatalf("GetPostCommands failed: %v", err) + } + + if lineCount != 2 { + t.Errorf("Expected 2 lines, got %d", lineCount) + } + + if len(content) != 2 { + t.Errorf("Expected 2 content entries, got %d", len(content)) + } +} + +func TestGetPostCommands_EmptyLines(t *testing.T) { + // Create a temp directory for testing + tempDir, err := os.MkdirTemp("", "shelltime-db-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tempDir) + + // Save and restore original paths + origBase := COMMAND_BASE_STORAGE_FOLDER + defer func() { + COMMAND_BASE_STORAGE_FOLDER = origBase + COMMAND_STORAGE_FOLDER = origBase + "/commands" + COMMAND_POST_STORAGE_FILE = COMMAND_STORAGE_FOLDER + "/post.txt" + }() + + COMMAND_BASE_STORAGE_FOLDER = tempDir + COMMAND_STORAGE_FOLDER = tempDir + "/commands" + COMMAND_POST_STORAGE_FILE = COMMAND_STORAGE_FOLDER + "/post.txt" + + // Create the commands directory + if err := os.MkdirAll(COMMAND_STORAGE_FOLDER, 0755); err != nil { + t.Fatalf("Failed to create commands dir: %v", err) + } + + // Create file with empty lines + cmd := Command{Shell: "bash", SessionID: 100, Command: "ls", Username: "user", Time: time.Now()} + f, _ := os.Create(GetPostCommandFilePath()) + f.WriteString("\n") // Empty line + line, _ := cmd.ToLine(time.Now()) + f.Write(line) + f.WriteString("\n\n") // Multiple empty lines + f.Close() + + ctx := context.Background() + content, lineCount, err := GetPostCommands(ctx) + if err != nil { + t.Fatalf("GetPostCommands failed: %v", err) + } + + // Should only have 1 command (empty lines skipped) + if lineCount != 1 { + t.Errorf("Expected 1 line, got %d", lineCount) + } + + if len(content) != 1 { + t.Errorf("Expected 1 content entry, got %d", len(content)) + } +} + +func TestGetPostCommands_FileNotExists(t *testing.T) { + // Create a temp directory for testing + tempDir, err := os.MkdirTemp("", "shelltime-db-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tempDir) + + // Save and restore original paths + origBase := COMMAND_BASE_STORAGE_FOLDER + defer func() { + COMMAND_BASE_STORAGE_FOLDER = origBase + COMMAND_STORAGE_FOLDER = origBase + "/commands" + COMMAND_POST_STORAGE_FILE = COMMAND_STORAGE_FOLDER + "/post.txt" + }() + + COMMAND_BASE_STORAGE_FOLDER = tempDir + COMMAND_STORAGE_FOLDER = tempDir + "/commands" + COMMAND_POST_STORAGE_FILE = COMMAND_STORAGE_FOLDER + "/post.txt" + + ctx := context.Background() + _, _, err = GetPostCommands(ctx) + if err == nil { + t.Error("Expected error when file doesn't exist") + } +} + +func TestSEPARATOR_Constant(t *testing.T) { + if SEPARATOR != byte('\t') { + t.Errorf("Expected SEPARATOR to be tab, got %q", SEPARATOR) + } +} diff --git a/model/heartbeat_test.go b/model/heartbeat_test.go new file mode 100644 index 0000000..b2b1d6b --- /dev/null +++ b/model/heartbeat_test.go @@ -0,0 +1,370 @@ +package model + +import ( + "encoding/json" + "testing" +) + +func TestHeartbeatPayload_JSON(t *testing.T) { + payload := HeartbeatPayload{ + Heartbeats: []HeartbeatData{ + { + HeartbeatID: "uuid-1234", + Entity: "/path/to/file.go", + EntityType: "file", + Category: "coding", + Time: 1234567890, + Project: "my-project", + }, + }, + } + + // Marshal + data, err := json.Marshal(payload) + if err != nil { + t.Fatalf("Failed to marshal: %v", err) + } + + // Unmarshal + var decoded HeartbeatPayload + if err := json.Unmarshal(data, &decoded); err != nil { + t.Fatalf("Failed to unmarshal: %v", err) + } + + if len(decoded.Heartbeats) != 1 { + t.Fatalf("Expected 1 heartbeat, got %d", len(decoded.Heartbeats)) + } + + hb := decoded.Heartbeats[0] + if hb.HeartbeatID != "uuid-1234" { + t.Errorf("HeartbeatID mismatch: expected uuid-1234, got %s", hb.HeartbeatID) + } + if hb.Entity != "/path/to/file.go" { + t.Errorf("Entity mismatch") + } + if hb.EntityType != "file" { + t.Errorf("EntityType mismatch") + } + if hb.Category != "coding" { + t.Errorf("Category mismatch") + } + if hb.Time != 1234567890 { + t.Errorf("Time mismatch") + } + if hb.Project != "my-project" { + t.Errorf("Project mismatch") + } +} + +func TestHeartbeatData_AllFields(t *testing.T) { + lines := 100 + lineNum := 50 + cursor := 25 + + hb := HeartbeatData{ + HeartbeatID: "test-id", + Entity: "/home/user/project/main.go", + EntityType: "file", + Category: "coding", + Time: 1234567890, + Project: "test-project", + ProjectRootPath: "/home/user/project", + Branch: "main", + Language: "go", + Lines: &lines, + LineNumber: &lineNum, + CursorPosition: &cursor, + Editor: "vscode", + EditorVersion: "1.85.0", + Plugin: "shelltime", + PluginVersion: "1.0.0", + Machine: "workstation", + OS: "linux", + OSVersion: "ubuntu 22.04", + IsWrite: true, + } + + // Marshal + data, err := json.Marshal(hb) + if err != nil { + t.Fatalf("Failed to marshal: %v", err) + } + + // Unmarshal + var decoded HeartbeatData + if err := json.Unmarshal(data, &decoded); err != nil { + t.Fatalf("Failed to unmarshal: %v", err) + } + + // Verify all fields + if decoded.HeartbeatID != hb.HeartbeatID { + t.Errorf("HeartbeatID mismatch") + } + if decoded.Entity != hb.Entity { + t.Errorf("Entity mismatch") + } + if decoded.Branch != hb.Branch { + t.Errorf("Branch mismatch") + } + if decoded.Language != hb.Language { + t.Errorf("Language mismatch") + } + if decoded.Lines == nil || *decoded.Lines != 100 { + t.Errorf("Lines mismatch") + } + if decoded.LineNumber == nil || *decoded.LineNumber != 50 { + t.Errorf("LineNumber mismatch") + } + if decoded.CursorPosition == nil || *decoded.CursorPosition != 25 { + t.Errorf("CursorPosition mismatch") + } + if decoded.Editor != hb.Editor { + t.Errorf("Editor mismatch") + } + if decoded.EditorVersion != hb.EditorVersion { + t.Errorf("EditorVersion mismatch") + } + if decoded.Plugin != hb.Plugin { + t.Errorf("Plugin mismatch") + } + if decoded.PluginVersion != hb.PluginVersion { + t.Errorf("PluginVersion mismatch") + } + if decoded.Machine != hb.Machine { + t.Errorf("Machine mismatch") + } + if decoded.OS != hb.OS { + t.Errorf("OS mismatch") + } + if decoded.OSVersion != hb.OSVersion { + t.Errorf("OSVersion mismatch") + } + if decoded.IsWrite != true { + t.Errorf("IsWrite mismatch") + } +} + +func TestHeartbeatData_OptionalFields(t *testing.T) { + // Minimal heartbeat with only required fields + hb := HeartbeatData{ + HeartbeatID: "minimal-id", + Entity: "/file.txt", + Time: 1234567890, + } + + data, err := json.Marshal(hb) + if err != nil { + t.Fatalf("Failed to marshal: %v", err) + } + + var decoded HeartbeatData + if err := json.Unmarshal(data, &decoded); err != nil { + t.Fatalf("Failed to unmarshal: %v", err) + } + + // Optional fields should be nil or empty + if decoded.Lines != nil { + t.Error("Lines should be nil") + } + if decoded.LineNumber != nil { + t.Error("LineNumber should be nil") + } + if decoded.CursorPosition != nil { + t.Error("CursorPosition should be nil") + } + if decoded.EntityType != "" { + t.Error("EntityType should be empty") + } + if decoded.Category != "" { + t.Error("Category should be empty") + } +} + +func TestHeartbeatResponse_JSON(t *testing.T) { + resp := HeartbeatResponse{ + Success: true, + Processed: 10, + Errors: 2, + Message: "Partially processed", + } + + data, err := json.Marshal(resp) + if err != nil { + t.Fatalf("Failed to marshal: %v", err) + } + + var decoded HeartbeatResponse + if err := json.Unmarshal(data, &decoded); err != nil { + t.Fatalf("Failed to unmarshal: %v", err) + } + + if decoded.Success != true { + t.Error("Success mismatch") + } + if decoded.Processed != 10 { + t.Errorf("Processed mismatch: expected 10, got %d", decoded.Processed) + } + if decoded.Errors != 2 { + t.Errorf("Errors mismatch: expected 2, got %d", decoded.Errors) + } + if decoded.Message != "Partially processed" { + t.Errorf("Message mismatch") + } +} + +func TestHeartbeatResponse_SuccessCase(t *testing.T) { + resp := HeartbeatResponse{ + Success: true, + Processed: 5, + Errors: 0, + } + + if !resp.Success { + t.Error("Expected success to be true") + } + if resp.Errors != 0 { + t.Error("Expected no errors") + } +} + +func TestHeartbeatResponse_FailureCase(t *testing.T) { + resp := HeartbeatResponse{ + Success: false, + Processed: 0, + Errors: 5, + Message: "All heartbeats failed", + } + + if resp.Success { + t.Error("Expected success to be false") + } + if resp.Processed != 0 { + t.Error("Expected 0 processed") + } +} + +func TestHeartbeatPayload_EmptyHeartbeats(t *testing.T) { + payload := HeartbeatPayload{ + Heartbeats: []HeartbeatData{}, + } + + data, err := json.Marshal(payload) + if err != nil { + t.Fatalf("Failed to marshal: %v", err) + } + + var decoded HeartbeatPayload + if err := json.Unmarshal(data, &decoded); err != nil { + t.Fatalf("Failed to unmarshal: %v", err) + } + + if len(decoded.Heartbeats) != 0 { + t.Errorf("Expected 0 heartbeats, got %d", len(decoded.Heartbeats)) + } +} + +func TestHeartbeatPayload_MultipleHeartbeats(t *testing.T) { + payload := HeartbeatPayload{ + Heartbeats: []HeartbeatData{ + {HeartbeatID: "id-1", Entity: "file1.go", Time: 1000}, + {HeartbeatID: "id-2", Entity: "file2.go", Time: 2000}, + {HeartbeatID: "id-3", Entity: "file3.go", Time: 3000}, + }, + } + + data, err := json.Marshal(payload) + if err != nil { + t.Fatalf("Failed to marshal: %v", err) + } + + var decoded HeartbeatPayload + if err := json.Unmarshal(data, &decoded); err != nil { + t.Fatalf("Failed to unmarshal: %v", err) + } + + if len(decoded.Heartbeats) != 3 { + t.Fatalf("Expected 3 heartbeats, got %d", len(decoded.Heartbeats)) + } + + for i, hb := range decoded.Heartbeats { + expectedID := "id-" + string(rune('1'+i)) + if hb.HeartbeatID != expectedID { + t.Errorf("Heartbeat %d: expected ID %s, got %s", i, expectedID, hb.HeartbeatID) + } + } +} + +func TestHeartbeatData_IsWrite(t *testing.T) { + testCases := []struct { + name string + isWrite bool + }{ + {"write event", true}, + {"read event", false}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + hb := HeartbeatData{ + HeartbeatID: "test", + Entity: "file.go", + Time: 1234567890, + IsWrite: tc.isWrite, + } + + data, _ := json.Marshal(hb) + var decoded HeartbeatData + json.Unmarshal(data, &decoded) + + if decoded.IsWrite != tc.isWrite { + t.Errorf("IsWrite mismatch: expected %v, got %v", tc.isWrite, decoded.IsWrite) + } + }) + } +} + +func TestHeartbeatData_EntityTypes(t *testing.T) { + entityTypes := []string{"file", "app", "domain"} + + for _, et := range entityTypes { + t.Run(et, func(t *testing.T) { + hb := HeartbeatData{ + HeartbeatID: "test", + Entity: "test-entity", + EntityType: et, + Time: 1234567890, + } + + data, _ := json.Marshal(hb) + var decoded HeartbeatData + json.Unmarshal(data, &decoded) + + if decoded.EntityType != et { + t.Errorf("EntityType mismatch: expected %s, got %s", et, decoded.EntityType) + } + }) + } +} + +func TestHeartbeatData_Categories(t *testing.T) { + categories := []string{"coding", "debugging", "browsing", "building", "running_tests"} + + for _, cat := range categories { + t.Run(cat, func(t *testing.T) { + hb := HeartbeatData{ + HeartbeatID: "test", + Entity: "test-entity", + Category: cat, + Time: 1234567890, + } + + data, _ := json.Marshal(hb) + var decoded HeartbeatData + json.Unmarshal(data, &decoded) + + if decoded.Category != cat { + t.Errorf("Category mismatch: expected %s, got %s", cat, decoded.Category) + } + }) + } +} diff --git a/model/log_cleanup_test.go b/model/log_cleanup_test.go new file mode 100644 index 0000000..6c865b7 --- /dev/null +++ b/model/log_cleanup_test.go @@ -0,0 +1,356 @@ +package model + +import ( + "os" + "path/filepath" + "runtime" + "testing" +) + +func TestCleanLogFile_FileNotExists(t *testing.T) { + freed, err := CleanLogFile("/nonexistent/path/file.log", 1024, false) + if err != nil { + t.Errorf("Should not return error for non-existent file: %v", err) + } + if freed != 0 { + t.Errorf("Expected 0 bytes freed for non-existent file, got %d", freed) + } +} + +func TestCleanLogFile_BelowThreshold(t *testing.T) { + // Create a temp file + tempDir, err := os.MkdirTemp("", "shelltime-cleanup-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tempDir) + + testFile := filepath.Join(tempDir, "test.log") + content := []byte("small content") + if err := os.WriteFile(testFile, content, 0644); err != nil { + t.Fatalf("Failed to write test file: %v", err) + } + + // File is smaller than threshold, should not be deleted + threshold := int64(1024 * 1024) // 1MB threshold + freed, err := CleanLogFile(testFile, threshold, false) + if err != nil { + t.Errorf("CleanLogFile failed: %v", err) + } + if freed != 0 { + t.Errorf("Expected 0 bytes freed (below threshold), got %d", freed) + } + + // Verify file still exists + if _, err := os.Stat(testFile); os.IsNotExist(err) { + t.Error("File should not be deleted when below threshold") + } +} + +func TestCleanLogFile_AboveThreshold(t *testing.T) { + // Create a temp file + tempDir, err := os.MkdirTemp("", "shelltime-cleanup-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tempDir) + + testFile := filepath.Join(tempDir, "test.log") + // Create a file larger than threshold + content := make([]byte, 2048) + for i := range content { + content[i] = 'x' + } + if err := os.WriteFile(testFile, content, 0644); err != nil { + t.Fatalf("Failed to write test file: %v", err) + } + + // File is larger than threshold, should be deleted + threshold := int64(1024) // 1KB threshold + freed, err := CleanLogFile(testFile, threshold, false) + if err != nil { + t.Errorf("CleanLogFile failed: %v", err) + } + if freed != 2048 { + t.Errorf("Expected 2048 bytes freed, got %d", freed) + } + + // Verify file was deleted + if _, err := os.Stat(testFile); !os.IsNotExist(err) { + t.Error("File should be deleted when above threshold") + } +} + +func TestCleanLogFile_ForceDelete(t *testing.T) { + // Create a temp file + tempDir, err := os.MkdirTemp("", "shelltime-cleanup-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tempDir) + + testFile := filepath.Join(tempDir, "test.log") + content := []byte("small content") + if err := os.WriteFile(testFile, content, 0644); err != nil { + t.Fatalf("Failed to write test file: %v", err) + } + + // File is smaller than threshold, but force=true + threshold := int64(1024 * 1024) // 1MB threshold + freed, err := CleanLogFile(testFile, threshold, true) + if err != nil { + t.Errorf("CleanLogFile failed: %v", err) + } + if freed != int64(len(content)) { + t.Errorf("Expected %d bytes freed, got %d", len(content), freed) + } + + // Verify file was deleted + if _, err := os.Stat(testFile); !os.IsNotExist(err) { + t.Error("File should be deleted when force=true") + } +} + +func TestCleanLargeLogFiles(t *testing.T) { + // Create a temp directory for testing + tempDir, err := os.MkdirTemp("", "shelltime-cleanup-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tempDir) + + // Save and restore original paths + origBase := COMMAND_BASE_STORAGE_FOLDER + defer func() { + COMMAND_BASE_STORAGE_FOLDER = origBase + }() + + COMMAND_BASE_STORAGE_FOLDER = tempDir + + // Create log files + logFile := GetLogFilePath() + heartbeatFile := GetHeartbeatLogFilePath() + syncFile := GetSyncPendingFilePath() + + // Create parent directories + for _, f := range []string{logFile, heartbeatFile, syncFile} { + dir := filepath.Dir(f) + if err := os.MkdirAll(dir, 0755); err != nil { + t.Fatalf("Failed to create dir for %s: %v", f, err) + } + } + + // Create files with content above threshold + largeContent := make([]byte, 2048) + for i := range largeContent { + largeContent[i] = 'x' + } + + for _, f := range []string{logFile, heartbeatFile} { + if err := os.WriteFile(f, largeContent, 0644); err != nil { + t.Fatalf("Failed to write %s: %v", f, err) + } + } + + // Clean with 1KB threshold + threshold := int64(1024) + freed, err := CleanLargeLogFiles(threshold, false) + if err != nil { + t.Errorf("CleanLargeLogFiles failed: %v", err) + } + + // Should have freed 2 files worth of data + expectedFreed := int64(2 * 2048) + if freed != expectedFreed { + t.Errorf("Expected %d bytes freed, got %d", expectedFreed, freed) + } +} + +func TestCleanLargeLogFiles_NoFilesExist(t *testing.T) { + // Create a temp directory for testing + tempDir, err := os.MkdirTemp("", "shelltime-cleanup-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tempDir) + + // Save and restore original paths + origBase := COMMAND_BASE_STORAGE_FOLDER + defer func() { + COMMAND_BASE_STORAGE_FOLDER = origBase + }() + + COMMAND_BASE_STORAGE_FOLDER = tempDir + + // No files exist + freed, err := CleanLargeLogFiles(1024, false) + if err != nil { + t.Errorf("CleanLargeLogFiles should not error when files don't exist: %v", err) + } + if freed != 0 { + t.Errorf("Expected 0 bytes freed when no files exist, got %d", freed) + } +} + +func TestCleanLargeLogFiles_ForceDelete(t *testing.T) { + // Create a temp directory for testing + tempDir, err := os.MkdirTemp("", "shelltime-cleanup-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tempDir) + + // Save and restore original paths + origBase := COMMAND_BASE_STORAGE_FOLDER + defer func() { + COMMAND_BASE_STORAGE_FOLDER = origBase + }() + + COMMAND_BASE_STORAGE_FOLDER = tempDir + + // Create log file with small content + logFile := GetLogFilePath() + if err := os.MkdirAll(filepath.Dir(logFile), 0755); err != nil { + t.Fatalf("Failed to create dir: %v", err) + } + + smallContent := []byte("tiny") + if err := os.WriteFile(logFile, smallContent, 0644); err != nil { + t.Fatalf("Failed to write log file: %v", err) + } + + // Clean with force=true (should delete regardless of size) + threshold := int64(1024 * 1024) // Very high threshold + freed, err := CleanLargeLogFiles(threshold, true) + if err != nil { + t.Errorf("CleanLargeLogFiles failed: %v", err) + } + + if freed != int64(len(smallContent)) { + t.Errorf("Expected %d bytes freed with force=true, got %d", len(smallContent), freed) + } +} + +func TestCleanDaemonLogFiles(t *testing.T) { + // This function is platform-specific + if runtime.GOOS != "darwin" { + // On non-macOS, should return 0 and no error + freed, err := CleanDaemonLogFiles(1024, false) + if err != nil { + t.Errorf("CleanDaemonLogFiles should not error on non-darwin: %v", err) + } + if freed != 0 { + t.Errorf("Expected 0 bytes freed on non-darwin, got %d", freed) + } + return + } + + // macOS-specific tests + tempDir, err := os.MkdirTemp("", "shelltime-cleanup-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tempDir) + + // Save and restore original paths + origBase := COMMAND_BASE_STORAGE_FOLDER + defer func() { + COMMAND_BASE_STORAGE_FOLDER = origBase + }() + + COMMAND_BASE_STORAGE_FOLDER = tempDir + + // Create daemon log files + logFile := GetDaemonLogFilePath() + errFile := GetDaemonErrFilePath() + + // Create parent directories + if err := os.MkdirAll(filepath.Dir(logFile), 0755); err != nil { + t.Fatalf("Failed to create logs dir: %v", err) + } + + // Create files with content above threshold + largeContent := make([]byte, 2048) + for i := range largeContent { + largeContent[i] = 'x' + } + + for _, f := range []string{logFile, errFile} { + if err := os.WriteFile(f, largeContent, 0644); err != nil { + t.Fatalf("Failed to write %s: %v", f, err) + } + } + + // Clean with 1KB threshold + threshold := int64(1024) + freed, err := CleanDaemonLogFiles(threshold, false) + if err != nil { + t.Errorf("CleanDaemonLogFiles failed: %v", err) + } + + // Should have freed 2 files worth of data + expectedFreed := int64(2 * 2048) + if freed != expectedFreed { + t.Errorf("Expected %d bytes freed, got %d", expectedFreed, freed) + } +} + +func TestCleanDaemonLogFiles_NoFilesExist(t *testing.T) { + if runtime.GOOS != "darwin" { + t.Skip("Skipping macOS-specific test on non-darwin platform") + } + + // Create a temp directory for testing + tempDir, err := os.MkdirTemp("", "shelltime-cleanup-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tempDir) + + // Save and restore original paths + origBase := COMMAND_BASE_STORAGE_FOLDER + defer func() { + COMMAND_BASE_STORAGE_FOLDER = origBase + }() + + COMMAND_BASE_STORAGE_FOLDER = tempDir + + // No files exist + freed, err := CleanDaemonLogFiles(1024, false) + if err != nil { + t.Errorf("CleanDaemonLogFiles should not error when files don't exist: %v", err) + } + if freed != 0 { + t.Errorf("Expected 0 bytes freed when no files exist, got %d", freed) + } +} + +func TestCleanLogFile_PermissionError(t *testing.T) { + if os.Getuid() == 0 { + t.Skip("Skipping permission test when running as root") + } + + // Create a temp file in a directory we'll make unreadable + tempDir, err := os.MkdirTemp("", "shelltime-cleanup-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tempDir) + + testFile := filepath.Join(tempDir, "test.log") + if err := os.WriteFile(testFile, []byte("content"), 0644); err != nil { + t.Fatalf("Failed to write test file: %v", err) + } + + // Make the directory read-only (prevents deletion) + if err := os.Chmod(tempDir, 0555); err != nil { + t.Fatalf("Failed to chmod dir: %v", err) + } + defer os.Chmod(tempDir, 0755) // Restore for cleanup + + // Try to clean the file with force=true + _, err = CleanLogFile(testFile, 0, true) + if err == nil { + t.Error("Expected error when deleting file in read-only directory") + } +} diff --git a/model/path_test.go b/model/path_test.go new file mode 100644 index 0000000..d26933d --- /dev/null +++ b/model/path_test.go @@ -0,0 +1,291 @@ +package model + +import ( + "os" + "path/filepath" + "strings" + "testing" +) + +func TestGetBaseStoragePath(t *testing.T) { + path := GetBaseStoragePath() + + homeDir, err := os.UserHomeDir() + if err != nil { + // If home dir is not available, should fallback to temp dir + if !strings.Contains(path, os.TempDir()) { + t.Errorf("Expected path to be in temp dir when home not available, got %s", path) + } + } else { + expected := filepath.Join(homeDir, COMMAND_BASE_STORAGE_FOLDER) + if path != expected { + t.Errorf("Expected %s, got %s", expected, path) + } + } +} + +func TestGetStoragePath(t *testing.T) { + basePath := GetBaseStoragePath() + + testCases := []struct { + name string + subpaths []string + expected string + }{ + { + "single subpath", + []string{"config.toml"}, + filepath.Join(basePath, "config.toml"), + }, + { + "multiple subpaths", + []string{"commands", "pre.txt"}, + filepath.Join(basePath, "commands", "pre.txt"), + }, + { + "no subpaths", + []string{}, + basePath, + }, + { + "nested subpaths", + []string{"logs", "daemon", "output.log"}, + filepath.Join(basePath, "logs", "daemon", "output.log"), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := GetStoragePath(tc.subpaths...) + if result != tc.expected { + t.Errorf("Expected %s, got %s", tc.expected, result) + } + }) + } +} + +func TestGetConfigFilePath(t *testing.T) { + path := GetConfigFilePath() + expected := GetStoragePath("config.toml") + + if path != expected { + t.Errorf("Expected %s, got %s", expected, path) + } + + if !strings.HasSuffix(path, "config.toml") { + t.Error("Config file path should end with config.toml") + } +} + +func TestGetLocalConfigFilePath(t *testing.T) { + path := GetLocalConfigFilePath() + expected := GetStoragePath("config.local.toml") + + if path != expected { + t.Errorf("Expected %s, got %s", expected, path) + } + + if !strings.HasSuffix(path, "config.local.toml") { + t.Error("Local config file path should end with config.local.toml") + } +} + +func TestGetYAMLConfigFilePath(t *testing.T) { + path := GetYAMLConfigFilePath() + expected := GetStoragePath("config.yaml") + + if path != expected { + t.Errorf("Expected %s, got %s", expected, path) + } +} + +func TestGetYMLConfigFilePath(t *testing.T) { + path := GetYMLConfigFilePath() + expected := GetStoragePath("config.yml") + + if path != expected { + t.Errorf("Expected %s, got %s", expected, path) + } +} + +func TestGetLocalYAMLConfigFilePath(t *testing.T) { + path := GetLocalYAMLConfigFilePath() + expected := GetStoragePath("config.local.yaml") + + if path != expected { + t.Errorf("Expected %s, got %s", expected, path) + } +} + +func TestGetLocalYMLConfigFilePath(t *testing.T) { + path := GetLocalYMLConfigFilePath() + expected := GetStoragePath("config.local.yml") + + if path != expected { + t.Errorf("Expected %s, got %s", expected, path) + } +} + +func TestGetLogFilePath(t *testing.T) { + path := GetLogFilePath() + expected := GetStoragePath("log.log") + + if path != expected { + t.Errorf("Expected %s, got %s", expected, path) + } +} + +func TestGetCommandsStoragePath(t *testing.T) { + path := GetCommandsStoragePath() + expected := GetStoragePath("commands") + + if path != expected { + t.Errorf("Expected %s, got %s", expected, path) + } +} + +func TestGetPreCommandFilePath(t *testing.T) { + path := GetPreCommandFilePath() + expected := GetStoragePath("commands", "pre.txt") + + if path != expected { + t.Errorf("Expected %s, got %s", expected, path) + } +} + +func TestGetPostCommandFilePath(t *testing.T) { + path := GetPostCommandFilePath() + expected := GetStoragePath("commands", "post.txt") + + if path != expected { + t.Errorf("Expected %s, got %s", expected, path) + } +} + +func TestGetCursorFilePath(t *testing.T) { + path := GetCursorFilePath() + expected := GetStoragePath("commands", "cursor.txt") + + if path != expected { + t.Errorf("Expected %s, got %s", expected, path) + } +} + +func TestGetHeartbeatLogFilePath(t *testing.T) { + path := GetHeartbeatLogFilePath() + expected := GetStoragePath("coding-heartbeat.data.log") + + if path != expected { + t.Errorf("Expected %s, got %s", expected, path) + } +} + +func TestGetSyncPendingFilePath(t *testing.T) { + path := GetSyncPendingFilePath() + expected := GetStoragePath("sync-pending.jsonl") + + if path != expected { + t.Errorf("Expected %s, got %s", expected, path) + } +} + +func TestGetBinFolderPath(t *testing.T) { + path := GetBinFolderPath() + expected := GetStoragePath("bin") + + if path != expected { + t.Errorf("Expected %s, got %s", expected, path) + } +} + +func TestGetHooksFolderPath(t *testing.T) { + path := GetHooksFolderPath() + expected := GetStoragePath("hooks") + + if path != expected { + t.Errorf("Expected %s, got %s", expected, path) + } +} + +func TestGetDaemonLogsPath(t *testing.T) { + path := GetDaemonLogsPath() + expected := GetStoragePath("logs") + + if path != expected { + t.Errorf("Expected %s, got %s", expected, path) + } +} + +func TestGetDaemonLogFilePath(t *testing.T) { + path := GetDaemonLogFilePath() + expected := GetStoragePath("logs", "shelltime-daemon.log") + + if path != expected { + t.Errorf("Expected %s, got %s", expected, path) + } +} + +func TestGetDaemonErrFilePath(t *testing.T) { + path := GetDaemonErrFilePath() + expected := GetStoragePath("logs", "shelltime-daemon.err") + + if path != expected { + t.Errorf("Expected %s, got %s", expected, path) + } +} + +func TestPathConsistency(t *testing.T) { + // All paths should be absolute + paths := []struct { + name string + path string + }{ + {"BaseStoragePath", GetBaseStoragePath()}, + {"ConfigFilePath", GetConfigFilePath()}, + {"LocalConfigFilePath", GetLocalConfigFilePath()}, + {"LogFilePath", GetLogFilePath()}, + {"CommandsStoragePath", GetCommandsStoragePath()}, + {"PreCommandFilePath", GetPreCommandFilePath()}, + {"PostCommandFilePath", GetPostCommandFilePath()}, + {"CursorFilePath", GetCursorFilePath()}, + {"HeartbeatLogFilePath", GetHeartbeatLogFilePath()}, + {"SyncPendingFilePath", GetSyncPendingFilePath()}, + {"BinFolderPath", GetBinFolderPath()}, + {"HooksFolderPath", GetHooksFolderPath()}, + {"DaemonLogsPath", GetDaemonLogsPath()}, + {"DaemonLogFilePath", GetDaemonLogFilePath()}, + {"DaemonErrFilePath", GetDaemonErrFilePath()}, + } + + basePath := GetBaseStoragePath() + for _, p := range paths { + t.Run(p.name, func(t *testing.T) { + // All paths should start with base path + if !strings.HasPrefix(p.path, basePath) { + t.Errorf("%s should start with base path %s, got %s", p.name, basePath, p.path) + } + + // Paths should be absolute (start with /) + if !filepath.IsAbs(p.path) { + t.Errorf("%s should be an absolute path, got %s", p.name, p.path) + } + }) + } +} + +func TestPathsAreClean(t *testing.T) { + // All paths should be clean (no . or ..) + paths := []string{ + GetBaseStoragePath(), + GetConfigFilePath(), + GetPreCommandFilePath(), + GetPostCommandFilePath(), + } + + for _, p := range paths { + cleaned := filepath.Clean(p) + if p != cleaned { + t.Errorf("Path %s is not clean (cleaned: %s)", p, cleaned) + } + } +} diff --git a/model/shell_test.go b/model/shell_test.go new file mode 100644 index 0000000..153fd6d --- /dev/null +++ b/model/shell_test.go @@ -0,0 +1,545 @@ +package model + +import ( + "os" + "path/filepath" + "strings" + "testing" +) + +// Test BaseHookService + +func TestBaseHookService_BackupFile_FileNotExists(t *testing.T) { + b := &BaseHookService{} + err := b.backupFile("/nonexistent/path/file.txt") + if err != nil { + t.Errorf("backupFile should not error for non-existent file: %v", err) + } +} + +func TestBaseHookService_BackupFile(t *testing.T) { + // Create a temp directory + tempDir, err := os.MkdirTemp("", "shelltime-shell-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tempDir) + + // Create a test file + testFile := filepath.Join(tempDir, "testrc") + content := "# Test config\nexport PATH=$PATH:/usr/bin\n" + if err := os.WriteFile(testFile, []byte(content), 0644); err != nil { + t.Fatalf("Failed to write test file: %v", err) + } + + b := &BaseHookService{} + err = b.backupFile(testFile) + if err != nil { + t.Fatalf("backupFile failed: %v", err) + } + + // Check that backup was created + files, _ := os.ReadDir(tempDir) + backupFound := false + for _, f := range files { + if strings.HasPrefix(f.Name(), "testrc.bak.") { + backupFound = true + // Verify backup content + backupContent, _ := os.ReadFile(filepath.Join(tempDir, f.Name())) + if string(backupContent) != content { + t.Error("Backup content doesn't match original") + } + break + } + } + + if !backupFound { + t.Error("Backup file was not created") + } +} + +func TestBaseHookService_AddHookLines(t *testing.T) { + // Create a temp directory + tempDir, err := os.MkdirTemp("", "shelltime-shell-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tempDir) + + // Create a test file + testFile := filepath.Join(tempDir, "testrc") + initialContent := "# Existing config\nexport EDITOR=vim\n" + if err := os.WriteFile(testFile, []byte(initialContent), 0644); err != nil { + t.Fatalf("Failed to write test file: %v", err) + } + + b := &BaseHookService{} + hookLines := []string{ + "# Added by shelltime", + "source ~/.shelltime/hooks/test.sh", + } + + err = b.addHookLines(testFile, hookLines) + if err != nil { + t.Fatalf("addHookLines failed: %v", err) + } + + // Verify content + content, _ := os.ReadFile(testFile) + for _, line := range hookLines { + if !strings.Contains(string(content), line) { + t.Errorf("Expected file to contain: %s", line) + } + } +} + +func TestBaseHookService_RemoveHookLines(t *testing.T) { + // Create a temp directory + tempDir, err := os.MkdirTemp("", "shelltime-shell-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tempDir) + + // Create a test file with hook lines + testFile := filepath.Join(tempDir, "testrc") + initialContent := `# Existing config +export EDITOR=vim +# Added by shelltime +source ~/.shelltime/hooks/test.sh +export PATH="~/.shelltime/bin:$PATH" +` + if err := os.WriteFile(testFile, []byte(initialContent), 0644); err != nil { + t.Fatalf("Failed to write test file: %v", err) + } + + b := &BaseHookService{} + hookLines := []string{ + "# Added by shelltime", + "source ~/.shelltime/hooks/test.sh", + } + + err = b.removeHookLines(testFile, hookLines) + if err != nil { + t.Fatalf("removeHookLines failed: %v", err) + } + + // Verify content - hook lines should be removed + content, _ := os.ReadFile(testFile) + for _, line := range hookLines { + if strings.Contains(string(content), line) { + t.Errorf("Expected file to NOT contain: %s", line) + } + } + + // Original content should still be there + if !strings.Contains(string(content), "export EDITOR=vim") { + t.Error("Original content should be preserved") + } +} + +func TestBaseHookService_RemoveHookLines_FileNotExists(t *testing.T) { + b := &BaseHookService{} + err := b.removeHookLines("/nonexistent/path/file.txt", []string{"line1"}) + if err == nil { + t.Error("Expected error for non-existent file") + } +} + +func TestBaseHookService_AddHookLines_FileNotExists(t *testing.T) { + b := &BaseHookService{} + err := b.addHookLines("/nonexistent/path/file.txt", []string{"line1"}) + if err == nil { + t.Error("Expected error for non-existent file") + } +} + +// Test BashHookService + +func TestNewBashHookService(t *testing.T) { + service := NewBashHookService() + if service == nil { + t.Fatal("NewBashHookService returned nil") + } + + bashService, ok := service.(*BashHookService) + if !ok { + t.Fatal("NewBashHookService did not return *BashHookService") + } + + if bashService.shellName != "bash" { + t.Errorf("Expected shellName 'bash', got '%s'", bashService.shellName) + } + + if !strings.HasSuffix(bashService.configPath, ".bashrc") { + t.Errorf("Expected configPath to end with .bashrc, got '%s'", bashService.configPath) + } +} + +func TestBashHookService_Match(t *testing.T) { + service := NewBashHookService() + + testCases := []struct { + input string + expected bool + }{ + {"bash", true}, + {"BASH", true}, + {"/bin/bash", true}, + {"/usr/bin/bash", true}, + {"zsh", false}, + {"fish", false}, + {"sh", false}, + } + + for _, tc := range testCases { + t.Run(tc.input, func(t *testing.T) { + result := service.Match(tc.input) + if result != tc.expected { + t.Errorf("Match(%q) = %v, expected %v", tc.input, result, tc.expected) + } + }) + } +} + +func TestBashHookService_ShellName(t *testing.T) { + service := NewBashHookService() + if service.ShellName() != "bash" { + t.Errorf("Expected 'bash', got '%s'", service.ShellName()) + } +} + +func TestBashHookService_Check_FileNotExists(t *testing.T) { + // Create a service with non-existent config path + service := &BashHookService{ + shellName: "bash", + configPath: "/nonexistent/path/.bashrc", + hookLines: []string{"# test"}, + } + + err := service.Check() + if err == nil { + t.Error("Expected error when config file doesn't exist") + } +} + +func TestBashHookService_Check_MissingHooks(t *testing.T) { + // Create a temp directory + tempDir, err := os.MkdirTemp("", "shelltime-shell-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tempDir) + + // Create a bashrc without hooks + configPath := filepath.Join(tempDir, ".bashrc") + if err := os.WriteFile(configPath, []byte("# Empty bashrc\n"), 0644); err != nil { + t.Fatalf("Failed to write test file: %v", err) + } + + service := &BashHookService{ + shellName: "bash", + configPath: configPath, + hookLines: []string{"# Added by shelltime", "source test"}, + } + + err = service.Check() + if err == nil { + t.Error("Expected error when hook lines are missing") + } +} + +func TestBashHookService_Check_HooksPresent(t *testing.T) { + // Create a temp directory + tempDir, err := os.MkdirTemp("", "shelltime-shell-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tempDir) + + hookLines := []string{"# Added by shelltime", "source test"} + + // Create a bashrc with hooks + configPath := filepath.Join(tempDir, ".bashrc") + content := strings.Join(hookLines, "\n") + "\n" + if err := os.WriteFile(configPath, []byte(content), 0644); err != nil { + t.Fatalf("Failed to write test file: %v", err) + } + + service := &BashHookService{ + shellName: "bash", + configPath: configPath, + hookLines: hookLines, + } + + err = service.Check() + if err != nil { + t.Errorf("Expected no error when hooks are present: %v", err) + } +} + +func TestBashHookService_Uninstall_FileNotExists(t *testing.T) { + service := &BashHookService{ + shellName: "bash", + configPath: "/nonexistent/path/.bashrc", + hookLines: []string{"# test"}, + } + + err := service.Uninstall() + if err != nil { + t.Errorf("Uninstall should not error when file doesn't exist: %v", err) + } +} + +// Test ZshHookService + +func TestNewZshHookService(t *testing.T) { + service := NewZshHookService() + if service == nil { + t.Fatal("NewZshHookService returned nil") + } + + zshService, ok := service.(*ZshHookService) + if !ok { + t.Fatal("NewZshHookService did not return *ZshHookService") + } + + if zshService.shellName != "zsh" { + t.Errorf("Expected shellName 'zsh', got '%s'", zshService.shellName) + } + + if !strings.HasSuffix(zshService.configPath, ".zshrc") { + t.Errorf("Expected configPath to end with .zshrc, got '%s'", zshService.configPath) + } +} + +func TestZshHookService_Match(t *testing.T) { + service := NewZshHookService() + + testCases := []struct { + input string + expected bool + }{ + {"zsh", true}, + {"ZSH", true}, + {"/bin/zsh", true}, + {"/usr/local/bin/zsh", true}, + {"bash", false}, + {"fish", false}, + } + + for _, tc := range testCases { + t.Run(tc.input, func(t *testing.T) { + result := service.Match(tc.input) + if result != tc.expected { + t.Errorf("Match(%q) = %v, expected %v", tc.input, result, tc.expected) + } + }) + } +} + +func TestZshHookService_ShellName(t *testing.T) { + service := NewZshHookService() + if service.ShellName() != "zsh" { + t.Errorf("Expected 'zsh', got '%s'", service.ShellName()) + } +} + +func TestZshHookService_Check_FileNotExists(t *testing.T) { + service := &ZshHookService{ + shellName: "zsh", + configPath: "/nonexistent/path/.zshrc", + hookLines: []string{"# test"}, + } + + err := service.Check() + if err == nil { + t.Error("Expected error when config file doesn't exist") + } +} + +func TestZshHookService_Uninstall_FileNotExists(t *testing.T) { + service := &ZshHookService{ + shellName: "zsh", + configPath: "/nonexistent/path/.zshrc", + hookLines: []string{"# test"}, + } + + err := service.Uninstall() + if err != nil { + t.Errorf("Uninstall should not error when file doesn't exist: %v", err) + } +} + +// Test FishHookService + +func TestNewFishHookService(t *testing.T) { + service := NewFishHookService() + if service == nil { + t.Fatal("NewFishHookService returned nil") + } + + fishService, ok := service.(*FishHookService) + if !ok { + t.Fatal("NewFishHookService did not return *FishHookService") + } + + if fishService.shellName != "fish" { + t.Errorf("Expected shellName 'fish', got '%s'", fishService.shellName) + } + + if !strings.Contains(fishService.configPath, "fish/config.fish") { + t.Errorf("Expected configPath to contain fish/config.fish, got '%s'", fishService.configPath) + } +} + +func TestFishHookService_Match(t *testing.T) { + service := NewFishHookService() + + testCases := []struct { + input string + expected bool + }{ + {"fish", true}, + {"FISH", true}, + {"/usr/bin/fish", true}, + {"/usr/local/bin/fish", true}, + {"bash", false}, + {"zsh", false}, + } + + for _, tc := range testCases { + t.Run(tc.input, func(t *testing.T) { + result := service.Match(tc.input) + if result != tc.expected { + t.Errorf("Match(%q) = %v, expected %v", tc.input, result, tc.expected) + } + }) + } +} + +func TestFishHookService_ShellName(t *testing.T) { + service := NewFishHookService() + if service.ShellName() != "fish" { + t.Errorf("Expected 'fish', got '%s'", service.ShellName()) + } +} + +func TestFishHookService_Check_FileNotExists(t *testing.T) { + service := &FishHookService{ + shellName: "fish", + configPath: "/nonexistent/path/config.fish", + hookLines: []string{"# test"}, + } + + err := service.Check() + if err == nil { + t.Error("Expected error when config file doesn't exist") + } +} + +func TestFishHookService_Uninstall_FileNotExists(t *testing.T) { + service := &FishHookService{ + shellName: "fish", + configPath: "/nonexistent/path/config.fish", + hookLines: []string{"# test"}, + } + + err := service.Uninstall() + if err != nil { + t.Errorf("Uninstall should not error when file doesn't exist: %v", err) + } +} + +// Interface compliance tests + +func TestBashHookService_ImplementsShellHookService(t *testing.T) { + var _ ShellHookService = &BashHookService{} +} + +func TestZshHookService_ImplementsShellHookService(t *testing.T) { + var _ ShellHookService = &ZshHookService{} +} + +func TestFishHookService_ImplementsShellHookService(t *testing.T) { + var _ ShellHookService = &FishHookService{} +} + +// Test hook line content + +func TestBashHookService_HookLines(t *testing.T) { + service := NewBashHookService().(*BashHookService) + + if len(service.hookLines) == 0 { + t.Error("Expected at least one hook line") + } + + // First line should be a comment + if !strings.HasPrefix(service.hookLines[0], "#") { + t.Error("First hook line should be a comment") + } + + // Should include PATH modification + hasPath := false + for _, line := range service.hookLines { + if strings.Contains(line, "PATH") { + hasPath = true + break + } + } + if !hasPath { + t.Error("Hook lines should include PATH modification") + } + + // Should include source command + hasSource := false + for _, line := range service.hookLines { + if strings.Contains(line, "source") { + hasSource = true + break + } + } + if !hasSource { + t.Error("Hook lines should include source command") + } +} + +func TestZshHookService_HookLines(t *testing.T) { + service := NewZshHookService().(*ZshHookService) + + if len(service.hookLines) == 0 { + t.Error("Expected at least one hook line") + } + + // Should include PATH modification + hasPath := false + for _, line := range service.hookLines { + if strings.Contains(line, "PATH") { + hasPath = true + break + } + } + if !hasPath { + t.Error("Hook lines should include PATH modification") + } +} + +func TestFishHookService_HookLines(t *testing.T) { + service := NewFishHookService().(*FishHookService) + + if len(service.hookLines) == 0 { + t.Error("Expected at least one hook line") + } + + // Should include fish_add_path + hasFishPath := false + for _, line := range service.hookLines { + if strings.Contains(line, "fish_add_path") { + hasFishPath = true + break + } + } + if !hasFishPath { + t.Error("Fish hook lines should include fish_add_path") + } +} From 2e5a1ebcc947ffc88a9502a56b3badb00d732b12 Mon Sep 17 00:00:00 2001 From: Claude Date: Sun, 4 Jan 2026 05:21:02 +0000 Subject: [PATCH 2/2] fix(test): correct path handling and mock setup in test files - Fix daemon/socket_test.go to pass required args to NewGoChannel - Fix daemon/circuit_breaker_test.go to properly set up temp directories and override HOME env var for SaveForRetry tests - Fix daemon/cleanup_timer_test.go to use correct LogCleanup type - Fix daemon/aicode_otel_processor_test.go to use correct AICodeOtel type - Fix model/command_test.go to use path helper functions for verification - Fix model/db_test.go to use path helper functions consistently - Remove unused filepath import from db_test.go --- daemon/aicode_otel_processor_test.go | 2 +- daemon/circuit_breaker_test.go | 41 ++++++++++++++++++++++++++-- daemon/cleanup_timer_test.go | 12 ++++---- daemon/socket_test.go | 8 +++--- model/command_test.go | 12 ++++---- model/db_test.go | 41 ++++++++++++++-------------- 6 files changed, 76 insertions(+), 40 deletions(-) diff --git a/daemon/aicode_otel_processor_test.go b/daemon/aicode_otel_processor_test.go index 8f6e265..c1df0d5 100644 --- a/daemon/aicode_otel_processor_test.go +++ b/daemon/aicode_otel_processor_test.go @@ -34,7 +34,7 @@ func TestNewAICodeOtelProcessor_Debug(t *testing.T) { debug := true config := model.ShellTimeConfig{ Token: "token", - AICodeOtel: &model.AICodeOtelConfig{ + AICodeOtel: &model.AICodeOtel{ Debug: &debug, }, } diff --git a/daemon/circuit_breaker_test.go b/daemon/circuit_breaker_test.go index 6d858eb..d32b824 100644 --- a/daemon/circuit_breaker_test.go +++ b/daemon/circuit_breaker_test.go @@ -3,9 +3,12 @@ package daemon import ( "context" "encoding/json" + "os" + "path/filepath" "testing" "github.com/ThreeDotsLabs/watermill/message" + "github.com/malamtime/cli/model" ) // Mock publisher for testing @@ -76,26 +79,60 @@ func TestSyncCircuitBreakerWrapper_RecordFailure(t *testing.T) { } func TestSyncCircuitBreakerWrapper_SaveForRetry(t *testing.T) { + // Create temp directory for test + tempDir, err := os.MkdirTemp("", "circuit-breaker-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tempDir) + + // Override sync pending file + origFile := model.SYNC_PENDING_FILE + model.SYNC_PENDING_FILE = filepath.Join(tempDir, "sync-pending.jsonl") + defer func() { model.SYNC_PENDING_FILE = origFile }() + + // Override HOME for test (so $HOME/ doesn't affect the path) + origHome := os.Getenv("HOME") + os.Setenv("HOME", "") + defer os.Setenv("HOME", origHome) + publisher := &mockPublisher{} wrapper := NewSyncCircuitBreakerService(publisher) ctx := context.Background() payload := map[string]string{"key": "value"} - err := wrapper.SaveForRetry(ctx, payload) + err = wrapper.SaveForRetry(ctx, payload) if err != nil { t.Fatalf("SaveForRetry failed: %v", err) } } func TestSyncCircuitBreakerWrapper_SaveForRetry_WrapsInSocketMessage(t *testing.T) { + // Create temp directory for test + tempDir, err := os.MkdirTemp("", "circuit-breaker-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tempDir) + + // Override sync pending file + origFile := model.SYNC_PENDING_FILE + model.SYNC_PENDING_FILE = filepath.Join(tempDir, "sync-pending.jsonl") + defer func() { model.SYNC_PENDING_FILE = origFile }() + + // Override HOME for test (so $HOME/ doesn't affect the path) + origHome := os.Getenv("HOME") + os.Setenv("HOME", "") + defer os.Setenv("HOME", origHome) + publisher := &mockPublisher{} wrapper := NewSyncCircuitBreakerService(publisher) ctx := context.Background() payload := map[string]string{"test": "data"} - err := wrapper.SaveForRetry(ctx, payload) + err = wrapper.SaveForRetry(ctx, payload) if err != nil { t.Fatalf("SaveForRetry failed: %v", err) } diff --git a/daemon/cleanup_timer_test.go b/daemon/cleanup_timer_test.go index f9fbbfd..49a53d0 100644 --- a/daemon/cleanup_timer_test.go +++ b/daemon/cleanup_timer_test.go @@ -12,7 +12,7 @@ import ( func TestNewCleanupTimerService(t *testing.T) { config := model.ShellTimeConfig{ - LogCleanup: model.LogCleanupConfig{ + LogCleanup: &model.LogCleanup{ ThresholdMB: 100, }, } @@ -33,7 +33,7 @@ func TestNewCleanupTimerService(t *testing.T) { func TestCleanupTimerService_StartStop(t *testing.T) { config := model.ShellTimeConfig{ - LogCleanup: model.LogCleanupConfig{ + LogCleanup: &model.LogCleanup{ ThresholdMB: 100, }, } @@ -61,7 +61,7 @@ func TestCleanupTimerService_StartStop(t *testing.T) { func TestCleanupTimerService_StopWithoutStart(t *testing.T) { config := model.ShellTimeConfig{ - LogCleanup: model.LogCleanupConfig{ + LogCleanup: &model.LogCleanup{ ThresholdMB: 100, }, } @@ -80,7 +80,7 @@ func TestCleanupTimerService_StopWithoutStart(t *testing.T) { func TestCleanupTimerService_ContextCancellation(t *testing.T) { config := model.ShellTimeConfig{ - LogCleanup: model.LogCleanupConfig{ + LogCleanup: &model.LogCleanup{ ThresholdMB: 100, }, } @@ -125,7 +125,7 @@ func TestCleanupTimerService_Cleanup(t *testing.T) { model.COMMAND_BASE_STORAGE_FOLDER = tempDir config := model.ShellTimeConfig{ - LogCleanup: model.LogCleanupConfig{ + LogCleanup: &model.LogCleanup{ ThresholdMB: 1, // 1MB threshold (1024 * 1024 bytes) }, } @@ -179,7 +179,7 @@ func TestCleanupTimerService_CleanupBelowThreshold(t *testing.T) { model.COMMAND_BASE_STORAGE_FOLDER = tempDir config := model.ShellTimeConfig{ - LogCleanup: model.LogCleanupConfig{ + LogCleanup: &model.LogCleanup{ ThresholdMB: 10, // 10MB threshold }, } diff --git a/daemon/socket_test.go b/daemon/socket_test.go index 2eaf87f..18c9043 100644 --- a/daemon/socket_test.go +++ b/daemon/socket_test.go @@ -15,7 +15,7 @@ func TestNewSocketHandler(t *testing.T) { config := &model.ShellTimeConfig{ SocketPath: "/tmp/test-shelltime.sock", } - ch := NewGoChannel() + ch := NewGoChannel(PubSubConfig{OutputChannelBuffer: 10}, nil) handler := NewSocketHandler(config, ch) if handler == nil { @@ -51,7 +51,7 @@ func TestSocketHandler_StartStop(t *testing.T) { config := &model.ShellTimeConfig{ SocketPath: socketPath, } - ch := NewGoChannel() + ch := NewGoChannel(PubSubConfig{OutputChannelBuffer: 10}, nil) handler := NewSocketHandler(config, ch) @@ -90,7 +90,7 @@ func TestSocketHandler_StatusRequest(t *testing.T) { config := &model.ShellTimeConfig{ SocketPath: socketPath, } - ch := NewGoChannel() + ch := NewGoChannel(PubSubConfig{OutputChannelBuffer: 10}, nil) handler := NewSocketHandler(config, ch) @@ -310,7 +310,7 @@ func TestSocketHandler_MultipleConnections(t *testing.T) { config := &model.ShellTimeConfig{ SocketPath: socketPath, } - ch := NewGoChannel() + ch := NewGoChannel(PubSubConfig{OutputChannelBuffer: 10}, nil) handler := NewSocketHandler(config, ch) diff --git a/model/command_test.go b/model/command_test.go index f258aa3..30e6acc 100644 --- a/model/command_test.go +++ b/model/command_test.go @@ -517,8 +517,8 @@ func TestCommand_DoSavePre(t *testing.T) { t.Fatalf("DoSavePre failed: %v", err) } - // Verify the file was created - preFilePath := filepath.Join(tempDir, ".shelltime-test-save-pre", "commands", "pre.txt") + // Verify the file was created using the same path helper function + preFilePath := GetPreCommandFilePath() if _, err := os.Stat(preFilePath); os.IsNotExist(err) { t.Error("Pre-command file was not created") } @@ -566,8 +566,8 @@ func TestCommand_DoUpdate(t *testing.T) { t.Fatalf("DoUpdate failed: %v", err) } - // Verify the file was created - postFilePath := filepath.Join(tempDir, ".shelltime-test-update", "commands", "post.txt") + // Verify the file was created using the same path helper function + postFilePath := GetPostCommandFilePath() if _, err := os.Stat(postFilePath); os.IsNotExist(err) { t.Error("Post-command file was not created") } @@ -608,8 +608,8 @@ func TestEnsureStorageFolder(t *testing.T) { t.Fatalf("ensureStorageFolder failed: %v", err) } - // Verify folder was created - if _, err := os.Stat(COMMAND_STORAGE_FOLDER); os.IsNotExist(err) { + // Verify folder was created using the same path helper function + if _, err := os.Stat(GetCommandsStoragePath()); os.IsNotExist(err) { t.Error("Storage folder was not created") } diff --git a/model/db_test.go b/model/db_test.go index ef47b9e..ec66f31 100644 --- a/model/db_test.go +++ b/model/db_test.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "os" - "path/filepath" "testing" "time" ) @@ -68,8 +67,8 @@ func TestGetPreCommandsTree(t *testing.T) { COMMAND_STORAGE_FOLDER = tempDir + "/commands" COMMAND_PRE_STORAGE_FILE = COMMAND_STORAGE_FOLDER + "/pre.txt" - // Create the commands directory - if err := os.MkdirAll(COMMAND_STORAGE_FOLDER, 0755); err != nil { + // Create the commands directory using the path function to ensure consistency + if err := os.MkdirAll(GetCommandsStoragePath(), 0755); err != nil { t.Fatalf("Failed to create commands dir: %v", err) } @@ -178,8 +177,8 @@ func TestGetPreCommands(t *testing.T) { COMMAND_STORAGE_FOLDER = tempDir + "/commands" COMMAND_PRE_STORAGE_FILE = COMMAND_STORAGE_FOLDER + "/pre.txt" - // Create the commands directory - if err := os.MkdirAll(COMMAND_STORAGE_FOLDER, 0755); err != nil { + // Create the commands directory using the path function to ensure consistency + if err := os.MkdirAll(GetCommandsStoragePath(), 0755); err != nil { t.Fatalf("Failed to create commands dir: %v", err) } @@ -234,8 +233,8 @@ func TestGetPreCommands_EmptyLines(t *testing.T) { COMMAND_STORAGE_FOLDER = tempDir + "/commands" COMMAND_PRE_STORAGE_FILE = COMMAND_STORAGE_FOLDER + "/pre.txt" - // Create the commands directory - if err := os.MkdirAll(COMMAND_STORAGE_FOLDER, 0755); err != nil { + // Create the commands directory using the path function to ensure consistency + if err := os.MkdirAll(GetCommandsStoragePath(), 0755); err != nil { t.Fatalf("Failed to create commands dir: %v", err) } @@ -315,14 +314,14 @@ func TestGetLastCursor_WithCursor(t *testing.T) { COMMAND_STORAGE_FOLDER = tempDir + "/commands" COMMAND_CURSOR_STORAGE_FILE = COMMAND_STORAGE_FOLDER + "/cursor.txt" - // Create the commands directory - if err := os.MkdirAll(COMMAND_STORAGE_FOLDER, 0755); err != nil { + // Create the commands directory using the path function to ensure consistency + if err := os.MkdirAll(GetCommandsStoragePath(), 0755); err != nil { t.Fatalf("Failed to create commands dir: %v", err) } - // Write cursor timestamp + // Write cursor timestamp using the path helper function expectedTime := time.Date(2024, 1, 15, 10, 0, 0, 0, time.UTC) - cursorFile := filepath.Join(COMMAND_STORAGE_FOLDER, "cursor.txt") + cursorFile := GetCursorFilePath() if err := os.WriteFile(cursorFile, []byte(fmt.Sprintf("%d", expectedTime.UnixNano())), 0644); err != nil { t.Fatalf("Failed to write cursor file: %v", err) } @@ -362,15 +361,15 @@ func TestGetLastCursor_MultipleLines(t *testing.T) { COMMAND_STORAGE_FOLDER = tempDir + "/commands" COMMAND_CURSOR_STORAGE_FILE = COMMAND_STORAGE_FOLDER + "/cursor.txt" - // Create the commands directory - if err := os.MkdirAll(COMMAND_STORAGE_FOLDER, 0755); err != nil { + // Create the commands directory using the path function to ensure consistency + if err := os.MkdirAll(GetCommandsStoragePath(), 0755); err != nil { t.Fatalf("Failed to create commands dir: %v", err) } // Write multiple cursor timestamps (should use last one) time1 := time.Date(2024, 1, 10, 10, 0, 0, 0, time.UTC) time2 := time.Date(2024, 1, 15, 10, 0, 0, 0, time.UTC) // Last line - cursorFile := filepath.Join(COMMAND_STORAGE_FOLDER, "cursor.txt") + cursorFile := GetCursorFilePath() content := fmt.Sprintf("%d\n%d\n", time1.UnixNano(), time2.UnixNano()) if err := os.WriteFile(cursorFile, []byte(content), 0644); err != nil { t.Fatalf("Failed to write cursor file: %v", err) @@ -407,13 +406,13 @@ func TestGetLastCursor_InvalidContent(t *testing.T) { COMMAND_STORAGE_FOLDER = tempDir + "/commands" COMMAND_CURSOR_STORAGE_FILE = COMMAND_STORAGE_FOLDER + "/cursor.txt" - // Create the commands directory - if err := os.MkdirAll(COMMAND_STORAGE_FOLDER, 0755); err != nil { + // Create the commands directory using the path function to ensure consistency + if err := os.MkdirAll(GetCommandsStoragePath(), 0755); err != nil { t.Fatalf("Failed to create commands dir: %v", err) } // Write invalid cursor content - cursorFile := filepath.Join(COMMAND_STORAGE_FOLDER, "cursor.txt") + cursorFile := GetCursorFilePath() if err := os.WriteFile(cursorFile, []byte("not-a-number"), 0644); err != nil { t.Fatalf("Failed to write cursor file: %v", err) } @@ -445,8 +444,8 @@ func TestGetPostCommands(t *testing.T) { COMMAND_STORAGE_FOLDER = tempDir + "/commands" COMMAND_POST_STORAGE_FILE = COMMAND_STORAGE_FOLDER + "/post.txt" - // Create the commands directory - if err := os.MkdirAll(COMMAND_STORAGE_FOLDER, 0755); err != nil { + // Create the commands directory using the path function to ensure consistency + if err := os.MkdirAll(GetCommandsStoragePath(), 0755); err != nil { t.Fatalf("Failed to create commands dir: %v", err) } @@ -504,8 +503,8 @@ func TestGetPostCommands_EmptyLines(t *testing.T) { COMMAND_STORAGE_FOLDER = tempDir + "/commands" COMMAND_POST_STORAGE_FILE = COMMAND_STORAGE_FOLDER + "/post.txt" - // Create the commands directory - if err := os.MkdirAll(COMMAND_STORAGE_FOLDER, 0755); err != nil { + // Create the commands directory using the path function to ensure consistency + if err := os.MkdirAll(GetCommandsStoragePath(), 0755); err != nil { t.Fatalf("Failed to create commands dir: %v", err) }