From 149cdbc6c03472900a472a59e0b560cf7a48c53c Mon Sep 17 00:00:00 2001 From: Fabian Wienand Date: Thu, 8 Jan 2026 15:04:44 +0100 Subject: [PATCH 1/3] refactor: extract insecure HTTP client to pkg/rpc Move the duplicated newInsecureClient implementation from dutctl and dutagent into a shared pkg/rpc.NewInsecureClient function. This eliminates code duplication and provides a single source of truth for h2c (HTTP/2 without TLS) client configuration. Signed-off-by: Fabian Wienand --- cmds/dutagent/dutagent.go | 24 ++---------------------- cmds/dutctl/dutctl.go | 27 ++++----------------------- pkg/rpc/httpclient.go | 27 +++++++++++++++++++++++++++ 3 files changed, 33 insertions(+), 45 deletions(-) create mode 100644 pkg/rpc/httpclient.go diff --git a/cmds/dutagent/dutagent.go b/cmds/dutagent/dutagent.go index 72ffc5cb..94a6dc33 100644 --- a/cmds/dutagent/dutagent.go +++ b/cmds/dutagent/dutagent.go @@ -9,13 +9,11 @@ package main import ( "context" - "crypto/tls" "errors" "flag" "fmt" "io" "log" - "net" "net/http" "os" "os/signal" @@ -25,6 +23,7 @@ import ( "github.com/BlindspotSoftware/dutctl/internal/buildinfo" "github.com/BlindspotSoftware/dutctl/internal/dutagent" "github.com/BlindspotSoftware/dutctl/pkg/dut" + "github.com/BlindspotSoftware/dutctl/pkg/rpc" "github.com/BlindspotSoftware/dutctl/protobuf/gen/dutctl/v1/dutctlv1connect" "golang.org/x/net/http2" "golang.org/x/net/http2/h2c" @@ -199,31 +198,12 @@ func spawnClient(agendURL string) dutctlv1connect.RelayServiceClient { log.Printf("Spawning new client for agent %q", agendURL) return dutctlv1connect.NewRelayServiceClient( - // Instead of http.DefaultClient, use the HTTP/2 protocol without TLS - newInsecureClient(), + rpc.NewInsecureClient(), fmt.Sprintf("http://%s", agendURL), connect.WithGRPC(), ) } -// TODO: refactor into pkg and reuse in dutctl and dutserver. -func newInsecureClient() *http.Client { - return &http.Client{ - Transport: &http2.Transport{ - AllowHTTP: true, - DialTLS: func(network, addr string, _ *tls.Config) (net.Conn, error) { - // If you're also using this client for non-h2c traffic, you may want - // to delegate to tls.Dial if the network isn't TCP or the addr isn't - // in an allowlist. - - //nolint:noctx - return net.Dial(network, addr) - }, - // TODO: Don't forget timeouts! - }, - } -} - // start orchestrates the dutagent execution. // //nolint:cyclop diff --git a/cmds/dutctl/dutctl.go b/cmds/dutctl/dutctl.go index 3d397d37..1ec8a9ab 100644 --- a/cmds/dutctl/dutctl.go +++ b/cmds/dutctl/dutctl.go @@ -7,25 +7,23 @@ package main import ( - "crypto/tls" "errors" "flag" "fmt" "io" "log" - "net" - "net/http" "os" "connectrpc.com/connect" "github.com/BlindspotSoftware/dutctl/internal/buildinfo" "github.com/BlindspotSoftware/dutctl/internal/output" + "github.com/BlindspotSoftware/dutctl/pkg/rpc" "github.com/BlindspotSoftware/dutctl/protobuf/gen/dutctl/v1/dutctlv1connect" - "golang.org/x/net/http2" ) const usageAbstract = `dutctl - The client application of the DUT Control system. ` + const usageSynopsis = ` SYNOPSIS: dutctl [options] list @@ -35,6 +33,7 @@ SYNOPSIS: dutctl version ` + const usageDescription = ` If a device and a command are provided, dutctl will execute the command on the device. The optional args are passed to the command. @@ -115,8 +114,7 @@ type application struct { func (app *application) setupRPCClient() { client := dutctlv1connect.NewDeviceServiceClient( - // Instead of http.DefaultClient, use the HTTP/2 protocol without TLS - newInsecureClient(), + rpc.NewInsecureClient(), fmt.Sprintf("http://%s", app.serverAddr), connect.WithGRPC(), ) @@ -124,23 +122,6 @@ func (app *application) setupRPCClient() { app.rpcClient = client } -func newInsecureClient() *http.Client { - return &http.Client{ - Transport: &http2.Transport{ - AllowHTTP: true, - DialTLS: func(network, addr string, _ *tls.Config) (net.Conn, error) { - // If you're also using this client for non-h2c traffic, you may want - // to delegate to tls.Dial if the network isn't TCP or the addr isn't - // in an allowlist. - - //nolint:noctx - return net.Dial(network, addr) - }, - // Don't forget timeouts! - }, - } -} - var errInvalidCmdline = fmt.Errorf("invalid command line") // start is the entry point of the application. diff --git a/pkg/rpc/httpclient.go b/pkg/rpc/httpclient.go new file mode 100644 index 00000000..43b88b76 --- /dev/null +++ b/pkg/rpc/httpclient.go @@ -0,0 +1,27 @@ +// Copyright 2025 Blindspot Software +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package rpc provides HTTP client utilities for RPC communication. +package rpc + +import ( + "crypto/tls" + "net" + "net/http" + + "golang.org/x/net/http2" +) + +// NewInsecureClient creates an HTTP client for h2c (HTTP/2 without TLS). +func NewInsecureClient() *http.Client { + return &http.Client{ + Transport: &http2.Transport{ + AllowHTTP: true, + DialTLS: func(network, addr string, _ *tls.Config) (net.Conn, error) { + //nolint:noctx + return net.Dial(network, addr) + }, + }, + } +} From 6fa0fa5ede16562e74ebfed899701c31d6d7540a Mon Sep 17 00:00:00 2001 From: Fabian Wienand Date: Thu, 8 Jan 2026 14:17:26 +0100 Subject: [PATCH 2/3] feat: add TLS encryption for client-agent communication Implement TLS support using Ed25519 self-signed certificates to encrypt communication between dutctl client and dutagent server. TLS is enabled by default with an --insecure flag available for HTTP support. This provides encryption only, not client authentication. Any client can connect to the agent. Signed-off-by: Fabian Wienand --- cmds/dutagent/dutagent.go | 94 ++++++++++++++--- cmds/dutctl/dutctl.go | 12 ++- internal/tlsutil/tlsutil.go | 203 ++++++++++++++++++++++++++++++++++++ pkg/rpc/httpclient.go | 42 +++++++- 4 files changed, 331 insertions(+), 20 deletions(-) create mode 100644 internal/tlsutil/tlsutil.go diff --git a/cmds/dutagent/dutagent.go b/cmds/dutagent/dutagent.go index 94a6dc33..d6a1e807 100644 --- a/cmds/dutagent/dutagent.go +++ b/cmds/dutagent/dutagent.go @@ -9,6 +9,7 @@ package main import ( "context" + "crypto/tls" "errors" "flag" "fmt" @@ -18,10 +19,12 @@ import ( "os" "os/signal" "syscall" + "time" "connectrpc.com/connect" "github.com/BlindspotSoftware/dutctl/internal/buildinfo" "github.com/BlindspotSoftware/dutctl/internal/dutagent" + "github.com/BlindspotSoftware/dutctl/internal/tlsutil" "github.com/BlindspotSoftware/dutctl/pkg/dut" "github.com/BlindspotSoftware/dutctl/pkg/rpc" "github.com/BlindspotSoftware/dutctl/protobuf/gen/dutctl/v1/dutctlv1connect" @@ -54,6 +57,9 @@ func newAgent(stdout io.Writer, exitFunc func(int), args []string) *agent { fs.BoolVar(&agt.dryRun, "dry-run", false, dryRunInfo) fs.StringVar(&agt.server, "server", "", serverInfo) fs.BoolVar(&agt.versionFlag, "v", false, versionFlagInfo) + fs.BoolVar(&agt.insecure, "insecure", false, "Disable TLS (use plain HTTP)") + fs.StringVar(&agt.tlsCertPath, "tls-cert", "/etc/dutagent/tls/cert.pem", "Path to TLS certificate file (auto-generated if missing)") + fs.StringVar(&agt.tlsKeyPath, "tls-key", "/etc/dutagent/tls/key.pem", "Path to TLS key file (auto-generated if missing)") //nolint:errcheck // flag.Parse always returns no error because of flag.ExitOnError fs.Parse(args[1:]) @@ -72,9 +78,13 @@ type agent struct { checkConfig bool dryRun bool server string + insecure bool + tlsCertPath string + tlsKeyPath string // state - config config + config config + httpServer *http.Server } // config holds the dutagent configuration that is parsed from YAML data. @@ -94,6 +104,20 @@ const ( // Afterwards agt.exit is called. If clean-up fails, agt.exit is called with code 1, // otherwise with provided exitCode. func (agt *agent) cleanup(code exitCode) { + // Gracefully shutdown HTTP server first + if agt.httpServer != nil { + ctx, cancel := context.WithTimeout(context.Background(), shutdownTimeout) + defer cancel() + + log.Print("Shutting down HTTP server gracefully...") + + err := agt.httpServer.Shutdown(ctx) + if err != nil { + log.Printf("HTTP server shutdown error: %v", err) + // Continue with cleanup even if shutdown fails + } + } + devlist := agt.config.Devices if devlist != nil { err := dutagent.Deinit(devlist) @@ -152,6 +176,14 @@ func printInitErr(err error) { log.Print(err) } +const ( + readHeaderTimeout = 10 * time.Second + writeTimeout = 30 * time.Second + idleTimeout = 120 * time.Second + maxHeaderBytes = 1 << 20 // 1 MB + shutdownTimeout = 10 * time.Second +) + // startRPCService starts the RPC service, that ideally listens for incoming // connections forever. It always returns an non-nil error. func (agt *agent) startRPCService() error { @@ -163,18 +195,49 @@ func (agt *agent) startRPCService() error { path, handler := dutctlv1connect.NewDeviceServiceHandler(service) mux.Handle(path, handler) - //nolint:gosec - return http.ListenAndServe( - agt.address, - // Use h2c so we can serve HTTP/2 without TLS. - h2c.NewHandler(mux, &http2.Server{}), - ) + if agt.insecure { + // Use h2c so we can serve HTTP/2 without TLS + log.Printf("Starting in INSECURE mode (plain HTTP) on %s", agt.address) + //nolint:gosec + return http.ListenAndServe( + agt.address, + h2c.NewHandler(mux, &http2.Server{}), + ) + } + + // Use TLS mode (default) - load or auto-generate certificate + cert, err := tlsutil.LoadOrGenerateCert(agt.tlsCertPath, agt.tlsKeyPath) + if err != nil { + return fmt.Errorf("failed to load/generate TLS certificate: %w", err) + } + + tlsConfig := &tls.Config{ + Certificates: []tls.Certificate{cert}, + MinVersion: tls.VersionTLS13, + } + + server := &http.Server{ + Addr: agt.address, + Handler: mux, + TLSConfig: tlsConfig, + ReadHeaderTimeout: readHeaderTimeout, + WriteTimeout: writeTimeout, + IdleTimeout: idleTimeout, + MaxHeaderBytes: maxHeaderBytes, + } + + agt.httpServer = server + + log.Printf("Starting TLS-enabled RPC service on %s", agt.address) + + // ListenAndServeTLS with empty cert/key paths since we've already loaded them in tlsConfig + return server.ListenAndServeTLS("", "") } func (agt *agent) registerWithServer() error { log.Printf("Registering with server %q", agt.server) - client := spawnClient(agt.server) + client := spawnClient(agt.server, agt.insecure) req := connect.NewRequest(&pb.RegisterRequest{ Devices: agt.config.Devices.Names(), Address: agt.address, @@ -191,15 +254,20 @@ func (agt *agent) registerWithServer() error { } // spawnClient creates a new client to the DUT server specified by the server address. -// TODO: refactor into pkg and reuse in dutctl and dutserver. // //nolint:ireturn -func spawnClient(agendURL string) dutctlv1connect.RelayServiceClient { - log.Printf("Spawning new client for agent %q", agendURL) +func spawnClient(serverURL string, insecure bool) dutctlv1connect.RelayServiceClient { + client, scheme := rpc.NewClient(insecure) + + if insecure { + log.Printf("Spawning insecure client for server %q", serverURL) + } else { + log.Printf("Spawning TLS client for server %q", serverURL) + } return dutctlv1connect.NewRelayServiceClient( - rpc.NewInsecureClient(), - fmt.Sprintf("http://%s", agendURL), + client, + fmt.Sprintf("%s://%s", scheme, serverURL), connect.WithGRPC(), ) } diff --git a/cmds/dutctl/dutctl.go b/cmds/dutctl/dutctl.go index 1ec8a9ab..9b993d1e 100644 --- a/cmds/dutctl/dutctl.go +++ b/cmds/dutctl/dutctl.go @@ -77,6 +77,7 @@ func newApp(stdin io.Reader, stdout, stderr io.Writer, exitFunc func(int), args fs.StringVar(&app.outputFormat, "f", "", outputFormatInfo) fs.BoolVar(&app.verbose, "v", false, verboseInfo) fs.BoolVar(&app.noColor, "no-color", false, noColorInfo) + fs.BoolVar(&app.insecure, "insecure", false, "Disable TLS (use plain HTTP)") //nolint:errcheck // flag.Parse always returns no error because of flag.ExitOnError fs.Parse(args[1:]) @@ -105,6 +106,7 @@ type application struct { outputFormat string verbose bool noColor bool + insecure bool args []string printFlagDefaults func() @@ -113,13 +115,13 @@ type application struct { } func (app *application) setupRPCClient() { - client := dutctlv1connect.NewDeviceServiceClient( - rpc.NewInsecureClient(), - fmt.Sprintf("http://%s", app.serverAddr), + client, scheme := rpc.NewClient(app.insecure) + + app.rpcClient = dutctlv1connect.NewDeviceServiceClient( + client, + fmt.Sprintf("%s://%s", scheme, app.serverAddr), connect.WithGRPC(), ) - - app.rpcClient = client } var errInvalidCmdline = fmt.Errorf("invalid command line") diff --git a/internal/tlsutil/tlsutil.go b/internal/tlsutil/tlsutil.go new file mode 100644 index 00000000..2007c6dd --- /dev/null +++ b/internal/tlsutil/tlsutil.go @@ -0,0 +1,203 @@ +// Copyright 2025 Blindspot Software +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tlsutil + +import ( + "crypto/ed25519" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "fmt" + "log" + "math/big" + "net" + "os" + "path/filepath" + "time" +) + +const ( + // File and directory permissions. + certFileMode = 0644 // Public read, owner write. + keyFileMode = 0600 // Owner read/write only + dirMode = 0755 // Standard directory permissions. + + // Certificate serial number bit size. + serialNumberBits = 128 +) + +// GenerateSelfSignedCert creates a new self-signed TLS certificate and private key. +// The certificate is valid for 10 years and includes localhost and system hostname in SANs. +// Uses Ed25519 for better performance and security compared to RSA. +func GenerateSelfSignedCert(certPath, keyPath string) error { + // Generate Ed25519 private key (much faster than RSA) + publicKey, privateKey, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + return fmt.Errorf("failed to generate keys: %w", err) + } + + // Create self-signed certificate + derBytes, err := createSelfSignedCertificate(publicKey, privateKey) + if err != nil { + return err + } + + err = writeCertificate(certPath, derBytes) + if err != nil { + return err + } + + err = writePrivateKey(keyPath, privateKey) + if err != nil { + return err + } + + log.Printf("Generated self-signed TLS certificate: %s", certPath) + log.Printf("Generated private key: %s", keyPath) + + return nil +} + +func createSelfSignedCertificate(publicKey ed25519.PublicKey, privateKey ed25519.PrivateKey) ([]byte, error) { + // Generate a random serial number + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), serialNumberBits) + + serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) + if err != nil { + return nil, fmt.Errorf("failed to generate serial number: %w", err) + } + + // Get system hostname for SANs + hostname, err := os.Hostname() + if err != nil { + hostname = "localhost" // Fallback if hostname detection fails + } + + template := &x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + Organization: []string{"Blindspot Software"}, + CommonName: "dutagent", + }, + NotBefore: time.Now(), + NotAfter: time.Now().Add(10 * 365 * 24 * time.Hour), // 10 years + KeyUsage: x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + DNSNames: []string{"localhost", hostname}, + IPAddresses: []net.IP{net.ParseIP("127.0.0.1"), net.ParseIP("::1")}, + } + + derBytes, err := x509.CreateCertificate(rand.Reader, template, template, publicKey, privateKey) + if err != nil { + return nil, fmt.Errorf("failed to create certificate: %w", err) + } + + return derBytes, nil +} + +func writeCertificate(certPath string, derBytes []byte) error { + certOut, err := os.OpenFile(certPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, certFileMode) + if err != nil { + return fmt.Errorf("failed to create certificate file: %w", err) + } + defer certOut.Close() + + err = pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) + if err != nil { + return fmt.Errorf("failed to write certificate: %w", err) + } + + return nil +} + +func writePrivateKey(keyPath string, privateKey ed25519.PrivateKey) error { + // Marshal Ed25519 private key in PKCS8 format + privBytes, err := x509.MarshalPKCS8PrivateKey(privateKey) + if err != nil { + return fmt.Errorf("failed to marshal private key: %w", err) + } + + keyOut, err := os.OpenFile(keyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, keyFileMode) + if err != nil { + return fmt.Errorf("failed to create key file: %w", err) + } + defer keyOut.Close() + + err = pem.Encode(keyOut, &pem.Block{Type: "PRIVATE KEY", Bytes: privBytes}) + if err != nil { + return fmt.Errorf("failed to write private key: %w", err) + } + + return nil +} + +// LoadOrGenerateCert attempts to load an existing TLS certificate/key pair. +// If the files don't exist, it generates a new self-signed certificate. +// If the files exist but cannot be loaded, it returns an error without overwriting them. +func LoadOrGenerateCert(certPath, keyPath string) (tls.Certificate, error) { + // Check if certificate and key files exist + certExists := fileExists(certPath) + keyExists := fileExists(keyPath) + + // If either file exists, we must load them (don't auto-generate) + if certExists || keyExists { + cert, err := tls.LoadX509KeyPair(certPath, keyPath) + if err != nil { + return tls.Certificate{}, fmt.Errorf("certificate/key files exist but failed to load (cert exists: %v, key exists: %v): %w", + certExists, keyExists, err) + } + + log.Printf("Loaded existing TLS certificate from: %s", certPath) + + return cert, nil + } + + // Neither file exists, generate new certificate + log.Printf("TLS certificate not found, generating new self-signed certificate...") + + // Derive directory from cert path + certDir := filepath.Dir(certPath) + keyDir := filepath.Dir(keyPath) + + // Ensure directories exist + err := os.MkdirAll(certDir, dirMode) + if err != nil { + return tls.Certificate{}, fmt.Errorf("failed to create certificate directory: %w", err) + } + + if certDir != keyDir { + err := os.MkdirAll(keyDir, dirMode) + if err != nil { + return tls.Certificate{}, fmt.Errorf("failed to create key directory: %w", err) + } + } + + // Generate certificate + err = GenerateSelfSignedCert(certPath, keyPath) + if err != nil { + return tls.Certificate{}, err + } + + // Load the newly generated certificate + cert, err := tls.LoadX509KeyPair(certPath, keyPath) + if err != nil { + return tls.Certificate{}, fmt.Errorf("failed to load generated certificate: %w", err) + } + + return cert, nil +} + +// fileExists checks if a file exists and is not a directory. +func fileExists(path string) bool { + info, err := os.Stat(path) + if err != nil { + return false + } + + return !info.IsDir() +} diff --git a/pkg/rpc/httpclient.go b/pkg/rpc/httpclient.go index 43b88b76..f3018054 100644 --- a/pkg/rpc/httpclient.go +++ b/pkg/rpc/httpclient.go @@ -9,12 +9,50 @@ import ( "crypto/tls" "net" "net/http" + "time" "golang.org/x/net/http2" ) -// NewInsecureClient creates an HTTP client for h2c (HTTP/2 without TLS). -func NewInsecureClient() *http.Client { +const ( + // HTTP transport timeout configurations. + responseHeaderTimeout = 10 * time.Second + idleConnTimeout = 90 * time.Second + tlsHandshakeTimeout = 10 * time.Second + expectContinueTimeout = 1 * time.Second +) + +// NewClient creates an HTTP client for RPC communication. +// If insecure is true, it returns an h2c (HTTP/2 without TLS) client with "http" scheme. +// Otherwise, it returns a TLS-enabled HTTP/2 client with proper timeouts and "https" scheme. +// Returns the HTTP client and the URL scheme to use. +func NewClient(insecure bool) (*http.Client, string) { + if insecure { + return newInsecureClient(), "http" + } + + return newTLSClient(), "https" +} + +// newTLSClient creates an HTTP client configured for TLS with HTTP/2. +func newTLSClient() *http.Client { + return &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + //nolint:gosec // User controls server trust; InsecureSkipVerify appropriate for test environments + InsecureSkipVerify: true, + MinVersion: tls.VersionTLS13, + }, + ResponseHeaderTimeout: responseHeaderTimeout, + IdleConnTimeout: idleConnTimeout, + TLSHandshakeTimeout: tlsHandshakeTimeout, + ExpectContinueTimeout: expectContinueTimeout, + }, + } +} + +// newInsecureClient creates an HTTP client for h2c (HTTP/2 without TLS). +func newInsecureClient() *http.Client { return &http.Client{ Transport: &http2.Transport{ AllowHTTP: true, From ab6f2298862295d386519b14415f100a56a829c1 Mon Sep 17 00:00:00 2001 From: Fabian Wienand Date: Thu, 8 Jan 2026 15:42:48 +0100 Subject: [PATCH 3/3] test: add unit tests for RPC client and TLS utilities Signed-off-by: Fabian Wienand --- internal/tlsutil/tlsutil_test.go | 134 +++++++++++++++++++++++++++++++ pkg/rpc/httpclient_test.go | 73 +++++++++++++++++ 2 files changed, 207 insertions(+) create mode 100644 internal/tlsutil/tlsutil_test.go create mode 100644 pkg/rpc/httpclient_test.go diff --git a/internal/tlsutil/tlsutil_test.go b/internal/tlsutil/tlsutil_test.go new file mode 100644 index 00000000..5d5d1bf3 --- /dev/null +++ b/internal/tlsutil/tlsutil_test.go @@ -0,0 +1,134 @@ +// Copyright 2025 Blindspot Software +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tlsutil + +import ( + "crypto/tls" + "crypto/x509" + "os" + "path/filepath" + "testing" +) + +func TestGenerateSelfSignedCert(t *testing.T) { + tests := []struct { + name string + setupFunc func(t *testing.T) (certPath, keyPath string) + wantErr bool + }{ + { + name: "generates valid certificate", + setupFunc: func(t *testing.T) (string, string) { + tmpDir := t.TempDir() + return filepath.Join(tmpDir, "cert.pem"), filepath.Join(tmpDir, "key.pem") + }, + wantErr: false, + }, + { + name: "fails with invalid path", + setupFunc: func(t *testing.T) (string, string) { + return "/nonexistent/directory/cert.pem", "/nonexistent/directory/key.pem" + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + certPath, keyPath := tt.setupFunc(t) + + err := GenerateSelfSignedCert(certPath, keyPath) + + if (err != nil) != tt.wantErr { + t.Errorf("GenerateSelfSignedCert() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if !tt.wantErr { + // Verify files exist and can be loaded + if _, err := os.Stat(certPath); err != nil { + t.Errorf("Certificate file not created: %v", err) + } + if _, err := os.Stat(keyPath); err != nil { + t.Errorf("Key file not created: %v", err) + } + + cert, err := tls.LoadX509KeyPair(certPath, keyPath) + if err != nil { + t.Fatalf("Failed to load certificate: %v", err) + } + + x509Cert, err := x509.ParseCertificate(cert.Certificate[0]) + if err != nil { + t.Fatalf("Failed to parse certificate: %v", err) + } + + if x509Cert.Subject.CommonName != "dutagent" { + t.Errorf("CommonName = %q, want %q", x509Cert.Subject.CommonName, "dutagent") + } + } + }) + } +} + +func TestLoadOrGenerateCert(t *testing.T) { + tests := []struct { + name string + setupFunc func(t *testing.T) (certPath, keyPath string) + wantErr bool + }{ + { + name: "generates when files don't exist", + setupFunc: func(t *testing.T) (string, string) { + tmpDir := t.TempDir() + return filepath.Join(tmpDir, "cert.pem"), filepath.Join(tmpDir, "key.pem") + }, + wantErr: false, + }, + { + name: "loads existing certificate", + setupFunc: func(t *testing.T) (string, string) { + tmpDir := t.TempDir() + certPath := filepath.Join(tmpDir, "cert.pem") + keyPath := filepath.Join(tmpDir, "key.pem") + if err := GenerateSelfSignedCert(certPath, keyPath); err != nil { + t.Fatalf("Setup failed: %v", err) + } + return certPath, keyPath + }, + wantErr: false, + }, + { + name: "fails when only cert exists", + setupFunc: func(t *testing.T) (string, string) { + tmpDir := t.TempDir() + certPath := filepath.Join(tmpDir, "cert.pem") + keyPath := filepath.Join(tmpDir, "key.pem") + if err := os.WriteFile(certPath, []byte("invalid"), 0644); err != nil { + t.Fatalf("Setup failed: %v", err) + } + return certPath, keyPath + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + certPath, keyPath := tt.setupFunc(t) + + cert, err := LoadOrGenerateCert(certPath, keyPath) + + if (err != nil) != tt.wantErr { + t.Errorf("LoadOrGenerateCert() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if !tt.wantErr && len(cert.Certificate) == 0 { + t.Error("Certificate is empty") + } + }) + } +} diff --git a/pkg/rpc/httpclient_test.go b/pkg/rpc/httpclient_test.go new file mode 100644 index 00000000..fcbac4aa --- /dev/null +++ b/pkg/rpc/httpclient_test.go @@ -0,0 +1,73 @@ +// Copyright 2025 Blindspot Software +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package rpc_test + +import ( + "crypto/tls" + "net/http" + "testing" + + "github.com/BlindspotSoftware/dutctl/pkg/rpc" + "golang.org/x/net/http2" +) + +func TestNewClient(t *testing.T) { + tests := []struct { + name string + insecure bool + wantScheme string + wantTransportType interface{} + }{ + { + name: "insecure returns http and http2.Transport", + insecure: true, + wantScheme: "http", + wantTransportType: &http2.Transport{}, + }, + { + name: "secure returns https and http.Transport with TLS", + insecure: false, + wantScheme: "https", + wantTransportType: &http.Transport{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + client, scheme := rpc.NewClient(tt.insecure) + + if client == nil { + t.Fatal("NewClient returned nil client") + } + + if scheme != tt.wantScheme { + t.Errorf("scheme = %q, want %q", scheme, tt.wantScheme) + } + + if client.Transport == nil { + t.Fatal("Client transport is nil") + } + + switch tt.wantTransportType.(type) { + case *http2.Transport: + if _, ok := client.Transport.(*http2.Transport); !ok { + t.Errorf("Expected *http2.Transport, got %T", client.Transport) + } + case *http.Transport: + transport, ok := client.Transport.(*http.Transport) + if !ok { + t.Fatalf("Expected *http.Transport, got %T", client.Transport) + } + + if transport.TLSClientConfig == nil { + t.Fatal("TLS client config is nil") + } + if transport.TLSClientConfig.MinVersion != tls.VersionTLS13 { + t.Errorf("MinVersion = %v, want TLS 1.3", transport.TLSClientConfig.MinVersion) + } + } + }) + } +}