Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 61 additions & 30 deletions libs/apps/vite/bridge.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"path/filepath"
"strings"
"sync"
"sync/atomic"
"time"

"github.com/databricks/cli/libs/cmdio"
Expand Down Expand Up @@ -75,17 +76,19 @@ type prioritizedMessage struct {
}

type Bridge struct {
ctx context.Context
w *databricks.WorkspaceClient
appName string
tunnelConn *websocket.Conn
ctx context.Context
w *databricks.WorkspaceClient
appName string
// Atomic because reconnects swap the connection while the writer goroutine reads it.
tunnelConn atomic.Pointer[websocket.Conn]
hmrConn *websocket.Conn
tunnelID string
tunnelWriteChan chan prioritizedMessage
stopChan chan struct{}
stop func()
httpClient *http.Client
connectionRequests chan *BridgeMessage
stdinLines chan string // Lines read by the persistent stdin reader, consumed by connection prompts
port int
keepaliveDone chan struct{} // Signals keepalive goroutine to stop on reconnect
keepaliveMu sync.Mutex // Protects keepaliveDone
Expand Down Expand Up @@ -116,16 +119,17 @@ func NewBridge(ctx context.Context, w *databricks.WorkspaceClient, appName strin
stopChan: make(chan struct{}),
tunnelWriteChan: make(chan prioritizedMessage, 100), // Buffered channel for async writes
connectionRequests: make(chan *BridgeMessage, 10),
stdinLines: make(chan string),
port: port,
autoApprove: autoApprove,
}

b.stop = sync.OnceFunc(func() {
close(b.stopChan)

if b.tunnelConn != nil {
_ = b.tunnelConn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
b.tunnelConn.Close()
if conn := b.tunnelConn.Load(); conn != nil {
_ = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
conn.Close()
}

if b.hmrConn != nil {
Expand Down Expand Up @@ -219,7 +223,7 @@ func (vb *Bridge) connectToTunnel(appDomain *url.URL) error {
return nil
})

vb.tunnelConn = conn
vb.setTunnelConn(conn)

// Start keepalive ping goroutine (stop existing one first if any)
vb.keepaliveMu.Lock()
Expand All @@ -235,6 +239,13 @@ func (vb *Bridge) connectToTunnel(appDomain *url.URL) error {
return nil
}

// setTunnelConn installs a new tunnel connection, closing the old one so reconnects don't leak it.
func (vb *Bridge) setTunnelConn(conn *websocket.Conn) {
if old := vb.tunnelConn.Swap(conn); old != nil {
old.Close()
}
}

// ConnectToTunnelWithRetry attempts to connect to the tunnel with exponential backoff.
// This handles cases where the app isn't fully ready yet (e.g., right after deployment).
func (vb *Bridge) ConnectToTunnelWithRetry(appDomain *url.URL) error {
Expand Down Expand Up @@ -346,7 +357,8 @@ func (vb *Bridge) tunnelWriter(ctx context.Context) error {
case <-vb.stopChan:
return nil
case msg := <-vb.tunnelWriteChan:
if err := vb.tunnelConn.WriteMessage(msg.messageType, msg.data); err != nil {
// Load per message so writes follow a reconnect to the new connection.
if err := vb.tunnelConn.Load().WriteMessage(msg.messageType, msg.data); err != nil {
log.Errorf(vb.ctx, "[vite_bridge] Failed to write message: %v", err)
return fmt.Errorf("failed to write to tunnel: %w", err)
}
Expand All @@ -364,7 +376,7 @@ func (vb *Bridge) handleTunnelMessages(ctx context.Context) error {
default:
}

_, message, err := vb.tunnelConn.ReadMessage()
_, message, err := vb.tunnelConn.Load().ReadMessage()
if err != nil {
if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseNoStatusReceived, websocket.CloseAbnormalClosure) {
cmdio.LogString(vb.ctx, "🔄 Tunnel closed, reconnecting...")
Expand Down Expand Up @@ -408,8 +420,15 @@ func (vb *Bridge) handleMessage(msg *BridgeMessage) error {
return nil

case "connection:request":
vb.connectionRequests <- msg
return nil
// The consumer may be blocked on a stdin prompt or gone after stop; a bare send could hang the tunnel reader.
select {
case vb.connectionRequests <- msg:
return nil
case <-vb.stopChan:
return nil
case <-time.After(wsWriteTimeout):
return errors.New("connection request queue full, dropping request")
}

case "fetch":
go func(fetchMsg BridgeMessage) {
Expand Down Expand Up @@ -446,6 +465,25 @@ func (vb *Bridge) handleMessage(msg *BridgeMessage) error {
}
}

// readStdinLines forwards lines from r to stdinLines, closing the channel on read
// error so prompts fail instead of hanging. A single persistent reader keeps a
// timed-out prompt from leaking a goroutine that swallows the next prompt's answer.
func (vb *Bridge) readStdinLines(r io.Reader) {
reader := bufio.NewReader(r)
for {
line, err := reader.ReadString('\n')
if err != nil {
close(vb.stdinLines)
return
}
select {
case vb.stdinLines <- line:
case <-vb.stopChan:
return
}
}
}

func (vb *Bridge) handleConnectionRequest(msg *BridgeMessage) error {
cmdio.LogString(vb.ctx, "")
cmdio.LogString(vb.ctx, "🔔 Connection Request")
Expand All @@ -458,25 +496,14 @@ func (vb *Bridge) handleConnectionRequest(msg *BridgeMessage) error {
} else {
cmdio.LogString(vb.ctx, " Approve this connection? (y/n)")

// Read from stdin with timeout to prevent indefinite blocking
inputChan := make(chan string, 1)
errChan := make(chan error, 1)

go func() {
reader := bufio.NewReader(os.Stdin)
input, err := reader.ReadString('\n')
if err != nil {
errChan <- err
return
}
inputChan <- input
}()

select {
case input := <-inputChan:
case input, ok := <-vb.stdinLines:
if !ok {
return errors.New("failed to read user input: stdin closed")
}
approved = strings.ToLower(strings.TrimSpace(input)) == "y"
case err := <-errChan:
return fmt.Errorf("failed to read user input: %w", err)
case <-vb.stopChan:
return nil
case <-time.After(BridgeConnTimeout):
// Default to denying after timeout
cmdio.LogString(vb.ctx, "⏱️ Timeout waiting for response, denying connection")
Expand Down Expand Up @@ -907,7 +934,7 @@ func (vb *Bridge) Start() error {
readyChan := make(chan error, 1)
go func() {
for vb.tunnelID == "" {
_, message, err := vb.tunnelConn.ReadMessage()
_, message, err := vb.tunnelConn.Load().ReadMessage()
if err != nil {
readyChan <- err
return
Expand Down Expand Up @@ -953,6 +980,10 @@ func (vb *Bridge) Start() error {
return nil
})

if !vb.autoApprove {
go vb.readStdinLines(os.Stdin)
}

// Connection request handler - not in errgroup to avoid blocking other handlers
go func() {
for {
Expand Down
Loading
Loading