diff --git a/README.md b/README.md index ed7b20e..0f8ca58 100644 --- a/README.md +++ b/README.md @@ -103,10 +103,21 @@ All containers are prefixed `px-` internally. Commands accept bare names (e.g., ## SSH Access -**Console** opens an interactive SSH session. If the container is stopped, it starts it automatically: +**Console** opens an interactive SSH session with zmx session persistence. +Disconnecting and reconnecting re-attaches to the same session: ```bash -pixels console mybox +pixels console mybox # default "console" session +pixels console mybox -s build # named session +pixels console mybox --no-persist # plain SSH, no zmx +``` + +Inside a session, press `Ctrl+\` to detach (works in TUIs too), or type `detach`. + +**Sessions** lists zmx sessions in a container: + +```bash +pixels sessions mybox ``` **Exec** runs a command and returns its exit code: diff --git a/cmd/console.go b/cmd/console.go index 185e446..02b7c6d 100644 --- a/cmd/console.go +++ b/cmd/console.go @@ -1,7 +1,9 @@ package cmd import ( + "context" "fmt" + "regexp" "time" "github.com/briandowns/spinner" @@ -12,19 +14,31 @@ import ( "github.com/deevus/pixels/internal/ssh" ) +var validSessionName = regexp.MustCompile(`^[a-zA-Z0-9._-]+$`) + func init() { - rootCmd.AddCommand(&cobra.Command{ + cmd := &cobra.Command{ Use: "console ", - Short: "Open an interactive SSH session", + Short: "Open a persistent SSH session (zmx)", Args: cobra.ExactArgs(1), RunE: runConsole, - }) + } + cmd.Flags().StringP("session", "s", "console", "zmx session name") + cmd.Flags().Bool("no-persist", false, "skip zmx, use plain SSH") + rootCmd.AddCommand(cmd) } func runConsole(cmd *cobra.Command, args []string) error { ctx := cmd.Context() name := args[0] + session, _ := cmd.Flags().GetString("session") + noPersist, _ := cmd.Flags().GetBool("no-persist") + + if !noPersist && !validSessionName.MatchString(session) { + return fmt.Errorf("invalid session name %q: must match [a-zA-Z0-9._-]", session) + } + // Try local cache first for fast path (already running). var ip string if cached := cache.Get(name); cached != nil && cached.IP != "" && cached.Status == "RUNNING" { @@ -74,7 +88,7 @@ func runConsole(cmd *cobra.Command, args []string) error { } // Wait for provisioning to finish before opening the console. - runner := &provision.Runner{Host: ip, User: "root", KeyPath: cfg.SSH.Key} + runner := provision.NewRunner(ip, "root", cfg.SSH.Key) var spin *spinner.Spinner if !verbose { spin = spinner.New(spinner.CharSets[14], 100*time.Millisecond, spinner.WithWriter(cmd.ErrOrStderr())) @@ -93,6 +107,26 @@ func runConsole(cmd *cobra.Command, args []string) error { spin.Stop() } + cc := ssh.ConnConfig{Host: ip, User: cfg.SSH.User, KeyPath: cfg.SSH.Key, Env: cfg.EnvForward} + + // Determine remote command for zmx session persistence. + var remoteCmd string + if !noPersist { + remoteCmd = zmxRemoteCmd(ctx, cc, session) + } + // Console replaces the process — does not return on success. - return ssh.Console(ip, cfg.SSH.User, cfg.SSH.Key, cfg.EnvForward) + return ssh.Console(cc, remoteCmd) +} + +// zmxRemoteCmd checks if zmx is available in the container and returns the +// attach command string. Returns empty string if zmx is not installed. +func zmxRemoteCmd(ctx context.Context, cc ssh.ConnConfig, session string) string { + // Check without env forwarding to avoid polluting the zmx check. + checkCC := ssh.ConnConfig{Host: cc.Host, User: cc.User, KeyPath: cc.KeyPath} + code, err := ssh.ExecQuiet(ctx, checkCC, []string{"command -v zmx >/dev/null 2>&1"}) + if err == nil && code == 0 { + return "unset XDG_RUNTIME_DIR && zmx attach " + session + " bash -l" + } + return "" } diff --git a/cmd/create.go b/cmd/create.go index 50d9ee9..6b0d9f6 100644 --- a/cmd/create.go +++ b/cmd/create.go @@ -313,13 +313,14 @@ func runCreate(cmd *cobra.Command, args []string) error { } if openConsole && ip != "" { - runner := &provision.Runner{Host: ip, User: "root", KeyPath: cfg.SSH.Key} + runner := provision.NewRunner(ip, "root", cfg.SSH.Key) runner.WaitProvisioned(ctx, func(status string) { setStatus(status) logv(cmd, "Provision: %s", status) }) stopSpinner() - return ssh.Console(ip, cfg.SSH.User, cfg.SSH.Key, cfg.EnvForward) + cc := ssh.ConnConfig{Host: ip, User: cfg.SSH.User, KeyPath: cfg.SSH.Key, Env: cfg.EnvForward} + return ssh.Console(cc, zmxRemoteCmd(ctx, cc, "console")) } return nil diff --git a/cmd/exec.go b/cmd/exec.go index bef125a..db7c2a5 100644 --- a/cmd/exec.go +++ b/cmd/exec.go @@ -38,7 +38,7 @@ func runExec(cmd *cobra.Command, args []string) error { return err } - exitCode, err := ssh.Exec(ctx, ip, cfg.SSH.User, cfg.SSH.Key, command, cfg.EnvForward) + exitCode, err := ssh.Exec(ctx, ssh.ConnConfig{Host: ip, User: cfg.SSH.User, KeyPath: cfg.SSH.Key, Env: cfg.EnvForward}, command) if err != nil { return err } diff --git a/cmd/network.go b/cmd/network.go index 3f80768..0ffe0d5 100644 --- a/cmd/network.go +++ b/cmd/network.go @@ -93,7 +93,7 @@ func resolveNetworkContext(cmd *cobra.Command, name string) (*networkContext, er // sshAsRoot runs a command on the container as root via SSH. func sshAsRoot(cmd *cobra.Command, ip string, command []string) (int, error) { - return ssh.Exec(cmd.Context(), ip, "root", cfg.SSH.Key, command, nil) + return ssh.Exec(cmd.Context(), ssh.ConnConfig{Host: ip, User: "root", KeyPath: cfg.SSH.Key}, command) } func runNetworkShow(cmd *cobra.Command, args []string) error { @@ -199,7 +199,7 @@ func runNetworkAllow(cmd *cobra.Command, args []string) error { } // Read current domains via SSH. - out, err := ssh.Output(ctx, nc.ip, "root", cfg.SSH.Key, []string{"cat", "/etc/pixels-egress-domains"}) + out, err := ssh.Output(ctx, ssh.ConnConfig{Host: nc.ip, User: "root", KeyPath: cfg.SSH.Key}, []string{"cat", "/etc/pixels-egress-domains"}) if err != nil { return fmt.Errorf("reading domains file: %w", err) } @@ -244,7 +244,7 @@ func runNetworkDeny(cmd *cobra.Command, args []string) error { cname := containerName(name) // Read current domains via SSH. - out, err := ssh.Output(ctx, nc.ip, "root", cfg.SSH.Key, []string{"cat", "/etc/pixels-egress-domains"}) + out, err := ssh.Output(ctx, ssh.ConnConfig{Host: nc.ip, User: "root", KeyPath: cfg.SSH.Key}, []string{"cat", "/etc/pixels-egress-domains"}) if err != nil { return fmt.Errorf("no egress policy configured on %s", name) } diff --git a/cmd/resolve.go b/cmd/resolve.go index c6992eb..633ebae 100644 --- a/cmd/resolve.go +++ b/cmd/resolve.go @@ -105,7 +105,7 @@ func readSSHPubKey() (string, error) { // ensureSSHAuth tests key auth and, if it fails, writes the current machine's // SSH public key to the container's authorized_keys via TrueNAS. func ensureSSHAuth(cmd *cobra.Command, ctx context.Context, ip, name string) error { - if err := ssh.TestAuth(ctx, ip, cfg.SSH.User, cfg.SSH.Key); err == nil { + if err := ssh.TestAuth(ctx, ssh.ConnConfig{Host: ip, User: cfg.SSH.User, KeyPath: cfg.SSH.Key}); err == nil { return nil } diff --git a/cmd/resolve_test.go b/cmd/resolve_test.go index 1399b05..992aa85 100644 --- a/cmd/resolve_test.go +++ b/cmd/resolve_test.go @@ -9,6 +9,32 @@ import ( truenas "github.com/deevus/truenas-go" ) +func TestValidSessionName(t *testing.T) { + tests := []struct { + name string + input string + want bool + }{ + {"console", "console", true}, + {"build", "build", true}, + {"my-session", "my-session", true}, + {"test.1", "test.1", true}, + {"a_b", "a_b", true}, + {"empty", "", false}, + {"has space", "has space", false}, + {"semicolon", "semi;colon", false}, + {"backtick", "back`tick", false}, + {"newline", "new\nline", false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := validSessionName.MatchString(tt.input); got != tt.want { + t.Errorf("validSessionName.MatchString(%q) = %v, want %v", tt.input, got, tt.want) + } + }) + } +} + func TestContainerName(t *testing.T) { if got := containerName("my-project"); got != "px-my-project" { t.Errorf("containerName(my-project) = %q, want %q", got, "px-my-project") diff --git a/cmd/sessions.go b/cmd/sessions.go new file mode 100644 index 0000000..55bcde7 --- /dev/null +++ b/cmd/sessions.go @@ -0,0 +1,64 @@ +package cmd + +import ( + "fmt" + "strings" + "time" + + "github.com/spf13/cobra" + + "github.com/deevus/pixels/internal/provision" + "github.com/deevus/pixels/internal/ssh" +) + +func init() { + rootCmd.AddCommand(&cobra.Command{ + Use: "sessions ", + Short: "List zmx sessions in a container", + Args: cobra.ExactArgs(1), + RunE: runSessions, + }) +} + +func runSessions(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + name := args[0] + + ip, err := resolveRunningIP(ctx, name) + if err != nil { + return err + } + + if err := ssh.WaitReady(ctx, ip, 30*time.Second, nil); err != nil { + return fmt.Errorf("waiting for SSH: %w", err) + } + + cc := ssh.ConnConfig{Host: ip, User: cfg.SSH.User, KeyPath: cfg.SSH.Key} + out, err := ssh.OutputQuiet(ctx, cc, []string{"unset XDG_RUNTIME_DIR && zmx list"}) + if err != nil { + return fmt.Errorf("zmx not available on %s", name) + } + + raw := strings.TrimSpace(string(out)) + if raw == "" { + fmt.Fprintln(cmd.OutOrStdout(), "No sessions") + return nil + } + + sessions := provision.ParseSessions(raw) + if len(sessions) == 0 { + fmt.Fprintln(cmd.OutOrStdout(), "No sessions") + return nil + } + + tw := newTabWriter(cmd) + fmt.Fprintln(tw, "SESSION\tSTATUS") + for _, s := range sessions { + status := "running" + if s.EndedAt != "" { + status = "exited" + } + fmt.Fprintf(tw, "%s\t%s\n", s.Name, status) + } + return tw.Flush() +} diff --git a/cmd/status.go b/cmd/status.go index 895dab2..0c83e76 100644 --- a/cmd/status.go +++ b/cmd/status.go @@ -34,7 +34,7 @@ func runStatus(cmd *cobra.Command, args []string) error { return fmt.Errorf("waiting for SSH: %w", err) } - runner := &provision.Runner{Host: ip, User: "root", KeyPath: cfg.SSH.Key} + runner := provision.NewRunner(ip, "root", cfg.SSH.Key) raw, err := runner.List(ctx) if err != nil { if strings.Contains(err.Error(), "not found") || strings.Contains(err.Error(), "No such file") { diff --git a/internal/provision/provision.go b/internal/provision/provision.go index 9a1f968..419fc3f 100644 --- a/internal/provision/provision.go +++ b/internal/provision/provision.go @@ -27,12 +27,66 @@ type Step struct { Finalize string // optional: runs after ALL steps complete (not tracked by zmx) } +// Executor runs commands on a remote host. +type Executor interface { + // Exec runs a command and returns its exit code. + Exec(ctx context.Context, command []string) (int, error) + // Output runs a command and returns its stdout. + Output(ctx context.Context, command []string) ([]byte, error) +} + +// MockExecutor is a test double for Executor. +type MockExecutor struct { + ExecFunc func(ctx context.Context, command []string) (int, error) + OutputFunc func(ctx context.Context, command []string) ([]byte, error) +} + +func (m *MockExecutor) Exec(ctx context.Context, command []string) (int, error) { + return m.ExecFunc(ctx, command) +} + +func (m *MockExecutor) Output(ctx context.Context, command []string) ([]byte, error) { + return m.OutputFunc(ctx, command) +} + +// sshExecutor implements Executor by shelling out to SSH. +type sshExecutor struct { + cc ssh.ConnConfig +} + +func (e *sshExecutor) Exec(ctx context.Context, command []string) (int, error) { + return ssh.ExecQuiet(ctx, e.cc, command) +} + +func (e *sshExecutor) Output(ctx context.Context, command []string) ([]byte, error) { + return ssh.OutputQuiet(ctx, e.cc, command) +} + // Runner executes and monitors zmx provisioning steps over SSH. type Runner struct { Host string User string // typically "root" KeyPath string Log io.Writer + exec Executor +} + +// NewRunner creates a Runner that executes commands over SSH. +func NewRunner(host, user, keyPath string) *Runner { + return &Runner{ + Host: host, + User: user, + KeyPath: keyPath, + exec: &sshExecutor{ + cc: ssh.ConnConfig{Host: host, User: user, KeyPath: keyPath}, + }, + } +} + +// NewRunnerWith creates a Runner using the provided Executor. +// This is intended for testing. +func NewRunnerWith(exec Executor) *Runner { + return &Runner{exec: exec} } func (r *Runner) logf(format string, a ...any) { @@ -54,7 +108,7 @@ func (r *Runner) InstallZmx(ctx context.Context) error { url := fmt.Sprintf("https://zmx.sh/a/zmx-%s-linux-x86_64.tar.gz", zmxVersion) script := fmt.Sprintf("curl -fsSL %s | tar xz -C /usr/local/bin/", url) r.logf("Installing zmx %s...", zmxVersion) - code, err := ssh.ExecQuiet(ctx, r.Host, r.User, r.KeyPath, []string{script}) + code, err := r.exec.Exec(ctx, []string{script}) if err != nil { return fmt.Errorf("installing zmx: %w", err) } @@ -72,7 +126,7 @@ func (r *Runner) Run(ctx context.Context, step Step) error { // Redirect stdout/stderr so SSH doesn't wait for the background zmx // session to finish (it inherits the FDs from zmx run). cmd := zmxCmd(fmt.Sprintf("zmx run %s %s >/dev/null 2>&1", step.Name, step.Script)) - code, err := ssh.ExecQuiet(ctx, r.Host, r.User, r.KeyPath, []string{cmd}) + code, err := r.exec.Exec(ctx, []string{cmd}) if err != nil { return fmt.Errorf("starting %s: %w", step.Name, err) } @@ -85,7 +139,7 @@ func (r *Runner) Run(ctx context.Context, step Step) error { // Wait blocks until all named zmx sessions complete. func (r *Runner) Wait(ctx context.Context, names ...string) error { cmd := zmxCmd("zmx wait " + strings.Join(names, " ")) - code, err := ssh.ExecQuiet(ctx, r.Host, r.User, r.KeyPath, []string{cmd}) + code, err := r.exec.Exec(ctx, []string{cmd}) if err != nil { return fmt.Errorf("waiting for steps: %w", err) } @@ -98,7 +152,7 @@ func (r *Runner) Wait(ctx context.Context, names ...string) error { // List runs zmx list and returns the raw output. The caller can display // this directly or parse it for structured status information. func (r *Runner) List(ctx context.Context) (string, error) { - out, err := ssh.OutputQuiet(ctx, r.Host, r.User, r.KeyPath, []string{zmxCmd("zmx list")}) + out, err := r.exec.Output(ctx, []string{zmxCmd("zmx list")}) if err != nil { return "", fmt.Errorf("listing zmx sessions: %w", err) } @@ -107,7 +161,7 @@ func (r *Runner) List(ctx context.Context) (string, error) { // History returns the scrollback output of a completed zmx session. func (r *Runner) History(ctx context.Context, name string) (string, error) { - out, err := ssh.OutputQuiet(ctx, r.Host, r.User, r.KeyPath, []string{zmxCmd("zmx history " + name)}) + out, err := r.exec.Output(ctx, []string{zmxCmd("zmx history " + name)}) if err != nil { return "", fmt.Errorf("getting history for %s: %w", name, err) } @@ -116,13 +170,13 @@ func (r *Runner) History(ctx context.Context, name string) (string, error) { // IsProvisioned checks if the provision sentinel file exists. func (r *Runner) IsProvisioned(ctx context.Context) bool { - code, err := ssh.ExecQuiet(ctx, r.Host, r.User, r.KeyPath, []string{"test -f /root/.pixels-provisioned"}) + code, err := r.exec.Exec(ctx, []string{"test -f /root/.pixels-provisioned"}) return err == nil && code == 0 } // HasProvisionScript checks if the provision script was written to the container. func (r *Runner) HasProvisionScript(ctx context.Context) bool { - code, err := ssh.ExecQuiet(ctx, r.Host, r.User, r.KeyPath, []string{"test -x /usr/local/bin/pixels-provision.sh"}) + code, err := r.exec.Exec(ctx, []string{"test -x /usr/local/bin/pixels-provision.sh"}) return err == nil && code == 0 } diff --git a/internal/provision/provision_test.go b/internal/provision/provision_test.go index 3045bf3..5354cc1 100644 --- a/internal/provision/provision_test.go +++ b/internal/provision/provision_test.go @@ -1,10 +1,340 @@ package provision import ( + "context" + "errors" "strings" "testing" ) +func TestZmxCmd(t *testing.T) { + got := zmxCmd("zmx list") + want := "unset XDG_RUNTIME_DIR && zmx list" + if got != want { + t.Errorf("zmxCmd(\"zmx list\") = %q, want %q", got, want) + } +} + +func TestNewRunner(t *testing.T) { + r := NewRunner("10.0.0.1", "root", "/tmp/key") + if r.Host != "10.0.0.1" { + t.Errorf("Host = %q, want %q", r.Host, "10.0.0.1") + } + if r.User != "root" { + t.Errorf("User = %q, want %q", r.User, "root") + } + if r.KeyPath != "/tmp/key" { + t.Errorf("KeyPath = %q, want %q", r.KeyPath, "/tmp/key") + } + if r.exec == nil { + t.Fatal("exec should not be nil") + } +} + +func TestInstallZmx(t *testing.T) { + tests := []struct { + name string + code int + err error + wantErr string + }{ + {"success", 0, nil, ""}, + {"ssh error", 0, errors.New("connection refused"), "installing zmx:"}, + {"non-zero exit", 5, nil, "installing zmx: exit code 5"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var captured []string + r := NewRunnerWith(&MockExecutor{ + ExecFunc: func(ctx context.Context, command []string) (int, error) { + captured = command + return tt.code, tt.err + }, + }) + err := r.InstallZmx(context.Background()) + if tt.wantErr == "" { + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(captured) == 0 || !strings.Contains(captured[0], zmxVersion) { + t.Errorf("command should contain zmx version, got %v", captured) + } + } else { + if err == nil || !strings.Contains(err.Error(), tt.wantErr) { + t.Errorf("error = %v, want containing %q", err, tt.wantErr) + } + } + }) + } +} + +func TestRun(t *testing.T) { + tests := []struct { + name string + code int + err error + wantErr string + }{ + {"success", 0, nil, ""}, + {"ssh error", 0, errors.New("connection refused"), "starting px-test:"}, + {"non-zero exit", 1, nil, "starting px-test: exit code 1"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var captured []string + r := NewRunnerWith(&MockExecutor{ + ExecFunc: func(ctx context.Context, command []string) (int, error) { + captured = command + return tt.code, tt.err + }, + }) + step := Step{Name: "px-test", Script: "/usr/local/bin/test.sh"} + err := r.Run(context.Background(), step) + if tt.wantErr == "" { + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + cmd := captured[0] + if !strings.Contains(cmd, "zmx run px-test /usr/local/bin/test.sh") { + t.Errorf("command missing zmx run, got %q", cmd) + } + if !strings.Contains(cmd, ">/dev/null 2>&1") { + t.Errorf("command missing redirect, got %q", cmd) + } + if !strings.HasPrefix(cmd, "unset XDG_RUNTIME_DIR") { + t.Errorf("command missing XDG unset, got %q", cmd) + } + } else { + if err == nil || !strings.Contains(err.Error(), tt.wantErr) { + t.Errorf("error = %v, want containing %q", err, tt.wantErr) + } + } + }) + } +} + +func TestWait(t *testing.T) { + tests := []struct { + name string + names []string + code int + err error + wantErr string + }{ + {"single name", []string{"px-devtools"}, 0, nil, ""}, + {"multiple names", []string{"px-devtools", "px-egress"}, 0, nil, ""}, + {"ssh error", []string{"px-test"}, 0, errors.New("timeout"), "waiting for steps:"}, + {"non-zero exit", []string{"px-test"}, 1, nil, "one or more provisioning steps failed"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var captured []string + r := NewRunnerWith(&MockExecutor{ + ExecFunc: func(ctx context.Context, command []string) (int, error) { + captured = command + return tt.code, tt.err + }, + }) + err := r.Wait(context.Background(), tt.names...) + if tt.wantErr == "" { + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + cmd := captured[0] + for _, n := range tt.names { + if !strings.Contains(cmd, n) { + t.Errorf("command missing name %q, got %q", n, cmd) + } + } + } else { + if err == nil || !strings.Contains(err.Error(), tt.wantErr) { + t.Errorf("error = %v, want containing %q", err, tt.wantErr) + } + } + }) + } +} + +func TestList(t *testing.T) { + t.Run("success", func(t *testing.T) { + r := NewRunnerWith(&MockExecutor{ + OutputFunc: func(ctx context.Context, command []string) ([]byte, error) { + return []byte(" session_name=px-test\tpid=1 \n"), nil + }, + }) + out, err := r.List(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if out != "session_name=px-test\tpid=1" { + t.Errorf("output = %q, want trimmed", out) + } + }) + + t.Run("error", func(t *testing.T) { + r := NewRunnerWith(&MockExecutor{ + OutputFunc: func(ctx context.Context, command []string) ([]byte, error) { + return nil, errors.New("connection refused") + }, + }) + _, err := r.List(context.Background()) + if err == nil || !strings.Contains(err.Error(), "listing zmx sessions") { + t.Errorf("error = %v, want containing 'listing zmx sessions'", err) + } + }) +} + +func TestHistory(t *testing.T) { + t.Run("success", func(t *testing.T) { + r := NewRunnerWith(&MockExecutor{ + OutputFunc: func(ctx context.Context, command []string) ([]byte, error) { + if !strings.Contains(command[0], "zmx history px-test") { + t.Errorf("command missing history, got %v", command) + } + return []byte("line1\nline2\n"), nil + }, + }) + out, err := r.History(context.Background(), "px-test") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if out != "line1\nline2\n" { + t.Errorf("output = %q", out) + } + }) + + t.Run("error", func(t *testing.T) { + r := NewRunnerWith(&MockExecutor{ + OutputFunc: func(ctx context.Context, command []string) ([]byte, error) { + return nil, errors.New("not found") + }, + }) + _, err := r.History(context.Background(), "px-test") + if err == nil || !strings.Contains(err.Error(), "getting history for px-test") { + t.Errorf("error = %v, want containing 'getting history'", err) + } + }) +} + +func TestIsProvisioned(t *testing.T) { + tests := []struct { + name string + code int + err error + want bool + }{ + {"file exists", 0, nil, true}, + {"file missing", 1, nil, false}, + {"ssh error", 0, errors.New("timeout"), false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := NewRunnerWith(&MockExecutor{ + ExecFunc: func(ctx context.Context, command []string) (int, error) { + if !strings.Contains(command[0], ".pixels-provisioned") { + t.Errorf("command should check sentinel, got %v", command) + } + return tt.code, tt.err + }, + }) + if got := r.IsProvisioned(context.Background()); got != tt.want { + t.Errorf("IsProvisioned() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestHasProvisionScript(t *testing.T) { + tests := []struct { + name string + code int + err error + want bool + }{ + {"script exists", 0, nil, true}, + {"script missing", 1, nil, false}, + {"ssh error", 0, errors.New("timeout"), false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := NewRunnerWith(&MockExecutor{ + ExecFunc: func(ctx context.Context, command []string) (int, error) { + if !strings.Contains(command[0], "pixels-provision.sh") { + t.Errorf("command should check provision script, got %v", command) + } + return tt.code, tt.err + }, + }) + if got := r.HasProvisionScript(context.Background()); got != tt.want { + t.Errorf("HasProvisionScript() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestPollStatus(t *testing.T) { + tests := []struct { + name string + output string + outErr error + names []string + wantStr string + wantDone bool + }{ + { + name: "all done", + output: "session_name=px-devtools\ttask_ended_at=100\ttask_exit_code=0", + names: []string{"px-devtools"}, + wantStr: "px-devtools done", + wantDone: true, + }, + { + name: "still running", + output: "session_name=px-devtools\tpid=1", + names: []string{"px-devtools"}, + wantStr: "px-devtools running", + wantDone: false, + }, + { + name: "step pending (not in list)", + output: "", + names: []string{"px-devtools"}, + wantStr: "px-devtools pending", + wantDone: false, + }, + { + name: "step failed", + output: "session_name=px-devtools\ttask_ended_at=100\ttask_exit_code=1", + names: []string{"px-devtools"}, + wantStr: "px-devtools failed (exit 1)", + wantDone: true, + }, + { + name: "list error", + outErr: errors.New("connection refused"), + names: []string{"px-devtools"}, + wantStr: "", + wantDone: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := NewRunnerWith(&MockExecutor{ + OutputFunc: func(ctx context.Context, command []string) ([]byte, error) { + return []byte(tt.output), tt.outErr + }, + }) + status, done := r.PollStatus(context.Background(), tt.names) + if status != tt.wantStr { + t.Errorf("status = %q, want %q", status, tt.wantStr) + } + if done != tt.wantDone { + t.Errorf("done = %v, want %v", done, tt.wantDone) + } + }) + } +} + func TestSteps(t *testing.T) { tests := []struct { name string diff --git a/internal/ssh/console_unix.go b/internal/ssh/console_unix.go index 63fba33..8ea8642 100644 --- a/internal/ssh/console_unix.go +++ b/internal/ssh/console_unix.go @@ -10,12 +10,12 @@ import ( ) // Console replaces the current process with an interactive SSH session. -// If env is non-nil, the entries are forwarded via SSH SetEnv. -func Console(host, user, keyPath string, env map[string]string) error { +// If remoteCmd is non-empty, it is executed in a forced PTY on the remote host. +func Console(cc ConnConfig, remoteCmd string) error { sshBin, err := exec.LookPath("ssh") if err != nil { return fmt.Errorf("ssh binary not found: %w", err) } - args := append([]string{"ssh"}, sshArgs(host, user, keyPath, env)...) + args := append([]string{"ssh"}, consoleArgs(cc, remoteCmd)...) return syscall.Exec(sshBin, args, os.Environ()) } diff --git a/internal/ssh/console_windows.go b/internal/ssh/console_windows.go index bdafef7..3c712a1 100644 --- a/internal/ssh/console_windows.go +++ b/internal/ssh/console_windows.go @@ -9,13 +9,13 @@ import ( ) // Console runs an interactive SSH session as a child process. -// If env is non-nil, the entries are forwarded via SSH SetEnv. -func Console(host, user, keyPath string, env map[string]string) error { +// If remoteCmd is non-empty, it is executed in a forced PTY on the remote host. +func Console(cc ConnConfig, remoteCmd string) error { sshBin, err := exec.LookPath("ssh") if err != nil { return fmt.Errorf("ssh binary not found: %w", err) } - cmd := exec.Command(sshBin, sshArgs(host, user, keyPath, env)...) + cmd := exec.Command(sshBin, consoleArgs(cc, remoteCmd)...) cmd.Stdin = os.Stdin cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr diff --git a/internal/ssh/ssh.go b/internal/ssh/ssh.go index 429e76d..efa89bd 100644 --- a/internal/ssh/ssh.go +++ b/internal/ssh/ssh.go @@ -13,6 +13,14 @@ import ( "time" ) +// ConnConfig holds the parameters for an SSH connection. +type ConnConfig struct { + Host string + User string + KeyPath string + Env map[string]string // optional, for SetEnv forwarding +} + // WaitReady polls the host's SSH port until it accepts connections or the timeout expires. // If log is non-nil, progress is written every 5 seconds. func WaitReady(ctx context.Context, host string, timeout time.Duration, log io.Writer) error { @@ -46,9 +54,8 @@ func WaitReady(ctx context.Context, host string, timeout time.Duration, log io.W } // Exec runs a command on the remote host via SSH and returns its exit code. -// If env is non-nil, the entries are forwarded via SSH SetEnv. -func Exec(ctx context.Context, host, user, keyPath string, command []string, env map[string]string) (int, error) { - args := append(sshArgs(host, user, keyPath, env), command...) +func Exec(ctx context.Context, cc ConnConfig, command []string) (int, error) { + args := append(sshArgs(cc), command...) cmd := exec.CommandContext(ctx, "ssh", args...) cmd.Stdin = os.Stdin cmd.Stdout = os.Stdout @@ -66,8 +73,8 @@ func Exec(ctx context.Context, host, user, keyPath string, command []string, env // ExecQuiet runs a non-interactive command on the remote host via SSH and // returns its exit code. Unlike Exec, it does not attach stdin/stdout/stderr. -func ExecQuiet(ctx context.Context, host, user, keyPath string, command []string) (int, error) { - args := append(sshArgs(host, user, keyPath, nil), command...) +func ExecQuiet(ctx context.Context, cc ConnConfig, command []string) (int, error) { + args := append(sshArgs(cc), command...) cmd := exec.CommandContext(ctx, "ssh", args...) if err := cmd.Run(); err != nil { @@ -81,8 +88,8 @@ func ExecQuiet(ctx context.Context, host, user, keyPath string, command []string } // Output runs a command on the remote host via SSH and returns its stdout. -func Output(ctx context.Context, host, user, keyPath string, command []string) ([]byte, error) { - args := append(sshArgs(host, user, keyPath, nil), command...) +func Output(ctx context.Context, cc ConnConfig, command []string) ([]byte, error) { + args := append(sshArgs(cc), command...) cmd := exec.CommandContext(ctx, "ssh", args...) cmd.Stderr = os.Stderr return cmd.Output() @@ -90,37 +97,37 @@ func Output(ctx context.Context, host, user, keyPath string, command []string) ( // OutputQuiet runs a command on the remote host via SSH and returns its stdout, // discarding stderr. Use this when parsing command output programmatically. -func OutputQuiet(ctx context.Context, host, user, keyPath string, command []string) ([]byte, error) { - args := append(sshArgs(host, user, keyPath, nil), command...) +func OutputQuiet(ctx context.Context, cc ConnConfig, command []string) ([]byte, error) { + args := append(sshArgs(cc), command...) cmd := exec.CommandContext(ctx, "ssh", args...) return cmd.Output() } // TestAuth runs a quick SSH connection test (ssh ... true) to verify // key-based authentication works. Returns nil on success. -func TestAuth(ctx context.Context, host, user, keyPath string) error { - args := append(sshArgs(host, user, keyPath, nil), "true") +func TestAuth(ctx context.Context, cc ConnConfig) error { + args := append(sshArgs(cc), "true") cmd := exec.CommandContext(ctx, "ssh", args...) return cmd.Run() } -func sshArgs(host, user, keyPath string, env map[string]string) []string { +func sshArgs(cc ConnConfig) []string { args := []string{ "-o", "StrictHostKeyChecking=no", "-o", "UserKnownHostsFile=" + os.DevNull, "-o", "PasswordAuthentication=no", "-o", "LogLevel=ERROR", } - if keyPath != "" { - args = append(args, "-i", keyPath) + if cc.KeyPath != "" { + args = append(args, "-i", cc.KeyPath) } // Forward env vars via SSH protocol (requires AcceptEnv on server). // All vars must be in a single SetEnv directive (multiple -o SetEnv // flags don't stack in OpenSSH — only the first takes effect). - if len(env) > 0 { - keys := make([]string, 0, len(env)) - for k := range env { + if len(cc.Env) > 0 { + keys := make([]string, 0, len(cc.Env)) + for k := range cc.Env { keys = append(keys, k) } sort.Strings(keys) @@ -131,11 +138,25 @@ func sshArgs(host, user, keyPath string, env map[string]string) []string { if i > 0 { setenv.WriteByte(' ') } - fmt.Fprintf(&setenv, "%s=%s", k, env[k]) + fmt.Fprintf(&setenv, "%s=%s", k, cc.Env[k]) } args = append(args, "-o", setenv.String()) } - args = append(args, user+"@"+host) + args = append(args, cc.User+"@"+cc.Host) + return args +} + +// consoleArgs builds SSH arguments for an interactive console session. +// When remoteCmd is non-empty, -t is inserted to force PTY allocation +// and the command is appended after user@host. +func consoleArgs(cc ConnConfig, remoteCmd string) []string { + if remoteCmd == "" { + return sshArgs(cc) + } + args := sshArgs(cc) + // Insert -t before user@host (last element). + userHost := args[len(args)-1] + args = append(args[:len(args)-1], "-t", userHost, remoteCmd) return args } diff --git a/internal/ssh/ssh_test.go b/internal/ssh/ssh_test.go index fd0931e..219d659 100644 --- a/internal/ssh/ssh_test.go +++ b/internal/ssh/ssh_test.go @@ -8,7 +8,7 @@ import ( func TestConsole_SSHNotFound(t *testing.T) { t.Setenv("PATH", t.TempDir()) // empty dir, no ssh binary - err := Console("10.0.0.1", "pixel", "", nil) + err := Console(ConnConfig{Host: "10.0.0.1", User: "pixel"}, "") if err == nil { t.Fatal("expected error when ssh is not on PATH") } @@ -19,7 +19,7 @@ func TestConsole_SSHNotFound(t *testing.T) { func TestSSHArgs(t *testing.T) { t.Run("with key", func(t *testing.T) { - args := sshArgs("10.0.0.1", "pixel", "/tmp/key", nil) + args := sshArgs(ConnConfig{Host: "10.0.0.1", User: "pixel", KeyPath: "/tmp/key"}) wantSuffix := []string{"-i", "/tmp/key", "pixel@10.0.0.1"} got := args[len(args)-3:] for i, w := range wantSuffix { @@ -30,7 +30,7 @@ func TestSSHArgs(t *testing.T) { }) t.Run("uses os.DevNull for UserKnownHostsFile", func(t *testing.T) { - args := sshArgs("10.0.0.1", "pixel", "", nil) + args := sshArgs(ConnConfig{Host: "10.0.0.1", User: "pixel"}) want := "UserKnownHostsFile=" + os.DevNull found := false for _, a := range args { @@ -45,7 +45,7 @@ func TestSSHArgs(t *testing.T) { }) t.Run("without key", func(t *testing.T) { - args := sshArgs("10.0.0.1", "pixel", "", nil) + args := sshArgs(ConnConfig{Host: "10.0.0.1", User: "pixel"}) last := args[len(args)-1] if last != "pixel@10.0.0.1" { t.Errorf("last arg = %q, want %q", last, "pixel@10.0.0.1") @@ -58,11 +58,15 @@ func TestSSHArgs(t *testing.T) { }) t.Run("SetEnv with env vars", func(t *testing.T) { - env := map[string]string{ - "GITHUB_TOKEN": "ghp_abc123", - "API_KEY": "sk-secret", + cc := ConnConfig{ + Host: "10.0.0.1", + User: "pixel", + Env: map[string]string{ + "GITHUB_TOKEN": "ghp_abc123", + "API_KEY": "sk-secret", + }, } - args := sshArgs("10.0.0.1", "pixel", "", env) + args := sshArgs(cc) // All vars should be in a single SetEnv directive (space-separated, // sorted by key), preceded by -o. Multiple -o SetEnv flags don't @@ -91,7 +95,7 @@ func TestSSHArgs(t *testing.T) { }) t.Run("nil env produces no SetEnv", func(t *testing.T) { - args := sshArgs("10.0.0.1", "pixel", "", nil) + args := sshArgs(ConnConfig{Host: "10.0.0.1", User: "pixel"}) for _, a := range args { if strings.HasPrefix(a, "SetEnv=") { t.Errorf("unexpected SetEnv arg %q with nil env", a) @@ -100,7 +104,7 @@ func TestSSHArgs(t *testing.T) { }) t.Run("empty env produces no SetEnv", func(t *testing.T) { - args := sshArgs("10.0.0.1", "pixel", "", map[string]string{}) + args := sshArgs(ConnConfig{Host: "10.0.0.1", User: "pixel", Env: map[string]string{}}) for _, a := range args { if strings.HasPrefix(a, "SetEnv=") { t.Errorf("unexpected SetEnv arg %q with empty env", a) @@ -108,3 +112,93 @@ func TestSSHArgs(t *testing.T) { } }) } + +func TestConsoleArgs(t *testing.T) { + t.Run("no remote command", func(t *testing.T) { + cc := ConnConfig{Host: "10.0.0.1", User: "pixel", KeyPath: "/tmp/key"} + args := consoleArgs(cc, "") + sshA := sshArgs(cc) + if len(args) != len(sshA) { + t.Fatalf("len(consoleArgs) = %d, want %d (same as sshArgs)", len(args), len(sshA)) + } + for i := range args { + if args[i] != sshA[i] { + t.Errorf("args[%d] = %q, want %q", i, args[i], sshA[i]) + } + } + for _, a := range args { + if a == "-t" { + t.Error("should not include -t when remoteCmd is empty") + } + } + }) + + t.Run("no key with remote command", func(t *testing.T) { + cc := ConnConfig{Host: "10.0.0.1", User: "pixel"} + args := consoleArgs(cc, "zmx attach console") + last3 := args[len(args)-3:] + if last3[0] != "-t" { + t.Errorf("expected -t before user@host, got %q", last3[0]) + } + if last3[1] != "pixel@10.0.0.1" { + t.Errorf("expected user@host, got %q", last3[1]) + } + if last3[2] != "zmx attach console" { + t.Errorf("expected remote command, got %q", last3[2]) + } + for _, a := range args { + if a == "-i" { + t.Error("should not include -i when keyPath is empty") + } + } + }) + + t.Run("with remote command", func(t *testing.T) { + cc := ConnConfig{Host: "10.0.0.1", User: "pixel", KeyPath: "/tmp/key"} + args := consoleArgs(cc, "zmx attach console") + // Should have -t before user@host and command after + last3 := args[len(args)-3:] + if last3[0] != "-t" { + t.Errorf("expected -t before user@host, got %q", last3[0]) + } + if last3[1] != "pixel@10.0.0.1" { + t.Errorf("expected user@host, got %q", last3[1]) + } + if last3[2] != "zmx attach console" { + t.Errorf("expected remote command, got %q", last3[2]) + } + }) + + t.Run("with env and remote command", func(t *testing.T) { + cc := ConnConfig{ + Host: "10.0.0.1", + User: "pixel", + KeyPath: "/tmp/key", + Env: map[string]string{"FOO": "bar"}, + } + args := consoleArgs(cc, "zmx attach build") + + // Verify SetEnv is present. + var foundSetEnv bool + for _, a := range args { + if strings.HasPrefix(a, "SetEnv=") { + foundSetEnv = true + } + } + if !foundSetEnv { + t.Error("SetEnv not found in args") + } + + // Verify -t and command at end. + last3 := args[len(args)-3:] + if last3[0] != "-t" { + t.Errorf("expected -t, got %q", last3[0]) + } + if last3[1] != "pixel@10.0.0.1" { + t.Errorf("expected user@host, got %q", last3[1]) + } + if last3[2] != "zmx attach build" { + t.Errorf("expected remote command, got %q", last3[2]) + } + }) +} diff --git a/internal/truenas/client.go b/internal/truenas/client.go index 4f964ef..9bbe5cb 100644 --- a/internal/truenas/client.go +++ b/internal/truenas/client.go @@ -28,6 +28,9 @@ var egressSetupScript string //go:embed scripts/enable-egress.sh var egressEnableScript string +//go:embed scripts/pixels-profile.sh +var pixelsProfileScript string + // Client wraps a truenas-go WebSocket client and its typed services. type Client struct { ws client.Client @@ -160,6 +163,15 @@ func (c *Client) Provision(ctx context.Context, name string, opts ProvisionOpts) } logf("Wrote sshd AcceptEnv config") + // Shell alias for detaching zmx sessions. + if err := c.Filesystem.WriteFile(ctx, rootfs+"/etc/profile.d/pixels.sh", truenas.WriteFileParams{ + Content: []byte(pixelsProfileScript), + Mode: 0o644, + }); err != nil { + return fmt.Errorf("writing /etc/profile.d/pixels.sh: %w", err) + } + logf("Wrote detach alias") + // Write environment variables to /etc/environment (sourced by PAM on login). if len(opts.Env) > 0 { var envBuf strings.Builder diff --git a/internal/truenas/client_test.go b/internal/truenas/client_test.go index c1d2bc5..7acb7b0 100644 --- a/internal/truenas/client_test.go +++ b/internal/truenas/client_test.go @@ -236,7 +236,7 @@ func TestProvision(t *testing.T) { DevTools: true, }, pool: "tank", - wantCalls: 7, // dns + sshd config + env + root key + pixel key + setup script + rc.local + wantCalls: 8, // dns + sshd config + profile.d + env + root key + pixel key + setup script + rc.local check: func(t *testing.T, calls []writeCall) { paths := make(map[string]writeCall) for _, c := range calls { @@ -285,7 +285,7 @@ func TestProvision(t *testing.T) { DNS: []string{"1.1.1.1", "8.8.8.8"}, }, pool: "tank", - wantCalls: 5, // sshd config + dns + root key + pixel key + rc.local + wantCalls: 6, // sshd config + dns + profile.d + root key + pixel key + rc.local check: func(t *testing.T, calls []writeCall) { // No devtools files should be written. for _, c := range calls { @@ -307,7 +307,7 @@ func TestProvision(t *testing.T) { SSHPubKey: "ssh-ed25519 AAAA test@host", }, pool: "tank", - wantCalls: 4, // sshd config + root key + pixel key + rc.local + wantCalls: 5, // sshd config + profile.d + root key + pixel key + rc.local }, { name: "env only, no ssh key", @@ -315,10 +315,10 @@ func TestProvision(t *testing.T) { Env: map[string]string{"FOO": "bar"}, }, pool: "tank", - wantCalls: 2, // sshd config + /etc/environment + wantCalls: 3, // sshd config + profile.d + /etc/environment check: func(t *testing.T, calls []writeCall) { - if !strings.Contains(calls[1].path, "/etc/environment") { - t.Errorf("expected /etc/environment, got %s", calls[1].path) + if !strings.Contains(calls[2].path, "/etc/environment") { + t.Errorf("expected /etc/environment, got %s", calls[2].path) } }, }, @@ -328,13 +328,13 @@ func TestProvision(t *testing.T) { DNS: []string{"1.1.1.1"}, }, pool: "tank", - wantCalls: 2, // dns + sshd config + wantCalls: 3, // dns + sshd config + profile.d }, { name: "no ssh key no dns", opts: ProvisionOpts{}, pool: "tank", - wantCalls: 1, // sshd config only + wantCalls: 2, // sshd config + profile.d }, { name: "global config error", @@ -363,7 +363,7 @@ func TestProvision(t *testing.T) { Egress: "agent", }, pool: "tank", - wantCalls: 12, // sshd config + root key + pixel key + domains + cidrs + nftables.conf + resolve script + safe-apt + sudoers.restricted + setup-egress + enable-egress + rc.local + wantCalls: 13, // sshd config + profile.d + root key + pixel key + domains + cidrs + nftables.conf + resolve script + safe-apt + sudoers.restricted + setup-egress + enable-egress + rc.local check: func(t *testing.T, calls []writeCall) { paths := make(map[string]writeCall) for _, c := range calls { @@ -432,7 +432,7 @@ func TestProvision(t *testing.T) { ProvisionScript: "#!/bin/sh\necho hello\n", }, pool: "tank", - wantCalls: 5, // sshd config + root key + pixel key + provision script + rc.local + wantCalls: 6, // sshd config + profile.d + root key + pixel key + provision script + rc.local check: func(t *testing.T, calls []writeCall) { paths := make(map[string]writeCall) for _, c := range calls { @@ -465,7 +465,7 @@ func TestProvision(t *testing.T) { Egress: "unrestricted", }, pool: "tank", - wantCalls: 4, // sshd config + root key + pixel key + rc.local (no egress files) + wantCalls: 5, // sshd config + profile.d + root key + pixel key + rc.local (no egress files) check: func(t *testing.T, calls []writeCall) { for _, c := range calls { if strings.Contains(c.path, "pixels-egress") || strings.Contains(c.path, "nftables") { @@ -482,7 +482,7 @@ func TestProvision(t *testing.T) { EgressAllow: []string{"custom.example.com"}, }, pool: "tank", - wantCalls: 11, // sshd config + root key + pixel key + domains + nftables.conf + resolve script + safe-apt + sudoers + setup-egress + enable-egress + rc.local + wantCalls: 12, // sshd config + profile.d + root key + pixel key + domains + nftables.conf + resolve script + safe-apt + sudoers + setup-egress + enable-egress + rc.local check: func(t *testing.T, calls []writeCall) { rootfs := "/var/lib/incus/storage-pools/tank/containers/px-test/rootfs" for _, c := range calls { @@ -587,6 +587,9 @@ func TestProvision(t *testing.T) { } idx++ + // Skip profile.d/pixels.sh (always written). + idx++ + if tt.opts.SSHPubKey == "" { return } @@ -645,6 +648,23 @@ func TestProvision(t *testing.T) { } } +func TestIsZFSPathChar(t *testing.T) { + tests := []struct { + r rune + want bool + }{ + {'a', true}, {'Z', true}, {'5', true}, + {'/', true}, {'-', true}, {'_', true}, {'.', true}, {'@', true}, + {'!', false}, {' ', false}, {'$', false}, {'\n', false}, + {';', false}, {'\\', false}, {'`', false}, + } + for _, tt := range tests { + if got := isZFSPathChar(tt.r); got != tt.want { + t.Errorf("isZFSPathChar(%q) = %v, want %v", tt.r, got, tt.want) + } + } +} + func TestCreateInstance(t *testing.T) { var captured truenas.CreateVirtInstanceOpts @@ -755,6 +775,223 @@ func TestListSnapshots(t *testing.T) { } } +func TestWriteContainerFile(t *testing.T) { + tests := []struct { + name string + pool string + configErr error + writeErr error + wantErr string + wantPath string + }{ + { + name: "writes to rootfs path", + pool: "tank", + wantPath: "/var/lib/incus/storage-pools/tank/containers/px-test/rootfs/etc/test.conf", + }, + { + name: "config error", + configErr: errors.New("api failure"), + wantErr: "querying virt global config", + }, + { + name: "empty pool", + pool: "", + wantErr: "no pool", + }, + { + name: "write error", + pool: "tank", + writeErr: errors.New("disk full"), + wantErr: "disk full", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var writtenPath string + var writtenContent string + var writtenMode uint32 + + c := &Client{ + Virt: &truenas.MockVirtService{ + GetGlobalConfigFunc: func(ctx context.Context) (*truenas.VirtGlobalConfig, error) { + if tt.configErr != nil { + return nil, tt.configErr + } + return &truenas.VirtGlobalConfig{Pool: tt.pool}, nil + }, + }, + Filesystem: &truenas.MockFilesystemService{ + WriteFileFunc: func(ctx context.Context, path string, params truenas.WriteFileParams) error { + if tt.writeErr != nil { + return tt.writeErr + } + writtenPath = path + writtenContent = string(params.Content) + writtenMode = uint32(params.Mode) + return nil + }, + }, + } + + err := c.WriteContainerFile(context.Background(), "px-test", "/etc/test.conf", []byte("hello"), 0o644) + if tt.wantErr != "" { + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), tt.wantErr) { + t.Errorf("error %q should contain %q", err.Error(), tt.wantErr) + } + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if writtenPath != tt.wantPath { + t.Errorf("path = %q, want %q", writtenPath, tt.wantPath) + } + if writtenContent != "hello" { + t.Errorf("content = %q, want %q", writtenContent, "hello") + } + if writtenMode != 0o644 { + t.Errorf("mode = %o, want 644", writtenMode) + } + }) + } +} + +func TestReplaceContainerRootfs(t *testing.T) { + tests := []struct { + name string + dataset string + container string + snapshot string + configErr error + createErr error + runErr error + wantErr string + wantCmd string + wantDelete bool + }{ + { + name: "creates cron job, runs, and deletes", + dataset: "tank/ix-virt", + container: "px-test", + snapshot: "tank/ix-virt/containers/px-test@snap1", + wantCmd: "/usr/sbin/zfs destroy -r tank/ix-virt/containers/px-test", + wantDelete: true, + }, + { + name: "config error", + configErr: errors.New("api down"), + container: "px-test", + snapshot: "tank@snap", + wantErr: "querying virt global config", + }, + { + name: "empty dataset", + dataset: "", + container: "px-test", + snapshot: "tank@snap", + wantErr: "no dataset", + }, + { + name: "unsafe chars in snapshot", + dataset: "tank/ix-virt", + container: "px-test", + snapshot: "tank@snap; rm -rf /", + wantErr: "unsafe character", + }, + { + name: "cron create error", + dataset: "tank/ix-virt", + container: "px-test", + snapshot: "tank/ix-virt/containers/px-test@snap1", + createErr: errors.New("cron api failed"), + wantErr: "creating temp cron job", + wantDelete: false, + }, + { + name: "cron run error still deletes job", + dataset: "tank/ix-virt", + container: "px-test", + snapshot: "tank/ix-virt/containers/px-test@snap1", + runErr: errors.New("zfs command failed"), + wantErr: "running ZFS clone", + wantDelete: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var createdCmd string + var deleted bool + + c := &Client{ + Virt: &truenas.MockVirtService{ + GetGlobalConfigFunc: func(ctx context.Context) (*truenas.VirtGlobalConfig, error) { + if tt.configErr != nil { + return nil, tt.configErr + } + return &truenas.VirtGlobalConfig{Dataset: tt.dataset}, nil + }, + }, + Cron: &truenas.MockCronService{ + CreateFunc: func(ctx context.Context, opts truenas.CreateCronJobOpts) (*truenas.CronJob, error) { + if tt.createErr != nil { + return nil, tt.createErr + } + createdCmd = opts.Command + if opts.User != "root" { + t.Errorf("cron user = %q, want root", opts.User) + } + if opts.Enabled { + t.Error("cron job should be disabled") + } + return &truenas.CronJob{ID: 42}, nil + }, + RunFunc: func(ctx context.Context, id int64, skipDisabled bool) error { + if id != 42 { + t.Errorf("run id = %d, want 42", id) + } + return tt.runErr + }, + DeleteFunc: func(ctx context.Context, id int64) error { + if id != 42 { + t.Errorf("delete id = %d, want 42", id) + } + deleted = true + return nil + }, + }, + } + + err := c.ReplaceContainerRootfs(context.Background(), tt.container, tt.snapshot) + if tt.wantErr != "" { + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), tt.wantErr) { + t.Errorf("error %q should contain %q", err.Error(), tt.wantErr) + } + } else if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if tt.wantCmd != "" && !strings.Contains(createdCmd, tt.wantCmd) { + t.Errorf("cron command %q should contain %q", createdCmd, tt.wantCmd) + } + if tt.wantDelete && !deleted { + t.Error("cron job should have been deleted") + } + if !tt.wantDelete && deleted { + t.Error("cron job should not have been deleted") + } + }) + } +} + func TestWriteAuthorizedKey(t *testing.T) { tests := []struct { name string diff --git a/internal/truenas/scripts/pixels-profile.sh b/internal/truenas/scripts/pixels-profile.sh new file mode 100644 index 0000000..6a54849 --- /dev/null +++ b/internal/truenas/scripts/pixels-profile.sh @@ -0,0 +1,2 @@ +alias detach='zmx detach' +[ -n "$ZMX_SESSION" ] && echo "Detach: Ctrl+\\ or type 'detach'"