From 7f9067179d1cbc1cecd9a61636bfe019ac07c284 Mon Sep 17 00:00:00 2001 From: Kiran Yadav Date: Thu, 15 Jan 2026 16:36:21 +0530 Subject: [PATCH 1/2] feat(ui): add active sessions list with kill and prefill support --- internal/adapters/ui/handlers.go | 69 +++- internal/adapters/ui/server_form.go | 160 ++++---- internal/adapters/ui/tui.go | 10 +- internal/core/domain/server.go | 1 + internal/core/ports/services.go | 3 + internal/core/services/server_service.go | 448 +++++++++++++++++++++++ 6 files changed, 617 insertions(+), 74 deletions(-) diff --git a/internal/adapters/ui/handlers.go b/internal/adapters/ui/handlers.go index 897e053..d3a49a9 100644 --- a/internal/adapters/ui/handlers.go +++ b/internal/adapters/ui/handlers.go @@ -86,6 +86,11 @@ func (t *tui) handleGlobalKeys(event *tcell.EventKey) *tcell.EventKey { case 'x': t.handleStopForwarding() return nil + case 'K': + if t.isActiveListFocused() { + t.handleKillActiveSessions() + } + return nil case 'j': t.handleNavigateDown() return nil @@ -172,6 +177,8 @@ func (t *tui) handleSearchInput(query string) { filtered, _ := t.serverService.ListServers(query) sortServersForUI(filtered, t.sortMode) t.serverList.UpdateServers(filtered) + active, _ := t.serverService.ListActiveSessions(query) + t.activeList.UpdateServers(active) if len(filtered) == 0 { t.details.ShowEmpty() } @@ -235,6 +242,19 @@ func (t *tui) handleServerSelectionChange(server domain.Server) { } func (t *tui) handleServerAdd() { + if t.isActiveListFocused() { + if server, ok := t.activeList.GetSelectedServer(); ok { + form := NewServerForm(ServerFormAdd, nil). + SetPrefill(&server). + SetApp(t.app). + SetVersionInfo(t.version, t.commit). + OnSave(t.handleServerSave). + OnCancel(t.handleFormCancel) + t.app.SetRoot(form, true) + return + } + } + form := NewServerForm(ServerFormAdd, nil). SetApp(t.app). SetVersionInfo(t.version, t.commit). @@ -244,6 +264,10 @@ func (t *tui) handleServerAdd() { } func (t *tui) handleServerEdit() { + if t.isActiveListFocused() { + t.showStatusTemp("Edit disabled for active sessions. Use 'a' to add.") + return + } if server, ok := t.serverList.GetSelectedServer(); ok { form := NewServerForm(ServerFormEdit, &server). SetApp(t.app). @@ -258,7 +282,15 @@ func (t *tui) handleServerSave(server domain.Server, original *domain.Server) { var err error if original != nil { // Edit mode - err = t.serverService.UpdateServer(*original, server) + base := *original + if resolved, ok, resolveErr := t.serverService.ResolveConfigServer(*original); resolveErr != nil { + err = resolveErr + } else if ok { + base = resolved + } + if err == nil { + err = t.serverService.UpdateServer(base, server) + } } else { // Add mode err = t.serverService.AddServer(server) @@ -332,9 +364,17 @@ func (t *tui) handleRefreshBackground() { }) return } + active, activeErr := t.serverService.ListActiveSessions(q) + if activeErr != nil { + t.app.QueueUpdateDraw(func() { + t.showStatusTempColor(fmt.Sprintf("Active refresh failed: %v", activeErr), "#FF6B6B") + }) + return + } sortServersForUI(servers, t.sortMode) t.app.QueueUpdateDraw(func() { t.serverList.UpdateServers(servers) + t.activeList.UpdateServers(active) // Try to restore selection if still valid if prevIdx >= 0 && prevIdx < t.serverList.List.GetItemCount() { t.serverList.SetCurrentItem(prevIdx) @@ -581,6 +621,8 @@ func (t *tui) refreshServerList() { filtered, _ := t.serverService.ListServers(query) sortServersForUI(filtered, t.sortMode) t.serverList.UpdateServers(filtered) + active, _ := t.serverService.ListActiveSessions(query) + t.activeList.UpdateServers(active) } func (t *tui) returnToMain() { @@ -629,3 +671,28 @@ func (t *tui) handleStopForwarding() { }() } } + +// Terminate active SSH sessions for the selected server. +func (t *tui) handleKillActiveSessions() { + if server, ok := t.activeList.GetSelectedServer(); ok { + go func(selected domain.Server) { + count, err := t.serverService.KillActiveSessions(selected) + t.app.QueueUpdateDraw(func() { + if err != nil { + t.showStatusTempColor("Failed to terminate SSH sessions: "+err.Error(), "#FF6B6B") + } else { + t.showStatusTemp(fmt.Sprintf("Terminated %d SSH session(s) for %s", count, selected.Alias)) + } + t.refreshServerList() + }) + }(server) + } +} + +func (t *tui) isActiveListFocused() bool { + if t.app == nil || t.activeList == nil { + return false + } + focus := t.app.GetFocus() + return focus == t.activeList || focus == t.activeList.List +} diff --git a/internal/adapters/ui/server_form.go b/internal/adapters/ui/server_form.go index 286b47f..ee474b8 100644 --- a/internal/adapters/ui/server_form.go +++ b/internal/adapters/ui/server_form.go @@ -55,6 +55,7 @@ type ServerForm struct { tabAbbrev map[string]string // Abbreviated tab names for narrow views mode ServerFormMode original *domain.Server + prefill *domain.Server onSave func(domain.Server, *domain.Server) onCancel func() app *tview.Application // Reference to app for showing modals @@ -113,6 +114,11 @@ func NewServerForm(mode ServerFormMode, original *domain.Server) *ServerForm { return form } +func (sf *ServerForm) SetPrefill(server *domain.Server) *ServerForm { + sf.prefill = server + return sf +} + func (sf *ServerForm) build() { // Create header sf.header = NewAppHeader(sf.version, sf.commit, RepoURL) @@ -1065,78 +1071,13 @@ func (sf *ServerForm) validateAllFields() bool { // getDefaultValues returns default form values based on mode func (sf *ServerForm) getDefaultValues() ServerFormData { if sf.mode == ServerFormEdit && sf.original != nil { - return ServerFormData{ - Alias: sf.original.Alias, - Host: sf.original.Host, - User: sf.original.User, - Port: fmt.Sprint(sf.original.Port), - Key: strings.Join(sf.original.IdentityFiles, ", "), - Tags: strings.Join(sf.original.Tags, ", "), - ProxyJump: sf.original.ProxyJump, - ProxyCommand: sf.original.ProxyCommand, - RemoteCommand: sf.original.RemoteCommand, - RequestTTY: sf.original.RequestTTY, - SessionType: sf.original.SessionType, - ConnectTimeout: sf.original.ConnectTimeout, - ConnectionAttempts: sf.original.ConnectionAttempts, - BindAddress: sf.original.BindAddress, - BindInterface: sf.original.BindInterface, - AddressFamily: sf.original.AddressFamily, - ExitOnForwardFailure: sf.original.ExitOnForwardFailure, - IPQoS: sf.original.IPQoS, - // Hostname canonicalization - CanonicalizeHostname: sf.original.CanonicalizeHostname, - CanonicalDomains: sf.original.CanonicalDomains, - CanonicalizeFallbackLocal: sf.original.CanonicalizeFallbackLocal, - CanonicalizeMaxDots: sf.original.CanonicalizeMaxDots, - CanonicalizePermittedCNAMEs: sf.original.CanonicalizePermittedCNAMEs, - GatewayPorts: sf.original.GatewayPorts, - LocalForward: strings.Join(sf.original.LocalForward, ", "), - RemoteForward: strings.Join(sf.original.RemoteForward, ", "), - DynamicForward: strings.Join(sf.original.DynamicForward, ", "), - ClearAllForwardings: sf.original.ClearAllForwardings, - // Public key - PubkeyAuthentication: sf.original.PubkeyAuthentication, - IdentitiesOnly: sf.original.IdentitiesOnly, - // SSH Agent - AddKeysToAgent: sf.original.AddKeysToAgent, - IdentityAgent: sf.original.IdentityAgent, - // Password & Interactive - PasswordAuthentication: sf.original.PasswordAuthentication, - KbdInteractiveAuthentication: sf.original.KbdInteractiveAuthentication, - NumberOfPasswordPrompts: sf.original.NumberOfPasswordPrompts, - // Advanced - PreferredAuthentications: sf.original.PreferredAuthentications, - ForwardAgent: sf.original.ForwardAgent, - ForwardX11: sf.original.ForwardX11, - ForwardX11Trusted: sf.original.ForwardX11Trusted, - ControlMaster: sf.original.ControlMaster, - ControlPath: sf.original.ControlPath, - ControlPersist: sf.original.ControlPersist, - ServerAliveInterval: sf.original.ServerAliveInterval, - ServerAliveCountMax: sf.original.ServerAliveCountMax, - Compression: sf.original.Compression, - TCPKeepAlive: sf.original.TCPKeepAlive, - BatchMode: sf.original.BatchMode, - StrictHostKeyChecking: sf.original.StrictHostKeyChecking, - UserKnownHostsFile: sf.original.UserKnownHostsFile, - HostKeyAlgorithms: sf.original.HostKeyAlgorithms, - PubkeyAcceptedAlgorithms: sf.original.PubkeyAcceptedAlgorithms, - HostbasedAcceptedAlgorithms: sf.original.HostbasedAcceptedAlgorithms, - MACs: sf.original.MACs, - Ciphers: sf.original.Ciphers, - KexAlgorithms: sf.original.KexAlgorithms, - VerifyHostKeyDNS: sf.original.VerifyHostKeyDNS, - UpdateHostKeys: sf.original.UpdateHostKeys, - HashKnownHosts: sf.original.HashKnownHosts, - VisualHostKey: sf.original.VisualHostKey, - LocalCommand: sf.original.LocalCommand, - PermitLocalCommand: sf.original.PermitLocalCommand, - EscapeChar: sf.original.EscapeChar, - SendEnv: strings.Join(sf.original.SendEnv, ", "), - SetEnv: strings.Join(sf.original.SetEnv, ", "), - LogLevel: sf.original.LogLevel, - } + return serverToFormData(*sf.original) + } + if sf.mode == ServerFormAdd && sf.prefill != nil { + data := serverToFormData(*sf.prefill) + data.Alias = "" + data.Tags = "" + return data } // For new servers, use empty values instead of SSH defaults // SSH defaults will be applied by the SSH client if values are not specified @@ -1236,6 +1177,81 @@ func (sf *ServerForm) getDefaultValues() ServerFormData { } } +func serverToFormData(server domain.Server) ServerFormData { + return ServerFormData{ + Alias: server.Alias, + Host: server.Host, + User: server.User, + Port: fmt.Sprint(server.Port), + Key: strings.Join(server.IdentityFiles, ", "), + Tags: strings.Join(server.Tags, ", "), + ProxyJump: server.ProxyJump, + ProxyCommand: server.ProxyCommand, + RemoteCommand: server.RemoteCommand, + RequestTTY: server.RequestTTY, + SessionType: server.SessionType, + ConnectTimeout: server.ConnectTimeout, + ConnectionAttempts: server.ConnectionAttempts, + BindAddress: server.BindAddress, + BindInterface: server.BindInterface, + AddressFamily: server.AddressFamily, + ExitOnForwardFailure: server.ExitOnForwardFailure, + IPQoS: server.IPQoS, + // Hostname canonicalization + CanonicalizeHostname: server.CanonicalizeHostname, + CanonicalDomains: server.CanonicalDomains, + CanonicalizeFallbackLocal: server.CanonicalizeFallbackLocal, + CanonicalizeMaxDots: server.CanonicalizeMaxDots, + CanonicalizePermittedCNAMEs: server.CanonicalizePermittedCNAMEs, + GatewayPorts: server.GatewayPorts, + LocalForward: strings.Join(server.LocalForward, ", "), + RemoteForward: strings.Join(server.RemoteForward, ", "), + DynamicForward: strings.Join(server.DynamicForward, ", "), + ClearAllForwardings: server.ClearAllForwardings, + // Public key + PubkeyAuthentication: server.PubkeyAuthentication, + IdentitiesOnly: server.IdentitiesOnly, + // SSH Agent + AddKeysToAgent: server.AddKeysToAgent, + IdentityAgent: server.IdentityAgent, + // Password & Interactive + PasswordAuthentication: server.PasswordAuthentication, + KbdInteractiveAuthentication: server.KbdInteractiveAuthentication, + NumberOfPasswordPrompts: server.NumberOfPasswordPrompts, + // Advanced + PreferredAuthentications: server.PreferredAuthentications, + ForwardAgent: server.ForwardAgent, + ForwardX11: server.ForwardX11, + ForwardX11Trusted: server.ForwardX11Trusted, + ControlMaster: server.ControlMaster, + ControlPath: server.ControlPath, + ControlPersist: server.ControlPersist, + ServerAliveInterval: server.ServerAliveInterval, + ServerAliveCountMax: server.ServerAliveCountMax, + Compression: server.Compression, + TCPKeepAlive: server.TCPKeepAlive, + BatchMode: server.BatchMode, + StrictHostKeyChecking: server.StrictHostKeyChecking, + UserKnownHostsFile: server.UserKnownHostsFile, + HostKeyAlgorithms: server.HostKeyAlgorithms, + PubkeyAcceptedAlgorithms: server.PubkeyAcceptedAlgorithms, + HostbasedAcceptedAlgorithms: server.HostbasedAcceptedAlgorithms, + MACs: server.MACs, + Ciphers: server.Ciphers, + KexAlgorithms: server.KexAlgorithms, + VerifyHostKeyDNS: server.VerifyHostKeyDNS, + UpdateHostKeys: server.UpdateHostKeys, + HashKnownHosts: server.HashKnownHosts, + VisualHostKey: server.VisualHostKey, + LocalCommand: server.LocalCommand, + PermitLocalCommand: server.PermitLocalCommand, + EscapeChar: server.EscapeChar, + SendEnv: strings.Join(server.SendEnv, ", "), + SetEnv: strings.Join(server.SetEnv, ", "), + LogLevel: server.LogLevel, + } +} + // createBasicForm creates the Basic configuration tab func (sf *ServerForm) createBasicForm() { form := tview.NewForm() diff --git a/internal/adapters/ui/tui.go b/internal/adapters/ui/tui.go index d938e6f..10b4a5e 100644 --- a/internal/adapters/ui/tui.go +++ b/internal/adapters/ui/tui.go @@ -38,6 +38,7 @@ type tui struct { header *AppHeader searchBar *SearchBar serverList *ServerList + activeList *ServerList details *ServerDetails statusBar *tview.TextView @@ -98,6 +99,10 @@ func (t *tui) buildComponents() *tui { t.serverList = NewServerList(). OnSelectionChange(t.handleServerSelectionChange). OnReturnToSearch(t.handleReturnToSearch) + t.activeList = NewServerList(). + SetTitle(" Active Sessions (K: Kill) "). + OnSelectionChange(t.handleServerSelectionChange). + OnReturnToSearch(t.handleReturnToSearch) t.details = NewServerDetails() t.statusBar = NewStatusBar() @@ -110,7 +115,8 @@ func (t *tui) buildComponents() *tui { func (t *tui) buildLayout() *tui { t.left = tview.NewFlex().SetDirection(tview.FlexRow). AddItem(t.searchBar, 3, 0, false). - AddItem(t.serverList, 0, 1, true) + AddItem(t.serverList, 0, 1, true). + AddItem(t.activeList, 10, 0, false) right := tview.NewFlex().SetDirection(tview.FlexRow). AddItem(t.details, 0, 1, false) @@ -136,6 +142,8 @@ func (t *tui) loadInitialData() *tui { sortServersForUI(servers, t.sortMode) t.updateListTitle() t.serverList.UpdateServers(servers) + active, _ := t.serverService.ListActiveSessions("") + t.activeList.UpdateServers(active) return t } diff --git a/internal/core/domain/server.go b/internal/core/domain/server.go index c23b301..7872743 100644 --- a/internal/core/domain/server.go +++ b/internal/core/domain/server.go @@ -27,6 +27,7 @@ type Server struct { LastSeen time.Time PinnedAt time.Time SSHCount int + ActivePID int // Additional SSH config fields // Connection and proxy settings diff --git a/internal/core/ports/services.go b/internal/core/ports/services.go index 2407269..807949d 100644 --- a/internal/core/ports/services.go +++ b/internal/core/ports/services.go @@ -22,6 +22,7 @@ import ( type ServerService interface { ListServers(query string) ([]domain.Server, error) + ListActiveSessions(query string) ([]domain.Server, error) UpdateServer(server domain.Server, newServer domain.Server) error AddServer(server domain.Server) error DeleteServer(server domain.Server) error @@ -31,5 +32,7 @@ type ServerService interface { StartForward(alias string, extraArgs []string) (int, error) StopForwarding(alias string) error IsForwarding(alias string) bool + KillActiveSessions(server domain.Server) (int, error) + ResolveConfigServer(server domain.Server) (domain.Server, bool, error) Ping(server domain.Server) (bool, time.Duration, error) } diff --git a/internal/core/services/server_service.go b/internal/core/services/server_service.go index 7926bc6..249cbe3 100644 --- a/internal/core/services/server_service.go +++ b/internal/core/services/server_service.go @@ -41,6 +41,18 @@ type serverService struct { forwards map[string][]*os.Process } +type activeSSHSession struct { + alias string + host string + user string + port int + identityFiles []string + localForward []string + remoteForward []string + dynamicForward []string + pid int +} + // NewServerService creates a new instance of serverService. func NewServerService(logger *zap.SugaredLogger, sr ports.ServerRepository) ports.ServerService { return &serverService{ @@ -73,6 +85,85 @@ func (s *serverService) ListServers(query string) ([]domain.Server, error) { return servers, nil } +// ListActiveSessions returns currently running SSH sessions as list entries. +func (s *serverService) ListActiveSessions(query string) ([]domain.Server, error) { + activeSessions, err := s.listActiveSSHSessions() + if err != nil { + return nil, err + } + + configured, err := s.serverRepository.ListServers("") + if err != nil { + return nil, err + } + + aliasIndex := make(map[string]domain.Server, len(configured)) + hostIndex := make(map[string]domain.Server, len(configured)) + for _, server := range configured { + aliasIndex[strings.ToLower(server.Alias)] = server + for _, alias := range server.Aliases { + aliasIndex[strings.ToLower(alias)] = server + } + if server.Host != "" { + hostIndex[strings.ToLower(server.Host)] = server + } + } + + query = strings.ToLower(strings.TrimSpace(query)) + entries := make([]domain.Server, 0, len(activeSessions)) + for _, session := range activeSessions { + entry := domain.Server{} + if session.alias != "" { + if server, ok := aliasIndex[strings.ToLower(session.alias)]; ok { + entry = server + } + } + if entry.Alias == "" && session.host != "" { + if server, ok := hostIndex[strings.ToLower(session.host)]; ok { + entry = server + } + } + + if entry.Alias == "" { + if session.alias != "" { + entry.Alias = session.alias + } else { + entry.Alias = "unknown" + } + entry.Aliases = []string{entry.Alias} + } + + if session.host != "" { + entry.Host = session.host + } else if entry.Host == "" { + entry.Host = "unknown" + } + if session.user != "" { + entry.User = session.user + } + if session.port > 0 { + entry.Port = session.port + } else if entry.Port == 0 { + entry.Port = 22 + } + + entry.IdentityFiles = mergeIdentityFiles(entry.IdentityFiles, session.identityFiles) + entry.LocalForward = mergeForwardSpecs(entry.LocalForward, session.localForward) + entry.RemoteForward = mergeForwardSpecs(entry.RemoteForward, session.remoteForward) + entry.DynamicForward = mergeForwardSpecs(entry.DynamicForward, session.dynamicForward) + entry.ActivePID = session.pid + entry.LastSeen = time.Now() + entry.Tags = append([]string{"active"}, entry.Tags...) + + if query != "" && !matchesServerQuery(entry, query) { + continue + } + entries = append(entries, entry) + } + + return entries, nil +} + // validateServer performs core validation of server fields. func validateServer(srv domain.Server) error { if strings.TrimSpace(srv.Alias) == "" { @@ -309,6 +400,75 @@ func (s *serverService) IsForwarding(alias string) bool { return len(s.forwards[alias]) > 0 } +// KillActiveSessions terminates active SSH sessions matching the server. +func (s *serverService) KillActiveSessions(server domain.Server) (int, error) { + sessions, err := s.listActiveSSHSessions() + if err != nil { + return 0, err + } + + if server.ActivePID > 0 { + for _, session := range sessions { + if session.pid == server.ActivePID { + if err := killPID(session.pid); err != nil { + return 0, err + } + return 1, nil + } + } + return 0, fmt.Errorf("active ssh session not found") + } + + var pids []int + for _, session := range sessions { + if matchSessionForServer(server, session) && session.pid > 0 { + pids = append(pids, session.pid) + } + } + if len(pids) == 0 { + return 0, fmt.Errorf("no active ssh sessions found") + } + + var errs []error + killed := 0 + for _, pid := range pids { + if killErr := killPID(pid); killErr != nil { + errs = append(errs, fmt.Errorf("pid %d: %w", pid, killErr)) + continue + } + killed++ + } + + if len(errs) > 0 { + return killed, fmt.Errorf("failed to terminate sessions: %v", errs) + } + return killed, nil +} + +// ResolveConfigServer attempts to map a server entry to a configured server. +func (s *serverService) ResolveConfigServer(server domain.Server) (domain.Server, bool, error) { + servers, err := s.serverRepository.ListServers("") + if err != nil { + return domain.Server{}, false, err + } + for _, candidate := range servers { + if strings.EqualFold(candidate.Alias, server.Alias) { + return candidate, true, nil + } + for _, alias := range candidate.Aliases { + if strings.EqualFold(alias, server.Alias) { + return candidate, true, nil + } + } + if server.Host != "" && candidate.Host != "" { + if strings.EqualFold(candidate.Host, server.Host) { + return candidate, true, nil + } + } + } + return domain.Server{}, false, nil +} + // Ping checks if the server is reachable on its SSH port. func (s *serverService) Ping(server domain.Server) (bool, time.Duration, error) { start := time.Now() @@ -377,3 +537,291 @@ func resolveSSHDestination(alias string) (string, int, bool) { } return host, port, true } + +func mergeIdentityFiles(existing []string, incoming []string) []string { + if len(incoming) == 0 { + return existing + } + seen := make(map[string]struct{}, len(existing)) + for _, v := range existing { + seen[v] = struct{}{} + } + for _, v := range incoming { + if v == "" { + continue + } + if _, ok := seen[v]; ok { + continue + } + existing = append(existing, v) + seen[v] = struct{}{} + } + return existing +} + +func matchSessionForServer(server domain.Server, session activeSSHSession) bool { + if session.alias != "" { + if strings.EqualFold(session.alias, server.Alias) { + return true + } + for _, alias := range server.Aliases { + if strings.EqualFold(session.alias, alias) { + return true + } + } + } + if session.host != "" && server.Host != "" { + return strings.EqualFold(session.host, server.Host) + } + return false +} + +func killPID(pid int) error { + proc, findErr := os.FindProcess(pid) + if findErr != nil { + return findErr + } + return proc.Signal(syscall.SIGTERM) +} + +func mergeForwardSpecs(existing []string, incoming []string) []string { + if len(incoming) == 0 { + return existing + } + seen := make(map[string]struct{}, len(existing)) + for _, v := range existing { + seen[v] = struct{}{} + } + for _, v := range incoming { + if v == "" { + continue + } + if _, ok := seen[v]; ok { + continue + } + existing = append(existing, v) + seen[v] = struct{}{} + } + return existing +} + +func matchesServerQuery(server domain.Server, query string) bool { + fields := []string{ + strings.ToLower(server.Host), + strings.ToLower(server.User), + strings.ToLower(server.Alias), + } + for _, tag := range server.Tags { + fields = append(fields, strings.ToLower(tag)) + } + if len(server.Aliases) > 0 { + for _, alias := range server.Aliases { + fields = append(fields, strings.ToLower(alias)) + } + } + + for _, field := range fields { + if strings.Contains(field, query) { + return true + } + } + return false +} + +func (s *serverService) listActiveSSHSessions() ([]activeSSHSession, error) { + cmd := exec.Command("ps", "-ax", "-o", "pid=", "-o", "comm=", "-o", "args=") + out, err := cmd.Output() + if err != nil { + return nil, err + } + + sessions := make([]activeSSHSession, 0) + scanner := bufio.NewScanner(strings.NewReader(string(out))) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "" { + continue + } + pid, comm, args := splitPSLine(line) + if comm != "ssh" { + continue + } + parts := strings.Fields(args) + if len(parts) == 0 { + continue + } + session := parseSSHArgs(parts) + if session.alias == "" { + continue + } + session.pid = pid + sessions = append(sessions, session) + } + if err := scanner.Err(); err != nil { + return nil, err + } + return sessions, nil +} + +func splitPSLine(line string) (pid int, comm string, args string) { + fields := strings.Fields(line) + if len(fields) < 2 { + return 0, "", "" + } + if n, err := strconv.Atoi(fields[0]); err == nil { + pid = n + } + comm = fields[1] + start := strings.Index(line, comm) + if start == -1 { + if len(fields) > 2 { + args = strings.Join(fields[2:], " ") + } + return pid, comm, strings.TrimSpace(args) + } + args = strings.TrimSpace(line[start+len(comm):]) + return pid, comm, args +} + +func parseSSHArgs(args []string) activeSSHSession { + if len(args) == 0 { + return activeSSHSession{} + } + user := "" + port := 0 + dest := "" + identityFiles := make([]string, 0) + localForward := make([]string, 0) + remoteForward := make([]string, 0) + dynamicForward := make([]string, 0) + + start := 1 + if args[0] != "ssh" { + start = 0 + } + for i := start; i < len(args); i++ { + arg := args[i] + if arg == "--" { + if i+1 < len(args) { + dest = args[i+1] + } + break + } + if strings.HasPrefix(arg, "-") { + if sshOptionConsumesValue(arg) { + val := "" + if len(arg) > 2 { + val = arg[2:] + } else if i+1 < len(args) { + val = args[i+1] + i++ + } + switch { + case strings.HasPrefix(arg, "-p"): + if n, err := strconv.Atoi(val); err == nil { + port = n + } + case strings.HasPrefix(arg, "-l"): + if val != "" { + user = val + } + case strings.HasPrefix(arg, "-i"): + if val != "" { + identityFiles = append(identityFiles, val) + } + case strings.HasPrefix(arg, "-L"): + if val != "" { + localForward = append(localForward, val) + } + case strings.HasPrefix(arg, "-R"): + if val != "" { + remoteForward = append(remoteForward, val) + } + case strings.HasPrefix(arg, "-D"): + if val != "" { + dynamicForward = append(dynamicForward, val) + } + case strings.HasPrefix(arg, "-o"): + lowerVal := strings.ToLower(val) + if strings.HasPrefix(lowerVal, "user=") { + user = val[len("user="):] + } + if strings.HasPrefix(lowerVal, "port=") { + if n, err := strconv.Atoi(val[len("port="):]); err == nil { + port = n + } + } + if strings.HasPrefix(lowerVal, "identityfile=") { + identity := val[len("identityfile="):] + if identity != "" { + identityFiles = append(identityFiles, identity) + } + } + if strings.HasPrefix(lowerVal, "localforward=") { + spec := val[len("localforward="):] + if spec != "" { + localForward = append(localForward, spec) + } + } + if strings.HasPrefix(lowerVal, "remoteforward=") { + spec := val[len("remoteforward="):] + if spec != "" { + remoteForward = append(remoteForward, spec) + } + } + if strings.HasPrefix(lowerVal, "dynamicforward=") { + spec := val[len("dynamicforward="):] + if spec != "" { + dynamicForward = append(dynamicForward, spec) + } + } + } + } + continue + } + dest = arg + break + } + + if dest == "" { + return activeSSHSession{} + } + + host := dest + if at := strings.LastIndex(dest, "@"); at > -1 { + if user == "" { + user = dest[:at] + } + host = dest[at+1:] + } + if host == "" { + host = "unknown" + } + if port == 0 { + port = 22 + } + + return activeSSHSession{ + alias: dest, + host: host, + user: user, + port: port, + identityFiles: identityFiles, + localForward: localForward, + remoteForward: remoteForward, + dynamicForward: dynamicForward, + } +} + +func sshOptionConsumesValue(opt string) bool { + base := opt + if len(opt) > 2 && strings.HasPrefix(opt, "-") && !strings.HasPrefix(opt, "--") { + base = opt[:2] + } + switch base { + case "-p", "-l", "-i", "-o", "-F", "-b", "-c", "-D", "-E", "-e", "-I", "-J", "-L", "-m", "-O", "-Q", "-R", "-S", "-W", "-w": + return true + default: + return false + } +} From c08abe591ab2322441fa92445aefd233d2ec774a Mon Sep 17 00:00:00 2001 From: Kiran Yadav Date: Thu, 15 Jan 2026 16:43:00 +0530 Subject: [PATCH 2/2] refactor: restructure ssh arg parsing and unknown labels --- internal/adapters/ui/tui.go | 2 +- internal/core/services/server_service.go | 182 +++++++++++++---------- 2 files changed, 102 insertions(+), 82 deletions(-) diff --git a/internal/adapters/ui/tui.go b/internal/adapters/ui/tui.go index 10b4a5e..addf548 100644 --- a/internal/adapters/ui/tui.go +++ b/internal/adapters/ui/tui.go @@ -100,9 +100,9 @@ func (t *tui) buildComponents() *tui { OnSelectionChange(t.handleServerSelectionChange). OnReturnToSearch(t.handleReturnToSearch) t.activeList = NewServerList(). - SetTitle(" Active Sessions (K: Kill) "). OnSelectionChange(t.handleServerSelectionChange). OnReturnToSearch(t.handleReturnToSearch) + t.activeList.List.SetTitle(" Active Sessions (K: Kill) ") t.details = NewServerDetails() t.statusBar = NewStatusBar() diff --git a/internal/core/services/server_service.go b/internal/core/services/server_service.go index 249cbe3..4ed4c6d 100644 --- a/internal/core/services/server_service.go +++ b/internal/core/services/server_service.go @@ -53,6 +53,8 @@ type activeSSHSession struct { pid int } +const unknownLabel = "unknown" + // NewServerService creates a new instance of serverService. func NewServerService(logger *zap.SugaredLogger, sr ports.ServerRepository) ports.ServerService { return &serverService{ @@ -128,7 +130,7 @@ func (s *serverService) ListActiveSessions(query string) ([]domain.Server, error if session.alias != "" { entry.Alias = session.alias } else { - entry.Alias = "unknown" + entry.Alias = unknownLabel } entry.Aliases = []string{entry.Alias} } @@ -136,7 +138,7 @@ func (s *serverService) ListActiveSessions(query string) ([]domain.Server, error if session.host != "" { entry.Host = session.host } else if entry.Host == "" { - entry.Host = "unknown" + entry.Host = unknownLabel } if session.user != "" { entry.User = session.user @@ -687,13 +689,54 @@ func parseSSHArgs(args []string) activeSSHSession { if len(args) == 0 { return activeSSHSession{} } - user := "" - port := 0 - dest := "" - identityFiles := make([]string, 0) - localForward := make([]string, 0) - remoteForward := make([]string, 0) - dynamicForward := make([]string, 0) + state := parseSSHOptions(args) + if state.dest == "" { + return activeSSHSession{} + } + + host := state.dest + if at := strings.LastIndex(state.dest, "@"); at > -1 { + if state.user == "" { + state.user = state.dest[:at] + } + host = state.dest[at+1:] + } + if host == "" { + host = unknownLabel + } + if state.port == 0 { + state.port = 22 + } + + return activeSSHSession{ + alias: state.dest, + host: host, + user: state.user, + port: state.port, + identityFiles: state.identityFiles, + localForward: state.localForward, + remoteForward: state.remoteForward, + dynamicForward: state.dynamicForward, + } +} + +type sshParseState struct { + user string + port int + dest string + identityFiles []string + localForward []string + remoteForward []string + dynamicForward []string +} + +func parseSSHOptions(args []string) sshParseState { + state := sshParseState{ + identityFiles: make([]string, 0), + localForward: make([]string, 0), + remoteForward: make([]string, 0), + dynamicForward: make([]string, 0), + } start := 1 if args[0] != "ssh" { @@ -703,113 +746,90 @@ func parseSSHArgs(args []string) activeSSHSession { arg := args[i] if arg == "--" { if i+1 < len(args) { - dest = args[i+1] + state.dest = args[i+1] } break } if strings.HasPrefix(arg, "-") { if sshOptionConsumesValue(arg) { - val := "" - if len(arg) > 2 { - val = arg[2:] - } else if i+1 < len(args) { - val = args[i+1] - i++ - } + val, nextIdx := sshOptionValue(arg, args, i) + i = nextIdx switch { case strings.HasPrefix(arg, "-p"): if n, err := strconv.Atoi(val); err == nil { - port = n + state.port = n } case strings.HasPrefix(arg, "-l"): if val != "" { - user = val + state.user = val } case strings.HasPrefix(arg, "-i"): if val != "" { - identityFiles = append(identityFiles, val) + state.identityFiles = append(state.identityFiles, val) } case strings.HasPrefix(arg, "-L"): if val != "" { - localForward = append(localForward, val) + state.localForward = append(state.localForward, val) } case strings.HasPrefix(arg, "-R"): if val != "" { - remoteForward = append(remoteForward, val) + state.remoteForward = append(state.remoteForward, val) } case strings.HasPrefix(arg, "-D"): if val != "" { - dynamicForward = append(dynamicForward, val) + state.dynamicForward = append(state.dynamicForward, val) } case strings.HasPrefix(arg, "-o"): - lowerVal := strings.ToLower(val) - if strings.HasPrefix(lowerVal, "user=") { - user = val[len("user="):] - } - if strings.HasPrefix(lowerVal, "port=") { - if n, err := strconv.Atoi(val[len("port="):]); err == nil { - port = n - } - } - if strings.HasPrefix(lowerVal, "identityfile=") { - identity := val[len("identityfile="):] - if identity != "" { - identityFiles = append(identityFiles, identity) - } - } - if strings.HasPrefix(lowerVal, "localforward=") { - spec := val[len("localforward="):] - if spec != "" { - localForward = append(localForward, spec) - } - } - if strings.HasPrefix(lowerVal, "remoteforward=") { - spec := val[len("remoteforward="):] - if spec != "" { - remoteForward = append(remoteForward, spec) - } - } - if strings.HasPrefix(lowerVal, "dynamicforward=") { - spec := val[len("dynamicforward="):] - if spec != "" { - dynamicForward = append(dynamicForward, spec) - } - } + applySSHOptionValue(val, &state) } } continue } - dest = arg + state.dest = arg break } + return state +} - if dest == "" { - return activeSSHSession{} - } - - host := dest - if at := strings.LastIndex(dest, "@"); at > -1 { - if user == "" { - user = dest[:at] - } - host = dest[at+1:] +func sshOptionValue(arg string, args []string, idx int) (string, int) { + if len(arg) > 2 { + return arg[2:], idx } - if host == "" { - host = "unknown" - } - if port == 0 { - port = 22 + if idx+1 < len(args) { + return args[idx+1], idx + 1 } + return "", idx +} - return activeSSHSession{ - alias: dest, - host: host, - user: user, - port: port, - identityFiles: identityFiles, - localForward: localForward, - remoteForward: remoteForward, - dynamicForward: dynamicForward, +func applySSHOptionValue(val string, state *sshParseState) { + lowerVal := strings.ToLower(val) + switch { + case strings.HasPrefix(lowerVal, "user="): + state.user = val[len("user="):] + case strings.HasPrefix(lowerVal, "port="): + if n, err := strconv.Atoi(val[len("port="):]); err == nil { + state.port = n + } + case strings.HasPrefix(lowerVal, "identityfile="): + identity := val[len("identityfile="):] + if identity != "" { + state.identityFiles = append(state.identityFiles, identity) + } + case strings.HasPrefix(lowerVal, "localforward="): + spec := val[len("localforward="):] + if spec != "" { + state.localForward = append(state.localForward, spec) + } + case strings.HasPrefix(lowerVal, "remoteforward="): + spec := val[len("remoteforward="):] + if spec != "" { + state.remoteForward = append(state.remoteForward, spec) + } + case strings.HasPrefix(lowerVal, "dynamicforward="): + spec := val[len("dynamicforward="):] + if spec != "" { + state.dynamicForward = append(state.dynamicForward, spec) + } } }