diff --git a/cmds/dutagent/dutagent.go b/cmds/dutagent/dutagent.go index 506b01e..631ee45 100644 --- a/cmds/dutagent/dutagent.go +++ b/cmds/dutagent/dutagent.go @@ -9,25 +9,22 @@ package main import ( "context" - "crypto/tls" "errors" "flag" "fmt" "io" "log" - "net" "net/http" "os" "os/signal" "syscall" + "time" "connectrpc.com/connect" "github.com/BlindspotSoftware/dutctl/internal/buildinfo" "github.com/BlindspotSoftware/dutctl/internal/dutagent" "github.com/BlindspotSoftware/dutctl/pkg/dut" "github.com/BlindspotSoftware/dutctl/protobuf/gen/dutctl/v1/dutctlv1connect" - "golang.org/x/net/http2" - "golang.org/x/net/http2/h2c" "gopkg.in/yaml.v3" pb "github.com/BlindspotSoftware/dutctl/protobuf/gen/dutctl/v1" @@ -155,6 +152,9 @@ func printInitErr(err error) { log.Print(err) } +// readHeaderTimeout bounds how long the server waits to read request headers. +const readHeaderTimeout = 10 * time.Second + // startRPCService starts the RPC service, that ideally listens for incoming // connections forever. It always returns an non-nil error. func (agt *agent) startRPCService() error { @@ -167,12 +167,17 @@ func (agt *agent) startRPCService() error { path, handler := dutctlv1connect.NewDeviceServiceHandler(service) mux.Handle(path, handler) - //nolint:gosec - return http.ListenAndServe( - agt.address, - // Use h2c so we can serve HTTP/2 without TLS. - h2c.NewHandler(mux, &http2.Server{}), - ) + // Serve HTTP/2 without TLS (h2c) + srv := &http.Server{ + Addr: agt.address, + Handler: mux, + ReadHeaderTimeout: readHeaderTimeout, + } + srv.Protocols = new(http.Protocols) + srv.Protocols.SetHTTP1(true) + srv.Protocols.SetUnencryptedHTTP2(true) + + return srv.ListenAndServe() } func (agt *agent) registerWithServer() error { @@ -211,19 +216,17 @@ func spawnClient(agendURL string) dutctlv1connect.RelayServiceClient { // TODO: refactor into pkg and reuse in dutctl and dutserver. func newInsecureClient() *http.Client { + transport := &http.Transport{} + transport.Protocols = new(http.Protocols) + transport.Protocols.SetUnencryptedHTTP2(true) + return &http.Client{ - Transport: &http2.Transport{ - AllowHTTP: true, - DialTLS: func(network, addr string, _ *tls.Config) (net.Conn, error) { - // If you're also using this client for non-h2c traffic, you may want - // to delegate to tls.Dial if the network isn't TCP or the addr isn't - // in an allowlist. - - //nolint:noctx - return net.Dial(network, addr) - }, - // TODO: Don't forget timeouts! - }, + Transport: transport, + // TODO: Don't forget timeouts! http.Client.Timeout must not be used here: + // it bounds the entire exchange including the response body, which would + // abort long-lived streaming RPCs. Instead use per-RPC context deadlines + // on unary calls and/or transport timeouts (DialContext, + // TLSHandshakeTimeout, ResponseHeaderTimeout, IdleConnTimeout). } } diff --git a/cmds/dutctl/dutctl.go b/cmds/dutctl/dutctl.go index 3810743..34d4d01 100644 --- a/cmds/dutctl/dutctl.go +++ b/cmds/dutctl/dutctl.go @@ -7,13 +7,11 @@ package main import ( - "crypto/tls" "errors" "flag" "fmt" "io" "log" - "net" "net/http" "os" @@ -22,7 +20,6 @@ import ( "github.com/BlindspotSoftware/dutctl/internal/output" "github.com/BlindspotSoftware/dutctl/pkg/lock" "github.com/BlindspotSoftware/dutctl/protobuf/gen/dutctl/v1/dutctlv1connect" - "golang.org/x/net/http2" ) const usageAbstract = `dutctl - The client application of the DUT Control system. @@ -128,7 +125,6 @@ type application struct { func (app *application) setupRPCClient() { client := dutctlv1connect.NewDeviceServiceClient( - // Instead of http.DefaultClient, use the HTTP/2 protocol without TLS newInsecureClient(), fmt.Sprintf("http://%s", app.serverAddr), connect.WithGRPC(), @@ -138,19 +134,18 @@ func (app *application) setupRPCClient() { } func newInsecureClient() *http.Client { + // Use the HTTP/2 protocol without TLS (h2c) + transport := &http.Transport{} + transport.Protocols = new(http.Protocols) + transport.Protocols.SetUnencryptedHTTP2(true) + return &http.Client{ - Transport: &http2.Transport{ - AllowHTTP: true, - DialTLS: func(network, addr string, _ *tls.Config) (net.Conn, error) { - // If you're also using this client for non-h2c traffic, you may want - // to delegate to tls.Dial if the network isn't TCP or the addr isn't - // in an allowlist. - - //nolint:noctx - return net.Dial(network, addr) - }, - // Don't forget timeouts! - }, + Transport: transport, + // TODO: Don't forget timeouts! http.Client.Timeout must not be used here: + // it bounds the entire exchange including the response body, which would + // abort long-lived streaming RPCs. Instead use per-RPC context deadlines + // on unary calls and/or transport timeouts (DialContext, + // TLSHandshakeTimeout, ResponseHeaderTimeout, IdleConnTimeout). } } diff --git a/cmds/dutctl/rawconsole_test.go b/cmds/dutctl/rawconsole_test.go new file mode 100644 index 0000000..ba7843f --- /dev/null +++ b/cmds/dutctl/rawconsole_test.go @@ -0,0 +1,60 @@ +// Copyright 2025 Blindspot Software +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import ( + "bytes" + "testing" +) + +// TestRawConsoleNeverArmsForNonFileStdin verifies that a non-*os.File stdin +// (e.g. a piped/scripted run) never switches to raw mode, regardless of the +// interactive hint, and that arm/disarm are safe no-ops in that case. +func TestRawConsoleNeverArmsForNonFileStdin(t *testing.T) { + console := newRawConsole(&bytes.Buffer{}, true) + + console.arm() + + if console.isActive() { + t.Error("isActive() = true for non-file stdin, want false") + } + + // disarm must not panic even though arm never engaged. + console.disarm() +} + +// TestRawConsoleNeverArmsForScriptedInvocation verifies that when the command +// was invoked with arguments (interactive=false), raw mode is never armed even +// though the agent streams console output (modelled here by calling arm). +func TestRawConsoleNeverArmsForScriptedInvocation(t *testing.T) { + // interactive=false models a serial expect/send sequence run. + console := newRawConsole(&bytes.Buffer{}, false) + + console.arm() + + if console.isActive() { + t.Error("isActive() = true for a scripted (argument-bearing) invocation, want false") + } + + console.disarm() +} + +// TestRawConsoleArmIsIdempotent verifies that repeated arm calls (one per +// console message) do not panic and leave a consistent state. With a non-file +// stdin it stays inactive; the point is that calling arm many times is safe. +func TestRawConsoleArmIsIdempotent(t *testing.T) { + console := newRawConsole(&bytes.Buffer{}, true) + + for range 5 { + console.arm() + } + + if console.isActive() { + t.Error("isActive() = true for non-file stdin, want false") + } + + console.disarm() + console.disarm() // double disarm must be safe +} diff --git a/cmds/dutctl/rawmode_darwin.go b/cmds/dutctl/rawmode_darwin.go new file mode 100644 index 0000000..a954fbb --- /dev/null +++ b/cmds/dutctl/rawmode_darwin.go @@ -0,0 +1,15 @@ +// Copyright 2025 Blindspot Software +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build darwin + +package main + +import "golang.org/x/sys/unix" + +// Terminal get/set ioctl request numbers for macOS (BSD-derived). +const ( + tcGetReq = unix.TIOCGETA + tcSetReq = unix.TIOCSETA +) diff --git a/cmds/dutctl/rawmode_linux.go b/cmds/dutctl/rawmode_linux.go new file mode 100644 index 0000000..954bc29 --- /dev/null +++ b/cmds/dutctl/rawmode_linux.go @@ -0,0 +1,15 @@ +// Copyright 2025 Blindspot Software +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build linux + +package main + +import "golang.org/x/sys/unix" + +// Terminal get/set ioctl request numbers for Linux. +const ( + tcGetReq = unix.TCGETS + tcSetReq = unix.TCSETS +) diff --git a/cmds/dutctl/rawmode_other.go b/cmds/dutctl/rawmode_other.go new file mode 100644 index 0000000..9c14a2e --- /dev/null +++ b/cmds/dutctl/rawmode_other.go @@ -0,0 +1,14 @@ +// Copyright 2025 Blindspot Software +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build !linux && !darwin + +package main + +// setRawInput is a no-op on platforms without termios support (e.g. Windows). +// Input stays line-buffered; the interactive serial experience is degraded but +// the client still builds and runs. +func setRawInput(_ int) func() { + return nil +} diff --git a/cmds/dutctl/rawmode_unix.go b/cmds/dutctl/rawmode_unix.go new file mode 100644 index 0000000..5ca6c8a --- /dev/null +++ b/cmds/dutctl/rawmode_unix.go @@ -0,0 +1,53 @@ +// Copyright 2025 Blindspot Software +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build linux || darwin + +package main + +import "golang.org/x/sys/unix" + +// setRawInput puts the terminal into raw input mode so the interactive serial +// session behaves like a direct console: +// +// - ECHO/ICANON off: keystrokes are delivered immediately, character at a +// time, and not echoed locally (the DUT echoes them back). +// - ISIG off: control characters such as Ctrl-C, Ctrl-Z and Ctrl-\ are NOT +// turned into local signals; they are forwarded to the DUT as raw bytes. +// - IXON off: Ctrl-S/Ctrl-Q flow control is forwarded to the DUT instead of +// being swallowed by the local terminal. +// - IEXTEN off: Ctrl-V and friends are forwarded literally. +// - ICRNL off: a typed CR is sent as CR, not translated to NL. +// +// dutctl is exited with the client-side escape sequence (Ctrl-A x), not with a +// terminal signal — see filterEscape in rpc.go. +// +// It returns a restore function, or nil if the fd is not a terminal (in which +// case input stays line-buffered, which is the correct fallback for pipes). +// +// The ioctl request numbers differ per OS (tcGetReq/tcSetReq are defined in the +// platform-specific files); the termios flags themselves are shared across +// unix platforms. +func setRawInput(fileDescriptor int) func() { + termios, err := unix.IoctlGetTermios(fileDescriptor, tcGetReq) + if err != nil { + return nil + } + + old := *termios + + termios.Iflag &^= unix.ICRNL | unix.IXON + termios.Lflag &^= unix.ECHO | unix.ICANON | unix.ISIG | unix.IEXTEN + termios.Cc[unix.VMIN] = 1 + termios.Cc[unix.VTIME] = 0 + + err = unix.IoctlSetTermios(fileDescriptor, tcSetReq, termios) + if err != nil { + return nil + } + + return func() { + _ = unix.IoctlSetTermios(fileDescriptor, tcSetReq, &old) + } +} diff --git a/cmds/dutctl/rpc.go b/cmds/dutctl/rpc.go index 13e3075..e86afd3 100644 --- a/cmds/dutctl/rpc.go +++ b/cmds/dutctl/rpc.go @@ -5,7 +5,6 @@ package main import ( - "bufio" "context" "errors" "fmt" @@ -14,6 +13,8 @@ import ( "log" "os" "strings" + "sync" + "sync/atomic" "time" "connectrpc.com/connect" @@ -23,6 +24,84 @@ import ( pb "github.com/BlindspotSoftware/dutctl/protobuf/gen/dutctl/v1" ) +// rawConsole lazily switches the terminal to raw input mode, and prints the +// interactive session banner, the first time the agent actually streams console +// output. One-shot commands (power, flash) only ever send Print messages, so +// they never arm it: the terminal is left untouched and no misleading banner is +// shown. A serial session sends its "Connected" banner as console output within +// milliseconds, which arms it just in time for keystroke forwarding. +// +// The zero value is not usable; build one with newRawConsole. +type rawConsole struct { + fd int // stdin file descriptor (valid only when canRaw is true) + canRaw bool // stdin is an *os.File, so raw mode may be attempted + once sync.Once // guards the one-time arm + active atomic.Bool // true once raw mode is on; read by the send goroutine + mu sync.Mutex // guards restore + restore func() // set by arm, called by disarm; nil until/unless armed +} + +// newRawConsole returns a console that may switch the terminal to raw mode only +// when both: interactive is true (the command was invoked without arguments, so +// it is a hand-driven session rather than a scripted one) and stdin is a real +// *os.File. A pipe, /dev/null, or any argument-bearing (scripted) invocation +// yields canRaw=false, so it never changes the terminal nor prints the banner. +func newRawConsole(stdin io.Reader, interactive bool) *rawConsole { + console := &rawConsole{} + + if f, ok := stdin.(*os.File); ok && interactive { + console.fd = int(f.Fd()) + console.canRaw = true + } + + return console +} + +// arm switches the terminal to raw mode and prints the session banner on its +// first call; later calls are no-ops. setRawInput returns nil for a non-terminal +// fd, so arming is silently skipped when stdin is not a TTY. Safe to call from +// any goroutine. +func (rc *rawConsole) arm() { + rc.once.Do(func() { + if !rc.canRaw { + return + } + + restore := setRawInput(rc.fd) + if restore == nil { + return // not a terminal — keep input line-buffered + } + + rc.mu.Lock() + rc.restore = restore + rc.mu.Unlock() + + rc.active.Store(true) + + // The escape sequence is the only way to quit while raw mode is on. + fmt.Fprint(os.Stderr, "\r\n[dutctl] interactive session — press Ctrl-A then x to quit\r\n") + }) +} + +// isActive reports whether raw mode is currently engaged. The send goroutine +// uses it to decide whether to apply the Ctrl-A escape filter to stdin. +func (rc *rawConsole) isActive() bool { + return rc.active.Load() +} + +// disarm restores the terminal if arm switched it to raw mode. It is safe to +// defer unconditionally and safe to call when arm never fired. +func (rc *rawConsole) disarm() { + rc.mu.Lock() + restore := rc.restore + rc.restore = nil + rc.mu.Unlock() + + if restore != nil { + restore() + } +} + func (app *application) listRPC() error { ctx := context.Background() req := connect.NewRequest(&pb.ListRequest{}) @@ -188,10 +267,72 @@ func (app *application) detailsRPC(device, command, keyword string) error { return nil } -//nolint:funlen,cyclop,gocognit +// Interactive escape sequence. Ctrl-A is the prefix; the following key decides: +// - 'x'/'X'/Ctrl-X: quit dutctl. +// - Ctrl-A: send a single literal Ctrl-A to the DUT. +// - anything else: forward both the prefix and the key unchanged. +// +// Every other byte (Ctrl-C, Ctrl-D, Ctrl-Z, ...) is forwarded to the DUT. +const ( + escapePrefix = 0x01 // Ctrl-A + escapeQuitCtrl = 0x18 // Ctrl-X +) + +// filterEscape applies the interactive escape state machine to in. It returns +// the bytes to forward to the DUT and whether the quit sequence was seen. +// escapePending carries the "prefix seen" state across reads, so the prefix and +// its following key may arrive in separate stdin reads. +func filterEscape(in []byte, escapePending *bool) ([]byte, bool) { + out := make([]byte, 0, len(in)) + + for _, char := range in { + if *escapePending { + *escapePending = false + + switch char { + case 'x', 'X', escapeQuitCtrl: + return out, true + case escapePrefix: // prefix pressed twice -> send one literal Ctrl-A + out = append(out, escapePrefix) + default: // not an escape; forward the prefix and this byte + out = append(out, escapePrefix, char) + } + + continue + } + + if char == escapePrefix { + *escapePending = true + + continue + } + + out = append(out, char) + } + + return out, false +} + +//nolint:funlen,cyclop,gocognit,maintidx func (app *application) runRPC(device, command string, cmdArgs []string) error { const numWorkers = 2 // The send and receive worker goroutines + // Raw input mode (no echo, no canonical line buffering, no local signal + // generation) is needed only for a live, hand-driven console session, so that + // each keystroke — including Ctrl-C — is forwarded to the DUT immediately and + // echoed by the remote side. Two conditions must hold, so it is armed lazily: + // + // - The invocation is interactive, i.e. the command was given no arguments. + // A command with arguments is parameterised/scripted (e.g. a serial + // expect/send sequence the agent drives on its own); raw mode there would + // only mislead and would steal Ctrl-C from aborting the client. + // - The agent actually streams console output, which one-shot commands + // (power, flash) never do — so they leave the terminal untouched. + // + // Piped/scripted stdin is never switched to raw mode and is forwarded verbatim. + console := newRawConsole(app.stdin, len(cmdArgs) == 0) + defer console.disarm() + runCtx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -256,6 +397,11 @@ func (app *application) runRPC(device, command string, cmdArgs []string) error { Metadata: metadata, }) case *pb.RunResponse_Console: + // First console output means this is a live console session: + // switch the terminal to raw mode now (no-op for non-TTY stdin) + // so keystrokes forward correctly and the banner is truthful. + console.arm() + switch consoleData := msg.Console.Data.(type) { case *pb.Console_Stdout: app.formatter.WriteContent(output.Content{ @@ -324,18 +470,22 @@ func (app *application) runRPC(device, command string, cmdArgs []string) error { } }() - // Send routine — reads lines from stdin and forwards them to the server. + // Send routine — reads raw bytes from stdin and forwards them to the server. // // Unlike the receive routine this goroutine intentionally does NOT defer - // cancel(). When stdin reaches EOF (e.g. /dev/null in non-interactive - // runs) this goroutine returns immediately. If it cancelled the context - // on exit, the receive routine would be torn down before it could read - // and print the server's response. + // cancel() for the EOF case. When stdin reaches EOF (e.g. /dev/null in + // non-interactive runs) this goroutine returns immediately; if it cancelled + // the context on exit, the receive routine would be torn down before it + // could read and print the server's response. // - // Only the receive routine drives context cancellation so that all - // server output is processed before the RPC terminates. + // It DOES cancel when the user types the interactive escape sequence + // (Ctrl-A x): that is an explicit request to end the session. go func() { - reader := bufio.NewReader(app.stdin) + const stdinBufSize = 256 + + buf := make([]byte, stdinBufSize) + + var escapePending bool for { select { @@ -346,7 +496,7 @@ func (app *application) runRPC(device, command string, cmdArgs []string) error { default: } - text, err := reader.ReadString('\n') + nRead, err := app.stdin.Read(buf) if err != nil { if !errors.Is(err, io.EOF) { errChan <- fmt.Errorf("reading stdin: %w", err) @@ -355,17 +505,35 @@ func (app *application) runRPC(device, command string, cmdArgs []string) error { return } - err = stream.Send(&pb.RunRequest{ - Msg: &pb.RunRequest_Console{ - Console: &pb.Console{ - Data: &pb.Console_Stdin{ - Stdin: []byte(text), + payload := buf[:nRead] + + quit := false + if console.isActive() { + // Intercept the escape sequence; forward everything else + // (including Ctrl-C) untouched. + payload, quit = filterEscape(payload, &escapePending) + } + + if len(payload) > 0 { + sendErr := stream.Send(&pb.RunRequest{ + Msg: &pb.RunRequest_Console{ + Console: &pb.Console{ + Data: &pb.Console_Stdin{ + Stdin: payload, + }, }, }, - }, - }) - if err != nil { - errChan <- fmt.Errorf("sending RPC message: %w", err) + }) + if sendErr != nil { + errChan <- fmt.Errorf("sending RPC message: %w", sendErr) + + return + } + } + + if quit { + log.Println("Send routine terminating: escape sequence") + cancel() return } diff --git a/cmds/dutctl/rpc_escape_test.go b/cmds/dutctl/rpc_escape_test.go new file mode 100644 index 0000000..9477c4f --- /dev/null +++ b/cmds/dutctl/rpc_escape_test.go @@ -0,0 +1,145 @@ +// Copyright 2025 Blindspot Software +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import "testing" + +// TestFilterEscapeSingleCall covers escape handling when the prefix and the +// following key arrive within one stdin read. +func TestFilterEscapeSingleCall(t *testing.T) { + tests := []struct { + name string + in []byte + wantOut []byte + wantQuit bool + }{ + { + name: "plain bytes pass through", + in: []byte("hello"), + wantOut: []byte("hello"), + }, + { + name: "ctrl-c is forwarded to the DUT", + in: []byte{0x03}, + wantOut: []byte{0x03}, + }, + { + name: "ctrl-d is forwarded to the DUT", + in: []byte{0x04}, + wantOut: []byte{0x04}, + }, + { + name: "ctrl-z is forwarded to the DUT", + in: []byte{0x1a}, + wantOut: []byte{0x1a}, + }, + { + name: "ctrl-a then x quits", + in: []byte{escapePrefix, 'x'}, + wantOut: []byte{}, + wantQuit: true, + }, + { + name: "ctrl-a then X quits", + in: []byte{escapePrefix, 'X'}, + wantOut: []byte{}, + wantQuit: true, + }, + { + name: "ctrl-a then ctrl-x quits", + in: []byte{escapePrefix, escapeQuitCtrl}, + wantOut: []byte{}, + wantQuit: true, + }, + { + name: "ctrl-a ctrl-a sends one literal ctrl-a", + in: []byte{escapePrefix, escapePrefix}, + wantOut: []byte{escapePrefix}, + }, + { + name: "ctrl-a then other key forwards both", + in: []byte{escapePrefix, 'z'}, + wantOut: []byte{escapePrefix, 'z'}, + }, + { + name: "text before quit sequence is forwarded", + in: append([]byte("ab"), escapePrefix, 'x'), + wantOut: []byte("ab"), + wantQuit: true, + }, + { + name: "text around literal ctrl-a", + in: append(append([]byte("a"), escapePrefix, escapePrefix), 'b'), + wantOut: append([]byte{'a', escapePrefix}, 'b'), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pending := false + + out, quit := filterEscape(tt.in, &pending) + + if string(out) != string(tt.wantOut) { + t.Errorf("out = %q, want %q", out, tt.wantOut) + } + + if quit != tt.wantQuit { + t.Errorf("quit = %v, want %v", quit, tt.wantQuit) + } + + if pending { + t.Errorf("escapePending left true, want false") + } + }) + } +} + +// TestFilterEscapeAcrossReads verifies the escape prefix and its following key +// are handled correctly when they arrive in separate stdin reads. +func TestFilterEscapeAcrossReads(t *testing.T) { + pending := false + + // Read 1: lone prefix — nothing forwarded yet, state armed. + out, quit := filterEscape([]byte{escapePrefix}, &pending) + if len(out) != 0 || quit { + t.Fatalf("read 1: out=%q quit=%v, want empty/false", out, quit) + } + + if !pending { + t.Fatal("read 1: escapePending = false, want true") + } + + // Read 2: the quit key arrives separately. + out, quit = filterEscape([]byte{'x'}, &pending) + if !quit { + t.Errorf("read 2: quit = false, want true") + } + + if len(out) != 0 { + t.Errorf("read 2: out = %q, want empty", out) + } + + if pending { + t.Errorf("read 2: escapePending = true, want false") + } +} + +// TestFilterEscapePrefixThenForward verifies an armed prefix followed by a +// normal byte in the next read forwards both bytes and disarms. +func TestFilterEscapePrefixThenForward(t *testing.T) { + pending := false + + filterEscape([]byte{escapePrefix}, &pending) // arm + + out, quit := filterEscape([]byte{'k'}, &pending) + if quit { + t.Errorf("quit = true, want false") + } + + if string(out) != string([]byte{escapePrefix, 'k'}) { + t.Errorf("out = %q, want %q", out, []byte{escapePrefix, 'k'}) + } +} diff --git a/cmds/exp/dutserver/dutserver.go b/cmds/exp/dutserver/dutserver.go index ab9965b..d029b0c 100644 --- a/cmds/exp/dutserver/dutserver.go +++ b/cmds/exp/dutserver/dutserver.go @@ -14,10 +14,9 @@ import ( "os" "os/signal" "syscall" + "time" "github.com/BlindspotSoftware/dutctl/protobuf/gen/dutctl/v1/dutctlv1connect" - "golang.org/x/net/http2" - "golang.org/x/net/http2/h2c" ) const ( @@ -76,6 +75,9 @@ func (svr *server) watchInterrupt() { }() } +// readHeaderTimeout bounds how long the server waits to read request headers. +const readHeaderTimeout = 10 * time.Second + // startRPCService starts the RPC service, that ideally listens for incoming // connections forever. It always returns an non-nil error. func (svr *server) startRPCService() error { @@ -94,12 +96,17 @@ func (svr *server) startRPCService() error { path, handler = dutctlv1connect.NewRelayServiceHandler(service) mux.Handle(path, handler) - //nolint:gosec - return http.ListenAndServe( - svr.address, - // Use h2c so we can serve HTTP/2 without TLS. - h2c.NewHandler(mux, &http2.Server{}), - ) + // Serve HTTP/2 without TLS (h2c) + srv := &http.Server{ + Addr: svr.address, + Handler: mux, + ReadHeaderTimeout: readHeaderTimeout, + } + srv.Protocols = new(http.Protocols) + srv.Protocols.SetHTTP1(true) + srv.Protocols.SetUnencryptedHTTP2(true) + + return srv.ListenAndServe() } // start orchestrates the dutagent execution. diff --git a/cmds/exp/dutserver/rpc.go b/cmds/exp/dutserver/rpc.go index c13cf90..4692684 100644 --- a/cmds/exp/dutserver/rpc.go +++ b/cmds/exp/dutserver/rpc.go @@ -6,13 +6,11 @@ package main import ( "context" - "crypto/tls" "errors" "fmt" "io" "log" "maps" - "net" "net/http" "slices" "sync" @@ -20,7 +18,6 @@ import ( "connectrpc.com/connect" "github.com/BlindspotSoftware/dutctl/pkg/lock" "github.com/BlindspotSoftware/dutctl/protobuf/gen/dutctl/v1/dutctlv1connect" - "golang.org/x/net/http2" pb "github.com/BlindspotSoftware/dutctl/protobuf/gen/dutctl/v1" ) @@ -390,19 +387,18 @@ func spawnClient(agendURL string) dutctlv1connect.DeviceServiceClient { // TODO: refactor into pkg and reuse in dutctl and dutserver. func newInsecureClient() *http.Client { + // Use the HTTP/2 protocol without TLS (h2c). + transport := &http.Transport{} + transport.Protocols = new(http.Protocols) + transport.Protocols.SetUnencryptedHTTP2(true) + return &http.Client{ - Transport: &http2.Transport{ - AllowHTTP: true, - DialTLS: func(network, addr string, _ *tls.Config) (net.Conn, error) { - // If you're also using this client for non-h2c traffic, you may want - // to delegate to tls.Dial if the network isn't TCP or the addr isn't - // in an allowlist. - - //nolint:noctx - return net.Dial(network, addr) - }, - // TODO: Don't forget timeouts! - }, + Transport: transport, + // TODO: Don't forget timeouts! http.Client.Timeout must not be used here: + // it bounds the entire exchange including the response body, which would + // abort long-lived streaming RPCs. Instead use per-RPC context deadlines + // on unary calls and/or transport timeouts (DialContext, + // TLSHandshakeTimeout, ResponseHeaderTimeout, IdleConnTimeout). } } diff --git a/go.mod b/go.mod index ba1890b..e5b5cc0 100644 --- a/go.mod +++ b/go.mod @@ -10,7 +10,7 @@ require ( github.com/stianeikeland/go-rpio/v4 v4.6.0 github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07 golang.org/x/crypto v0.51.0 - golang.org/x/net v0.53.0 + golang.org/x/sys v0.44.0 google.golang.org/protobuf v1.36.11 gopkg.in/yaml.v3 v3.0.1 ) @@ -32,6 +32,6 @@ require ( github.com/olekukonko/tablewriter v1.0.9 // indirect github.com/rivo/uniseg v0.2.0 // indirect github.com/rogpeppe/go-internal v1.6.1 // indirect - golang.org/x/sys v0.44.0 // indirect + golang.org/x/net v0.53.0 // indirect golang.org/x/text v0.37.0 // indirect ) diff --git a/internal/dutagent/broker_test.go b/internal/dutagent/broker_test.go index d8db2c6..5a58962 100644 --- a/internal/dutagent/broker_test.go +++ b/internal/dutagent/broker_test.go @@ -32,9 +32,8 @@ func (s *testStream) Send(_ *pb.RunResponse) error { func (s *testStream) Receive() (*pb.RunRequest, error) { if s.recvBlock { - if s.unblockCh == nil { - s.unblockCh = make(chan struct{}) - } + // unblockCh must be created by the test before Start so this goroutine + // only reads it, never writes it (a write here would race the test). <-s.unblockCh // will block until closed; simulates a long receive } @@ -132,7 +131,8 @@ func TestBroker_StdinForwarding(t *testing.T) { b := &Broker{} stdinPayload := []byte("user input") req := &pb.RunRequest{Msg: &pb.RunRequest_Console{Console: &pb.Console{Data: &pb.Console_Stdin{Stdin: stdinPayload}}}} - stream := &testStream{recvReqs: []*pb.RunRequest{req}, recvErrs: []error{nil}} // after first req, EOF + // No recvErrs: testStream returns reqs in order then EOF by default. + stream := &testStream{recvReqs: []*pb.RunRequest{req}} ctx, cancel := context.WithCancel(context.Background()) sess, errCh := b.Start(ctx, stream) @@ -144,8 +144,7 @@ func TestBroker_StdinForwarding(t *testing.T) { t.Fatalf("stdin mismatch: got %q want %q", string(data), string(stdinPayload)) } case <-time.After(200 * time.Millisecond): - // Expected to fail until refactor ensures proper sequencing / closure. - // (Current code may work but channel closure semantics will still fail later.) + t.Fatal("timeout: stdin payload was not forwarded to stdinCh") } cancel() // simulate module completion @@ -156,15 +155,13 @@ func TestBroker_StdinForwarding(t *testing.T) { // Cancellation during a blocked receive should terminate fromClientWorker without producing errors. func TestBroker_CancelDuringBlockedReceive(t *testing.T) { b := &Broker{} - stream := &testStream{recvBlock: true} + stream := &testStream{recvBlock: true, unblockCh: make(chan struct{})} ctx, cancel := context.WithCancel(context.Background()) _, errCh := b.Start(ctx, stream) // Cancel promptly, then unblock the fake receive so worker goroutine does not leak. cancel() - if stream.unblockCh != nil { - close(stream.unblockCh) - } + close(stream.unblockCh) errs := collectErrors(t, errCh, 200*time.Millisecond) if len(errs) != 0 { diff --git a/internal/dutagent/worker.go b/internal/dutagent/worker.go index 62484fe..61b626c 100644 --- a/internal/dutagent/worker.go +++ b/internal/dutagent/worker.go @@ -100,6 +100,10 @@ func toClientWorker(ctx context.Context, stream Stream, s *session) error { // //nolint:cyclop,funlen,gocognit func fromClientWorker(ctx context.Context, stream Stream, s *session) error { + // Close stdinCh on exit so ChanReader.Read returns io.EOF and any module + // goroutine blocked on stdin (e.g. the serial inner goroutine) unblocks cleanly. + defer close(s.stdinCh) + type recvResult struct { req *pb.RunRequest err error diff --git a/pkg/dut/dut.go b/pkg/dut/dut.go index 2933768..6ff8c3b 100644 --- a/pkg/dut/dut.go +++ b/pkg/dut/dut.go @@ -15,9 +15,9 @@ import ( ) var ( - ErrDeviceNotFound = errors.New("no such device") - ErrCommandNotFound = errors.New("no such command") - ErrNoPassthroughForArgs = errors.New("arguments provided but command has no passthrough module to receive them") + ErrDeviceNotFound = errors.New("no such device") + ErrCommandNotFound = errors.New("no such command") + ErrNoReceiverForArgs = errors.New("arguments provided but command has neither a passthrough module nor declared arguments to receive them") ) // Devlist is a list of devices-under-test. @@ -124,8 +124,11 @@ func (c *Command) countPassthrough() int { // receive their statically configured Args with template references substituted // using runtimeArgs. The returned slice has the same length and ordering as c.Modules. func (c *Command) ModuleArgs(runtimeArgs []string) ([][]string, error) { - if len(runtimeArgs) > 0 && !c.HasPassthrough() { - return nil, ErrNoPassthroughForArgs + // Runtime args may be consumed either by a passthrough module or by + // command-level templating (declared c.Args substituted via ${name}). + // Only reject when neither can receive them. + if len(runtimeArgs) > 0 && !c.HasPassthrough() && len(c.Args) == 0 { + return nil, ErrNoReceiverForArgs } result := make([][]string, len(c.Modules)) diff --git a/pkg/dut/dut_test.go b/pkg/dut/dut_test.go index 918d26a..8260ff1 100644 --- a/pkg/dut/dut_test.go +++ b/pkg/dut/dut_test.go @@ -191,7 +191,7 @@ func TestModuleArgs(t *testing.T) { }}, runtimeArgs: []string{"a"}, want: nil, - err: ErrNoPassthroughForArgs, + err: ErrNoReceiverForArgs, }, { name: "mixed passthrough and non-passthrough", @@ -222,7 +222,7 @@ func TestModuleArgs(t *testing.T) { cmd: Command{}, runtimeArgs: []string{"a"}, want: nil, - err: ErrNoPassthroughForArgs, + err: ErrNoReceiverForArgs, }, { name: "error when runtime args provided but no passthrough module", @@ -231,7 +231,7 @@ func TestModuleArgs(t *testing.T) { }}, runtimeArgs: []string{"a"}, want: nil, - err: ErrNoPassthroughForArgs, + err: ErrNoReceiverForArgs, }, { name: "passthrough module with no runtime args", diff --git a/pkg/module/serial/serial.go b/pkg/module/serial/serial.go index 75039d9..6bd4c02 100644 --- a/pkg/module/serial/serial.go +++ b/pkg/module/serial/serial.go @@ -6,15 +6,15 @@ package serial import ( - "bytes" "context" - "errors" "flag" "fmt" "io" "log" + "os" "regexp" "strings" + "sync" "time" "github.com/BlindspotSoftware/dutctl/pkg/module" @@ -31,14 +31,77 @@ func init() { // DefaultBaudRate is the default baud rate for the serial connection. const DefaultBaudRate = 115200 -// Serial is a module that forwards the serial output of a connected DUT to the dutctl client. -// It is non-interactive and does not support stdin yet. +// pairsDrain is the time the module continues reading serial output after +// sending the last response in expect-send mode, so the output of the +// triggered command is visible before the connection closes. +const pairsDrain = 1 * time.Second + +// reconnectInterval is the delay between attempts to reopen the serial device +// after it has disappeared (e.g. an FTDI chip that powers down with the DUT). +const reconnectInterval = 500 * time.Millisecond + +// deviceLossGraceDefault is how long the module tolerates receiving no serial +// data before it suspects the device is gone (rather than merely idle) and +// verifies by checking the device node. With ReadTimeout set the driver uses +// VMIN=0, so an idle read returns io.EOF exactly like a removed device would — +// the grace plus the node check tell a quiet but healthy console apart from a +// device that has actually vanished. Kept short so a real disconnect is noticed +// quickly, but long enough not to probe on every idle read. +const deviceLossGraceDefault = 2 * time.Second + +// maxMatchWindow bounds the buffer of recent serial output kept for regex +// matching. Matching is always done against the tail of the output, so old +// bytes can be discarded; the cap keeps memory and regex cost bounded while +// staying far larger than any realistic prompt or expect pattern. +const maxMatchWindow = 64 * 1024 + +// stepKind distinguishes the two kinds of step in a serial sequence. +type stepKind int + +const ( + stepExpect stepKind = iota // wait until pattern matches the serial output + stepSend // write data to the serial port +) + +// seqStep is one step of a serial automation sequence. The module walks the +// steps in order: it waits for each expect-step's pattern to appear in the +// output, and writes each send-step's data to the port. A single expect +// pattern and the legacy expect-send pairs both compile down to a slice of +// these, so the run loop has exactly one code path for all non-interactive +// modes. +type seqStep struct { + kind stepKind + pattern *regexp.Regexp // set when kind == stepExpect + data []byte // set when kind == stepSend (empty = send nothing) +} + +// serialPort is the subset of the serial device used by Run. It is satisfied by +// *serial.Port and allows a fake to be injected in tests via Serial.dialPort. +type serialPort interface { + io.ReadWriteCloser + Flush() error +} + +// Serial is a module that provides an interactive serial connection to a DUT. type Serial struct { Port string // Port is the path to the serial device on the dutagent. - Baud int // Baud is the baud rate of the serial device. Is unset, DefaultBaudRate is used. + Baud int // Baud is the baud rate of the serial device. If unset, DefaultBaudRate is used. + + steps []seqStep // steps is the ordered expect/send sequence (empty = interactive mode). + timeout time.Duration // timeout is the maximum time to wait for the sequence to complete. - expect *regexp.Regexp // expect is a pattern to match against the serial output. - timeout time.Duration // timeout is the maximum time to wait for the expect pattern to match. + // dialPort opens the serial port. It defaults to openPort and is overridden + // in tests to inject a fake. Set lazily in Run so the zero value works. + dialPort func() (serialPort, error) + + // portPresent reports whether the configured device node still exists. It + // defaults to a filesystem stat of Port and is overridden in tests. Used to + // tell a benign idle-timeout EOF apart from real device loss. + portPresent func() bool + + // deviceLossGrace overrides deviceLossGraceDefault in tests so the device-loss + // path can be exercised quickly. Zero means use the default. + deviceLossGrace time.Duration } // Ensure implementing the Module interface. @@ -49,18 +112,42 @@ const abstract = `Serial connection to the DUT const usage = ` ARGUMENTS: - [-t ] [] + [-t ] [ [ ...]] + [-t ] expect:|send: [expect:|send: ...] ` const description = ` -The serial connection is read-only and does not support stdin yet. -If a regex is provided, the module will wait for the regex to match on the serial output, -then exit with success. If no expect string is provided, the module will read from the serial port -until it is terminated by a signal (e.g. Ctrl-C). -The expect string supports regular expressions according to [1]. -The optional -t flag specifies the maximum time to wait for the regex to match. -Quote the expect string if it contains spaces or special characters. E.g.: "(?i)hello\s+world!? dutctl" +The serial module provides an interactive connection to the DUT's serial port. +Input from the client is forwarded to the serial port, and output from the serial port is displayed. + +Modes of operation: + - Interactive (no arguments): read and write until terminated by a signal (e.g. Ctrl-C). + - Expect (1 argument): wait for the regex to match on the serial output, then exit. + - Expect-send (even number of arguments >= 2): pass pattern/response pairs. + For each pair, the module waits for the pattern to match and then sends the + response to the serial port. Pairs are processed in order; after the last + pair matches the module reads serial output for 1 more second so the output + of the triggered command is visible, then exits. + - Sequence (every argument carries an "expect:" or "send:" tag): an ordered + list of steps run one after another. An expect-step waits for its regex to + match the serial output; a send-step writes its data to the port. Unlike + expect-send pairs, the steps may appear in any order, so a sequence can + begin with a send (e.g. an Enter to wake the console) or chain several + sends or expects in a row. The whole sequence shares the -t deadline. If + the last step is a send, the module drains output for 1 more second before + exiting; if it is an expect, it exits as soon as that pattern matches. + + e.g.: send:"\n" expect:"login:" send:"root\n" expect:"# " send:"reboot\n" + +If the serial device disappears mid-session (e.g. an FTDI chip that powers down +when the DUT loses power), the module waits for it to reappear and reconnects +automatically instead of ending the session. + +The expect string supports regular expressions according to [1]. +The optional -t flag specifies the maximum time to wait. +Quote strings containing spaces or special characters. E.g.: "(?i)hello\s+world" +Response and send strings support C-style escape sequences: \n, \r, \t, \\, \xHH. [1] https://golang.org/s/re2syntax. ` @@ -71,7 +158,7 @@ func (s *Serial) Help() string { help := strings.Builder{} help.WriteString(abstract) help.WriteString(usage) - help.WriteString(fmt.Sprintf("Configured COM port is %q with baud rate %d.\n", s.Port, s.Baud)) + fmt.Fprintf(&help, "Configured COM port is %q with baud rate %d.\n", s.Port, s.Baud) help.WriteString(description) return help.String() @@ -101,7 +188,25 @@ func (s *Serial) Deinit() error { return nil } -//nolint:cyclop,funlen,gocognit +// Run bridges the DUT's serial port to the client console. +// +// Concurrency model (deliberately minimal to be race- and deadlock-free): +// +// - The main loop is the sole owner of all match state (expect/pairs/draining, +// the match window, and the CSI remainder). Nothing else touches it, so no +// lock is needed to protect it. +// - One "stdin pump" goroutine forwards client keystrokes to the port. It is +// the only goroutine that may briefly outlive Run; it never reads or writes +// match state. It unblocks via stdin EOF, which the agent guarantees by +// closing the stdin channel during session teardown. +// - One "stdout pump" goroutine performs every client write, so the main loop +// can keep matching and draining without blocking on a slow client, and so +// a vanished client can never wedge the main loop (writes abort on context +// cancellation). +// - Port writes (stdin forwarding + auto-responses) are serialised by portMu, +// which is never held across a channel operation, so it cannot deadlock. +// +//nolint:gocognit,cyclop,funlen,gocyclo,maintidx func (s *Serial) Run(ctx context.Context, session module.Session, args ...string) error { log.Println("serial module: Run called") @@ -110,94 +215,414 @@ func (s *Serial) Run(ctx context.Context, session module.Session, args ...string return err } - port, err := s.openPort() + dial := s.dialPort + if dial == nil { + dial = s.openPort + } + + present := s.portPresent + if present == nil { + present = func() bool { + _, statErr := os.Stat(s.Port) + + return statErr == nil + } + } + + grace := s.deviceLossGrace + if grace == 0 { + grace = deviceLossGraceDefault + } + + port, err := dial() if err != nil { return err } - defer port.Close() + + // done is closed when Run returns, signalling the stdin pump to stop writing + // to the (about to be closed) port. + done := make(chan struct{}) + defer close(done) + + // portMu guards the current port handle, which is replaced when the serial + // device disappears (e.g. an FTDI chip that powers down with the DUT) and is + // reopened once it reappears — see reconnect below. The main loop owns the + // handle and is its only writer, so it may read port directly; the stdin + // pump (another goroutine) goes through writeToPort under the lock. + var portMu sync.Mutex + + storePort := func(next serialPort) { + portMu.Lock() + defer portMu.Unlock() + + port = next + } + + closePort := func() { + portMu.Lock() + defer portMu.Unlock() + + if port != nil { + _ = port.Close() + port = nil + } + } + + defer closePort() // Discard any stale bytes left in the kernel/driver RX buffer from a // previous session, otherwise the user sees data from the last boot. - err = port.Flush() - if err != nil { - log.Printf("serial module: flush failed: %v", err) + flushErr := port.Flush() + if flushErr != nil { + log.Printf("serial module: flush failed: %v", flushErr) } + stdin, stdout, _ := session.Console() // stderr intentionally unused: serial output goes to stdout only + log.Printf("serial module: connected to %s at %d baud", s.Port, s.Baud) - session.Print(fmt.Sprintf("--- Connected to %s at %d baud ---\n", s.Port, s.Baud)) + fmt.Fprintf(stdout, "--- Connected to %s at %d baud ---\n", s.Port, s.Baud) + + // baseCtx tracks external cancellation (client disconnect / agent teardown). + // It is never replaced, so the stdout pump keeps flushing even while a + // derived deadline (expect timeout or post-match drain) is counting down. + baseCtx := ctx + + writeToPort := func(data []byte) { + portMu.Lock() + defer portMu.Unlock() + + if port == nil { + return // device is gone; drop input until it reconnects + } + + _, werr := port.Write(data) + if werr != nil { + select { + case <-done: // port closed on Run exit — not an error. + default: + log.Printf("serial module: error writing to serial port: %v", werr) + } + } + } + + s.startStdinPump(stdin, done, writeToPort) - var cancel context.CancelFunc + emit, flushAndWait := newStdoutPump(baseCtx, stdout) + defer flushAndWait(false) // ensure the pump goroutine is released on every exit path + + // loopCtx carries the active deadline: the expect timeout first, then the + // post-match drain. It is always derived from baseCtx, so external + // cancellation still propagates through it. + loopCtx := baseCtx + + var loopCancel context.CancelFunc if s.timeout > 0 { log.Printf("serial module: setting timeout of %s", s.timeout) - ctx, cancel = context.WithTimeout(ctx, s.timeout) + loopCtx, loopCancel = context.WithTimeout(baseCtx, s.timeout) + } + + defer func() { + if loopCancel != nil { + loopCancel() + } + }() + + var ( + remainder []byte // partial CSI sequence carried across reads + matchWindow []byte // bounded window of recent output for regex matching + currentStep int // cursor into s.steps + draining bool + ) + + // lastData is the time of the last data-bearing read; with no data for the + // grace period it triggers the device-loss check. + lastData := time.Now() + + //nolint:fatcontext // intentional: the post-match drain replaces the loop deadline + startDrain := func() { + if loopCancel != nil { + loopCancel() + } + + loopCtx, loopCancel = context.WithTimeout(baseCtx, pairsDrain) + draining = true + + log.Printf("serial module: draining serial output for %s before closing", pairsDrain) + } + + // advanceSends fires consecutive send-steps at the cursor, writing each to + // the port, until the cursor reaches an expect-step or the end of the + // sequence. It is called once before the first read (so a sequence may begin + // with a send) and again after every expect-step matches. + advanceSends := func() { + for currentStep < len(s.steps) && s.steps[currentStep].kind == stepSend { + data := s.steps[currentStep].data + currentStep++ + + if len(data) > 0 { + writeToPort(data) + } + } + } + + // reconnect closes the vanished port and retries opening it until the device + // reappears, the deadline (watchCtx) fires, or the session is cancelled. Once + // it returns, either a fresh port is in place or watchCtx is done and the + // loop's select handles it. Mirrors tio's auto-reconnect so a DUT power-cycle + // that drops the FTDI device does not end the session. + reconnect := func(watchCtx context.Context) { + closePort() + + emit([]byte("\n--- Serial device disconnected, waiting to reconnect ---\n")) + log.Printf("serial module: device %s disconnected, waiting to reconnect", s.Port) + + for { + fresh, dialErr := dial() + if dialErr == nil { + flushErr := fresh.Flush() + if flushErr != nil { + log.Printf("serial module: flush after reconnect failed: %v", flushErr) + } + + storePort(fresh) + + emit([]byte("\n--- Serial device reconnected ---\n")) + log.Printf("serial module: device %s reconnected", s.Port) + + return + } - defer cancel() + select { + case <-watchCtx.Done(): + return + case <-time.After(reconnectInterval): + } + } } const bufferSize = 4096 readBuffer := make([]byte, bufferSize) - lineBuffer := &bytes.Buffer{} + + // Fire any leading send-steps before the first read, so a sequence may begin + // by sending (e.g. an Enter to wake the console) rather than expecting. If + // the sequence is sends only, drain briefly so the DUT's response to the + // final input is visible, then exit via the drain deadline below. + advanceSends() + + if len(s.steps) > 0 && currentStep >= len(s.steps) { + startDrain() + } for { select { - case <-ctx.Done(): - if errors.Is(ctx.Err(), context.DeadlineExceeded) { - session.Print("\n--- Timeout reached, no match found ---\n") + case <-loopCtx.Done(): + if baseCtx.Err() != nil { + // External cancellation: client disconnect or agent teardown. + // The client may no longer be reading, so do not wait on a flush. + log.Println("serial module: context cancelled, closing") - return fmt.Errorf("timeout of %s reached, pattern %q not found", s.timeout, s.expect) + return baseCtx.Err() } - session.Print("\n--- Connection closed ---\n") + // One of our own deadlines fired. + if draining { + emit([]byte("\n--- Pattern matched, connection closed ---")) + flushAndWait(true) - return ctx.Err() - default: - sbytes, err := port.Read(readBuffer) - if err != nil { - // Ignore timeout errors as these are expected with the read timeout config - if err != io.EOF && !strings.Contains(err.Error(), "timeout") { - return fmt.Errorf("error reading from serial port: %w", err) - } + return nil + } - continue + emit([]byte("\n--- Timeout reached, no match found ---")) + flushAndWait(true) + + if len(s.steps) == 1 && s.steps[0].kind == stepExpect { + return fmt.Errorf("timeout of %s reached, pattern %q not found", s.timeout, s.steps[0].pattern) } - if sbytes == 0 { - continue + return fmt.Errorf("timeout of %s reached, expect-send sequence not completed", s.timeout) + default: + } + + nRead, readErr := port.Read(readBuffer) + + // A hard read error (EIO/ENXIO/…) means the device errored outright — wait + // for it to come back instead of ending the session. A plain read timeout + // or io.EOF is NOT a hard error: with VMIN=0 an idle read returns io.EOF, + // so it is handled below together with zero-length reads. + if readErr != nil && readErr != io.EOF && !strings.Contains(readErr.Error(), "timeout") { + reconnect(loopCtx) + + lastData = time.Now() + + continue + } + + if nRead == 0 { + // No data this cycle. Usually a benign idle read — but a removed or + // powered-down device idles identically (io.EOF / zero bytes), so once + // no data has arrived for the grace period, verify the device node + // still exists and reconnect if it has vanished. reconnect returns + // false only when the deadline or an external cancellation fired; the + // next loop iteration's select then handles it. + if time.Since(lastData) > grace && !present() { + reconnect(loopCtx) + + lastData = time.Now() } - // Forward bytes to the client immediately so partial lines are - // not held back waiting for a newline from the DUT. - session.Print(string(readBuffer[:sbytes])) + continue + } + + lastData = time.Now() + + // Filter cursor/query CSI sequences (SGR colour is preserved), + // reconstructing sequences split across reads via remainder. + chunk := readBuffer[:nRead] + if len(remainder) > 0 { + chunk = append(remainder, chunk...) + remainder = nil + } + + out := filterOutputCSI(chunk, &remainder) + if len(out) == 0 { + continue + } + + // Display everything immediately, including partial lines (e.g. prompts + // without a trailing newline). out is a fresh slice owned by the pump. + emit(out) - if s.expect == nil { - continue + // Matching is skipped while draining and in interactive mode (no steps). + if draining || len(s.steps) == 0 { + continue + } + + matchWindow = append(matchWindow, out...) + if len(matchWindow) > maxMatchWindow { + matchWindow = matchWindow[len(matchWindow)-maxMatchWindow:] + } + + // Walk the sequence: satisfy every expect-step the current window already + // matches, firing the send-steps that follow each match. The cursor rests + // on an expect-step here, because advanceSends consumed any leading or + // trailing sends after the previous iteration. + for currentStep < len(s.steps) && s.steps[currentStep].kind == stepExpect { + loc := s.steps[currentStep].pattern.FindIndex(matchWindow) + if loc == nil { + break } - // Process the data read character by character - for i := range sbytes { - b := readBuffer[i] - lineBuffer.WriteByte(b) + matchWindow = matchWindow[loc[1]:] // consume through the match + currentStep++ - // If we reach a newline or a buffer limit, process the line - if b == '\n' || lineBuffer.Len() >= 1024 { - line := lineBuffer.String() + advanceSends() // fire the send-steps that follow this match - // Check for regex match if we have one - if s.expect.MatchString(line) { - session.Print("\n--- Pattern matched, connection closed ---\n") + if currentStep < len(s.steps) { + continue // more steps remain; keep matching the current window + } - return nil // Success - pattern found - } + // Sequence complete. If it ended on a send, drain so the DUT's + // response to the final input is visible; if it ended on an expect, + // the match itself is the completion, so exit immediately. + if s.steps[len(s.steps)-1].kind == stepSend { + startDrain() + } else { + emit([]byte("\n--- Pattern matched, connection closed ---")) + flushAndWait(true) - lineBuffer.Reset() + return nil + } + + break + } + } +} + +const stdinBufSize = 256 + +// startStdinPump forwards client keystrokes to the serial port until stdin +// reaches EOF. It is the only goroutine that may outlive Run; it touches no +// match state, so it cannot race with the main loop. +func (s *Serial) startStdinPump(stdin io.Reader, done <-chan struct{}, writeToPort func([]byte)) { + go func() { + buf := make([]byte, stdinBufSize) + + for { + nRead, err := stdin.Read(buf) + if nRead > 0 { + select { + case <-done: + return // Run exited — do not write to the closing port. + default: + writeToPort(buf[:nRead]) } } + + if err != nil { + return // EOF on session teardown, or a read error. + } + } + }() +} + +const stdoutQueueLen = 256 + +// newStdoutPump starts a single goroutine that performs every client write and +// returns two closures: +// +// - emit(data) queues data for the client. It never blocks the caller +// indefinitely: if baseCtx is cancelled (client gone) the data is dropped +// rather than wedging the main loop. The caller must not reuse data +// afterwards; pass a fresh slice. +// - flushAndWait(deliver) shuts the pump down. With deliver=true it waits for +// all queued data to be written (safe only while the client is still +// reading, i.e. on graceful completion). With deliver=false it just releases +// the pump without waiting; it is idempotent and safe to defer on every path. +func newStdoutPump(baseCtx context.Context, stdout io.Writer) (func([]byte), func(bool)) { + outCh := make(chan []byte, stdoutQueueLen) + pumpDone := make(chan struct{}) + + go func() { + defer close(pumpDone) + + for data := range outCh { + _, _ = stdout.Write(data) // ChanWriter.Write only fails on misuse + } + }() + + var closeOnce sync.Once + + closeOut := func() { closeOnce.Do(func() { close(outCh) }) } + + emit := func(data []byte) { + select { + case outCh <- data: + case <-baseCtx.Done(): // client gone — drop rather than block forever. } } + + flushAndWait := func(deliver bool) { + closeOut() + + if deliver { + <-pumpDone + } + } + + return emit, flushAndWait } +// pairStride is the number of positional arguments per expect-send pair. +const pairStride = 2 + +// Tag prefixes that mark a tagged-sequence argument. +const ( + expectTag = "expect:" + sendTag = "send:" +) + func (s *Serial) evalArgs(args []string) error { fs := flag.NewFlagSet("serial", flag.ContinueOnError) fs.SetOutput(io.Discard) // Suppress default error output @@ -208,17 +633,93 @@ func (s *Serial) evalArgs(args []string) error { return fmt.Errorf("failed to parse arguments: %w", err) } - // Reset expect so a previous pattern does not carry over into the next run. - s.expect = nil + positional := fs.Args() + + // Tagged-sequence mode is selected when the first argument carries a tag. + // It allows arbitrarily ordered expect/send steps (e.g. a leading send). + if len(positional) > 0 && isTaggedStep(positional[0]) { + return s.evalSequenceArgs(positional) + } + + return s.evalLegacyArgs(positional) +} + +// evalLegacyArgs parses the backward-compatible positional argument forms into +// s.steps: no args (interactive), one arg (single expect), or an even number of +// args (expect-send pairs). +func (s *Serial) evalLegacyArgs(positional []string) error { + switch len(positional) { + case 0: + // Interactive mode: no steps. + case 1: + // Single-expect mode. + log.Printf("serial module: Will wait for pattern: %q", positional[0]) + + pattern, compileErr := regexp.Compile(positional[0]) + if compileErr != nil { + return fmt.Errorf("invalid regular expression: %w", compileErr) + } + + s.steps = []seqStep{{kind: stepExpect, pattern: pattern}} + default: + // Expect-send pairs mode. + if len(positional)%pairStride != 0 { + return fmt.Errorf("expect-send requires an even number of arguments, got %d", len(positional)) + } + + s.steps = make([]seqStep, 0, len(positional)) + + for idx := 0; idx < len(positional); idx += pairStride { + pattern, compileErr := regexp.Compile(positional[idx]) + if compileErr != nil { + return fmt.Errorf("invalid regular expression %q: %w", positional[idx], compileErr) + } + + log.Printf("serial module: Pair %d: pattern=%q response=%q", idx/pairStride+1, positional[idx], positional[idx+1]) + + s.steps = append(s.steps, + seqStep{kind: stepExpect, pattern: pattern}, + seqStep{kind: stepSend, data: unescape(positional[idx+1])}, + ) + } + } + + return nil +} + +// isTaggedStep reports whether arg carries an "expect:" or "send:" tag. +func isTaggedStep(arg string) bool { + return strings.HasPrefix(arg, expectTag) || strings.HasPrefix(arg, sendTag) +} + +// evalSequenceArgs parses tagged-sequence arguments into s.steps. Every +// argument must carry a tag; mixing tagged and untagged arguments is rejected +// so a malformed command fails loudly instead of being silently misread. +func (s *Serial) evalSequenceArgs(args []string) error { + s.steps = make([]seqStep, 0, len(args)) + + for idx, arg := range args { + switch { + case strings.HasPrefix(arg, expectTag): + expr := arg[len(expectTag):] + + pattern, compileErr := regexp.Compile(expr) + if compileErr != nil { + return fmt.Errorf("step %d: invalid regular expression %q: %w", idx+1, expr, compileErr) + } + + log.Printf("serial module: Step %d: expect=%q", idx+1, expr) + + s.steps = append(s.steps, seqStep{kind: stepExpect, pattern: pattern}) + case strings.HasPrefix(arg, sendTag): + data := arg[len(sendTag):] - // Get the expect string if provided (args after flags) - if fs.NArg() > 0 { - expectPattern := fs.Arg(0) - log.Printf("serial module: Will wait for pattern: %q", expectPattern) + log.Printf("serial module: Step %d: send=%q", idx+1, data) - s.expect, err = regexp.Compile(expectPattern) - if err != nil { - return fmt.Errorf("invalid regular expression: %w", err) + s.steps = append(s.steps, seqStep{kind: stepSend, data: unescape(data)}) + default: + return fmt.Errorf("step %d %q: in sequence mode every argument must start with %q or %q", + idx+1, arg, expectTag, sendTag) } } @@ -227,7 +728,8 @@ func (s *Serial) evalArgs(args []string) error { const readTimeout = 100 * time.Millisecond -func (s *Serial) openPort() (*serial.Port, error) { +//nolint:ireturn // intentional: returns the serialPort interface so a fake can be injected in tests +func (s *Serial) openPort() (serialPort, error) { config := &serial.Config{ Name: s.Port, Baud: s.Baud, @@ -241,3 +743,209 @@ func (s *Serial) openPort() (*serial.Port, error) { return port, nil } + +// escByte is the ASCII escape character that starts ANSI/VT escape sequences. +const escByte = 0x1b + +// csiPrefixLen is the length of the CSI prefix "ESC[". +const csiPrefixLen = 2 + +// Bytes that terminate an ANSI string sequence. +const ( + stByte = 0x9c // single-byte String Terminator (C1) + belByte = 0x07 // BEL — also terminates an OSC sequence +) + +// isStringSeqIntroducer reports whether b, following ESC, begins an ANSI string +// sequence: DCS (P), OSC (]), SOS (X), PM (^) or APC (_). These carry terminal +// queries and responses (e.g. capability reports like the WezTerm "name=" DCS), +// never display content, so they are dropped in full. +func isStringSeqIntroducer(b byte) bool { + return b == 'P' || b == ']' || b == 'X' || b == '^' || b == '_' +} + +// stringSeqEnd returns the index of the last byte of the ANSI string sequence +// that starts at data[escIdx] (with introducer data[escIdx+1]), or -1 if the +// sequence is not terminated within data — in which case the caller carries it +// to the next read. The sequence ends at ST ("ESC \" or the single-byte 0x9c); +// an OSC may also end at BEL. +func stringSeqEnd(data []byte, escIdx int) int { + osc := data[escIdx+1] == ']' + + for pos := escIdx + csiPrefixLen; pos < len(data); pos++ { + switch { + case data[pos] == stByte: + return pos + case osc && data[pos] == belByte: + return pos + case data[pos] == escByte: + if pos+1 >= len(data) { + return -1 // ESC at end — the ST may be split across reads + } + + if data[pos+1] == '\\' { + return pos + 1 // "ESC \" = ST + } + // ESC followed by anything else is not a terminator; keep scanning. + } + } + + return -1 +} + +// filterOutputCSI removes terminal control sequences from serial output data, +// except for SGR (Select Graphic Rendition, final byte 'm') which handles colors +// and styles. It drops CSI sequences (cursor positioning, screen clearing, +// queries) and ANSI string sequences (DCS/OSC/SOS/PM/APC terminal query +// responses), while preserving coloured output. +// +// Incomplete sequences at the end of data are stored in remainder so they can be +// prepended to the next buffer read and reconstituted correctly. +// +//nolint:cyclop,varnamelen +func filterOutputCSI(data []byte, remainder *[]byte) []byte { + result := make([]byte, 0, len(data)) + *remainder = nil + + for i := 0; i < len(data); i++ { + if data[i] != escByte { + result = append(result, data[i]) + + continue + } + + // ESC at end of buffer: might be the start of a sequence split across reads. + if i+1 >= len(data) { + *remainder = []byte{escByte} + + break + } + + if isStringSeqIntroducer(data[i+1]) { + end := stringSeqEnd(data, i) + if end < 0 { + // Incomplete string sequence — carry it over to the next read. + *remainder = make([]byte, len(data)-i) + copy(*remainder, data[i:]) + + break + } + + // Drop the whole sequence; the outer loop's i++ advances past data[end]. + i = end + + continue + } + + if data[i+1] != '[' { + // ESC not followed by '[' or a string introducer — emit as-is. + result = append(result, data[i]) + + continue + } + + // CSI sequence: ESC [ + // Find the extent of this CSI sequence. + j := i + csiPrefixLen + + for j < len(data) && data[j] >= 0x30 && data[j] <= 0x3f { + j++ // parameter bytes + } + + for j < len(data) && data[j] >= 0x20 && data[j] <= 0x2f { + j++ // intermediate bytes + } + + if j >= len(data) { + // Incomplete sequence at end of buffer — carry it over to the next read. + *remainder = make([]byte, len(data)-i) + copy(*remainder, data[i:]) + + break + } + + if data[j] == 'm' { + // SGR (colors/styles) — keep it. + result = append(result, data[i:j+1]...) + } + + // All other CSI sequences are dropped, including malformed ones where + // data[j] is not a valid final byte (0x40–0x7E). The byte at data[j] + // is consumed by setting i = j; the outer loop's i++ then advances past + // it. Silently dropping malformed sequences is safer than emitting + // partial escape bytes which could corrupt the terminal display. + i = j + } + + return result +} + +// unescape converts C-style escape sequences in s to their byte equivalents. +// Supported sequences: \n (newline), \r (carriage return), \t (tab), +// \\ (backslash), \xHH (hex byte). Unrecognised sequences are emitted as-is. +// +//nolint:cyclop // a flat switch over the supported escapes; splitting it would not help readability +func unescape(s string) []byte { + out := make([]byte, 0, len(s)) + + for idx := 0; idx < len(s); idx++ { + if s[idx] != '\\' || idx+1 >= len(s) { + out = append(out, s[idx]) + + continue + } + + idx++ + + switch s[idx] { + case 'n': + out = append(out, '\n') + case 'r': + out = append(out, '\r') + case 't': + out = append(out, '\t') + case '\\': + out = append(out, '\\') + case 'x': + if idx+hexDigits < len(s) { + hi, hiOK := fromHex(s[idx+1]) + lo, loOK := fromHex(s[idx+hexDigits]) + + if hiOK && loOK { + out = append(out, hi<= '0' && digit <= '9': + return digit - '0', true + case digit >= 'a' && digit <= 'f': + return digit - 'a' + hexLetterOffset, true + case digit >= 'A' && digit <= 'F': + return digit - 'A' + hexLetterOffset, true + default: + return 0, false + } +} diff --git a/pkg/module/serial/serial_run_test.go b/pkg/module/serial/serial_run_test.go new file mode 100644 index 0000000..9b0d1a6 --- /dev/null +++ b/pkg/module/serial/serial_run_test.go @@ -0,0 +1,592 @@ +// Copyright 2025 Blindspot Software +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package serial + +import ( + "bytes" + "context" + "errors" + "io" + "strings" + "sync" + "testing" + "time" +) + +// --- test doubles --------------------------------------------------------- + +// fakePort is an in-memory serialPort. Queued chunks are delivered by Read in +// order; once drained, Read emulates a serial read timeout (so Run stays +// responsive to context cancellation). Bytes written by the module are captured. +type fakePort struct { + mu sync.Mutex + out [][]byte + written []byte + closed bool + + // disconnectErr, if set, is returned once by Read after all queued output is + // drained, simulating the device disappearing. Subsequent reads time out. + disconnectErr error + disconnected bool +} + +type ioTimeout struct{} + +func (ioTimeout) Error() string { return "i/o timeout" } + +func (p *fakePort) queue(b []byte) { + p.mu.Lock() + defer p.mu.Unlock() + p.out = append(p.out, b) +} + +func (p *fakePort) Read(b []byte) (int, error) { + p.mu.Lock() + if len(p.out) > 0 { + chunk := p.out[0] + n := copy(b, chunk) + + if n < len(chunk) { + p.out[0] = chunk[n:] + } else { + p.out = p.out[1:] + } + + p.mu.Unlock() + + return n, nil + } + + if p.disconnectErr != nil && !p.disconnected { + p.disconnected = true + err := p.disconnectErr + p.mu.Unlock() + + return 0, err + } + p.mu.Unlock() + + // Emulate the real port's ReadTimeout so the caller's loop spins at a + // bounded rate instead of busy-waiting. + time.Sleep(2 * time.Millisecond) + + return 0, ioTimeout{} +} + +func (p *fakePort) Write(b []byte) (int, error) { + p.mu.Lock() + defer p.mu.Unlock() + p.written = append(p.written, b...) + + return len(b), nil +} + +func (p *fakePort) Close() error { + p.mu.Lock() + defer p.mu.Unlock() + p.closed = true + + return nil +} + +func (p *fakePort) Flush() error { return nil } + +func (p *fakePort) writtenString() string { + p.mu.Lock() + defer p.mu.Unlock() + + return string(p.written) +} + +// syncBuffer is a goroutine-safe io.Writer for capturing client stdout, which +// the stdout pump writes concurrently with the test reading it. +type syncBuffer struct { + mu sync.Mutex + buf bytes.Buffer +} + +func (s *syncBuffer) Write(b []byte) (int, error) { + s.mu.Lock() + defer s.mu.Unlock() + + return s.buf.Write(b) +} + +func (s *syncBuffer) String() string { + s.mu.Lock() + defer s.mu.Unlock() + + return s.buf.String() +} + +// fakeSession implements module.Session for tests. +type fakeSession struct { + stdinR io.Reader + stdoutW io.Writer +} + +func (f *fakeSession) Print(...any) {} +func (f *fakeSession) Printf(string, ...any) {} +func (f *fakeSession) Println(...any) {} + +func (f *fakeSession) Console() (io.Reader, io.Writer, io.Writer) { + return f.stdinR, f.stdoutW, io.Discard +} + +func (f *fakeSession) RequestFile(string) (io.Reader, error) { return nil, nil } +func (f *fakeSession) SendFile(string, io.Reader) error { return nil } + +// newSession wires a fake session with a pipe-backed stdin (so the stdin pump +// unblocks on cleanup) and a capturing stdout. It returns the session, the +// stdout capture, and the stdin write end. +func newSession(t *testing.T) (*fakeSession, *syncBuffer, *io.PipeWriter) { + t.Helper() + + pr, pw := io.Pipe() + stdout := &syncBuffer{} + + t.Cleanup(func() { _ = pw.Close() }) // EOF unblocks the stdin pump + + return &fakeSession{stdinR: pr, stdoutW: stdout}, stdout, pw +} + +// --- tests ---------------------------------------------------------------- + +func TestRunSingleExpectMatch(t *testing.T) { + port := &fakePort{} + port.queue([]byte("booting kernel...\n")) + port.queue([]byte("dut login: ")) + + s := &Serial{Port: "fake", Baud: DefaultBaudRate} + s.dialPort = func() (serialPort, error) { return port, nil } + + sess, stdout, _ := newSession(t) + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + err := s.Run(ctx, sess, "login:") + if err != nil { + t.Fatalf("Run returned error, want nil: %v", err) + } + + if got := stdout.String(); !strings.Contains(got, "booting kernel") || + !strings.Contains(got, "Pattern matched") { + t.Errorf("stdout missing expected content: %q", got) + } +} + +func TestRunExpectTimeout(t *testing.T) { + port := &fakePort{} + port.queue([]byte("nothing interesting here\n")) + + s := &Serial{Port: "fake", Baud: DefaultBaudRate} + s.dialPort = func() (serialPort, error) { return port, nil } + + sess, stdout, _ := newSession(t) + + err := s.Run(context.Background(), sess, "-t", "150ms", "will-never-appear") + if err == nil { + t.Fatal("Run returned nil, want timeout error") + } + + if !strings.Contains(err.Error(), "timeout") { + t.Errorf("error = %v, want timeout", err) + } + + if got := stdout.String(); !strings.Contains(got, "Timeout reached") { + t.Errorf("stdout missing timeout banner: %q", got) + } +} + +func TestRunExpectSendPairs(t *testing.T) { + port := &fakePort{} + // Both prompts are present in the output stream; both pairs must fire in + // order, then the module drains briefly and exits. + port.queue([]byte("dut login: ")) + port.queue([]byte("user\r\nPassword: ")) + + s := &Serial{Port: "fake", Baud: DefaultBaudRate} + s.dialPort = func() (serialPort, error) { return port, nil } + + sess, _, _ := newSession(t) + + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + start := time.Now() + + err := s.Run(ctx, sess, "login:", "admin\\n", "Password:", "secret\\n") + if err != nil { + t.Fatalf("Run returned error, want nil: %v", err) + } + + if elapsed := time.Since(start); elapsed < pairsDrain { + t.Errorf("Run returned after %s, expected at least the %s drain", elapsed, pairsDrain) + } + + if got := port.writtenString(); got != "admin\nsecret\n" { + t.Errorf("responses written to port = %q, want %q", got, "admin\nsecret\n") + } +} + +func TestRunExpectSendStopsBeforeLastResponseOnTimeout(t *testing.T) { + port := &fakePort{} + port.queue([]byte("dut login: ")) // only the first prompt ever appears + + s := &Serial{Port: "fake", Baud: DefaultBaudRate} + s.dialPort = func() (serialPort, error) { return port, nil } + + sess, _, _ := newSession(t) + + err := s.Run(context.Background(), sess, "-t", "200ms", "login:", "admin\\n", "Password:", "secret\\n") + if err == nil || !strings.Contains(err.Error(), "sequence not completed") { + t.Fatalf("Run error = %v, want 'sequence not completed' timeout", err) + } + + // First pair fired, second did not. + if got := port.writtenString(); got != "admin\n" { + t.Errorf("responses written to port = %q, want %q", got, "admin\n") + } +} + +func TestRunInteractiveForwardsStdinToPort(t *testing.T) { + port := &fakePort{} + + s := &Serial{Port: "fake", Baud: DefaultBaudRate} + s.dialPort = func() (serialPort, error) { return port, nil } + // interactive: no args + + sess, _, stdin := newSession(t) + + ctx, cancel := context.WithCancel(context.Background()) + + runErr := make(chan error, 1) + + go func() { runErr <- s.Run(ctx, sess) }() + + if _, err := stdin.Write([]byte("reboot\n")); err != nil { + t.Fatalf("write stdin: %v", err) + } + + // Wait until the keystrokes reach the port. + deadline := time.After(2 * time.Second) + + for port.writtenString() != "reboot\n" { + select { + case <-deadline: + t.Fatalf("port did not receive stdin, got %q", port.writtenString()) + case <-time.After(5 * time.Millisecond): + } + } + + cancel() + + select { + case err := <-runErr: + if !errors.Is(err, context.Canceled) { + t.Errorf("Run error = %v, want context.Canceled", err) + } + case <-time.After(2 * time.Second): + t.Fatal("Run did not return after cancel") + } +} + +func TestRunContextCancelReturnsPromptly(t *testing.T) { + port := &fakePort{} + + s := &Serial{Port: "fake", Baud: DefaultBaudRate} + s.dialPort = func() (serialPort, error) { return port, nil } + + sess, _, _ := newSession(t) + + ctx, cancel := context.WithCancel(context.Background()) + + runErr := make(chan error, 1) + + go func() { runErr <- s.Run(ctx, sess) }() + + time.Sleep(20 * time.Millisecond) + cancel() + + select { + case err := <-runErr: + if !errors.Is(err, context.Canceled) { + t.Errorf("Run error = %v, want context.Canceled", err) + } + case <-time.After(time.Second): + t.Fatal("Run did not return promptly after cancel") + } +} + +func TestRunStripsNonSGROutputButKeepsColour(t *testing.T) { + port := &fakePort{} + // DSR query and cursor-up (must be stripped), an SGR colour (must survive), + // plain text on either side. + port.queue([]byte("ABC\x1b[6nDEF\x1b[2A \x1b[31mred\x1b[0m XYZ")) + + s := &Serial{Port: "fake", Baud: DefaultBaudRate} + s.dialPort = func() (serialPort, error) { return port, nil } + + sess, stdout, _ := newSession(t) + + ctx, cancel := context.WithCancel(context.Background()) + + runErr := make(chan error, 1) + + go func() { runErr <- s.Run(ctx, sess) }() + + // Wait for the chunk to be displayed. + deadline := time.After(2 * time.Second) + + for !strings.Contains(stdout.String(), "XYZ") { + select { + case <-deadline: + t.Fatalf("output not displayed, got %q", stdout.String()) + case <-time.After(5 * time.Millisecond): + } + } + + cancel() + <-runErr + + got := stdout.String() + if strings.Contains(got, "\x1b[6n") { + t.Errorf("DSR query not stripped: %q", got) + } + + if strings.Contains(got, "\x1b[2A") { + t.Errorf("cursor-up not stripped: %q", got) + } + + if !strings.Contains(got, "\x1b[31mred\x1b[0m") { + t.Errorf("SGR colour not preserved: %q", got) + } + + // The two stripped CSIs sat between ABC/DEF and after red; surrounding plain + // text must remain intact (just with the CSIs removed). + if !strings.Contains(got, "ABCDEF ") || !strings.Contains(got, " XYZ") { + t.Errorf("plain text mangled: %q", got) + } +} + +// TestRunMatchSpanningReads verifies a pattern split across two serial reads is +// still matched, exercising the rolling match window. +func TestRunMatchSpanningReads(t *testing.T) { + port := &fakePort{} + port.queue([]byte("please lo")) + port.queue([]byte("gin: now")) + + s := &Serial{Port: "fake", Baud: DefaultBaudRate} + s.dialPort = func() (serialPort, error) { return port, nil } + + sess, _, _ := newSession(t) + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + if err := s.Run(ctx, sess, "login:"); err != nil { + t.Fatalf("Run returned error, want nil: %v", err) + } +} + +type deviceGoneError struct{} + +func (deviceGoneError) Error() string { return "read /dev/ttyUSB0: input/output error" } + +// TestRunReconnectsOnDeviceLoss verifies the module survives the serial device +// disappearing mid-session, reopens it once it reappears, and keeps matching. +func TestRunReconnectsOnDeviceLoss(t *testing.T) { + // First port delivers some output then reports a disconnect. + port1 := &fakePort{disconnectErr: deviceGoneError{}} + port1.queue([]byte("booting...\n")) + + // Second port (the reappeared device) presents the prompt to match. + port2 := &fakePort{} + port2.queue([]byte("dut login: ")) + + dialed := 0 + + s := &Serial{Port: "/dev/ttyUSB0", Baud: DefaultBaudRate} + s.dialPort = func() (serialPort, error) { + dialed++ + if dialed == 1 { + return port1, nil + } + + return port2, nil + } + + sess, stdout, _ := newSession(t) + + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + if err := s.Run(ctx, sess, "login:"); err != nil { + t.Fatalf("Run returned error, want nil: %v", err) + } + + if dialed < 2 { + t.Errorf("dial called %d times, want >= 2 (reconnect)", dialed) + } + + got := stdout.String() + if !strings.Contains(got, "disconnected") || !strings.Contains(got, "reconnected") { + t.Errorf("missing reconnect status messages: %q", got) + } + + if !strings.Contains(got, "Pattern matched") { + t.Errorf("pattern not matched after reconnect: %q", got) + } +} + +// TestRunReconnectAbortsOnTimeout verifies the expect timeout still fires while +// the module is waiting for a missing device to come back. +func TestRunReconnectAbortsOnTimeout(t *testing.T) { + port := &fakePort{disconnectErr: deviceGoneError{}} + port.queue([]byte("hello\n")) + + s := &Serial{Port: "/dev/ttyUSB0", Baud: DefaultBaudRate} + s.dialPort = func() (serialPort, error) { + // Device never comes back after the first open. + if port.disconnected { + return nil, deviceGoneError{} + } + + return port, nil + } + + sess, _, _ := newSession(t) + + err := s.Run(context.Background(), sess, "-t", "300ms", "will-never-appear") + if err == nil || !strings.Contains(err.Error(), "timeout") { + t.Fatalf("Run error = %v, want timeout while reconnecting", err) + } +} + +// TestRunReconnectsWhenDeviceNodeVanishes covers the device-loss path that real +// hardware actually hits: a removed USB serial adapter surfaces to Read as an +// idle io.EOF/timeout (NOT a distinct "device gone" error), so the module must +// notice the vanished device node and reconnect instead of spinning forever on +// what looks like a benign idle read. +func TestRunReconnectsWhenDeviceNodeVanishes(t *testing.T) { + port := &fakePort{} // no queued data: idle reads return a timeout, like a quiet port + + opens := 0 + + s := &Serial{Port: "/dev/fake", Baud: DefaultBaudRate} + s.deviceLossGrace = 20 * time.Millisecond // suspect loss quickly in the test + s.portPresent = func() bool { return false } // the device node has vanished + s.dialPort = func() (serialPort, error) { + opens++ + if opens == 1 { + return port, nil // initial open succeeds + } + + return nil, deviceGoneError{} // device stays gone while reconnecting + } + + sess, stdout, _ := newSession(t) + + err := s.Run(context.Background(), sess, "-t", "300ms", "will-never-appear") + if err == nil || !strings.Contains(err.Error(), "timeout") { + t.Fatalf("Run error = %v, want timeout", err) + } + + if got := stdout.String(); !strings.Contains(got, "disconnected, waiting to reconnect") { + t.Errorf("expected reconnect to start after the device node vanished on idle EOF, got %q", got) + } +} + +// TestRunSequenceLeadingSendThenExpect covers a tagged sequence that begins +// with a send (impossible with expect-send pairs): the module writes the +// leading input before any output arrives, then waits for the prompt and sends +// the reply. It ends on a send, so it drains before exiting. +func TestRunSequenceLeadingSendThenExpect(t *testing.T) { + port := &fakePort{} + port.queue([]byte("dut login: ")) // appears after the leading send goes out + + s := &Serial{Port: "fake", Baud: DefaultBaudRate} + s.dialPort = func() (serialPort, error) { return port, nil } + + sess, _, _ := newSession(t) + + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + start := time.Now() + + err := s.Run(ctx, sess, `send:\n`, "expect:login:", `send:root\n`) + if err != nil { + t.Fatalf("Run returned error, want nil: %v", err) + } + + if elapsed := time.Since(start); elapsed < pairsDrain { + t.Errorf("Run returned after %s, expected at least the %s drain", elapsed, pairsDrain) + } + + if got := port.writtenString(); got != "\nroot\n" { + t.Errorf("written to port = %q, want %q", got, "\nroot\n") + } +} + +// TestRunSequenceSendOnly covers a sequence with no expect step at all: the +// module sends the input and drains briefly so the response is visible. +func TestRunSequenceSendOnly(t *testing.T) { + port := &fakePort{} + + s := &Serial{Port: "fake", Baud: DefaultBaudRate} + s.dialPort = func() (serialPort, error) { return port, nil } + + sess, _, _ := newSession(t) + + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + err := s.Run(ctx, sess, `send:reboot\n`) + if err != nil { + t.Fatalf("Run returned error, want nil: %v", err) + } + + if got := port.writtenString(); got != "reboot\n" { + t.Errorf("written to port = %q, want %q", got, "reboot\n") + } +} + +// TestRunSequenceEndingOnExpectExitsWithoutDraining covers a sequence whose +// final step is an expect: the match is the completion, so the module exits +// immediately rather than draining. +func TestRunSequenceEndingOnExpectExitsWithoutDraining(t *testing.T) { + port := &fakePort{} + port.queue([]byte("starting\n")) + port.queue([]byte("ready> ")) + + s := &Serial{Port: "fake", Baud: DefaultBaudRate} + s.dialPort = func() (serialPort, error) { return port, nil } + + sess, stdout, _ := newSession(t) + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + start := time.Now() + + err := s.Run(ctx, sess, `send:\n`, "expect:ready>") + if err != nil { + t.Fatalf("Run returned error, want nil: %v", err) + } + + if elapsed := time.Since(start); elapsed >= pairsDrain { + t.Errorf("Run took %s; a sequence ending on expect should exit without the %s drain", elapsed, pairsDrain) + } + + if got := port.writtenString(); got != "\n" { + t.Errorf("written to port = %q, want %q", got, "\n") + } + + if got := stdout.String(); !strings.Contains(got, "Pattern matched") { + t.Errorf("stdout missing match banner: %q", got) + } +} diff --git a/pkg/module/serial/serial_test.go b/pkg/module/serial/serial_test.go new file mode 100644 index 0000000..7687934 --- /dev/null +++ b/pkg/module/serial/serial_test.go @@ -0,0 +1,755 @@ +// Copyright 2025 Blindspot Software +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package serial + +import ( + "strings" + "testing" + "time" +) + +// TestFilterOutputCSI verifies that filterOutputCSI correctly passes through +// plain text and SGR sequences, drops all other CSI sequences, and stores +// incomplete sequences in the remainder for the next read. +// +//nolint:funlen +func TestFilterOutputCSI(t *testing.T) { + tests := []struct { + name string + input []byte + wantOutput []byte + wantRemainder []byte + }{ + { + name: "empty input", + input: []byte{}, + wantOutput: []byte{}, + wantRemainder: nil, + }, + { + name: "plain text without ESC", + input: []byte("hello world"), + wantOutput: []byte("hello world"), + wantRemainder: nil, + }, + { + name: "SGR reset ESC[m preserved", + input: []byte("\x1b[m"), + wantOutput: []byte("\x1b[m"), + wantRemainder: nil, + }, + { + name: "SGR foreground colour ESC[31m preserved", + input: []byte("\x1b[31m"), + wantOutput: []byte("\x1b[31m"), + wantRemainder: nil, + }, + { + name: "SGR multi-parameter ESC[0;31m preserved", + input: []byte("\x1b[0;31m"), + wantOutput: []byte("\x1b[0;31m"), + wantRemainder: nil, + }, + { + name: "DSR query ESC[6n dropped", + input: []byte("\x1b[6n"), + wantOutput: []byte{}, + wantRemainder: nil, + }, + { + name: "cursor position report ESC[10;20R dropped", + input: []byte("\x1b[10;20R"), + wantOutput: []byte{}, + wantRemainder: nil, + }, + { + name: "cursor-up ESC[2A dropped", + input: []byte("\x1b[2A"), + wantOutput: []byte{}, + wantRemainder: nil, + }, + { + name: "erase-display ESC[2J dropped", + input: []byte("\x1b[2J"), + wantOutput: []byte{}, + wantRemainder: nil, + }, + { + name: "text surrounding SGR preserved", + input: []byte("hello \x1b[31mworld"), + wantOutput: []byte("hello \x1b[31mworld"), + wantRemainder: nil, + }, + { + name: "non-SGR CSI in middle of text dropped", + input: []byte("hello \x1b[6n world"), + wantOutput: []byte("hello world"), + wantRemainder: nil, + }, + { + name: "mixed SGR and non-SGR: only SGR kept", + input: []byte("\x1b[31mcolor\x1b[6nquery\x1b[0mreset"), + wantOutput: []byte("\x1b[31mcolorquery\x1b[0mreset"), + wantRemainder: nil, + }, + { + name: "multiple consecutive SGR sequences preserved", + input: []byte("\x1b[31mred\x1b[0mreset"), + wantOutput: []byte("\x1b[31mred\x1b[0mreset"), + wantRemainder: nil, + }, + { + name: "lone ESC at end of buffer stored in remainder", + input: []byte("text\x1b"), + wantOutput: []byte("text"), + wantRemainder: []byte{escByte}, + }, + { + name: "incomplete CSI — ESC[ only — stored in remainder", + input: []byte("\x1b["), + wantOutput: []byte{}, + wantRemainder: []byte("\x1b["), + }, + { + name: "incomplete CSI with params stored in remainder", + input: []byte("text\x1b[31"), + wantOutput: []byte("text"), + wantRemainder: []byte("\x1b[31"), + }, + { + name: "ESC not followed by bracket emitted as-is", + input: []byte("\x1bO"), // SS3 — not a CSI sequence + wantOutput: []byte("\x1bO"), + wantRemainder: nil, + }, + { + name: "ESC not followed by bracket in middle of text", + input: []byte("ab\x1bOcd"), + wantOutput: []byte("ab\x1bOcd"), + wantRemainder: nil, + }, + { + name: "DCS terminated by ST dropped", + input: []byte("\x1bP1+r6E616D65\x1b\\"), + wantOutput: []byte{}, + wantRemainder: nil, + }, + { + // The exact capability-report DCS that leaked from the BMC console. + name: "real WezTerm capability DCS dropped, surrounding text kept", + input: []byte("before\x1bP1+r6E616D65=57657A5465726D\x1b\\after"), + wantOutput: []byte("beforeafter"), + wantRemainder: nil, + }, + { + name: "OSC terminated by BEL dropped", + input: []byte("\x1b]0;window title\x07rest"), + wantOutput: []byte("rest"), + wantRemainder: nil, + }, + { + name: "OSC terminated by ST dropped", + input: []byte("\x1b]0;title\x1b\\rest"), + wantOutput: []byte("rest"), + wantRemainder: nil, + }, + { + name: "DCS single-byte ST (0x9c) dropped", + input: []byte("x\x1bPdata\x9cy"), + wantOutput: []byte("xy"), + wantRemainder: nil, + }, + { + name: "APC sequence dropped", + input: []byte("a\x1b_payload\x1b\\b"), + wantOutput: []byte("ab"), + wantRemainder: nil, + }, + { + name: "incomplete DCS stored in remainder", + input: []byte("text\x1bP1+r6E61"), + wantOutput: []byte("text"), + wantRemainder: []byte("\x1bP1+r6E61"), + }, + { + name: "DCS with ESC at end (partial ST) stored in remainder", + input: []byte("\x1bP1+rABC\x1b"), + wantOutput: []byte{}, + wantRemainder: []byte("\x1bP1+rABC\x1b"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var remainder []byte + + got := filterOutputCSI(tt.input, &remainder) + + if string(got) != string(tt.wantOutput) { + t.Errorf("output = %q, want %q", got, tt.wantOutput) + } + + if string(remainder) != string(tt.wantRemainder) { + t.Errorf("remainder = %q, want %q", remainder, tt.wantRemainder) + } + }) + } +} + +// TestFilterOutputCSISequenceSplitAcrossReads simulates an SGR sequence +// whose bytes arrive in two separate buffer reads, verifying that the +// remainder mechanism reconstitutes it correctly. +func TestFilterOutputCSISequenceSplitAcrossReads(t *testing.T) { + var remainder []byte + + // Read 1: "ESC[31" — missing the final byte 'm'. + out1 := filterOutputCSI([]byte("\x1b[31"), &remainder) + + if len(out1) != 0 { + t.Errorf("read 1: expected no output for incomplete sequence, got %q", out1) + } + + if string(remainder) != "\x1b[31" { + t.Errorf("read 1: remainder = %q, want %q", remainder, "\x1b[31") + } + + // Read 2: prepend remainder to the new chunk (mirroring the main loop). + chunk := append(remainder, 'm') + out2 := filterOutputCSI(chunk, &remainder) + + if string(out2) != "\x1b[31m" { + t.Errorf("read 2: output = %q, want %q", out2, "\x1b[31m") + } + + if len(remainder) != 0 { + t.Errorf("read 2: expected empty remainder, got %q", remainder) + } +} + +// TestFilterOutputCSILoneESCSplitAcrossReads simulates a lone ESC at the end +// of one read whose '[' and final byte arrive in the next read. +func TestFilterOutputCSILoneESCSplitAcrossReads(t *testing.T) { + var remainder []byte + + // Read 1: text followed by a lone ESC at the buffer boundary. + out1 := filterOutputCSI([]byte("text\x1b"), &remainder) + + if string(out1) != "text" { + t.Errorf("read 1: output = %q, want %q", out1, "text") + } + + if string(remainder) != "\x1b" { + t.Errorf("read 1: remainder = %q, want %q", remainder, "\x1b") + } + + // Read 2: the bytes "[A" arrive — together with the remainder this forms the + // cursor-up sequence ESC[A which must be dropped. + chunk := append(remainder, []byte("[A")...) + out2 := filterOutputCSI(chunk, &remainder) + + if len(out2) != 0 { + t.Errorf("read 2: cursor-up sequence should be dropped, got %q", out2) + } + + if len(remainder) != 0 { + t.Errorf("read 2: expected empty remainder, got %q", remainder) + } +} + +// TestFilterOutputCSIStringSeqSplitAcrossReads simulates a DCS sequence whose +// String Terminator arrives in a later read, verifying the remainder mechanism +// reconstitutes and drops the whole sequence. +func TestFilterOutputCSIStringSeqSplitAcrossReads(t *testing.T) { + var remainder []byte + + // Read 1: a DCS that is not yet terminated. + out1 := filterOutputCSI([]byte("keep\x1bP1+r41424"), &remainder) + + if string(out1) != "keep" { + t.Errorf("read 1: output = %q, want %q", out1, "keep") + } + + if string(remainder) != "\x1bP1+r41424" { + t.Errorf("read 1: remainder = %q, want %q", remainder, "\x1bP1+r41424") + } + + // Read 2: prepend remainder to the new chunk; the ST now completes the DCS, + // which must be dropped, leaving only the trailing text. + chunk := append(remainder, []byte("3\x1b\\done")...) + out2 := filterOutputCSI(chunk, &remainder) + + if string(out2) != "done" { + t.Errorf("read 2: output = %q, want %q", out2, "done") + } + + if len(remainder) != 0 { + t.Errorf("read 2: expected empty remainder, got %q", remainder) + } +} + +// TestEvalArgs covers argument parsing for the serial module. +func TestEvalArgs(t *testing.T) { + tests := []struct { + name string + args []string + wantTimeout time.Duration + wantPattern string // empty means expect should be nil + wantErr bool + }{ + { + name: "no args", + args: nil, + }, + { + name: "timeout flag only", + args: []string{"-t", "5s"}, + wantTimeout: 5 * time.Second, + }, + { + name: "pattern only", + args: []string{"login:"}, + wantPattern: "login:", + }, + { + name: "timeout and pattern", + args: []string{"-t", "2m", "hello world"}, + wantTimeout: 2 * time.Minute, + wantPattern: "hello world", + }, + { + name: "regex pattern with flags", + args: []string{`(?i)Login\s*:`}, + wantPattern: `(?i)Login\s*:`, + }, + { + name: "invalid regex", + args: []string{"[invalid"}, + wantErr: true, + }, + { + name: "unknown flag", + args: []string{"-x"}, + wantErr: true, + }, + { + name: "invalid timeout value", + args: []string{"-t", "not-a-duration"}, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &Serial{} + + err := s.evalArgs(tt.args) + + if (err != nil) != tt.wantErr { + t.Fatalf("evalArgs() error = %v, wantErr %v", err, tt.wantErr) + } + + if tt.wantErr { + return + } + + if s.timeout != tt.wantTimeout { + t.Errorf("timeout = %v, want %v", s.timeout, tt.wantTimeout) + } + + if tt.wantPattern == "" { + if len(s.steps) != 0 { + t.Errorf("steps = %v, want none (interactive)", s.steps) + } + } else { + if len(s.steps) != 1 || s.steps[0].kind != stepExpect { + t.Fatalf("steps = %v, want a single expect-step for %q", s.steps, tt.wantPattern) + } + + if s.steps[0].pattern.String() != tt.wantPattern { + t.Errorf("expect pattern = %q, want %q", s.steps[0].pattern.String(), tt.wantPattern) + } + } + }) + } +} + +// TestSerialInit covers Init validation and default-baud-rate assignment. +func TestSerialInit(t *testing.T) { + tests := []struct { + name string + serial Serial + wantBaud int + wantErr bool + }{ + { + name: "missing port returns error", + serial: Serial{}, + wantErr: true, + }, + { + name: "zero baud is replaced with DefaultBaudRate", + serial: Serial{Port: "/dev/ttyS0"}, + wantBaud: DefaultBaudRate, + }, + { + name: "explicit baud rate is preserved", + serial: Serial{Port: "/dev/ttyS0", Baud: 9600}, + wantBaud: 9600, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.serial.Init() + + if (err != nil) != tt.wantErr { + t.Fatalf("Init() error = %v, wantErr %v", err, tt.wantErr) + } + + if !tt.wantErr && tt.serial.Baud != tt.wantBaud { + t.Errorf("Baud = %d, want %d", tt.serial.Baud, tt.wantBaud) + } + }) + } +} + +// TestSerialHelp verifies that the help text includes the configured port, +// baud rate, and key usage information. +func TestSerialHelp(t *testing.T) { + tests := []struct { + name string + serial Serial + wantIn []string + }{ + { + name: "contains configured port and baud", + serial: Serial{Port: "/dev/ttyS0", Baud: 9600}, + wantIn: []string{"/dev/ttyS0", "9600"}, + }, + { + name: "contains timeout flag and regex mention", + serial: Serial{Port: "/dev/ttyUSB0", Baud: DefaultBaudRate}, + wantIn: []string{"-t", "expect", "regex"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + help := strings.ToLower(tt.serial.Help()) + + for _, want := range tt.wantIn { + if !strings.Contains(help, strings.ToLower(want)) { + t.Errorf("Help() missing %q", want) + } + } + }) + } +} + +// TestUnescape verifies that unescape correctly converts C-style escape sequences. +// +//nolint:funlen +func TestUnescape(t *testing.T) { + tests := []struct { + name string + input string + want []byte + }{ + { + name: "empty string", + input: "", + want: []byte{}, + }, + { + name: "plain text unchanged", + input: "hello", + want: []byte("hello"), + }, + { + name: "newline escape", + input: `\n`, + want: []byte{'\n'}, + }, + { + name: "carriage return escape", + input: `\r`, + want: []byte{'\r'}, + }, + { + name: "tab escape", + input: `\t`, + want: []byte{'\t'}, + }, + { + name: "backslash escape", + input: `\\`, + want: []byte{'\\'}, + }, + { + name: "hex escape lowercase", + input: `\x41`, + want: []byte{'A'}, + }, + { + name: "hex escape uppercase", + input: `\x4F`, + want: []byte{'O'}, + }, + { + name: "hex escape zero byte", + input: `\x00`, + want: []byte{0x00}, + }, + { + name: "multiple hex escapes", + input: `\x0d\x0a`, + want: []byte{'\r', '\n'}, + }, + { + name: "mixed text and escapes", + input: `user\r\n`, + want: []byte{'u', 's', 'e', 'r', '\r', '\n'}, + }, + { + name: "unknown escape emitted as-is", + input: `\q`, + want: []byte{'\\', 'q'}, + }, + { + name: "lone backslash at end emitted as-is", + input: `abc\`, + want: []byte{'a', 'b', 'c', '\\'}, + }, + { + name: "incomplete hex escape emitted as-is", + input: `\x`, + want: []byte{'\\', 'x'}, + }, + { + // \x requires exactly two hex digits; with only one the \x is + // emitted as-is and the remaining character follows normally. + name: "hex escape with one digit only emitted as-is", + input: `\x4`, + want: []byte{'\\', 'x', '4'}, + }, + { + // Non-hex chars after \x: the \x is emitted as-is and the + // non-hex characters are emitted as regular bytes. + name: "hex escape with non-hex chars emitted as-is", + input: `\xGG`, + want: []byte{'\\', 'x', 'G', 'G'}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := unescape(tt.input) + + if string(got) != string(tt.want) { + t.Errorf("unescape(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} + +// TestEvalArgsPairs covers the expect-send pair parsing mode of evalArgs. +func TestEvalArgsPairs(t *testing.T) { + tests := []struct { + name string + args []string + wantPairs int // expected number of pairs + wantPat0 string // pattern of first pair (empty = skip check) + wantResp0 string // response of first pair (empty = skip check) + wantPat1 string // pattern of second pair (empty = skip check) + wantResp1 string // response of second pair (empty = skip check) + wantErr bool + }{ + { + name: "one pair", + args: []string{"login:", "user\\n"}, + wantPairs: 1, + wantPat0: "login:", + wantResp0: "user\n", + }, + { + name: "two pairs", + args: []string{"login:", "user\\n", "Password:", "secret\\n"}, + wantPairs: 2, + wantPat0: "login:", + wantResp0: "user\n", + wantPat1: "Password:", + wantResp1: "secret\n", + }, + { + name: "pair with hex escape in response", + args: []string{"ready", "\\x0d"}, + wantPairs: 1, + wantPat0: "ready", + wantResp0: "\r", + }, + { + name: "odd number of args returns error", + args: []string{"login:", "user", "extra"}, + wantErr: true, + }, + { + name: "invalid regex in pair returns error", + args: []string{"[invalid", "response"}, + wantErr: true, + }, + { + name: "invalid regex in second pair returns error", + args: []string{"login:", "user\\n", "[bad", "pass"}, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &Serial{} + + err := s.evalArgs(tt.args) + + if (err != nil) != tt.wantErr { + t.Fatalf("evalArgs() error = %v, wantErr %v", err, tt.wantErr) + } + + if tt.wantErr { + return + } + + // Each expect-send pair compiles to two steps: an expect-step + // followed by a send-step. + if len(s.steps) != tt.wantPairs*pairStride { + t.Fatalf("len(steps) = %d, want %d (%d pairs)", len(s.steps), tt.wantPairs*pairStride, tt.wantPairs) + } + + if tt.wantPat0 != "" && s.steps[0].pattern.String() != tt.wantPat0 { + t.Errorf("steps[0].pattern = %q, want %q", s.steps[0].pattern.String(), tt.wantPat0) + } + + if tt.wantResp0 != "" && string(s.steps[1].data) != tt.wantResp0 { + t.Errorf("steps[1].data = %q, want %q", s.steps[1].data, tt.wantResp0) + } + + if tt.wantPat1 != "" && len(s.steps) > 2 && s.steps[2].pattern.String() != tt.wantPat1 { + t.Errorf("steps[2].pattern = %q, want %q", s.steps[2].pattern.String(), tt.wantPat1) + } + + if tt.wantResp1 != "" && len(s.steps) > 3 && string(s.steps[3].data) != tt.wantResp1 { + t.Errorf("steps[3].data = %q, want %q", s.steps[3].data, tt.wantResp1) + } + }) + } +} + +// wantStep is the expected kind and payload of a single parsed sequence step. +type wantStep struct { + expect string // non-empty => stepExpect with this pattern + send string // used when expect is empty => stepSend with this (unescaped) data +} + +// TestEvalArgsSequence covers the tagged expect:/send: sequence parsing mode. +// +//nolint:funlen +func TestEvalArgsSequence(t *testing.T) { + tests := []struct { + name string + args []string + want []wantStep + wantErr bool + }{ + { + name: "leading send then expect", + args: []string{`send:\n`, "expect:login:"}, + want: []wantStep{{send: "\n"}, {expect: "login:"}}, + }, + { + name: "expect then send", + args: []string{"expect:login:", `send:root\n`}, + want: []wantStep{{expect: "login:"}, {send: "root\n"}}, + }, + { + name: "two sends in a row", + args: []string{`send:a`, `send:b`}, + want: []wantStep{{send: "a"}, {send: "b"}}, + }, + { + name: "full login-and-reboot sequence", + args: []string{`send:\n`, "expect:login:", `send:root\n`, "expect:# ", `send:reboot\n`}, + want: []wantStep{ + {send: "\n"}, {expect: "login:"}, {send: "root\n"}, {expect: "# "}, {send: "reboot\n"}, + }, + }, + { + name: "send with hex escape", + args: []string{`send:\x03`}, // Ctrl-C + want: []wantStep{{send: "\x03"}}, + }, + { + name: "empty send is allowed", + args: []string{"expect:done", "send:"}, + want: []wantStep{{expect: "done"}, {send: ""}}, + }, + { + name: "invalid regex in expect step", + args: []string{"expect:[invalid", `send:x`}, + wantErr: true, + }, + { + name: "untagged argument mixed in is rejected", + args: []string{"expect:login:", "root"}, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &Serial{} + + err := s.evalArgs(tt.args) + + if (err != nil) != tt.wantErr { + t.Fatalf("evalArgs() error = %v, wantErr %v", err, tt.wantErr) + } + + if tt.wantErr { + return + } + + if len(s.steps) != len(tt.want) { + t.Fatalf("len(steps) = %d, want %d", len(s.steps), len(tt.want)) + } + + for i, w := range tt.want { + got := s.steps[i] + + if w.expect != "" { + if got.kind != stepExpect { + t.Errorf("steps[%d] kind = send, want expect %q", i, w.expect) + + continue + } + + if got.pattern.String() != w.expect { + t.Errorf("steps[%d].pattern = %q, want %q", i, got.pattern.String(), w.expect) + } + + continue + } + + if got.kind != stepSend { + t.Errorf("steps[%d] kind = expect, want send %q", i, w.send) + + continue + } + + if string(got.data) != w.send { + t.Errorf("steps[%d].data = %q, want %q", i, got.data, w.send) + } + } + }) + } +}