Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
240 changes: 33 additions & 207 deletions pkg/ssh/config/parse_flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,8 @@ import (
"fmt"
"io"
"os"
"os/user"
"path/filepath"
"sort"
"strconv"
"strings"

"github.com/deckhouse/lib-dhctl/pkg/log"
Expand All @@ -32,6 +30,8 @@ import (
"github.com/deckhouse/lib-connection/pkg/settings"
"github.com/deckhouse/lib-connection/pkg/ssh/utils"
"github.com/deckhouse/lib-connection/pkg/ssh/utils/terminal"
"github.com/deckhouse/lib-connection/pkg/utils/defaults"
"github.com/deckhouse/lib-connection/pkg/utils/env"
)

const (
Expand Down Expand Up @@ -84,7 +84,7 @@ type Flags struct {
forceNoPrivateKeys bool

flagSet *flag.FlagSet
envExtractor *envExtractor
envExtractor *env.Extractor
}

func (f *Flags) IsConflictBetweenFlags() error {
Expand All @@ -106,7 +106,7 @@ func (f *Flags) IsConflictBetweenFlags() error {

func (f *Flags) FillDefaults() error {
if len(f.PrivateKeysPaths) == 0 && !f.forceNoPrivateKeys {
home, err := getHomeDir(f.envExtractor)
home, err := defaults.HomeDir(f.envExtractor)
if err != nil {
return err
}
Expand Down Expand Up @@ -152,38 +152,35 @@ func (f *Flags) RewriteFromEnvs() error {
return notInitializedError("envExtractor")
}

f.envExtractor.Bool(ForceNoPrivateKeysEnv, &f.forceNoPrivateKeys)

if !f.forceNoPrivateKeys {
isSet := f.envExtractor.Strings(AgentPrivateKeysEnv, &f.PrivateKeysPaths)

if isSet && len(f.PrivateKeysPaths) == 0 {
f.forceNoPrivateKeys = true
}
}
privateKeysVal := env.NewVar(AgentPrivateKeysEnv, &f.PrivateKeysPaths)

err := f.envExtractor.ExtractAllVars(
env.NewVar(BastionPortEnv, &f.BastionPort),
env.NewVar(PortEnv, &f.Port),
env.NewVar(ForceNoPrivateKeysEnv, &f.forceNoPrivateKeys),
privateKeysVal,
env.NewVar(BastionHostEnv, &f.BastionHost),
env.NewVar(BastionUserEnv, &f.BastionUser),
env.NewVar(UserEnv, &f.User),
env.NewVar(HostsEnv, &f.Hosts),
env.NewVar(ExtraArgsEnv, &f.ExtraArgs),
env.NewVar(ConnectionConfigEnv, &f.ConnectionConfigPath),
env.NewVar(LegacyModeEnv, &f.ForceLegacy),
env.NewVar(ModernModeEnv, &f.ForceModern),
env.NewVar(AskBastionPasswordEnv, &f.AskBastionPass),
env.NewVar(AskSudoPasswordEnv, &f.AskSudoPass),
)

f.envExtractor.String(BastionHostEnv, &f.BastionHost)
f.envExtractor.String(BastionUserEnv, &f.BastionUser)
if _, err := f.envExtractor.Int(BastionPortEnv, &f.BastionPort); err != nil {
if err != nil {
return err
}

f.envExtractor.String(UserEnv, &f.User)
f.envExtractor.Strings(HostsEnv, &f.Hosts)
if _, err := f.envExtractor.Int(PortEnv, &f.Port); err != nil {
return err
if !f.forceNoPrivateKeys {
if privateKeysVal.Present && len(f.PrivateKeysPaths) == 0 {
f.forceNoPrivateKeys = true
}
}

f.envExtractor.String(ExtraArgsEnv, &f.ExtraArgs)

f.envExtractor.String(ConnectionConfigEnv, &f.ConnectionConfigPath)

f.envExtractor.Bool(LegacyModeEnv, &f.ForceLegacy)
f.envExtractor.Bool(ModernModeEnv, &f.ForceModern)

f.envExtractor.Bool(AskBastionPasswordEnv, &f.AskBastionPass)
f.envExtractor.Bool(AskSudoPasswordEnv, &f.AskSudoPass)

return nil
}

Expand Down Expand Up @@ -218,7 +215,7 @@ func (f *Flags) userExtractor() func() (string, error) {
return *currentUser, nil
}

userName, err := getCurrentUser(f.envExtractor)
userName, err := defaults.CurrentUserName(f.envExtractor)
if err != nil {
return "", err
}
Expand All @@ -231,15 +228,14 @@ func (f *Flags) userExtractor() func() (string, error) {

type (
AskPasswordFunc func(promt string) ([]byte, error)
EnvsLookupFunc func(name string) (string, bool)
PrivateKeyExtractorFunc func(path string, logger log.Logger) (password string, err error)
)

type FlagsParser struct {
envsPrefix string
ask AskPasswordFunc
sett settings.Settings
envsLookup EnvsLookupFunc
envsLookup env.EnvsLookupFunc

// extractPrivateKey
// custom extract content and password for private key file
Expand Down Expand Up @@ -274,9 +270,7 @@ func NewFlagsParser(sett settings.Settings) *FlagsParser {
// This method trim right all _ ang - symbols and spaces left and right
// By default parser add _ after prefix for all env vars
func (p *FlagsParser) WithEnvsPrefix(envsPrefix string) *FlagsParser {
envsPrefix = strings.TrimSpace(envsPrefix)
envsPrefix = strings.TrimRight(envsPrefix, "_-")
p.envsPrefix = envsPrefix
p.envsPrefix = env.SimplifyPrefix(envsPrefix)
return p
}

Expand All @@ -290,7 +284,7 @@ func (p *FlagsParser) WithAsk(ask AskPasswordFunc) *FlagsParser {
return p
}

func (p *FlagsParser) WithEnvsLookup(lookup EnvsLookupFunc) *FlagsParser {
func (p *FlagsParser) WithEnvsLookup(lookup env.EnvsLookupFunc) *FlagsParser {
if govalue.Nil(lookup) {
p.sett.Logger().WarnF("Envs lookup function is nil. Skip set ask function.")
return p
Expand Down Expand Up @@ -607,8 +601,8 @@ func (p *FlagsParser) ParseFlagsAndExtractConfig(arguments []string, set *flag.F
return p.ExtractConfigAfterParse(flags, opts...)
}

func (p *FlagsParser) envsExtractor() *envExtractor {
return newEnvExtractor(p.envsPrefix, p.envsLookup)
func (p *FlagsParser) envsExtractor() *env.Extractor {
return env.NewExtractor(p.envsPrefix, p.envsLookup)
}

func (p *FlagsParser) readPrivateKeysFromFlags(flags *Flags, logger log.Logger) ([]AgentPrivateKey, error) {
Expand Down Expand Up @@ -671,67 +665,6 @@ func (p *FlagsParser) getPasswordsFromUser(flags *Flags) (*passwordsFromUser, er
return res, nil
}

func getHomeDir(extractor *envExtractor) (string, error) {
home := ""

extractor.StringWithoutPrefix("HOME", &home)

if home == "" {
var err error
home, err = os.UserHomeDir()
if err != nil {
return "", fmt.Errorf("Cannot get user home dir: %w", err)
}

if home == "" {
return "", fmt.Errorf("Cannot get user home dir: empty after call os.UserHomeDir")
}
}

var err error
home, err = filepath.Abs(home)
if err != nil {
return "", fmt.Errorf("Cannot get absolute path of home directory: %w", err)
}

stat, err := os.Stat(home)
if err != nil {
return "", fmt.Errorf("Cannot get user home dir stat: %w", err)
}

if !stat.IsDir() {
return "", fmt.Errorf("Cannot get user home dir: '%s' not a directory", home)
}

return home, nil
}

// getCurrentUser
// returns current user name
// first attempt get user from env
// can be call multiple times because user.Current() cache user info
func getCurrentUser(extractor *envExtractor) (string, error) {
userName := ""

extractor.StringWithoutPrefix("USER", &userName)

if userName != "" {
return userName, nil
}

currentUser, err := user.Current()
if err != nil {
return "", fmt.Errorf("cannot get current user: %w", err)
}

userName = currentUser.Username
if userName == "" {
return "", fmt.Errorf("Cannot get current user: empty after call user.Current")
}

return userName, nil
}

func fileReader(path string, fileType string) (io.ReadCloser, error) {
fullPath, err := filepath.Abs(path)
if err != nil {
Expand All @@ -750,113 +683,6 @@ func fileReader(path string, fileType string) (io.ReadCloser, error) {
return os.Open(fullPath)
}

type envExtractor struct {
prefix string
lookupFunc func(string) (string, bool)
}

func newEnvExtractor(prefix string, lookupFunc EnvsLookupFunc) *envExtractor {
return &envExtractor{
prefix: prefix,
lookupFunc: lookupFunc,
}
}

func (e *envExtractor) NameWithPrefix(name string) string {
if e.prefix != "" {
name = fmt.Sprintf("%s_%s", e.prefix, name)
}

return name
}

func (e *envExtractor) AddEnvToUsage(usage string, envName string) string {
if envName == "" {
return usage
}

return fmt.Sprintf("%s (Can rewrite with %s env)", usage, e.NameWithPrefix(envName))
}

func (e *envExtractor) Var(name string) (string, bool) {
return e.lookupFunc(e.NameWithPrefix(name))
}

func (e *envExtractor) VarWithoutPrefix(name string) (string, bool) {
return e.lookupFunc(name)
}

func (e *envExtractor) Int(name string, destination *int) (bool, error) {
strVar, ok := e.Var(name)
if !ok {
return false, nil
}

value, err := strconv.Atoi(strVar)
if err != nil {
return false, fmt.Errorf("Cannot convert '%s' to int for %s: %w", strVar, e.NameWithPrefix(name), err)
}

*destination = value

return true, nil
}

func (e *envExtractor) StringWithoutPrefix(name string, destination *string) bool {
strVar, ok := e.VarWithoutPrefix(name)
if !ok {
return false
}

*destination = strVar

return true
}

func (e *envExtractor) String(name string, destination *string) bool {
strVar, ok := e.Var(name)
if !ok {
return false
}

*destination = strVar

return true
}

func (e *envExtractor) Strings(name string, destination *[]string) bool {
valsStr, ok := e.Var(name)
if !ok {
return false
}

valsSplit := strings.Split(valsStr, ",")
vals := make([]string, 0, len(valsSplit))
for _, v := range valsSplit {
if strings.TrimSpace(v) != "" {
vals = append(vals, v)
}
}

*destination = vals

return true
}

// Bool
// returns that env is set
func (e *envExtractor) Bool(name string, destination *bool) bool {
strVar, ok := e.Var(name)
if !ok {
return false
}
value := strVar != ""

*destination = value

return true
}

func terminalPrivateKeyPasswordExtractor(path string, defaultPassword []byte, logger log.Logger) (string, error) {
_, password, err := utils.ParseSSHPrivateKeyFile(path, string(defaultPassword), logger)

Expand Down
1 change: 1 addition & 0 deletions pkg/ssh/config/parse_flags_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,7 @@ func TestParseFlags(t *testing.T) {
envs: map[string]string{
"DHCTL_SSH_HOSTS": "192.168.0.2,192.168.0.3",
"DHCTL_SSH_MODERN_MODE": "true",
"DHCTL_SSH_LEGACY_MODE": "false",
"DHCTL_SSH_BASTION_PORT": "2200",
},

Expand Down
Loading