From bc79362ffa3af2d46cc3a07ef299221e3a8d9967 Mon Sep 17 00:00:00 2001 From: Jens Topp Date: Thu, 11 Jun 2026 12:17:21 +0200 Subject: [PATCH 1/3] Replace deprecated h2c package with net/http Protocols (#341) golang.org/x/net/http2/h2c is deprecated; unencrypted HTTP/2 is now supported directly by net/http via the Protocols field (Go 1.24+). - servers: replace h2c.NewHandler with an http.Server configured via Server.Protocols (HTTP/1 + unencrypted HTTP/2) - clients: replace http2.Transport (AllowHTTP/DialTLS) with http.Transport configured via Transport.Protocols - add ReadHeaderTimeout to the servers to mitigate Slowloris (gosec G112), removing the previous //nolint:gosec - drop the now-unused golang.org/x/net direct dependency (moved to indirect) Signed-off-by: Jens Topp --- cmds/dutagent/dutagent.go | 47 ++++++++++++++++++--------------- cmds/dutctl/dutctl.go | 27 ++++++++----------- cmds/exp/dutserver/dutserver.go | 23 ++++++++++------ cmds/exp/dutserver/rpc.go | 26 ++++++++---------- go.mod | 2 +- 5 files changed, 63 insertions(+), 62 deletions(-) diff --git a/cmds/dutagent/dutagent.go b/cmds/dutagent/dutagent.go index 506b01e4..631ee45d 100644 --- a/cmds/dutagent/dutagent.go +++ b/cmds/dutagent/dutagent.go @@ -9,25 +9,22 @@ package main import ( "context" - "crypto/tls" "errors" "flag" "fmt" "io" "log" - "net" "net/http" "os" "os/signal" "syscall" + "time" "connectrpc.com/connect" "github.com/BlindspotSoftware/dutctl/internal/buildinfo" "github.com/BlindspotSoftware/dutctl/internal/dutagent" "github.com/BlindspotSoftware/dutctl/pkg/dut" "github.com/BlindspotSoftware/dutctl/protobuf/gen/dutctl/v1/dutctlv1connect" - "golang.org/x/net/http2" - "golang.org/x/net/http2/h2c" "gopkg.in/yaml.v3" pb "github.com/BlindspotSoftware/dutctl/protobuf/gen/dutctl/v1" @@ -155,6 +152,9 @@ func printInitErr(err error) { log.Print(err) } +// readHeaderTimeout bounds how long the server waits to read request headers. +const readHeaderTimeout = 10 * time.Second + // startRPCService starts the RPC service, that ideally listens for incoming // connections forever. It always returns an non-nil error. func (agt *agent) startRPCService() error { @@ -167,12 +167,17 @@ func (agt *agent) startRPCService() error { path, handler := dutctlv1connect.NewDeviceServiceHandler(service) mux.Handle(path, handler) - //nolint:gosec - return http.ListenAndServe( - agt.address, - // Use h2c so we can serve HTTP/2 without TLS. - h2c.NewHandler(mux, &http2.Server{}), - ) + // Serve HTTP/2 without TLS (h2c) + srv := &http.Server{ + Addr: agt.address, + Handler: mux, + ReadHeaderTimeout: readHeaderTimeout, + } + srv.Protocols = new(http.Protocols) + srv.Protocols.SetHTTP1(true) + srv.Protocols.SetUnencryptedHTTP2(true) + + return srv.ListenAndServe() } func (agt *agent) registerWithServer() error { @@ -211,19 +216,17 @@ func spawnClient(agendURL string) dutctlv1connect.RelayServiceClient { // TODO: refactor into pkg and reuse in dutctl and dutserver. func newInsecureClient() *http.Client { + transport := &http.Transport{} + transport.Protocols = new(http.Protocols) + transport.Protocols.SetUnencryptedHTTP2(true) + return &http.Client{ - Transport: &http2.Transport{ - AllowHTTP: true, - DialTLS: func(network, addr string, _ *tls.Config) (net.Conn, error) { - // If you're also using this client for non-h2c traffic, you may want - // to delegate to tls.Dial if the network isn't TCP or the addr isn't - // in an allowlist. - - //nolint:noctx - return net.Dial(network, addr) - }, - // TODO: Don't forget timeouts! - }, + Transport: transport, + // TODO: Don't forget timeouts! http.Client.Timeout must not be used here: + // it bounds the entire exchange including the response body, which would + // abort long-lived streaming RPCs. Instead use per-RPC context deadlines + // on unary calls and/or transport timeouts (DialContext, + // TLSHandshakeTimeout, ResponseHeaderTimeout, IdleConnTimeout). } } diff --git a/cmds/dutctl/dutctl.go b/cmds/dutctl/dutctl.go index 3810743c..34d4d017 100644 --- a/cmds/dutctl/dutctl.go +++ b/cmds/dutctl/dutctl.go @@ -7,13 +7,11 @@ package main import ( - "crypto/tls" "errors" "flag" "fmt" "io" "log" - "net" "net/http" "os" @@ -22,7 +20,6 @@ import ( "github.com/BlindspotSoftware/dutctl/internal/output" "github.com/BlindspotSoftware/dutctl/pkg/lock" "github.com/BlindspotSoftware/dutctl/protobuf/gen/dutctl/v1/dutctlv1connect" - "golang.org/x/net/http2" ) const usageAbstract = `dutctl - The client application of the DUT Control system. @@ -128,7 +125,6 @@ type application struct { func (app *application) setupRPCClient() { client := dutctlv1connect.NewDeviceServiceClient( - // Instead of http.DefaultClient, use the HTTP/2 protocol without TLS newInsecureClient(), fmt.Sprintf("http://%s", app.serverAddr), connect.WithGRPC(), @@ -138,19 +134,18 @@ func (app *application) setupRPCClient() { } func newInsecureClient() *http.Client { + // Use the HTTP/2 protocol without TLS (h2c) + transport := &http.Transport{} + transport.Protocols = new(http.Protocols) + transport.Protocols.SetUnencryptedHTTP2(true) + return &http.Client{ - Transport: &http2.Transport{ - AllowHTTP: true, - DialTLS: func(network, addr string, _ *tls.Config) (net.Conn, error) { - // If you're also using this client for non-h2c traffic, you may want - // to delegate to tls.Dial if the network isn't TCP or the addr isn't - // in an allowlist. - - //nolint:noctx - return net.Dial(network, addr) - }, - // Don't forget timeouts! - }, + Transport: transport, + // TODO: Don't forget timeouts! http.Client.Timeout must not be used here: + // it bounds the entire exchange including the response body, which would + // abort long-lived streaming RPCs. Instead use per-RPC context deadlines + // on unary calls and/or transport timeouts (DialContext, + // TLSHandshakeTimeout, ResponseHeaderTimeout, IdleConnTimeout). } } diff --git a/cmds/exp/dutserver/dutserver.go b/cmds/exp/dutserver/dutserver.go index ab9965b9..d029b0c1 100644 --- a/cmds/exp/dutserver/dutserver.go +++ b/cmds/exp/dutserver/dutserver.go @@ -14,10 +14,9 @@ import ( "os" "os/signal" "syscall" + "time" "github.com/BlindspotSoftware/dutctl/protobuf/gen/dutctl/v1/dutctlv1connect" - "golang.org/x/net/http2" - "golang.org/x/net/http2/h2c" ) const ( @@ -76,6 +75,9 @@ func (svr *server) watchInterrupt() { }() } +// readHeaderTimeout bounds how long the server waits to read request headers. +const readHeaderTimeout = 10 * time.Second + // startRPCService starts the RPC service, that ideally listens for incoming // connections forever. It always returns an non-nil error. func (svr *server) startRPCService() error { @@ -94,12 +96,17 @@ func (svr *server) startRPCService() error { path, handler = dutctlv1connect.NewRelayServiceHandler(service) mux.Handle(path, handler) - //nolint:gosec - return http.ListenAndServe( - svr.address, - // Use h2c so we can serve HTTP/2 without TLS. - h2c.NewHandler(mux, &http2.Server{}), - ) + // Serve HTTP/2 without TLS (h2c) + srv := &http.Server{ + Addr: svr.address, + Handler: mux, + ReadHeaderTimeout: readHeaderTimeout, + } + srv.Protocols = new(http.Protocols) + srv.Protocols.SetHTTP1(true) + srv.Protocols.SetUnencryptedHTTP2(true) + + return srv.ListenAndServe() } // start orchestrates the dutagent execution. diff --git a/cmds/exp/dutserver/rpc.go b/cmds/exp/dutserver/rpc.go index c13cf901..4692684f 100644 --- a/cmds/exp/dutserver/rpc.go +++ b/cmds/exp/dutserver/rpc.go @@ -6,13 +6,11 @@ package main import ( "context" - "crypto/tls" "errors" "fmt" "io" "log" "maps" - "net" "net/http" "slices" "sync" @@ -20,7 +18,6 @@ import ( "connectrpc.com/connect" "github.com/BlindspotSoftware/dutctl/pkg/lock" "github.com/BlindspotSoftware/dutctl/protobuf/gen/dutctl/v1/dutctlv1connect" - "golang.org/x/net/http2" pb "github.com/BlindspotSoftware/dutctl/protobuf/gen/dutctl/v1" ) @@ -390,19 +387,18 @@ func spawnClient(agendURL string) dutctlv1connect.DeviceServiceClient { // TODO: refactor into pkg and reuse in dutctl and dutserver. func newInsecureClient() *http.Client { + // Use the HTTP/2 protocol without TLS (h2c). + transport := &http.Transport{} + transport.Protocols = new(http.Protocols) + transport.Protocols.SetUnencryptedHTTP2(true) + return &http.Client{ - Transport: &http2.Transport{ - AllowHTTP: true, - DialTLS: func(network, addr string, _ *tls.Config) (net.Conn, error) { - // If you're also using this client for non-h2c traffic, you may want - // to delegate to tls.Dial if the network isn't TCP or the addr isn't - // in an allowlist. - - //nolint:noctx - return net.Dial(network, addr) - }, - // TODO: Don't forget timeouts! - }, + Transport: transport, + // TODO: Don't forget timeouts! http.Client.Timeout must not be used here: + // it bounds the entire exchange including the response body, which would + // abort long-lived streaming RPCs. Instead use per-RPC context deadlines + // on unary calls and/or transport timeouts (DialContext, + // TLSHandshakeTimeout, ResponseHeaderTimeout, IdleConnTimeout). } } diff --git a/go.mod b/go.mod index ba1890b8..113c51bf 100644 --- a/go.mod +++ b/go.mod @@ -10,7 +10,6 @@ require ( github.com/stianeikeland/go-rpio/v4 v4.6.0 github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07 golang.org/x/crypto v0.51.0 - golang.org/x/net v0.53.0 google.golang.org/protobuf v1.36.11 gopkg.in/yaml.v3 v3.0.1 ) @@ -32,6 +31,7 @@ require ( github.com/olekukonko/tablewriter v1.0.9 // indirect github.com/rivo/uniseg v0.2.0 // indirect github.com/rogpeppe/go-internal v1.6.1 // indirect + golang.org/x/net v0.53.0 // indirect golang.org/x/sys v0.44.0 // indirect golang.org/x/text v0.37.0 // indirect ) From 416a8d169880883345531a2f67ec71a65f41b474 Mon Sep 17 00:00:00 2001 From: Fabian Wienand Date: Wed, 10 Jun 2026 13:30:42 +0200 Subject: [PATCH 2/3] fix: allow command templating without a passthrough module Signed-off-by: Fabian Wienand --- pkg/dut/dut.go | 13 ++++++++----- pkg/dut/dut_test.go | 6 +++--- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/pkg/dut/dut.go b/pkg/dut/dut.go index 29337683..6ff8c3b5 100644 --- a/pkg/dut/dut.go +++ b/pkg/dut/dut.go @@ -15,9 +15,9 @@ import ( ) var ( - ErrDeviceNotFound = errors.New("no such device") - ErrCommandNotFound = errors.New("no such command") - ErrNoPassthroughForArgs = errors.New("arguments provided but command has no passthrough module to receive them") + ErrDeviceNotFound = errors.New("no such device") + ErrCommandNotFound = errors.New("no such command") + ErrNoReceiverForArgs = errors.New("arguments provided but command has neither a passthrough module nor declared arguments to receive them") ) // Devlist is a list of devices-under-test. @@ -124,8 +124,11 @@ func (c *Command) countPassthrough() int { // receive their statically configured Args with template references substituted // using runtimeArgs. The returned slice has the same length and ordering as c.Modules. func (c *Command) ModuleArgs(runtimeArgs []string) ([][]string, error) { - if len(runtimeArgs) > 0 && !c.HasPassthrough() { - return nil, ErrNoPassthroughForArgs + // Runtime args may be consumed either by a passthrough module or by + // command-level templating (declared c.Args substituted via ${name}). + // Only reject when neither can receive them. + if len(runtimeArgs) > 0 && !c.HasPassthrough() && len(c.Args) == 0 { + return nil, ErrNoReceiverForArgs } result := make([][]string, len(c.Modules)) diff --git a/pkg/dut/dut_test.go b/pkg/dut/dut_test.go index 918d26a1..8260ff1a 100644 --- a/pkg/dut/dut_test.go +++ b/pkg/dut/dut_test.go @@ -191,7 +191,7 @@ func TestModuleArgs(t *testing.T) { }}, runtimeArgs: []string{"a"}, want: nil, - err: ErrNoPassthroughForArgs, + err: ErrNoReceiverForArgs, }, { name: "mixed passthrough and non-passthrough", @@ -222,7 +222,7 @@ func TestModuleArgs(t *testing.T) { cmd: Command{}, runtimeArgs: []string{"a"}, want: nil, - err: ErrNoPassthroughForArgs, + err: ErrNoReceiverForArgs, }, { name: "error when runtime args provided but no passthrough module", @@ -231,7 +231,7 @@ func TestModuleArgs(t *testing.T) { }}, runtimeArgs: []string{"a"}, want: nil, - err: ErrNoPassthroughForArgs, + err: ErrNoReceiverForArgs, }, { name: "passthrough module with no runtime args", From 5dea46f3e69237535e2f0cc6ca64fb0fe75bd51f Mon Sep 17 00:00:00 2001 From: llogen Date: Mon, 22 Jun 2026 15:40:07 +0200 Subject: [PATCH 3/3] feat: serial expect/send automation sequences MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The serial module gains expect-send and arbitrarily ordered tagged expect:/send: sequences, so a test can drive a DUT hands-off — wait for a prompt, send input, in any order. Matching uses a rolling window (so prompts without a trailing newline match immediately) and non-SGR terminal escape sequences are stripped from the output. Signed-off-by: llogen --- pkg/module/serial/serial.go | 721 +++++++++++++++++++++--- pkg/module/serial/serial_run_test.go | 458 ++++++++++++++++ pkg/module/serial/serial_test.go | 789 +++++++++++++++++++++++++++ 3 files changed, 1896 insertions(+), 72 deletions(-) create mode 100644 pkg/module/serial/serial_run_test.go create mode 100644 pkg/module/serial/serial_test.go diff --git a/pkg/module/serial/serial.go b/pkg/module/serial/serial.go index 75039d9c..00bcb081 100644 --- a/pkg/module/serial/serial.go +++ b/pkg/module/serial/serial.go @@ -6,15 +6,14 @@ package serial import ( - "bytes" "context" - "errors" "flag" "fmt" "io" "log" "regexp" "strings" + "sync" "time" "github.com/BlindspotSoftware/dutctl/pkg/module" @@ -31,14 +30,63 @@ func init() { // DefaultBaudRate is the default baud rate for the serial connection. const DefaultBaudRate = 115200 -// Serial is a module that forwards the serial output of a connected DUT to the dutctl client. -// It is non-interactive and does not support stdin yet. +// sendDrain is the time the module keeps reading serial output after the last +// send-step of a sequence, so the output triggered by that final input is +// visible before the connection closes. +const sendDrain = time.Second + +// defaultDelay is the pause applied before each send when no Delay is +// configured. A small pause makes sends robust against consoles that present a +// prompt slightly before the tty is ready to read; configure Delay as "0s" to +// disable it. +const defaultDelay = 50 * time.Millisecond + +// maxMatchWindow bounds the buffer of recent serial output kept for regex +// matching. Matching is always done against the tail of the output, so old +// bytes can be discarded; the cap keeps memory and regex cost bounded while +// staying far larger than any realistic prompt or expect pattern. +const maxMatchWindow = 64 * 1024 + +// stepKind distinguishes the two kinds of step in a serial sequence. +type stepKind int + +const ( + stepExpect stepKind = iota // wait until pattern matches the serial output + stepSend // write data to the serial port +) + +// seqStep is one step of a serial automation sequence. The module walks the +// steps in order: it waits for each expect-step's pattern to appear in the +// output, and writes each send-step's data to the port. A single expect +// pattern and the legacy expect-send pairs both compile down to a slice of +// these, so the run loop has exactly one code path for all non-interactive +// modes. +type seqStep struct { + kind stepKind + pattern *regexp.Regexp // set when kind == stepExpect + data []byte // set when kind == stepSend (empty = send nothing) +} + +// serialPort is the subset of the serial device used by Run. It is satisfied by +// *serial.Port and allows a fake to be injected in tests via Serial.dialPort. +type serialPort interface { + io.ReadWriteCloser + Flush() error +} + +// Serial is a module that automates interaction with a DUT's serial console. type Serial struct { - Port string // Port is the path to the serial device on the dutagent. - Baud int // Baud is the baud rate of the serial device. Is unset, DefaultBaudRate is used. + Port string // Port is the path to the serial device on the dutagent. + Baud int // Baud is the baud rate of the serial device. If unset, DefaultBaudRate is used. + Delay string // Delay is the pause before each send (e.g. "200ms") to pace input. Default: 50ms; "0s" disables. - expect *regexp.Regexp // expect is a pattern to match against the serial output. - timeout time.Duration // timeout is the maximum time to wait for the expect pattern to match. + steps []seqStep // steps is the ordered expect/send sequence (empty = monitor mode). + timeout time.Duration // timeout is the maximum time to wait for the sequence to complete. + delay time.Duration // delay is the parsed Delay, applied before each send-step. + + // dialPort opens the serial port. It defaults to openPort and is overridden + // in tests to inject a fake. Set lazily in Run so the zero value works. + dialPort func() (serialPort, error) } // Ensure implementing the Module interface. @@ -49,18 +97,39 @@ const abstract = `Serial connection to the DUT const usage = ` ARGUMENTS: - [-t ] [] + [-t ] [ [ ...]] + [-t ] expect:|send: [expect:|send: ...] ` const description = ` -The serial connection is read-only and does not support stdin yet. -If a regex is provided, the module will wait for the regex to match on the serial output, -then exit with success. If no expect string is provided, the module will read from the serial port -until it is terminated by a signal (e.g. Ctrl-C). -The expect string supports regular expressions according to [1]. -The optional -t flag specifies the maximum time to wait for the regex to match. -Quote the expect string if it contains spaces or special characters. E.g.: "(?i)hello\s+world!? dutctl" +The serial module automates interaction with the DUT's serial console: it reads +the serial output, optionally matches it against expect patterns, and writes +responses back to the port. + +Modes of operation: + - Monitor (no arguments): stream the serial output until the session is cancelled. + - Expect (1 argument): wait for the regex to match on the serial output, then exit. + - Expect-send (even number of arguments >= 2): pass pattern/response pairs. + For each pair, the module waits for the pattern to match and then sends the + response to the serial port. Pairs are processed in order; after the last + pair matches the module reads serial output for 1 more second so the output + of the triggered command is visible, then exits. + - Sequence (every argument carries an "expect:" or "send:" tag): an ordered + list of steps run one after another. An expect-step waits for its regex to + match the serial output; a send-step writes its data to the port. Unlike + expect-send pairs, the steps may appear in any order, so a sequence can + begin with a send (e.g. an Enter to wake the console) or chain several + sends or expects in a row. The whole sequence shares the -t deadline. If + the last step is a send, the module drains output for 1 more second before + exiting; if it is an expect, it exits as soon as that pattern matches. + + e.g.: send:"\n" expect:"login:" send:"root\n" expect:"# " send:"reboot\n" + +The expect string supports regular expressions according to [1]. +The optional -t flag specifies the maximum time to wait. +Quote strings containing spaces or special characters. E.g.: "(?i)hello\s+world" +Response and send strings support C-style escape sequences: \n, \r, \t, \\, \xHH. [1] https://golang.org/s/re2syntax. ` @@ -71,7 +140,15 @@ func (s *Serial) Help() string { help := strings.Builder{} help.WriteString(abstract) help.WriteString(usage) - help.WriteString(fmt.Sprintf("Configured COM port is %q with baud rate %d.\n", s.Port, s.Baud)) + fmt.Fprintf(&help, "Configured COM port is %q with baud rate %d.\n", s.Port, s.Baud) + + delayDesc := s.Delay + if delayDesc == "" { + delayDesc = defaultDelay.String() + } + + fmt.Fprintf(&help, "A delay of %s is applied before each send.\n", delayDesc) + help.WriteString(description) return help.String() @@ -88,6 +165,17 @@ func (s *Serial) Init() error { s.Baud = DefaultBaudRate } + s.delay = defaultDelay + + if s.Delay != "" { + parsed, err := time.ParseDuration(s.Delay) + if err != nil { + return fmt.Errorf("invalid delay %q: %w", s.Delay, err) + } + + s.delay = parsed + } + // Note: We don't open the port here to allow dutagent to start // even if the serial device is not yet available (e.g., powered off). // The port will be opened when Run() is called. @@ -101,7 +189,21 @@ func (s *Serial) Deinit() error { return nil } -//nolint:cyclop,funlen,gocognit +// Run bridges the DUT's serial port to the client. It reads serial output, +// streams it to the client, and drives the expect/send sequence by matching +// patterns against the output and writing send-step responses to the port. +// +// Concurrency model (deliberately minimal to be race- and deadlock-free): +// +// - The main loop is the sole owner of the port and all match state (steps, +// draining, the match window, and the CSI remainder). Nothing else touches +// them, so no lock is needed. +// - One "stdout pump" goroutine performs every client write, so the main loop +// can keep matching and draining without blocking on a slow client, and so +// a vanished client can never wedge the main loop (writes abort on context +// cancellation). +// +//nolint:gocognit,cyclop,funlen,gocyclo,maintidx func (s *Serial) Run(ctx context.Context, session module.Session, args ...string) error { log.Println("serial module: Run called") @@ -110,94 +212,286 @@ func (s *Serial) Run(ctx context.Context, session module.Session, args ...string return err } - port, err := s.openPort() + dial := s.dialPort + if dial == nil { + dial = s.openPort + } + + port, err := dial() if err != nil { return err } - defer port.Close() + + defer func() { _ = port.Close() }() // Discard any stale bytes left in the kernel/driver RX buffer from a // previous session, otherwise the user sees data from the last boot. - err = port.Flush() - if err != nil { - log.Printf("serial module: flush failed: %v", err) + flushErr := port.Flush() + if flushErr != nil { + log.Printf("serial module: flush failed: %v", flushErr) } + // stdin is unused: this module never forwards client input; it only reads the + // port, matches, and writes send-step responses. stderr is unused too. + _, stdout, _ := session.Console() + log.Printf("serial module: connected to %s at %d baud", s.Port, s.Baud) - session.Print(fmt.Sprintf("--- Connected to %s at %d baud ---\n", s.Port, s.Baud)) + fmt.Fprintf(stdout, "--- Connected to %s at %d baud ---\n", s.Port, s.Baud) - var cancel context.CancelFunc + emit, flushAndWait := newStdoutPump(ctx, stdout) + defer flushAndWait(false) // ensure the pump goroutine is released on every exit path + + // loopCtx carries the active deadline: the expect timeout first, then the + // post-send drain. It is always derived from ctx, so external cancellation + // still propagates through it. + loopCtx := ctx + + var loopCancel context.CancelFunc if s.timeout > 0 { log.Printf("serial module: setting timeout of %s", s.timeout) - ctx, cancel = context.WithTimeout(ctx, s.timeout) + loopCtx, loopCancel = context.WithTimeout(ctx, s.timeout) + } + + defer func() { + if loopCancel != nil { + loopCancel() + } + }() + + var ( + remainder []byte // partial CSI sequence carried across reads + matchWindow []byte // bounded window of recent output for regex matching + currentStep int // cursor into s.steps + draining bool + ) + + //nolint:fatcontext // intentional: the post-send drain replaces the loop deadline + startDrain := func() { + if loopCancel != nil { + loopCancel() + } - defer cancel() + loopCtx, loopCancel = context.WithTimeout(ctx, sendDrain) + draining = true + + log.Printf("serial module: draining serial output for %s before closing", sendDrain) + } + + // advanceSends fires consecutive send-steps at the cursor, writing each to + // the port, until the cursor reaches an expect-step or the end of the + // sequence. It is called once before the first read (so a sequence may begin + // with a send) and again after every expect-step matches. + // + // When a delay is configured it pauses before every send to pace input — both + // between back-to-back sends and after a prompt match (some consoles drop + // characters sent the instant the prompt appears). The pause is interruptible + // so an expired deadline or a cancelled session still ends the run promptly. + advanceSends := func() { + for currentStep < len(s.steps) && s.steps[currentStep].kind == stepSend { + if s.delay > 0 { + select { + case <-time.After(s.delay): + case <-loopCtx.Done(): + return + } + } + + data := s.steps[currentStep].data + currentStep++ + + if len(data) > 0 { + _, werr := port.Write(data) + if werr != nil { + log.Printf("serial module: error writing to serial port: %v", werr) + } + } + } } const bufferSize = 4096 readBuffer := make([]byte, bufferSize) - lineBuffer := &bytes.Buffer{} + + // Fire any leading send-steps before the first read, so a sequence may begin + // by sending (e.g. an Enter to wake the console) rather than expecting. If + // the sequence is sends only, drain briefly so the DUT's response to the + // final input is visible, then exit via the drain deadline below. + advanceSends() + + if len(s.steps) > 0 && currentStep >= len(s.steps) { + startDrain() + } for { select { - case <-ctx.Done(): - if errors.Is(ctx.Err(), context.DeadlineExceeded) { - session.Print("\n--- Timeout reached, no match found ---\n") + case <-loopCtx.Done(): + if ctx.Err() != nil { + // External cancellation: client disconnect or agent teardown. + // The client may no longer be reading, so do not wait on a flush. + log.Println("serial module: context cancelled, closing") - return fmt.Errorf("timeout of %s reached, pattern %q not found", s.timeout, s.expect) + return ctx.Err() } - session.Print("\n--- Connection closed ---\n") + // One of our own deadlines fired. + if draining { + emit([]byte("\n--- Sequence complete, connection closed ---")) + flushAndWait(true) - return ctx.Err() - default: - sbytes, err := port.Read(readBuffer) - if err != nil { - // Ignore timeout errors as these are expected with the read timeout config - if err != io.EOF && !strings.Contains(err.Error(), "timeout") { - return fmt.Errorf("error reading from serial port: %w", err) - } + return nil + } - continue + emit([]byte("\n--- Timeout reached, no match found ---")) + flushAndWait(true) + + if len(s.steps) == 1 && s.steps[0].kind == stepExpect { + return fmt.Errorf("timeout of %s reached, pattern %q not found", s.timeout, s.steps[0].pattern) } - if sbytes == 0 { + return fmt.Errorf("timeout of %s reached, sequence not completed", s.timeout) + default: + } + + nRead, readErr := port.Read(readBuffer) + if readErr != nil { + // With VMIN=0 an idle read returns io.EOF, and some platforms report a + // "timeout" error — both just mean "no data yet", so keep looping and + // stay responsive to ctx. Any other error means the device failed. + if readErr == io.EOF || strings.Contains(readErr.Error(), "timeout") { continue } - // Forward bytes to the client immediately so partial lines are - // not held back waiting for a newline from the DUT. - session.Print(string(readBuffer[:sbytes])) + return fmt.Errorf("serial read error: %w", readErr) + } - if s.expect == nil { - continue + if nRead == 0 { + continue + } + + // Filter cursor/query CSI sequences (SGR colour is preserved), + // reconstructing sequences split across reads via remainder. + chunk := readBuffer[:nRead] + if len(remainder) > 0 { + chunk = append(remainder, chunk...) + remainder = nil + } + + out := filterOutputCSI(chunk, &remainder) + if len(out) == 0 { + continue + } + + // Display everything immediately, including partial lines (e.g. prompts + // without a trailing newline). out is a fresh slice owned by the pump. + emit(out) + + // Matching is skipped while draining and in monitor mode (no steps). + if draining || len(s.steps) == 0 { + continue + } + + matchWindow = append(matchWindow, out...) + if len(matchWindow) > maxMatchWindow { + // Copy the tail to the front of the same backing array rather than + // resliding the start forward, which would leak the discarded head + // until the slice grew enough to force a reallocation. This keeps the + // backing array bounded at ~maxMatchWindow, as the cap promises. + matchWindow = append(matchWindow[:0], matchWindow[len(matchWindow)-maxMatchWindow:]...) + } + + // Walk the sequence: satisfy every expect-step the current window already + // matches, firing the send-steps that follow each match. The cursor rests + // on an expect-step here, because advanceSends consumed any leading or + // trailing sends after the previous iteration. + for currentStep < len(s.steps) && s.steps[currentStep].kind == stepExpect { + loc := s.steps[currentStep].pattern.FindIndex(matchWindow) + if loc == nil { + break } - // Process the data read character by character - for i := range sbytes { - b := readBuffer[i] - lineBuffer.WriteByte(b) + matchWindow = matchWindow[loc[1]:] // consume through the match + currentStep++ - // If we reach a newline or a buffer limit, process the line - if b == '\n' || lineBuffer.Len() >= 1024 { - line := lineBuffer.String() + advanceSends() // fire the send-steps that follow this match - // Check for regex match if we have one - if s.expect.MatchString(line) { - session.Print("\n--- Pattern matched, connection closed ---\n") + if currentStep < len(s.steps) { + continue // more steps remain; keep matching the current window + } - return nil // Success - pattern found - } + // Sequence complete. If it ended on a send, drain so the DUT's + // response to the final input is visible; if it ended on an expect, + // the match itself is the completion, so exit immediately. + if s.steps[len(s.steps)-1].kind == stepSend { + startDrain() + } else { + emit([]byte("\n--- Pattern matched, connection closed ---")) + flushAndWait(true) - lineBuffer.Reset() - } + return nil } + + break } } } +const stdoutQueueLen = 256 + +// newStdoutPump starts a single goroutine that performs every client write and +// returns two closures: +// +// - emit(data) queues data for the client. It never blocks the caller +// indefinitely: if baseCtx is cancelled (client gone) the data is dropped +// rather than wedging the main loop. The caller must not reuse data +// afterwards; pass a fresh slice. +// - flushAndWait(deliver) shuts the pump down. With deliver=true it waits for +// all queued data to be written (safe only while the client is still +// reading, i.e. on graceful completion). With deliver=false it just releases +// the pump without waiting; it is idempotent and safe to defer on every path. +func newStdoutPump(baseCtx context.Context, stdout io.Writer) (func([]byte), func(bool)) { + outCh := make(chan []byte, stdoutQueueLen) + pumpDone := make(chan struct{}) + + go func() { + defer close(pumpDone) + + for data := range outCh { + _, _ = stdout.Write(data) // ChanWriter.Write only fails on misuse + } + }() + + var closeOnce sync.Once + + closeOut := func() { closeOnce.Do(func() { close(outCh) }) } + + emit := func(data []byte) { + select { + case outCh <- data: + case <-baseCtx.Done(): // client gone — drop rather than block forever. + } + } + + flushAndWait := func(deliver bool) { + closeOut() + + if deliver { + <-pumpDone + } + } + + return emit, flushAndWait +} + +// pairStride is the number of positional arguments per expect-send pair. +const pairStride = 2 + +// Tag prefixes that mark a tagged-sequence argument. +const ( + expectTag = "expect:" + sendTag = "send:" +) + func (s *Serial) evalArgs(args []string) error { fs := flag.NewFlagSet("serial", flag.ContinueOnError) fs.SetOutput(io.Discard) // Suppress default error output @@ -208,17 +502,93 @@ func (s *Serial) evalArgs(args []string) error { return fmt.Errorf("failed to parse arguments: %w", err) } - // Reset expect so a previous pattern does not carry over into the next run. - s.expect = nil + positional := fs.Args() - // Get the expect string if provided (args after flags) - if fs.NArg() > 0 { - expectPattern := fs.Arg(0) - log.Printf("serial module: Will wait for pattern: %q", expectPattern) + // Tagged-sequence mode is selected when the first argument carries a tag. + // It allows arbitrarily ordered expect/send steps (e.g. a leading send). + if len(positional) > 0 && isTaggedStep(positional[0]) { + return s.evalSequenceArgs(positional) + } - s.expect, err = regexp.Compile(expectPattern) - if err != nil { - return fmt.Errorf("invalid regular expression: %w", err) + return s.evalLegacyArgs(positional) +} + +// evalLegacyArgs parses the backward-compatible positional argument forms into +// s.steps: no args (interactive), one arg (single expect), or an even number of +// args (expect-send pairs). +func (s *Serial) evalLegacyArgs(positional []string) error { + switch len(positional) { + case 0: + // Interactive mode: no steps. + case 1: + // Single-expect mode. + log.Printf("serial module: Will wait for pattern: %q", positional[0]) + + pattern, compileErr := regexp.Compile(positional[0]) + if compileErr != nil { + return fmt.Errorf("invalid regular expression: %w", compileErr) + } + + s.steps = []seqStep{{kind: stepExpect, pattern: pattern}} + default: + // Expect-send pairs mode. + if len(positional)%pairStride != 0 { + return fmt.Errorf("expect-send requires an even number of arguments, got %d", len(positional)) + } + + s.steps = make([]seqStep, 0, len(positional)) + + for idx := 0; idx < len(positional); idx += pairStride { + pattern, compileErr := regexp.Compile(positional[idx]) + if compileErr != nil { + return fmt.Errorf("invalid regular expression %q: %w", positional[idx], compileErr) + } + + log.Printf("serial module: Pair %d: pattern=%q response=%q", idx/pairStride+1, positional[idx], positional[idx+1]) + + s.steps = append(s.steps, + seqStep{kind: stepExpect, pattern: pattern}, + seqStep{kind: stepSend, data: unescape(positional[idx+1])}, + ) + } + } + + return nil +} + +// isTaggedStep reports whether arg carries an "expect:" or "send:" tag. +func isTaggedStep(arg string) bool { + return strings.HasPrefix(arg, expectTag) || strings.HasPrefix(arg, sendTag) +} + +// evalSequenceArgs parses tagged-sequence arguments into s.steps. Every +// argument must carry a tag; mixing tagged and untagged arguments is rejected +// so a malformed command fails loudly instead of being silently misread. +func (s *Serial) evalSequenceArgs(args []string) error { + s.steps = make([]seqStep, 0, len(args)) + + for idx, arg := range args { + switch { + case strings.HasPrefix(arg, expectTag): + expr := arg[len(expectTag):] + + pattern, compileErr := regexp.Compile(expr) + if compileErr != nil { + return fmt.Errorf("step %d: invalid regular expression %q: %w", idx+1, expr, compileErr) + } + + log.Printf("serial module: Step %d: expect=%q", idx+1, expr) + + s.steps = append(s.steps, seqStep{kind: stepExpect, pattern: pattern}) + case strings.HasPrefix(arg, sendTag): + data := arg[len(sendTag):] + + log.Printf("serial module: Step %d: send=%q", idx+1, data) + + s.steps = append(s.steps, seqStep{kind: stepSend, data: unescape(data)}) + default: + return fmt.Errorf("step %d %q: in sequence mode every argument must start with %q or %q", + idx+1, arg, expectTag, sendTag) } } @@ -227,7 +597,8 @@ func (s *Serial) evalArgs(args []string) error { const readTimeout = 100 * time.Millisecond -func (s *Serial) openPort() (*serial.Port, error) { +//nolint:ireturn // intentional: returns the serialPort interface so a fake can be injected in tests +func (s *Serial) openPort() (serialPort, error) { config := &serial.Config{ Name: s.Port, Baud: s.Baud, @@ -241,3 +612,209 @@ func (s *Serial) openPort() (*serial.Port, error) { return port, nil } + +// escByte is the ASCII escape character that starts ANSI/VT escape sequences. +const escByte = 0x1b + +// csiPrefixLen is the length of the CSI prefix "ESC[". +const csiPrefixLen = 2 + +// Bytes that terminate an ANSI string sequence. +const ( + stByte = 0x9c // single-byte String Terminator (C1) + belByte = 0x07 // BEL — also terminates an OSC sequence +) + +// isStringSeqIntroducer reports whether b, following ESC, begins an ANSI string +// sequence: DCS (P), OSC (]), SOS (X), PM (^) or APC (_). These carry terminal +// queries and responses (e.g. capability reports like the WezTerm "name=" DCS), +// never display content, so they are dropped in full. +func isStringSeqIntroducer(b byte) bool { + return b == 'P' || b == ']' || b == 'X' || b == '^' || b == '_' +} + +// stringSeqEnd returns the index of the last byte of the ANSI string sequence +// that starts at data[escIdx] (with introducer data[escIdx+1]), or -1 if the +// sequence is not terminated within data — in which case the caller carries it +// to the next read. The sequence ends at ST ("ESC \" or the single-byte 0x9c); +// an OSC may also end at BEL. +func stringSeqEnd(data []byte, escIdx int) int { + osc := data[escIdx+1] == ']' + + for pos := escIdx + csiPrefixLen; pos < len(data); pos++ { + switch { + case data[pos] == stByte: + return pos + case osc && data[pos] == belByte: + return pos + case data[pos] == escByte: + if pos+1 >= len(data) { + return -1 // ESC at end — the ST may be split across reads + } + + if data[pos+1] == '\\' { + return pos + 1 // "ESC \" = ST + } + // ESC followed by anything else is not a terminator; keep scanning. + } + } + + return -1 +} + +// filterOutputCSI removes terminal control sequences from serial output data, +// except for SGR (Select Graphic Rendition, final byte 'm') which handles colors +// and styles. It drops CSI sequences (cursor positioning, screen clearing, +// queries) and ANSI string sequences (DCS/OSC/SOS/PM/APC terminal query +// responses), while preserving coloured output. +// +// Incomplete sequences at the end of data are stored in remainder so they can be +// prepended to the next buffer read and reconstituted correctly. +// +//nolint:cyclop,varnamelen +func filterOutputCSI(data []byte, remainder *[]byte) []byte { + result := make([]byte, 0, len(data)) + *remainder = nil + + for i := 0; i < len(data); i++ { + if data[i] != escByte { + result = append(result, data[i]) + + continue + } + + // ESC at end of buffer: might be the start of a sequence split across reads. + if i+1 >= len(data) { + *remainder = []byte{escByte} + + break + } + + if isStringSeqIntroducer(data[i+1]) { + end := stringSeqEnd(data, i) + if end < 0 { + // Incomplete string sequence — carry it over to the next read. + *remainder = make([]byte, len(data)-i) + copy(*remainder, data[i:]) + + break + } + + // Drop the whole sequence; the outer loop's i++ advances past data[end]. + i = end + + continue + } + + if data[i+1] != '[' { + // ESC not followed by '[' or a string introducer — emit as-is. + result = append(result, data[i]) + + continue + } + + // CSI sequence: ESC [ + // Find the extent of this CSI sequence. + j := i + csiPrefixLen + + for j < len(data) && data[j] >= 0x30 && data[j] <= 0x3f { + j++ // parameter bytes + } + + for j < len(data) && data[j] >= 0x20 && data[j] <= 0x2f { + j++ // intermediate bytes + } + + if j >= len(data) { + // Incomplete sequence at end of buffer — carry it over to the next read. + *remainder = make([]byte, len(data)-i) + copy(*remainder, data[i:]) + + break + } + + if data[j] == 'm' { + // SGR (colors/styles) — keep it. + result = append(result, data[i:j+1]...) + } + + // All other CSI sequences are dropped, including malformed ones where + // data[j] is not a valid final byte (0x40–0x7E). The byte at data[j] + // is consumed by setting i = j; the outer loop's i++ then advances past + // it. Silently dropping malformed sequences is safer than emitting + // partial escape bytes which could corrupt the terminal display. + i = j + } + + return result +} + +// unescape converts C-style escape sequences in s to their byte equivalents. +// Supported sequences: \n (newline), \r (carriage return), \t (tab), +// \\ (backslash), \xHH (hex byte). Unrecognised sequences are emitted as-is. +// +//nolint:cyclop // a flat switch over the supported escapes; splitting it would not help readability +func unescape(s string) []byte { + out := make([]byte, 0, len(s)) + + for idx := 0; idx < len(s); idx++ { + if s[idx] != '\\' || idx+1 >= len(s) { + out = append(out, s[idx]) + + continue + } + + idx++ + + switch s[idx] { + case 'n': + out = append(out, '\n') + case 'r': + out = append(out, '\r') + case 't': + out = append(out, '\t') + case '\\': + out = append(out, '\\') + case 'x': + if idx+hexDigits < len(s) { + hi, hiOK := fromHex(s[idx+1]) + lo, loOK := fromHex(s[idx+hexDigits]) + + if hiOK && loOK { + out = append(out, hi<= '0' && digit <= '9': + return digit - '0', true + case digit >= 'a' && digit <= 'f': + return digit - 'a' + hexLetterOffset, true + case digit >= 'A' && digit <= 'F': + return digit - 'A' + hexLetterOffset, true + default: + return 0, false + } +} diff --git a/pkg/module/serial/serial_run_test.go b/pkg/module/serial/serial_run_test.go new file mode 100644 index 00000000..e1e1ca9e --- /dev/null +++ b/pkg/module/serial/serial_run_test.go @@ -0,0 +1,458 @@ +// Copyright 2025 Blindspot Software +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package serial + +import ( + "bytes" + "context" + "errors" + "io" + "strings" + "sync" + "testing" + "time" +) + +// --- test doubles --------------------------------------------------------- + +// fakePort is an in-memory serialPort. Queued chunks are delivered by Read in +// order; once drained, Read emulates a serial read timeout (so Run stays +// responsive to context cancellation). Bytes written by the module are captured. +type fakePort struct { + mu sync.Mutex + out [][]byte + written []byte + closed bool +} + +type ioTimeout struct{} + +func (ioTimeout) Error() string { return "i/o timeout" } + +func (p *fakePort) queue(b []byte) { + p.mu.Lock() + defer p.mu.Unlock() + p.out = append(p.out, b) +} + +func (p *fakePort) Read(b []byte) (int, error) { + p.mu.Lock() + if len(p.out) > 0 { + chunk := p.out[0] + n := copy(b, chunk) + + if n < len(chunk) { + p.out[0] = chunk[n:] + } else { + p.out = p.out[1:] + } + + p.mu.Unlock() + + return n, nil + } + p.mu.Unlock() + + // Emulate the real port's ReadTimeout so the caller's loop spins at a + // bounded rate instead of busy-waiting. + time.Sleep(2 * time.Millisecond) + + return 0, ioTimeout{} +} + +func (p *fakePort) Write(b []byte) (int, error) { + p.mu.Lock() + defer p.mu.Unlock() + p.written = append(p.written, b...) + + return len(b), nil +} + +func (p *fakePort) Close() error { + p.mu.Lock() + defer p.mu.Unlock() + p.closed = true + + return nil +} + +func (p *fakePort) Flush() error { return nil } + +func (p *fakePort) writtenString() string { + p.mu.Lock() + defer p.mu.Unlock() + + return string(p.written) +} + +// syncBuffer is a goroutine-safe io.Writer for capturing client stdout, which +// the stdout pump writes concurrently with the test reading it. +type syncBuffer struct { + mu sync.Mutex + buf bytes.Buffer +} + +func (s *syncBuffer) Write(b []byte) (int, error) { + s.mu.Lock() + defer s.mu.Unlock() + + return s.buf.Write(b) +} + +func (s *syncBuffer) String() string { + s.mu.Lock() + defer s.mu.Unlock() + + return s.buf.String() +} + +// fakeSession implements module.Session for tests. +type fakeSession struct { + stdinR io.Reader + stdoutW io.Writer +} + +func (f *fakeSession) Print(...any) {} +func (f *fakeSession) Printf(string, ...any) {} +func (f *fakeSession) Println(...any) {} + +func (f *fakeSession) Console() (io.Reader, io.Writer, io.Writer) { + return f.stdinR, f.stdoutW, io.Discard +} + +func (f *fakeSession) RequestFile(string) (io.Reader, error) { return nil, nil } +func (f *fakeSession) SendFile(string, io.Reader) error { return nil } + +// newSession wires a fake session with an empty stdin (the module never reads +// it) and a capturing stdout. It returns the session and the stdout capture. +func newSession(t *testing.T) (*fakeSession, *syncBuffer) { + t.Helper() + + stdout := &syncBuffer{} + + return &fakeSession{stdinR: strings.NewReader(""), stdoutW: stdout}, stdout +} + +// --- tests ---------------------------------------------------------------- + +func TestRunSingleExpectMatch(t *testing.T) { + port := &fakePort{} + port.queue([]byte("booting kernel...\n")) + port.queue([]byte("dut login: ")) + + s := &Serial{Port: "fake", Baud: DefaultBaudRate} + s.dialPort = func() (serialPort, error) { return port, nil } + + sess, stdout := newSession(t) + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + err := s.Run(ctx, sess, "login:") + if err != nil { + t.Fatalf("Run returned error, want nil: %v", err) + } + + if got := stdout.String(); !strings.Contains(got, "booting kernel") || + !strings.Contains(got, "Pattern matched") { + t.Errorf("stdout missing expected content: %q", got) + } +} + +func TestRunExpectTimeout(t *testing.T) { + port := &fakePort{} + port.queue([]byte("nothing interesting here\n")) + + s := &Serial{Port: "fake", Baud: DefaultBaudRate} + s.dialPort = func() (serialPort, error) { return port, nil } + + sess, stdout := newSession(t) + + err := s.Run(context.Background(), sess, "-t", "150ms", "will-never-appear") + if err == nil { + t.Fatal("Run returned nil, want timeout error") + } + + if !strings.Contains(err.Error(), "timeout") { + t.Errorf("error = %v, want timeout", err) + } + + if got := stdout.String(); !strings.Contains(got, "Timeout reached") { + t.Errorf("stdout missing timeout banner: %q", got) + } +} + +func TestRunExpectSendPairs(t *testing.T) { + port := &fakePort{} + // Both prompts are present in the output stream; both pairs must fire in + // order, then the module drains briefly and exits. + port.queue([]byte("dut login: ")) + port.queue([]byte("user\r\nPassword: ")) + + s := &Serial{Port: "fake", Baud: DefaultBaudRate} + s.dialPort = func() (serialPort, error) { return port, nil } + + sess, _ := newSession(t) + + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + start := time.Now() + + err := s.Run(ctx, sess, "login:", "admin\\n", "Password:", "secret\\n") + if err != nil { + t.Fatalf("Run returned error, want nil: %v", err) + } + + if elapsed := time.Since(start); elapsed < sendDrain { + t.Errorf("Run returned after %s, expected at least the %s drain", elapsed, sendDrain) + } + + if got := port.writtenString(); got != "admin\nsecret\n" { + t.Errorf("responses written to port = %q, want %q", got, "admin\nsecret\n") + } +} + +func TestRunExpectSendStopsBeforeLastResponseOnTimeout(t *testing.T) { + port := &fakePort{} + port.queue([]byte("dut login: ")) // only the first prompt ever appears + + s := &Serial{Port: "fake", Baud: DefaultBaudRate} + s.dialPort = func() (serialPort, error) { return port, nil } + + sess, _ := newSession(t) + + err := s.Run(context.Background(), sess, "-t", "200ms", "login:", "admin\\n", "Password:", "secret\\n") + if err == nil || !strings.Contains(err.Error(), "sequence not completed") { + t.Fatalf("Run error = %v, want 'sequence not completed' timeout", err) + } + + // First pair fired, second did not. + if got := port.writtenString(); got != "admin\n" { + t.Errorf("responses written to port = %q, want %q", got, "admin\n") + } +} + +func TestRunContextCancelReturnsPromptly(t *testing.T) { + port := &fakePort{} + + s := &Serial{Port: "fake", Baud: DefaultBaudRate} + s.dialPort = func() (serialPort, error) { return port, nil } + + sess, _ := newSession(t) + + ctx, cancel := context.WithCancel(context.Background()) + + runErr := make(chan error, 1) + + go func() { runErr <- s.Run(ctx, sess) }() + + time.Sleep(20 * time.Millisecond) + cancel() + + select { + case err := <-runErr: + if !errors.Is(err, context.Canceled) { + t.Errorf("Run error = %v, want context.Canceled", err) + } + case <-time.After(time.Second): + t.Fatal("Run did not return promptly after cancel") + } +} + +func TestRunStripsNonSGROutputButKeepsColour(t *testing.T) { + port := &fakePort{} + // DSR query and cursor-up (must be stripped), an SGR colour (must survive), + // plain text on either side. + port.queue([]byte("ABC\x1b[6nDEF\x1b[2A \x1b[31mred\x1b[0m XYZ")) + + s := &Serial{Port: "fake", Baud: DefaultBaudRate} + s.dialPort = func() (serialPort, error) { return port, nil } + + sess, stdout := newSession(t) + + ctx, cancel := context.WithCancel(context.Background()) + + runErr := make(chan error, 1) + + go func() { runErr <- s.Run(ctx, sess) }() + + // Wait for the chunk to be displayed. + deadline := time.After(2 * time.Second) + + for !strings.Contains(stdout.String(), "XYZ") { + select { + case <-deadline: + t.Fatalf("output not displayed, got %q", stdout.String()) + case <-time.After(5 * time.Millisecond): + } + } + + cancel() + <-runErr + + got := stdout.String() + if strings.Contains(got, "\x1b[6n") { + t.Errorf("DSR query not stripped: %q", got) + } + + if strings.Contains(got, "\x1b[2A") { + t.Errorf("cursor-up not stripped: %q", got) + } + + if !strings.Contains(got, "\x1b[31mred\x1b[0m") { + t.Errorf("SGR colour not preserved: %q", got) + } + + // The two stripped CSIs sat between ABC/DEF and after red; surrounding plain + // text must remain intact (just with the CSIs removed). + if !strings.Contains(got, "ABCDEF ") || !strings.Contains(got, " XYZ") { + t.Errorf("plain text mangled: %q", got) + } +} + +// TestRunMatchSpanningReads verifies a pattern split across two serial reads is +// still matched, exercising the rolling match window. +func TestRunMatchSpanningReads(t *testing.T) { + port := &fakePort{} + port.queue([]byte("please lo")) + port.queue([]byte("gin: now")) + + s := &Serial{Port: "fake", Baud: DefaultBaudRate} + s.dialPort = func() (serialPort, error) { return port, nil } + + sess, _ := newSession(t) + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + if err := s.Run(ctx, sess, "login:"); err != nil { + t.Fatalf("Run returned error, want nil: %v", err) + } +} + +// TestRunSequenceLeadingSendThenExpect covers a tagged sequence that begins +// with a send (impossible with expect-send pairs): the module writes the +// leading input before any output arrives, then waits for the prompt and sends +// the reply. It ends on a send, so it drains before exiting. +func TestRunSequenceLeadingSendThenExpect(t *testing.T) { + port := &fakePort{} + port.queue([]byte("dut login: ")) // appears after the leading send goes out + + s := &Serial{Port: "fake", Baud: DefaultBaudRate} + s.dialPort = func() (serialPort, error) { return port, nil } + + sess, _ := newSession(t) + + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + start := time.Now() + + err := s.Run(ctx, sess, `send:\n`, "expect:login:", `send:root\n`) + if err != nil { + t.Fatalf("Run returned error, want nil: %v", err) + } + + if elapsed := time.Since(start); elapsed < sendDrain { + t.Errorf("Run returned after %s, expected at least the %s drain", elapsed, sendDrain) + } + + if got := port.writtenString(); got != "\nroot\n" { + t.Errorf("written to port = %q, want %q", got, "\nroot\n") + } +} + +// TestRunSequenceSendOnly covers a sequence with no expect step at all: the +// module sends the input and drains briefly so the response is visible. +func TestRunSequenceSendOnly(t *testing.T) { + port := &fakePort{} + + s := &Serial{Port: "fake", Baud: DefaultBaudRate} + s.dialPort = func() (serialPort, error) { return port, nil } + + sess, _ := newSession(t) + + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + err := s.Run(ctx, sess, `send:reboot\n`) + if err != nil { + t.Fatalf("Run returned error, want nil: %v", err) + } + + if got := port.writtenString(); got != "reboot\n" { + t.Errorf("written to port = %q, want %q", got, "reboot\n") + } +} + +// TestRunSequenceEndingOnExpectExitsWithoutDraining covers a sequence whose +// final step is an expect: the match is the completion, so the module exits +// immediately rather than draining. +func TestRunSequenceEndingOnExpectExitsWithoutDraining(t *testing.T) { + port := &fakePort{} + port.queue([]byte("starting\n")) + port.queue([]byte("ready> ")) + + s := &Serial{Port: "fake", Baud: DefaultBaudRate} + s.dialPort = func() (serialPort, error) { return port, nil } + + sess, stdout := newSession(t) + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + start := time.Now() + + err := s.Run(ctx, sess, `send:\n`, "expect:ready>") + if err != nil { + t.Fatalf("Run returned error, want nil: %v", err) + } + + if elapsed := time.Since(start); elapsed >= sendDrain { + t.Errorf("Run took %s; a sequence ending on expect should exit without the %s drain", elapsed, sendDrain) + } + + if got := port.writtenString(); got != "\n" { + t.Errorf("written to port = %q, want %q", got, "\n") + } + + if got := stdout.String(); !strings.Contains(got, "Pattern matched") { + t.Errorf("stdout missing match banner: %q", got) + } +} + +// TestRunDelayPausesBeforeSend verifies the configured delay paces sends: the +// leading send waits the delay before it is written to the port. +func TestRunDelayPausesBeforeSend(t *testing.T) { + port := &fakePort{} + port.queue([]byte("ready> ")) // appears once the leading send (after the delay) goes out + + s := &Serial{Port: "fake", Baud: DefaultBaudRate} + s.dialPort = func() (serialPort, error) { return port, nil } + s.delay = 60 * time.Millisecond + + sess, _ := newSession(t) + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + start := time.Now() + + // Leading send is paced by the delay; the sequence ends on an expect, so it + // exits as soon as the prompt matches (no drain to inflate the timing). + err := s.Run(ctx, sess, `send:\n`, "expect:ready>") + if err != nil { + t.Fatalf("Run returned error, want nil: %v", err) + } + + if elapsed := time.Since(start); elapsed < s.delay { + t.Errorf("Run took %s; expected at least the %s pre-send delay", elapsed, s.delay) + } + + if got := port.writtenString(); got != "\n" { + t.Errorf("written to port = %q, want %q", got, "\n") + } +} diff --git a/pkg/module/serial/serial_test.go b/pkg/module/serial/serial_test.go new file mode 100644 index 00000000..eec0f48c --- /dev/null +++ b/pkg/module/serial/serial_test.go @@ -0,0 +1,789 @@ +// Copyright 2025 Blindspot Software +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package serial + +import ( + "strings" + "testing" + "time" +) + +// TestFilterOutputCSI verifies that filterOutputCSI correctly passes through +// plain text and SGR sequences, drops all other CSI sequences, and stores +// incomplete sequences in the remainder for the next read. +// +//nolint:funlen +func TestFilterOutputCSI(t *testing.T) { + tests := []struct { + name string + input []byte + wantOutput []byte + wantRemainder []byte + }{ + { + name: "empty input", + input: []byte{}, + wantOutput: []byte{}, + wantRemainder: nil, + }, + { + name: "plain text without ESC", + input: []byte("hello world"), + wantOutput: []byte("hello world"), + wantRemainder: nil, + }, + { + name: "SGR reset ESC[m preserved", + input: []byte("\x1b[m"), + wantOutput: []byte("\x1b[m"), + wantRemainder: nil, + }, + { + name: "SGR foreground colour ESC[31m preserved", + input: []byte("\x1b[31m"), + wantOutput: []byte("\x1b[31m"), + wantRemainder: nil, + }, + { + name: "SGR multi-parameter ESC[0;31m preserved", + input: []byte("\x1b[0;31m"), + wantOutput: []byte("\x1b[0;31m"), + wantRemainder: nil, + }, + { + name: "DSR query ESC[6n dropped", + input: []byte("\x1b[6n"), + wantOutput: []byte{}, + wantRemainder: nil, + }, + { + name: "cursor position report ESC[10;20R dropped", + input: []byte("\x1b[10;20R"), + wantOutput: []byte{}, + wantRemainder: nil, + }, + { + name: "cursor-up ESC[2A dropped", + input: []byte("\x1b[2A"), + wantOutput: []byte{}, + wantRemainder: nil, + }, + { + name: "erase-display ESC[2J dropped", + input: []byte("\x1b[2J"), + wantOutput: []byte{}, + wantRemainder: nil, + }, + { + name: "text surrounding SGR preserved", + input: []byte("hello \x1b[31mworld"), + wantOutput: []byte("hello \x1b[31mworld"), + wantRemainder: nil, + }, + { + name: "non-SGR CSI in middle of text dropped", + input: []byte("hello \x1b[6n world"), + wantOutput: []byte("hello world"), + wantRemainder: nil, + }, + { + name: "mixed SGR and non-SGR: only SGR kept", + input: []byte("\x1b[31mcolor\x1b[6nquery\x1b[0mreset"), + wantOutput: []byte("\x1b[31mcolorquery\x1b[0mreset"), + wantRemainder: nil, + }, + { + name: "multiple consecutive SGR sequences preserved", + input: []byte("\x1b[31mred\x1b[0mreset"), + wantOutput: []byte("\x1b[31mred\x1b[0mreset"), + wantRemainder: nil, + }, + { + name: "lone ESC at end of buffer stored in remainder", + input: []byte("text\x1b"), + wantOutput: []byte("text"), + wantRemainder: []byte{escByte}, + }, + { + name: "incomplete CSI — ESC[ only — stored in remainder", + input: []byte("\x1b["), + wantOutput: []byte{}, + wantRemainder: []byte("\x1b["), + }, + { + name: "incomplete CSI with params stored in remainder", + input: []byte("text\x1b[31"), + wantOutput: []byte("text"), + wantRemainder: []byte("\x1b[31"), + }, + { + name: "ESC not followed by bracket emitted as-is", + input: []byte("\x1bO"), // SS3 — not a CSI sequence + wantOutput: []byte("\x1bO"), + wantRemainder: nil, + }, + { + name: "ESC not followed by bracket in middle of text", + input: []byte("ab\x1bOcd"), + wantOutput: []byte("ab\x1bOcd"), + wantRemainder: nil, + }, + { + name: "DCS terminated by ST dropped", + input: []byte("\x1bP1+r6E616D65\x1b\\"), + wantOutput: []byte{}, + wantRemainder: nil, + }, + { + // The exact capability-report DCS that leaked from the BMC console. + name: "real WezTerm capability DCS dropped, surrounding text kept", + input: []byte("before\x1bP1+r6E616D65=57657A5465726D\x1b\\after"), + wantOutput: []byte("beforeafter"), + wantRemainder: nil, + }, + { + name: "OSC terminated by BEL dropped", + input: []byte("\x1b]0;window title\x07rest"), + wantOutput: []byte("rest"), + wantRemainder: nil, + }, + { + name: "OSC terminated by ST dropped", + input: []byte("\x1b]0;title\x1b\\rest"), + wantOutput: []byte("rest"), + wantRemainder: nil, + }, + { + name: "DCS single-byte ST (0x9c) dropped", + input: []byte("x\x1bPdata\x9cy"), + wantOutput: []byte("xy"), + wantRemainder: nil, + }, + { + name: "APC sequence dropped", + input: []byte("a\x1b_payload\x1b\\b"), + wantOutput: []byte("ab"), + wantRemainder: nil, + }, + { + name: "incomplete DCS stored in remainder", + input: []byte("text\x1bP1+r6E61"), + wantOutput: []byte("text"), + wantRemainder: []byte("\x1bP1+r6E61"), + }, + { + name: "DCS with ESC at end (partial ST) stored in remainder", + input: []byte("\x1bP1+rABC\x1b"), + wantOutput: []byte{}, + wantRemainder: []byte("\x1bP1+rABC\x1b"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var remainder []byte + + got := filterOutputCSI(tt.input, &remainder) + + if string(got) != string(tt.wantOutput) { + t.Errorf("output = %q, want %q", got, tt.wantOutput) + } + + if string(remainder) != string(tt.wantRemainder) { + t.Errorf("remainder = %q, want %q", remainder, tt.wantRemainder) + } + }) + } +} + +// TestFilterOutputCSISequenceSplitAcrossReads simulates an SGR sequence +// whose bytes arrive in two separate buffer reads, verifying that the +// remainder mechanism reconstitutes it correctly. +func TestFilterOutputCSISequenceSplitAcrossReads(t *testing.T) { + var remainder []byte + + // Read 1: "ESC[31" — missing the final byte 'm'. + out1 := filterOutputCSI([]byte("\x1b[31"), &remainder) + + if len(out1) != 0 { + t.Errorf("read 1: expected no output for incomplete sequence, got %q", out1) + } + + if string(remainder) != "\x1b[31" { + t.Errorf("read 1: remainder = %q, want %q", remainder, "\x1b[31") + } + + // Read 2: prepend remainder to the new chunk (mirroring the main loop). + chunk := append(remainder, 'm') + out2 := filterOutputCSI(chunk, &remainder) + + if string(out2) != "\x1b[31m" { + t.Errorf("read 2: output = %q, want %q", out2, "\x1b[31m") + } + + if len(remainder) != 0 { + t.Errorf("read 2: expected empty remainder, got %q", remainder) + } +} + +// TestFilterOutputCSILoneESCSplitAcrossReads simulates a lone ESC at the end +// of one read whose '[' and final byte arrive in the next read. +func TestFilterOutputCSILoneESCSplitAcrossReads(t *testing.T) { + var remainder []byte + + // Read 1: text followed by a lone ESC at the buffer boundary. + out1 := filterOutputCSI([]byte("text\x1b"), &remainder) + + if string(out1) != "text" { + t.Errorf("read 1: output = %q, want %q", out1, "text") + } + + if string(remainder) != "\x1b" { + t.Errorf("read 1: remainder = %q, want %q", remainder, "\x1b") + } + + // Read 2: the bytes "[A" arrive — together with the remainder this forms the + // cursor-up sequence ESC[A which must be dropped. + chunk := append(remainder, []byte("[A")...) + out2 := filterOutputCSI(chunk, &remainder) + + if len(out2) != 0 { + t.Errorf("read 2: cursor-up sequence should be dropped, got %q", out2) + } + + if len(remainder) != 0 { + t.Errorf("read 2: expected empty remainder, got %q", remainder) + } +} + +// TestFilterOutputCSIStringSeqSplitAcrossReads simulates a DCS sequence whose +// String Terminator arrives in a later read, verifying the remainder mechanism +// reconstitutes and drops the whole sequence. +func TestFilterOutputCSIStringSeqSplitAcrossReads(t *testing.T) { + var remainder []byte + + // Read 1: a DCS that is not yet terminated. + out1 := filterOutputCSI([]byte("keep\x1bP1+r41424"), &remainder) + + if string(out1) != "keep" { + t.Errorf("read 1: output = %q, want %q", out1, "keep") + } + + if string(remainder) != "\x1bP1+r41424" { + t.Errorf("read 1: remainder = %q, want %q", remainder, "\x1bP1+r41424") + } + + // Read 2: prepend remainder to the new chunk; the ST now completes the DCS, + // which must be dropped, leaving only the trailing text. + chunk := append(remainder, []byte("3\x1b\\done")...) + out2 := filterOutputCSI(chunk, &remainder) + + if string(out2) != "done" { + t.Errorf("read 2: output = %q, want %q", out2, "done") + } + + if len(remainder) != 0 { + t.Errorf("read 2: expected empty remainder, got %q", remainder) + } +} + +// TestEvalArgs covers argument parsing for the serial module. +func TestEvalArgs(t *testing.T) { + tests := []struct { + name string + args []string + wantTimeout time.Duration + wantPattern string // empty means expect should be nil + wantErr bool + }{ + { + name: "no args", + args: nil, + }, + { + name: "timeout flag only", + args: []string{"-t", "5s"}, + wantTimeout: 5 * time.Second, + }, + { + name: "pattern only", + args: []string{"login:"}, + wantPattern: "login:", + }, + { + name: "timeout and pattern", + args: []string{"-t", "2m", "hello world"}, + wantTimeout: 2 * time.Minute, + wantPattern: "hello world", + }, + { + name: "regex pattern with flags", + args: []string{`(?i)Login\s*:`}, + wantPattern: `(?i)Login\s*:`, + }, + { + name: "invalid regex", + args: []string{"[invalid"}, + wantErr: true, + }, + { + name: "unknown flag", + args: []string{"-x"}, + wantErr: true, + }, + { + name: "invalid timeout value", + args: []string{"-t", "not-a-duration"}, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &Serial{} + + err := s.evalArgs(tt.args) + + if (err != nil) != tt.wantErr { + t.Fatalf("evalArgs() error = %v, wantErr %v", err, tt.wantErr) + } + + if tt.wantErr { + return + } + + if s.timeout != tt.wantTimeout { + t.Errorf("timeout = %v, want %v", s.timeout, tt.wantTimeout) + } + + if tt.wantPattern == "" { + if len(s.steps) != 0 { + t.Errorf("steps = %v, want none (interactive)", s.steps) + } + } else { + if len(s.steps) != 1 || s.steps[0].kind != stepExpect { + t.Fatalf("steps = %v, want a single expect-step for %q", s.steps, tt.wantPattern) + } + + if s.steps[0].pattern.String() != tt.wantPattern { + t.Errorf("expect pattern = %q, want %q", s.steps[0].pattern.String(), tt.wantPattern) + } + } + }) + } +} + +// TestSerialInit covers Init validation and default-baud-rate assignment. +func TestSerialInit(t *testing.T) { + tests := []struct { + name string + serial Serial + wantBaud int + wantDelay time.Duration + wantErr bool + }{ + { + name: "missing port returns error", + serial: Serial{}, + wantErr: true, + }, + { + name: "zero baud is replaced with DefaultBaudRate", + serial: Serial{Port: "/dev/ttyS0"}, + wantBaud: DefaultBaudRate, + wantDelay: defaultDelay, + }, + { + name: "explicit baud rate is preserved", + serial: Serial{Port: "/dev/ttyS0", Baud: 9600}, + wantBaud: 9600, + wantDelay: defaultDelay, + }, + { + name: "empty delay uses the default", + serial: Serial{Port: "/dev/ttyS0"}, + wantBaud: DefaultBaudRate, + wantDelay: defaultDelay, + }, + { + name: "explicit delay overrides the default", + serial: Serial{Port: "/dev/ttyS0", Delay: "200ms"}, + wantBaud: DefaultBaudRate, + wantDelay: 200 * time.Millisecond, + }, + { + name: "zero delay disables the default", + serial: Serial{Port: "/dev/ttyS0", Delay: "0s"}, + wantBaud: DefaultBaudRate, + wantDelay: 0, + }, + { + name: "invalid delay returns error", + serial: Serial{Port: "/dev/ttyS0", Delay: "not-a-duration"}, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.serial.Init() + + if (err != nil) != tt.wantErr { + t.Fatalf("Init() error = %v, wantErr %v", err, tt.wantErr) + } + + if tt.wantErr { + return + } + + if tt.serial.Baud != tt.wantBaud { + t.Errorf("Baud = %d, want %d", tt.serial.Baud, tt.wantBaud) + } + + if tt.serial.delay != tt.wantDelay { + t.Errorf("delay = %v, want %v", tt.serial.delay, tt.wantDelay) + } + }) + } +} + +// TestSerialHelp verifies that the help text includes the configured port, +// baud rate, and key usage information. +func TestSerialHelp(t *testing.T) { + tests := []struct { + name string + serial Serial + wantIn []string + }{ + { + name: "contains configured port and baud", + serial: Serial{Port: "/dev/ttyS0", Baud: 9600}, + wantIn: []string{"/dev/ttyS0", "9600"}, + }, + { + name: "contains timeout flag and regex mention", + serial: Serial{Port: "/dev/ttyUSB0", Baud: DefaultBaudRate}, + wantIn: []string{"-t", "expect", "regex"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + help := strings.ToLower(tt.serial.Help()) + + for _, want := range tt.wantIn { + if !strings.Contains(help, strings.ToLower(want)) { + t.Errorf("Help() missing %q", want) + } + } + }) + } +} + +// TestUnescape verifies that unescape correctly converts C-style escape sequences. +// +//nolint:funlen +func TestUnescape(t *testing.T) { + tests := []struct { + name string + input string + want []byte + }{ + { + name: "empty string", + input: "", + want: []byte{}, + }, + { + name: "plain text unchanged", + input: "hello", + want: []byte("hello"), + }, + { + name: "newline escape", + input: `\n`, + want: []byte{'\n'}, + }, + { + name: "carriage return escape", + input: `\r`, + want: []byte{'\r'}, + }, + { + name: "tab escape", + input: `\t`, + want: []byte{'\t'}, + }, + { + name: "backslash escape", + input: `\\`, + want: []byte{'\\'}, + }, + { + name: "hex escape lowercase", + input: `\x41`, + want: []byte{'A'}, + }, + { + name: "hex escape uppercase", + input: `\x4F`, + want: []byte{'O'}, + }, + { + name: "hex escape zero byte", + input: `\x00`, + want: []byte{0x00}, + }, + { + name: "multiple hex escapes", + input: `\x0d\x0a`, + want: []byte{'\r', '\n'}, + }, + { + name: "mixed text and escapes", + input: `user\r\n`, + want: []byte{'u', 's', 'e', 'r', '\r', '\n'}, + }, + { + name: "unknown escape emitted as-is", + input: `\q`, + want: []byte{'\\', 'q'}, + }, + { + name: "lone backslash at end emitted as-is", + input: `abc\`, + want: []byte{'a', 'b', 'c', '\\'}, + }, + { + name: "incomplete hex escape emitted as-is", + input: `\x`, + want: []byte{'\\', 'x'}, + }, + { + // \x requires exactly two hex digits; with only one the \x is + // emitted as-is and the remaining character follows normally. + name: "hex escape with one digit only emitted as-is", + input: `\x4`, + want: []byte{'\\', 'x', '4'}, + }, + { + // Non-hex chars after \x: the \x is emitted as-is and the + // non-hex characters are emitted as regular bytes. + name: "hex escape with non-hex chars emitted as-is", + input: `\xGG`, + want: []byte{'\\', 'x', 'G', 'G'}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := unescape(tt.input) + + if string(got) != string(tt.want) { + t.Errorf("unescape(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} + +// TestEvalArgsPairs covers the expect-send pair parsing mode of evalArgs. +func TestEvalArgsPairs(t *testing.T) { + tests := []struct { + name string + args []string + wantPairs int // expected number of pairs + wantPat0 string // pattern of first pair (empty = skip check) + wantResp0 string // response of first pair (empty = skip check) + wantPat1 string // pattern of second pair (empty = skip check) + wantResp1 string // response of second pair (empty = skip check) + wantErr bool + }{ + { + name: "one pair", + args: []string{"login:", "user\\n"}, + wantPairs: 1, + wantPat0: "login:", + wantResp0: "user\n", + }, + { + name: "two pairs", + args: []string{"login:", "user\\n", "Password:", "secret\\n"}, + wantPairs: 2, + wantPat0: "login:", + wantResp0: "user\n", + wantPat1: "Password:", + wantResp1: "secret\n", + }, + { + name: "pair with hex escape in response", + args: []string{"ready", "\\x0d"}, + wantPairs: 1, + wantPat0: "ready", + wantResp0: "\r", + }, + { + name: "odd number of args returns error", + args: []string{"login:", "user", "extra"}, + wantErr: true, + }, + { + name: "invalid regex in pair returns error", + args: []string{"[invalid", "response"}, + wantErr: true, + }, + { + name: "invalid regex in second pair returns error", + args: []string{"login:", "user\\n", "[bad", "pass"}, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &Serial{} + + err := s.evalArgs(tt.args) + + if (err != nil) != tt.wantErr { + t.Fatalf("evalArgs() error = %v, wantErr %v", err, tt.wantErr) + } + + if tt.wantErr { + return + } + + // Each expect-send pair compiles to two steps: an expect-step + // followed by a send-step. + if len(s.steps) != tt.wantPairs*pairStride { + t.Fatalf("len(steps) = %d, want %d (%d pairs)", len(s.steps), tt.wantPairs*pairStride, tt.wantPairs) + } + + if tt.wantPat0 != "" && s.steps[0].pattern.String() != tt.wantPat0 { + t.Errorf("steps[0].pattern = %q, want %q", s.steps[0].pattern.String(), tt.wantPat0) + } + + if tt.wantResp0 != "" && string(s.steps[1].data) != tt.wantResp0 { + t.Errorf("steps[1].data = %q, want %q", s.steps[1].data, tt.wantResp0) + } + + if tt.wantPat1 != "" && len(s.steps) > 2 && s.steps[2].pattern.String() != tt.wantPat1 { + t.Errorf("steps[2].pattern = %q, want %q", s.steps[2].pattern.String(), tt.wantPat1) + } + + if tt.wantResp1 != "" && len(s.steps) > 3 && string(s.steps[3].data) != tt.wantResp1 { + t.Errorf("steps[3].data = %q, want %q", s.steps[3].data, tt.wantResp1) + } + }) + } +} + +// wantStep is the expected kind and payload of a single parsed sequence step. +type wantStep struct { + expect string // non-empty => stepExpect with this pattern + send string // used when expect is empty => stepSend with this (unescaped) data +} + +// TestEvalArgsSequence covers the tagged expect:/send: sequence parsing mode. +// +//nolint:funlen +func TestEvalArgsSequence(t *testing.T) { + tests := []struct { + name string + args []string + want []wantStep + wantErr bool + }{ + { + name: "leading send then expect", + args: []string{`send:\n`, "expect:login:"}, + want: []wantStep{{send: "\n"}, {expect: "login:"}}, + }, + { + name: "expect then send", + args: []string{"expect:login:", `send:root\n`}, + want: []wantStep{{expect: "login:"}, {send: "root\n"}}, + }, + { + name: "two sends in a row", + args: []string{`send:a`, `send:b`}, + want: []wantStep{{send: "a"}, {send: "b"}}, + }, + { + name: "full login-and-reboot sequence", + args: []string{`send:\n`, "expect:login:", `send:root\n`, "expect:# ", `send:reboot\n`}, + want: []wantStep{ + {send: "\n"}, {expect: "login:"}, {send: "root\n"}, {expect: "# "}, {send: "reboot\n"}, + }, + }, + { + name: "send with hex escape", + args: []string{`send:\x03`}, // Ctrl-C + want: []wantStep{{send: "\x03"}}, + }, + { + name: "empty send is allowed", + args: []string{"expect:done", "send:"}, + want: []wantStep{{expect: "done"}, {send: ""}}, + }, + { + name: "invalid regex in expect step", + args: []string{"expect:[invalid", `send:x`}, + wantErr: true, + }, + { + name: "untagged argument mixed in is rejected", + args: []string{"expect:login:", "root"}, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &Serial{} + + err := s.evalArgs(tt.args) + + if (err != nil) != tt.wantErr { + t.Fatalf("evalArgs() error = %v, wantErr %v", err, tt.wantErr) + } + + if tt.wantErr { + return + } + + if len(s.steps) != len(tt.want) { + t.Fatalf("len(steps) = %d, want %d", len(s.steps), len(tt.want)) + } + + for i, w := range tt.want { + got := s.steps[i] + + if w.expect != "" { + if got.kind != stepExpect { + t.Errorf("steps[%d] kind = send, want expect %q", i, w.expect) + + continue + } + + if got.pattern.String() != w.expect { + t.Errorf("steps[%d].pattern = %q, want %q", i, got.pattern.String(), w.expect) + } + + continue + } + + if got.kind != stepSend { + t.Errorf("steps[%d] kind = expect, want send %q", i, w.send) + + continue + } + + if string(got.data) != w.send { + t.Errorf("steps[%d].data = %q, want %q", i, got.data, w.send) + } + } + }) + } +}