diff --git a/cmd/main.go b/cmd/main.go index ee9cf15..680f1c9 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -64,6 +64,35 @@ func main() { } rootCmd.SilenceUsage = true + // Add --server flag to connect directly to a server + rootCmd.Flags().StringP("server", "s", "", "Connect directly to the specified server alias") + + // Handle direct SSH connection when server flag is provided + rootCmd.RunE = func(cmd *cobra.Command, args []string) error { + serverAlias, _ := cmd.Flags().GetString("server") + + // If server alias provided as flag or argument, connect directly + if serverAlias == "" && len(args) > 0 { + serverAlias = args[0] + } + + if serverAlias != "" { + server, err := serverService.GetServerByAlias(serverAlias) + if err != nil { + log.Errorw("failed to get server", "error", err, "alias", serverAlias) + return fmt.Errorf("failed to connect to server '%s': %w", serverAlias, err) + } + if server == nil { + log.Errorw("server not found", "alias", serverAlias) + return fmt.Errorf("server '%s' not found in SSH config", serverAlias) + } + return serverService.SSH(server.Alias) + } + + // Otherwise, run the TUI + return tui.Run() + } + if err := rootCmd.Execute(); err != nil { _, _ = fmt.Fprintln(os.Stderr, err) os.Exit(1) diff --git a/internal/adapters/data/ssh_config_file/ssh_config_file_repo.go b/internal/adapters/data/ssh_config_file/ssh_config_file_repo.go index 37e8004..aab1221 100644 --- a/internal/adapters/data/ssh_config_file/ssh_config_file_repo.go +++ b/internal/adapters/data/ssh_config_file/ssh_config_file_repo.go @@ -73,6 +73,30 @@ func (r *Repository) ListServers(query string) ([]domain.Server, error) { return r.filterServers(servers, query), nil } +// GetServerByAlias returns a server by its alias, or nil if not found. +func (r *Repository) GetServerByAlias(alias string) (*domain.Server, error) { + cfg, err := r.loadConfig() + if err != nil { + return nil, fmt.Errorf("failed to load config: %w", err) + } + + servers := r.toDomainServer(cfg) + metadata, err := r.metadataManager.loadAll() + if err != nil { + r.logger.Warnf("Failed to load metadata: %v", err) + metadata = make(map[string]ServerMetadata) + } + servers = r.mergeMetadata(servers, metadata) + + for i := range servers { + if servers[i].Alias == alias { + return &servers[i], nil + } + } + + return nil, nil +} + // AddServer adds a new server to the SSH config. func (r *Repository) AddServer(server domain.Server) error { cfg, err := r.loadConfig() diff --git a/internal/core/ports/repositories.go b/internal/core/ports/repositories.go index 133e4f8..4b1ba3a 100644 --- a/internal/core/ports/repositories.go +++ b/internal/core/ports/repositories.go @@ -18,6 +18,7 @@ import "github.com/Adembc/lazyssh/internal/core/domain" type ServerRepository interface { ListServers(query string) ([]domain.Server, error) + GetServerByAlias(alias string) (*domain.Server, error) UpdateServer(server domain.Server, newServer domain.Server) error AddServer(server domain.Server) error DeleteServer(server domain.Server) error diff --git a/internal/core/ports/services.go b/internal/core/ports/services.go index 2407269..31c286b 100644 --- a/internal/core/ports/services.go +++ b/internal/core/ports/services.go @@ -22,6 +22,7 @@ import ( type ServerService interface { ListServers(query string) ([]domain.Server, error) + GetServerByAlias(alias string) (*domain.Server, error) UpdateServer(server domain.Server, newServer domain.Server) error AddServer(server domain.Server) error DeleteServer(server domain.Server) error diff --git a/internal/core/services/server_service.go b/internal/core/services/server_service.go index 7926bc6..bb5c60f 100644 --- a/internal/core/services/server_service.go +++ b/internal/core/services/server_service.go @@ -73,6 +73,11 @@ func (s *serverService) ListServers(query string) ([]domain.Server, error) { return servers, nil } +// GetServerByAlias returns a server by its alias, or nil if not found. +func (s *serverService) GetServerByAlias(alias string) (*domain.Server, error) { + return s.serverRepository.GetServerByAlias(alias) +} + // validateServer performs core validation of server fields. func validateServer(srv domain.Server) error { if strings.TrimSpace(srv.Alias) == "" {