Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,9 @@ func sandboxConfig() map[string]string {
if cfg.SSH.Key != "" {
m["ssh_key"] = cfg.SSH.Key
}
if cfg.SSH.StrictHostKeysEnabled() {
m["ssh_known_hosts"] = config.KnownHostsPath()
}
m["provision"] = strconv.FormatBool(cfg.Provision.IsEnabled())
m["devtools"] = strconv.FormatBool(cfg.Provision.DevToolsEnabled())
if cfg.Network.Egress != "" {
Expand Down
20 changes: 18 additions & 2 deletions internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,18 @@ type Defaults struct {
}

type SSH struct {
User string `toml:"user" env:"PIXELS_SSH_USER"`
Key string `toml:"key" env:"PIXELS_SSH_KEY"`
User string `toml:"user" env:"PIXELS_SSH_USER"`
Key string `toml:"key" env:"PIXELS_SSH_KEY"`
StrictHostKeys *bool `toml:"strict_host_keys" env:"PIXELS_SSH_STRICT_HOST_KEYS"`
}

// StrictHostKeysEnabled returns whether SSH host key verification is enabled.
// Defaults to true when not explicitly set.
func (s *SSH) StrictHostKeysEnabled() bool {
if s.StrictHostKeys == nil {
return true
}
return *s.StrictHostKeys
}

type Checkpoint struct {
Expand Down Expand Up @@ -203,6 +213,12 @@ func (t *TrueNAS) InsecureSkipVerifyValue() bool {
return *t.InsecureSkipVerify
}

// KnownHostsPath returns the path to the pixels-managed SSH known_hosts file.
func KnownHostsPath() string {
dir := filepath.Dir(configPath())
return filepath.Join(dir, "known_hosts")
}

func expandHome(path string) string {
if strings.HasPrefix(path, "~/") {
if home, err := os.UserHomeDir(); err == nil {
Expand Down
72 changes: 72 additions & 0 deletions internal/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -507,6 +507,78 @@ func TestConfigPathDefault(t *testing.T) {
}
}

func TestStrictHostKeysEnabled(t *testing.T) {
tests := []struct {
name string
val *bool
want bool
}{
{"nil defaults to true", nil, true},
{"explicit true", boolPtr(true), true},
{"explicit false", boolPtr(false), false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := SSH{StrictHostKeys: tt.val}
if got := s.StrictHostKeysEnabled(); got != tt.want {
t.Errorf("StrictHostKeysEnabled() = %v, want %v", got, tt.want)
}
})
}
}

func boolPtr(v bool) *bool { return &v }

func TestStrictHostKeysFromFile(t *testing.T) {
dir := t.TempDir()
t.Setenv("XDG_CONFIG_HOME", dir)

cfgDir := filepath.Join(dir, "pixels")
if err := os.MkdirAll(cfgDir, 0o755); err != nil {
t.Fatal(err)
}

content := `
[ssh]
strict_host_keys = false
`
if err := os.WriteFile(filepath.Join(cfgDir, "config.toml"), []byte(content), 0o644); err != nil {
t.Fatal(err)
}

cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
if cfg.SSH.StrictHostKeysEnabled() {
t.Error("StrictHostKeysEnabled() = true, want false (set in TOML)")
}
}

func TestStrictHostKeysEnvOverride(t *testing.T) {
t.Setenv("XDG_CONFIG_HOME", t.TempDir())
t.Setenv("PIXELS_SSH_STRICT_HOST_KEYS", "false")

cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
if cfg.SSH.StrictHostKeysEnabled() {
t.Error("StrictHostKeysEnabled() = true, want false (env override)")
}
}

func TestKnownHostsPath(t *testing.T) {
dir := t.TempDir()
t.Setenv("XDG_CONFIG_HOME", dir)

got := KnownHostsPath()
want := filepath.Join(dir, "pixels", "known_hosts")
if got != want {
t.Errorf("KnownHostsPath() = %q, want %q", got, want)
}
}

func TestExpandHome(t *testing.T) {
home, err := os.UserHomeDir()
if err != nil {
Expand Down
4 changes: 2 additions & 2 deletions internal/provision/provision.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,13 @@ type Runner struct {
}

// NewRunner creates a Runner that executes commands over SSH.
func NewRunner(host, user, keyPath string) *Runner {
func NewRunner(host, user, keyPath, knownHostsPath string) *Runner {
return &Runner{
Host: host,
User: user,
KeyPath: keyPath,
exec: &sshExecutor{
cc: ssh.ConnConfig{Host: host, User: user, KeyPath: keyPath},
cc: ssh.NewConnConfig(host, user, keyPath, knownHostsPath),
},
}
}
Expand Down
2 changes: 1 addition & 1 deletion internal/provision/provision_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ func TestZmxCmd(t *testing.T) {
}

func TestNewRunner(t *testing.T) {
r := NewRunner("10.0.0.1", "root", "/tmp/key")
r := NewRunner("10.0.0.1", "root", "/tmp/key", "/tmp/known_hosts")
if r.Host != "10.0.0.1" {
t.Errorf("Host = %q, want %q", r.Host, "10.0.0.1")
}
Expand Down
59 changes: 53 additions & 6 deletions internal/ssh/ssh.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package ssh

import (
"bytes"
"context"
"errors"
"fmt"
Expand All @@ -11,14 +12,28 @@ import (
"sort"
"strings"
"time"

"github.com/deevus/pixels/internal/config"
)

// ConnConfig holds the parameters for an SSH connection.
// Use NewConnConfig to construct — it ensures secure defaults.
type ConnConfig struct {
Host string
User string
KeyPath string
Env map[string]string // optional, for SetEnv forwarding
Host string
User string
KeyPath string
Env map[string]string // optional, for SetEnv forwarding
KnownHostsPath string // path to known_hosts file for accept-new verification
}

// NewConnConfig creates a ConnConfig with the given parameters.
func NewConnConfig(host, user, keyPath, knownHostsPath string) ConnConfig {
return ConnConfig{
Host: host,
User: user,
KeyPath: keyPath,
KnownHostsPath: knownHostsPath,
}
}

// WaitReady polls the host's SSH port until it accepts connections or the timeout expires.
Expand Down Expand Up @@ -115,9 +130,13 @@ func TestAuth(ctx context.Context, cc ConnConfig) error {
// It is exported for use by callers that need to construct custom exec.Cmd
// with non-standard Stdin/Stdout/Stderr (e.g. sandbox backends).
func Args(cc ConnConfig) []string {
knownHosts := cc.KnownHostsPath
if knownHosts == "" {
knownHosts = config.KnownHostsPath()
}
args := []string{
"-o", "StrictHostKeyChecking=no",
"-o", "UserKnownHostsFile=" + os.DevNull,
"-o", "StrictHostKeyChecking=accept-new",
"-o", "UserKnownHostsFile=" + knownHosts,
"-o", "PasswordAuthentication=no",
"-o", "LogLevel=ERROR",
}
Expand Down Expand Up @@ -150,6 +169,34 @@ func Args(cc ConnConfig) []string {
return args
}

// RemoveKnownHost removes all entries for the given host from the known_hosts
// file. This is used to clean up stale entries when containers are
// created, destroyed, or restored from snapshots. It is a no-op if the
// known_hosts file does not exist or the path is empty.
func RemoveKnownHost(knownHostsPath, host string) error {
if knownHostsPath == "" || host == "" {
return nil
}
data, err := os.ReadFile(knownHostsPath)
if errors.Is(err, os.ErrNotExist) {
return nil
}
if err != nil {
return fmt.Errorf("reading known_hosts: %w", err)
}

prefix := []byte(host + " ")
var kept []byte
for _, line := range bytes.SplitAfter(data, []byte("\n")) {
if !bytes.HasPrefix(line, prefix) {
kept = append(kept, line...)
}
}

return os.WriteFile(knownHostsPath, kept, 0o600)
}


// consoleArgs builds SSH arguments for an interactive console session.
// When remoteCmd is non-empty, -t is inserted to force PTY allocation
// and the command is appended after user@host.
Expand Down
83 changes: 79 additions & 4 deletions internal/ssh/ssh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,18 +29,53 @@ func TestSSHArgs(t *testing.T) {
}
})

t.Run("uses os.DevNull for UserKnownHostsFile", func(t *testing.T) {
t.Run("always uses accept-new even without explicit KnownHostsPath", func(t *testing.T) {
args := Args(ConnConfig{Host: "10.0.0.1", User: "pixel"})
want := "UserKnownHostsFile=" + os.DevNull
foundAcceptNew := false
for _, a := range args {
if a == "StrictHostKeyChecking=accept-new" {
foundAcceptNew = true
}
if a == "StrictHostKeyChecking=no" {
t.Error("should never use StrictHostKeyChecking=no")
}
if strings.Contains(a, os.DevNull) {
t.Errorf("should never use DevNull for known hosts, got %q", a)
}
}
if !foundAcceptNew {
t.Errorf("expected StrictHostKeyChecking=accept-new, got %v", args)
}
})

t.Run("accept-new with KnownHostsPath", func(t *testing.T) {
khFile := "/tmp/pixels-test-known-hosts"
args := Args(ConnConfig{Host: "10.0.0.1", User: "pixel", KnownHostsPath: khFile})

// Should use accept-new instead of no.
foundAcceptNew := false
for _, a := range args {
if a == "StrictHostKeyChecking=accept-new" {
foundAcceptNew = true
}
if a == "StrictHostKeyChecking=no" {
t.Error("should not use StrictHostKeyChecking=no when KnownHostsPath is set")
}
}
if !foundAcceptNew {
t.Errorf("expected StrictHostKeyChecking=accept-new, got %v", args)
}

// Should use the provided known hosts file.
want := "UserKnownHostsFile=" + khFile
found := false
for _, a := range args {
if a == want {
found = true
break
}
}
if !found {
t.Errorf("sshArgs should contain %q, got %v", want, args)
t.Errorf("expected %q, got %v", want, args)
}
})

Expand Down Expand Up @@ -113,6 +148,46 @@ func TestSSHArgs(t *testing.T) {
})
}

func TestRemoveKnownHost(t *testing.T) {
t.Run("no-op when file is empty string", func(t *testing.T) {
if err := RemoveKnownHost("", "10.0.0.1"); err != nil {
t.Errorf("expected no error, got %v", err)
}
})

t.Run("no-op when file does not exist", func(t *testing.T) {
if err := RemoveKnownHost("/tmp/nonexistent-known-hosts-file", "10.0.0.1"); err != nil {
t.Errorf("expected no error for missing file, got %v", err)
}
})

t.Run("removes entry from existing file", func(t *testing.T) {
dir := t.TempDir()
khFile := dir + "/known_hosts"
// Use valid ssh-ed25519 key data (32 bytes base64-encoded with key type prefix).
key1 := "10.0.0.1 ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIBVlGh5YxGBMp/DO3OjAHsMR0DVQS2DJnpOqaGP2MkNl\n"
key2 := "10.0.0.2 ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIKxvLhGmlN1sdag3FISwEVfAGwC+v3+x0v6qIFNyGmNd\n"
if err := os.WriteFile(khFile, []byte(key1+key2), 0o600); err != nil {
t.Fatal(err)
}

if err := RemoveKnownHost(khFile, "10.0.0.1"); err != nil {
t.Fatalf("RemoveKnownHost: %v", err)
}

data, err := os.ReadFile(khFile)
if err != nil {
t.Fatal(err)
}
if strings.Contains(string(data), "10.0.0.1") {
t.Errorf("expected 10.0.0.1 to be removed, file contains: %s", data)
}
if !strings.Contains(string(data), "10.0.0.2") {
t.Errorf("expected 10.0.0.2 to remain, file contains: %s", data)
}
})
}

func TestConsoleArgs(t *testing.T) {
t.Run("no remote command", func(t *testing.T) {
cc := ConnConfig{Host: "10.0.0.1", User: "pixel", KeyPath: "/tmp/key"}
Expand Down
Loading