Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 81 additions & 33 deletions cmds/dutagent/dutagent.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,18 @@ import (
"fmt"
"io"
"log"
"net"
"net/http"
"os"
"os/signal"
"syscall"
"time"

"connectrpc.com/connect"
"github.com/BlindspotSoftware/dutctl/internal/buildinfo"
"github.com/BlindspotSoftware/dutctl/internal/dutagent"
"github.com/BlindspotSoftware/dutctl/internal/tlsutil"
"github.com/BlindspotSoftware/dutctl/pkg/dut"
"github.com/BlindspotSoftware/dutctl/pkg/rpc"
"github.com/BlindspotSoftware/dutctl/protobuf/gen/dutctl/v1/dutctlv1connect"
"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"
Expand Down Expand Up @@ -55,6 +57,9 @@ func newAgent(stdout io.Writer, exitFunc func(int), args []string) *agent {
fs.BoolVar(&agt.dryRun, "dry-run", false, dryRunInfo)
fs.StringVar(&agt.server, "server", "", serverInfo)
fs.BoolVar(&agt.versionFlag, "v", false, versionFlagInfo)
fs.BoolVar(&agt.insecure, "insecure", false, "Disable TLS (use plain HTTP)")
fs.StringVar(&agt.tlsCertPath, "tls-cert", "/etc/dutagent/tls/cert.pem", "Path to TLS certificate file (auto-generated if missing)")
fs.StringVar(&agt.tlsKeyPath, "tls-key", "/etc/dutagent/tls/key.pem", "Path to TLS key file (auto-generated if missing)")
Comment thread
RiSKeD marked this conversation as resolved.
//nolint:errcheck // flag.Parse always returns no error because of flag.ExitOnError
fs.Parse(args[1:])

Expand All @@ -73,9 +78,13 @@ type agent struct {
checkConfig bool
dryRun bool
server string
insecure bool
tlsCertPath string
tlsKeyPath string

// state
config config
config config
httpServer *http.Server
}

// config holds the dutagent configuration that is parsed from YAML data.
Expand All @@ -95,6 +104,20 @@ const (
// Afterwards agt.exit is called. If clean-up fails, agt.exit is called with code 1,
// otherwise with provided exitCode.
func (agt *agent) cleanup(code exitCode) {
// Gracefully shutdown HTTP server first
if agt.httpServer != nil {
ctx, cancel := context.WithTimeout(context.Background(), shutdownTimeout)
defer cancel()

log.Print("Shutting down HTTP server gracefully...")

err := agt.httpServer.Shutdown(ctx)
if err != nil {
log.Printf("HTTP server shutdown error: %v", err)
// Continue with cleanup even if shutdown fails
}
}

devlist := agt.config.Devices
if devlist != nil {
err := dutagent.Deinit(devlist)
Expand Down Expand Up @@ -153,6 +176,14 @@ func printInitErr(err error) {
log.Print(err)
}

const (
readHeaderTimeout = 10 * time.Second
writeTimeout = 30 * time.Second
idleTimeout = 120 * time.Second
maxHeaderBytes = 1 << 20 // 1 MB
shutdownTimeout = 10 * time.Second
)

// startRPCService starts the RPC service, that ideally listens for incoming
// connections forever. It always returns an non-nil error.
func (agt *agent) startRPCService() error {
Expand All @@ -164,18 +195,49 @@ func (agt *agent) startRPCService() error {
path, handler := dutctlv1connect.NewDeviceServiceHandler(service)
mux.Handle(path, handler)

//nolint:gosec
return http.ListenAndServe(
agt.address,
// Use h2c so we can serve HTTP/2 without TLS.
h2c.NewHandler(mux, &http2.Server{}),
)
if agt.insecure {
// Use h2c so we can serve HTTP/2 without TLS
log.Printf("Starting in INSECURE mode (plain HTTP) on %s", agt.address)
//nolint:gosec
return http.ListenAndServe(
agt.address,
h2c.NewHandler(mux, &http2.Server{}),
)
}

// Use TLS mode (default) - load or auto-generate certificate
cert, err := tlsutil.LoadOrGenerateCert(agt.tlsCertPath, agt.tlsKeyPath)
if err != nil {
return fmt.Errorf("failed to load/generate TLS certificate: %w", err)
}

tlsConfig := &tls.Config{
Certificates: []tls.Certificate{cert},
MinVersion: tls.VersionTLS13,
}
Comment thread
RiSKeD marked this conversation as resolved.

server := &http.Server{
Addr: agt.address,
Handler: mux,
TLSConfig: tlsConfig,
ReadHeaderTimeout: readHeaderTimeout,
WriteTimeout: writeTimeout,
IdleTimeout: idleTimeout,
MaxHeaderBytes: maxHeaderBytes,
}

agt.httpServer = server

log.Printf("Starting TLS-enabled RPC service on %s", agt.address)

// ListenAndServeTLS with empty cert/key paths since we've already loaded them in tlsConfig
return server.ListenAndServeTLS("", "")
}

func (agt *agent) registerWithServer() error {
log.Printf("Registering with server %q", agt.server)

client := spawnClient(agt.server)
client := spawnClient(agt.server, agt.insecure)
req := connect.NewRequest(&pb.RegisterRequest{
Devices: agt.config.Devices.Names(),
Address: agt.address,
Expand All @@ -192,38 +254,24 @@ func (agt *agent) registerWithServer() error {
}

// spawnClient creates a new client to the DUT server specified by the server address.
// TODO: refactor into pkg and reuse in dutctl and dutserver.
//
//nolint:ireturn
func spawnClient(agendURL string) dutctlv1connect.RelayServiceClient {
log.Printf("Spawning new client for agent %q", agendURL)
func spawnClient(serverURL string, insecure bool) dutctlv1connect.RelayServiceClient {
client, scheme := rpc.NewClient(insecure)

if insecure {
log.Printf("Spawning insecure client for server %q", serverURL)
} else {
log.Printf("Spawning TLS client for server %q", serverURL)
}

return dutctlv1connect.NewRelayServiceClient(
// Instead of http.DefaultClient, use the HTTP/2 protocol without TLS
newInsecureClient(),
fmt.Sprintf("http://%s", agendURL),
client,
fmt.Sprintf("%s://%s", scheme, serverURL),
connect.WithGRPC(),
)
}

// TODO: refactor into pkg and reuse in dutctl and dutserver.
func newInsecureClient() *http.Client {
return &http.Client{
Transport: &http2.Transport{
AllowHTTP: true,
DialTLS: func(network, addr string, _ *tls.Config) (net.Conn, error) {
// If you're also using this client for non-h2c traffic, you may want
// to delegate to tls.Dial if the network isn't TCP or the addr isn't
// in an allowlist.

//nolint:noctx
return net.Dial(network, addr)
},
// TODO: Don't forget timeouts!
},
}
}

// start orchestrates the dutagent execution.
//
//nolint:cyclop
Expand Down
37 changes: 10 additions & 27 deletions cmds/dutctl/dutctl.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,25 +7,23 @@
package main

import (
"crypto/tls"
"errors"
"flag"
"fmt"
"io"
"log"
"net"
"net/http"
"os"

"connectrpc.com/connect"
"github.com/BlindspotSoftware/dutctl/internal/buildinfo"
"github.com/BlindspotSoftware/dutctl/internal/output"
"github.com/BlindspotSoftware/dutctl/pkg/rpc"
"github.com/BlindspotSoftware/dutctl/protobuf/gen/dutctl/v1/dutctlv1connect"
"golang.org/x/net/http2"
)

const usageAbstract = `dutctl - The client application of the DUT Control system.
`

const usageSynopsis = `
SYNOPSIS:
dutctl [options] list
Expand All @@ -35,6 +33,7 @@ SYNOPSIS:
dutctl version

`

const usageDescription = `
If a device and a command are provided, dutctl will execute the command on the device.
The optional args are passed to the command.
Expand Down Expand Up @@ -78,6 +77,7 @@ func newApp(stdin io.Reader, stdout, stderr io.Writer, exitFunc func(int), args
fs.StringVar(&app.outputFormat, "f", "", outputFormatInfo)
fs.BoolVar(&app.verbose, "v", false, verboseInfo)
fs.BoolVar(&app.noColor, "no-color", false, noColorInfo)
fs.BoolVar(&app.insecure, "insecure", false, "Disable TLS (use plain HTTP)")
Comment thread
RiSKeD marked this conversation as resolved.

//nolint:errcheck // flag.Parse always returns no error because of flag.ExitOnError
fs.Parse(args[1:])
Expand Down Expand Up @@ -106,6 +106,7 @@ type application struct {
outputFormat string
verbose bool
noColor bool
insecure bool
args []string
printFlagDefaults func()

Expand All @@ -114,31 +115,13 @@ type application struct {
}

func (app *application) setupRPCClient() {
client := dutctlv1connect.NewDeviceServiceClient(
// Instead of http.DefaultClient, use the HTTP/2 protocol without TLS
newInsecureClient(),
fmt.Sprintf("http://%s", app.serverAddr),
client, scheme := rpc.NewClient(app.insecure)

app.rpcClient = dutctlv1connect.NewDeviceServiceClient(
client,
fmt.Sprintf("%s://%s", scheme, app.serverAddr),
connect.WithGRPC(),
)

app.rpcClient = client
}

func newInsecureClient() *http.Client {
return &http.Client{
Transport: &http2.Transport{
AllowHTTP: true,
DialTLS: func(network, addr string, _ *tls.Config) (net.Conn, error) {
// If you're also using this client for non-h2c traffic, you may want
// to delegate to tls.Dial if the network isn't TCP or the addr isn't
// in an allowlist.

//nolint:noctx
return net.Dial(network, addr)
},
// Don't forget timeouts!
},
}
}

var errInvalidCmdline = fmt.Errorf("invalid command line")
Expand Down
Loading
Loading