Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cmds/dutagent/dutagent.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ func printInitErr(err error) {
func (agt *agent) startRPCService() error {
service := &rpcService{
devices: agt.config.Devices,
locker: dutagent.NewLocker(),
}

mux := http.NewServeMux()
Expand Down
130 changes: 128 additions & 2 deletions cmds/dutagent/rpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,32 @@ import (
"errors"
"fmt"
"log"
"net/http"
"time"

"connectrpc.com/connect"
"github.com/BlindspotSoftware/dutctl/internal/dutagent"
"github.com/BlindspotSoftware/dutctl/internal/fsm"
"github.com/BlindspotSoftware/dutctl/pkg/dut"
"github.com/BlindspotSoftware/dutctl/pkg/lock"

pb "github.com/BlindspotSoftware/dutctl/protobuf/gen/dutctl/v1"
)

// rpcService is the service implementation for the RPCs provided by dutagent.
type rpcService struct {
devices dut.Devlist
locker *dutagent.Locker
}

// userFromHeader returns the calling user's identity from a request header,
// or a unique anonymous placeholder when the header is missing.
func userFromHeader(h http.Header) string {
if user := h.Get(lock.UserHeader); user != "" {
return user
}

return lock.AnonymousUser()
}

// List is the handler for the List RPC.
Expand All @@ -29,8 +44,27 @@ func (a *rpcService) List(
) (*connect.Response[pb.ListResponse], error) {
log.Println("Server received List request")

locks := a.locker.StatusAll()

names := a.devices.Names()
infos := make([]*pb.DeviceInfo, 0, len(names))

for _, name := range names {
info := &pb.DeviceInfo{Name: name}

if explicit := locks[name].Explicit; explicit != nil {
info.Lock = &pb.LockInfo{
Owner: explicit.Owner,
LockedAt: explicit.LockedAt.Unix(),
ExpiresAt: explicit.ExpiresAt.Unix(),
}
}

infos = append(infos, info)
}

res := connect.NewResponse(&pb.ListResponse{
Devices: a.devices.Names(),
Devices: infos,
})

log.Print("List-RPC finished")
Expand Down Expand Up @@ -121,6 +155,87 @@ func (a *rpcService) Details(
return res, nil
}

// Lock is the handler for the Lock RPC.
func (a *rpcService) Lock(
_ context.Context,
req *connect.Request[pb.LockRequest],
) (*connect.Response[pb.LockResponse], error) {
log.Println("Server received Lock request")

device := req.Msg.GetDevice()
user := userFromHeader(req.Header())

if _, ok := a.devices[device]; !ok {
return nil, connect.NewError(
connect.CodeInvalidArgument,
fmt.Errorf("device %q: %w", device, dut.ErrDeviceNotFound),
)
}

dur := time.Duration(req.Msg.GetDurationSeconds()) * time.Second

info, lockErr := a.locker.Lock(device, user, dur)
if lockErr != nil {
switch {
case errors.Is(lockErr, dutagent.ErrWrongOwner):
return nil, connect.NewError(connect.CodeFailedPrecondition, lockErr)
case errors.Is(lockErr, dutagent.ErrInvalidDuration):
return nil, connect.NewError(connect.CodeInvalidArgument, lockErr)
default:
return nil, connect.NewError(connect.CodeInternal, lockErr)
}
}

var expiresAt int64
if !info.ExpiresAt.IsZero() {
expiresAt = info.ExpiresAt.Unix()
}

res := connect.NewResponse(&pb.LockResponse{
Device: device,
Owner: info.Owner,
LockedAt: info.LockedAt.Unix(),
ExpiresAt: expiresAt,
})

log.Print("Lock-RPC finished")

return res, nil
}

// Unlock is the handler for the Unlock RPC.
func (a *rpcService) Unlock(
_ context.Context,
req *connect.Request[pb.UnlockRequest],
) (*connect.Response[pb.UnlockResponse], error) {
log.Println("Server received Unlock request")

device := req.Msg.GetDevice()
user := userFromHeader(req.Header())

var err error
if req.Msg.GetForce() {
err = a.locker.ForceClearLock(device)
} else {
err = a.locker.ClearLock(device, user)
}

if err != nil {
switch {
case errors.Is(err, dutagent.ErrWrongOwner):
return nil, connect.NewError(connect.CodePermissionDenied, err)
case errors.Is(err, dutagent.ErrNotLocked):
return nil, connect.NewError(connect.CodeFailedPrecondition, err)
default:
return nil, connect.NewError(connect.CodeInternal, err)
}
}

log.Print("Unlock-RPC finished")

return connect.NewResponse(&pb.UnlockResponse{}), nil
}

// streamAdapter decouples a connect.BidiStream to the dutagent.Stream interface.
type streamAdapter struct {
inner *connect.BidiStream[pb.RunRequest, pb.RunResponse]
Expand All @@ -139,9 +254,20 @@ func (a *rpcService) Run(
fsmArgs := runCmdArgs{
stream: &streamAdapter{inner: stream},
deviceList: a.devices,
locker: a.locker,
user: userFromHeader(stream.RequestHeader()),
}

_, err := fsm.Run(ctx, fsmArgs, receiveCommandRPC)
finalArgs, err := fsm.Run(ctx, fsmArgs, receiveCommandRPC)

// Safety net for error paths that short-circuit the FSM before
// releaseAutoLock runs. Delegating to the state function keeps the
// cleanup logic in one place. The state tolerates ErrNotLocked, so a
// happy-path call (where the FSM already released the auto-lock) is a
// harmless no-op.
if finalArgs.cmdMsg != nil {
releaseAutoLock(ctx, finalArgs) //nolint:errcheck // state never returns an error
}

var connectErr *connect.Error
if err != nil && !errors.As(err, &connectErr) {
Expand Down
217 changes: 217 additions & 0 deletions cmds/dutagent/rpc_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,217 @@
// Copyright 2025 Blindspot Software
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main

import (
"context"
"strings"
"testing"
"time"

"connectrpc.com/connect"
"github.com/BlindspotSoftware/dutctl/internal/dutagent"
"github.com/BlindspotSoftware/dutctl/pkg/dut"
"github.com/BlindspotSoftware/dutctl/pkg/lock"

pb "github.com/BlindspotSoftware/dutctl/protobuf/gen/dutctl/v1"
)

func newTestService() *rpcService {
return &rpcService{
devices: dut.Devlist{"devA": dut.Device{}, "otherDev": dut.Device{}},
locker: dutagent.NewLocker(),
}
}

func lockReq(device, user string, durSeconds int64) *connect.Request[pb.LockRequest] {
req := connect.NewRequest(&pb.LockRequest{Device: device, DurationSeconds: durSeconds})
if user != "" {
req.Header().Set(lock.UserHeader, user)
}

return req
}

func unlockReq(device, user string, force bool) *connect.Request[pb.UnlockRequest] {
req := connect.NewRequest(&pb.UnlockRequest{Device: device, Force: force})
if user != "" {
req.Header().Set(lock.UserHeader, user)
}

return req
}

func TestLockRPC(t *testing.T) {
svc := newTestService()

res, err := svc.Lock(context.Background(), lockReq("devA", "alice", 60))
if err != nil {
t.Fatalf("Lock: unexpected error: %v", err)
}

if res.Msg.GetOwner() != "alice" {
t.Errorf("owner = %q, want alice", res.Msg.GetOwner())
}

if res.Msg.GetExpiresAt() == 0 {
t.Error("expires_at = 0, want a timed expiry")
}
}

func TestLockRPCUnknownDevice(t *testing.T) {
svc := newTestService()

_, err := svc.Lock(context.Background(), lockReq("ghost", "alice", 60))
if connect.CodeOf(err) != connect.CodeInvalidArgument {
t.Errorf("code = %v, want InvalidArgument", connect.CodeOf(err))
}
}

func TestLockRPCDifferentOwnerRejected(t *testing.T) {
svc := newTestService()

if _, err := svc.Lock(context.Background(), lockReq("devA", "alice", 60)); err != nil {
t.Fatalf("first Lock: %v", err)
}

_, err := svc.Lock(context.Background(), lockReq("devA", "bob", 60))
if connect.CodeOf(err) != connect.CodeFailedPrecondition {
t.Errorf("code = %v, want FailedPrecondition", connect.CodeOf(err))
}
}

func TestLockRPCMissingUserHeader(t *testing.T) {
svc := newTestService()

first, err := svc.Lock(context.Background(), lockReq("devA", "", 60))
if err != nil {
t.Fatalf("Lock: %v", err)
}

if !strings.HasPrefix(first.Msg.GetOwner(), "unknown-") {
t.Errorf("owner = %q, want unknown-<rand> prefix", first.Msg.GetOwner())
}

// A second anonymous caller must get a distinct identity so they cannot
// satisfy CheckAccess against the first caller's lock.
second, err := svc.Lock(context.Background(), lockReq("otherDev", "", 60))
if err != nil {
t.Fatalf("second Lock: %v", err)
}

if first.Msg.GetOwner() == second.Msg.GetOwner() {
t.Errorf("two anonymous callers shared identity %q", first.Msg.GetOwner())
}
}

func TestUnlockRPC(t *testing.T) {
svc := newTestService()

if _, err := svc.Lock(context.Background(), lockReq("devA", "alice", 60)); err != nil {
t.Fatalf("Lock: %v", err)
}

if _, err := svc.Unlock(context.Background(), unlockReq("devA", "alice", false)); err != nil {
t.Errorf("Unlock by owner: %v", err)
}
}

func TestUnlockRPCWrongOwner(t *testing.T) {
svc := newTestService()

if _, err := svc.Lock(context.Background(), lockReq("devA", "alice", 60)); err != nil {
t.Fatalf("Lock: %v", err)
}

_, err := svc.Unlock(context.Background(), unlockReq("devA", "bob", false))
if connect.CodeOf(err) != connect.CodePermissionDenied {
t.Errorf("code = %v, want PermissionDenied", connect.CodeOf(err))
}
}

func TestUnlockRPCNotLocked(t *testing.T) {
svc := newTestService()

_, err := svc.Unlock(context.Background(), unlockReq("devA", "alice", false))
if connect.CodeOf(err) != connect.CodeFailedPrecondition {
t.Errorf("code = %v, want FailedPrecondition", connect.CodeOf(err))
}
}

func TestUnlockRPCForce(t *testing.T) {
svc := newTestService()

if _, err := svc.Lock(context.Background(), lockReq("devA", "alice", 60)); err != nil {
t.Fatalf("Lock: %v", err)
}

if _, err := svc.Unlock(context.Background(), unlockReq("devA", "bob", true)); err != nil {
t.Errorf("forced Unlock by non-owner: %v", err)
}
}

func TestLockRPCZeroDurationRejected(t *testing.T) {
svc := newTestService()

for _, dur := range []int64{0, -5} {
_, err := svc.Lock(context.Background(), lockReq("devA", "alice", dur))
if connect.CodeOf(err) != connect.CodeInvalidArgument {
t.Errorf("dur=%d: code = %v, want InvalidArgument", dur, connect.CodeOf(err))
}
}
}

func TestListRPCHidesAutoOnlyLock(t *testing.T) {
svc := newTestService()

if _, err := svc.locker.AutoLock("devA", "alice"); err != nil {
t.Fatalf("AutoLock: %v", err)
}

res, err := svc.List(context.Background(), connect.NewRequest(&pb.ListRequest{}))
if err != nil {
t.Fatalf("List: %v", err)
}

var got *pb.LockInfo

for _, info := range res.Msg.GetDevices() {
if info.GetName() == "devA" {
got = info.GetLock()
}
}

if got != nil {
t.Errorf("auto-only lock surfaced in List: %+v, want no lock info", got)
}
}

func TestListRPCExplicitShadowsAuto(t *testing.T) {
svc := newTestService()

if _, err := svc.locker.AutoLock("devA", "alice"); err != nil {
t.Fatalf("AutoLock: %v", err)
}

if _, err := svc.locker.Lock("devA", "alice", time.Minute); err != nil {
t.Fatalf("Lock: %v", err)
}

res, err := svc.List(context.Background(), connect.NewRequest(&pb.ListRequest{}))
if err != nil {
t.Fatalf("List: %v", err)
}

var got *pb.LockInfo

for _, info := range res.Msg.GetDevices() {
if info.GetName() == "devA" {
got = info.GetLock()
}
}

if got.GetExpiresAt() == 0 {
t.Error("expected explicit-slot expires_at to win, got 0")
}
}
Loading