diff --git a/config/profile.go b/config/profile.go index 1f9468fd4..d3b28b139 100644 --- a/config/profile.go +++ b/config/profile.go @@ -2,6 +2,7 @@ package config import ( "fmt" + "os" "path" "path/filepath" "reflect" @@ -10,6 +11,7 @@ import ( "time" "github.com/Masterminds/semver/v3" + "github.com/creativeprojects/clog" "github.com/creativeprojects/resticprofile/constants" "github.com/creativeprojects/resticprofile/restic" "github.com/creativeprojects/resticprofile/shell" @@ -17,6 +19,7 @@ import ( "github.com/creativeprojects/resticprofile/util/bools" "github.com/mitchellh/mapstructure" "golang.org/x/exp/maps" + "golang.org/x/exp/slices" ) // resticVersion14 is the semver of restic 0.14 (the version where several flag names were changed) @@ -249,13 +252,19 @@ type ScheduleBaseSection struct { SchedulePriority string `mapstructure:"schedule-priority" show:"noshow" default:"background" enum:"background;standard" description:"Set the priority at which the schedule is run"` ScheduleLockMode string `mapstructure:"schedule-lock-mode" show:"noshow" default:"default" enum:"default;fail;ignore" description:"Specify how locks are used when running on schedule - see https://creativeprojects.github.io/resticprofile/schedules/configuration/"` ScheduleLockWait time.Duration `mapstructure:"schedule-lock-wait" show:"noshow" examples:"150s;15m;30m;45m;1h;2h30m" description:"Set the maximum time to wait for acquiring locks when running on schedule"` + ScheduleEnvCapture []string `mapstructure:"schedule-capture-environment" show:"noshow" default:"RESTIC_*" description:"Set names (or glob expressions) of environment variables to capture during schedule creation. The captured environment is applied prior to \"profile.env\" when running the schedule. Whether capturing is supported depends on the type of scheduler being used (supported in \"systemd\" and \"launchd\")"` } func (s *ScheduleBaseSection) setRootPath(_ *Profile, _ string) { s.ScheduleLog = fixPath(s.ScheduleLog, expandEnv, expandUserHome) } -func (s *ScheduleBaseSection) GetSchedule() *ScheduleBaseSection { return s } +func (s *ScheduleBaseSection) GetSchedule() *ScheduleBaseSection { + if s != nil && s.ScheduleEnvCapture == nil { + s.ScheduleEnvCapture = []string{"RESTIC_*"} + } + return s +} // CopySection contains the destination parameters for a copy command type CopySection struct { @@ -479,6 +488,16 @@ func (p *Profile) ResolveConfiguration() { r.resolve(p) } + // Resolve environment variable name case (p.Environment keys are all lower case due to config parser) + // Custom env variables (without a match in os.Environ) are changed to uppercase (like before in wrapper) + osEnv := util.NewFoldingEnvironment(os.Environ()...) + for name, value := range p.Environment { + if newName := osEnv.ResolveName(strings.ToUpper(name)); newName != name { + delete(p.Environment, name) + p.Environment[newName] = value + } + } + // Deal with "path" & "tag" flags if p.Backup != nil { // Copy tags from backup if tag is set to boolean true @@ -736,9 +755,29 @@ func (p *Profile) Schedules() []*ScheduleConfig { for name, section := range sections { if s := section.GetSchedule(); len(s.Schedule) > 0 { - env := map[string]string{} - for key, value := range p.Environment { - env[key] = value.Value() + env := util.NewDefaultEnvironment() + + if len(s.ScheduleEnvCapture) > 0 { + // Capture OS env + env.SetValues(os.Environ()...) + + // Capture profile env + for key, value := range p.Environment { + env.Put(key, value.Value()) + } + + for index, key := range env.Names() { + matched := slices.ContainsFunc(s.ScheduleEnvCapture, func(pattern string) bool { + matched, err := filepath.Match(pattern, key) + if err != nil && index == 0 { + clog.Tracef("env not matched with invalid glob expression '%s': %s", pattern, err.Error()) + } + return matched + }) + if !matched { + env.Remove(key) + } + } } config := &ScheduleConfig{ @@ -746,7 +785,7 @@ func (p *Profile) Schedules() []*ScheduleConfig { SubTitle: name, Schedules: s.Schedule, Permission: s.SchedulePermission, - Environment: env, + Environment: env.Values(), Log: s.ScheduleLog, LockMode: s.ScheduleLockMode, LockWait: s.ScheduleLockWait, diff --git a/config/profile_test.go b/config/profile_test.go index ab93b5cc7..2b03669e9 100644 --- a/config/profile_test.go +++ b/config/profile_test.go @@ -791,16 +791,22 @@ func TestSchedules(t *testing.T) { testConfig := func(command string, scheduled bool) string { schedule := "" if scheduled { - schedule = `schedule = "@hourly" -schedule-log = "` + logFile + `"` + schedule = ` + schedule = "@hourly" + schedule-log = "` + logFile + `"` } config := ` -[profile] -initialize = true + [profile] + initialize = true -[profile.%s] -%s + [profile.env] + TEST_VAR="non-captured-test-value" + RESTIC_VAR="profile-only-value" + RESTIC_ANY2="123" + + [profile.%s] + %s ` return fmt.Sprintf(config, command, schedule) } @@ -808,19 +814,28 @@ initialize = true sections := NewProfile(nil, "").SchedulableCommands() require.GreaterOrEqual(t, len(sections), 6) + require.NoError(t, os.Setenv("RESTIC_ANY1", "xyz")) + require.NoError(t, os.Setenv("RESTIC_ANY2", "xyz")) + for _, command := range sections { t.Run(command, func(t *testing.T) { // Check that schedule is supported - profile, err := getProfile("toml", testConfig(command, true), "profile", "") + profile, err := getResolvedProfile("toml", testConfig(command, true), "profile") require.NoError(t, err) assert.NotNil(t, profile) config := profile.Schedules() - assert.Len(t, config, 1) - assert.Equal(t, config[0].SubTitle, command) - assert.Len(t, config[0].Schedules, 1) - assert.Equal(t, config[0].Schedules[0], "@hourly") - assert.Equal(t, config[0].Log, path.Join(constants.TemporaryDirMarker, "rp.log")) + require.Len(t, config, 1) + + schedule := config[0] + assert.Equal(t, command, schedule.SubTitle) + assert.Equal(t, []string{"@hourly"}, schedule.Schedules) + assert.Equal(t, path.Join(constants.TemporaryDirMarker, "rp.log"), schedule.Log) + assert.Equal(t, map[string]string{ + "RESTIC_VAR": "profile-only-value", + "RESTIC_ANY1": "xyz", + "RESTIC_ANY2": "123", + }, util.NewDefaultEnvironment(schedule.Environment...).ValuesAsMap()) // Check that schedule is optional profile, err = getProfile("toml", testConfig(command, false), "profile", "") diff --git a/config/schedule_config.go b/config/schedule_config.go index c77759e36..23747f998 100644 --- a/config/schedule_config.go +++ b/config/schedule_config.go @@ -28,7 +28,7 @@ type ScheduleConfig struct { WorkingDirectory string Command string Arguments []string - Environment map[string]string + Environment []string JobDescription string TimerDescription string Priority string diff --git a/config/schedule_config_test.go b/config/schedule_config_test.go index fcac46f84..ea77525a2 100644 --- a/config/schedule_config_test.go +++ b/config/schedule_config_test.go @@ -18,7 +18,7 @@ func TestScheduleProperties(t *testing.T) { WorkingDirectory: "home", Command: "command", Arguments: []string{"1", "2"}, - Environment: map[string]string{"test": "dev"}, + Environment: []string{"test=dev"}, JobDescription: "job", TimerDescription: "timer", Log: "log.txt", @@ -36,7 +36,7 @@ func TestScheduleProperties(t *testing.T) { assert.Equal(t, "command", schedule.Command) assert.Equal(t, "home", schedule.WorkingDirectory) assert.ElementsMatch(t, []string{"1", "2"}, schedule.Arguments) - assert.Equal(t, "dev", schedule.Environment["test"]) + assert.Equal(t, []string{"test=dev"}, schedule.Environment) assert.Equal(t, "background", schedule.GetPriority()) // default value assert.Equal(t, "log.txt", schedule.Log) assert.Equal(t, ScheduleLockModeDefault, schedule.GetLockMode()) diff --git a/schedule/handler_darwin.go b/schedule/handler_darwin.go index 68bfcf999..610063f72 100644 --- a/schedule/handler_darwin.go +++ b/schedule/handler_darwin.go @@ -17,6 +17,7 @@ import ( "github.com/creativeprojects/resticprofile/constants" "github.com/creativeprojects/resticprofile/dial" "github.com/creativeprojects/resticprofile/term" + "github.com/creativeprojects/resticprofile/util" "github.com/spf13/afero" "golang.org/x/exp/slices" "howett.net/plist" @@ -167,10 +168,10 @@ func (h *HandlerLaunchd) getLaunchdJob(job *config.ScheduleConfig, schedules []* logfile = name + ".log" } - // Add path to env variables - env := make(map[string]string, 1) - if pathEnv := os.Getenv("PATH"); pathEnv != "" { - env["PATH"] = pathEnv + // Format schedule env, adding PATH if not yet provided by the schedule config + env := util.NewDefaultEnvironment(job.Environment...) + if !env.Has("PATH") { + env.Put("PATH", os.Getenv("PATH")) } lowPriorityIO := true @@ -188,7 +189,7 @@ func (h *HandlerLaunchd) getLaunchdJob(job *config.ScheduleConfig, schedules []* StandardErrorPath: logfile, WorkingDirectory: job.WorkingDirectory, StartCalendarInterval: getCalendarIntervalsFromSchedules(schedules), - EnvironmentVariables: env, + EnvironmentVariables: env.ValuesAsMap(), Nice: nice, ProcessType: priorityValues[job.GetPriority()], LowPriorityIO: lowPriorityIO, diff --git a/schedule/handler_darwin_test.go b/schedule/handler_darwin_test.go index 8ba784776..25ade997c 100644 --- a/schedule/handler_darwin_test.go +++ b/schedule/handler_darwin_test.go @@ -4,6 +4,8 @@ package schedule import ( "bytes" + "fmt" + "os" "path" "testing" @@ -162,6 +164,27 @@ func TestLaunchdJobLog(t *testing.T) { } } +func TestLaunchdJobPreservesEnv(t *testing.T) { + pathEnv := os.Getenv("PATH") + fixtures := []struct { + environment []string + expected map[string]string + }{ + {expected: map[string]string{"PATH": pathEnv}}, + {environment: []string{"path=extra-var"}, expected: map[string]string{"PATH": pathEnv, "path": "extra-var"}}, + {environment: []string{"PATH=custom-path"}, expected: map[string]string{"PATH": "custom-path"}}, + } + + for i, fixture := range fixtures { + t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { + handler := NewHandler(SchedulerLaunchd{}) + cfg := &config.ScheduleConfig{Title: "t", SubTitle: "s", Environment: fixture.environment} + launchdJob := handler.getLaunchdJob(cfg, []*calendar.Event{}) + assert.Equal(t, fixture.expected, launchdJob.EnvironmentVariables) + }) + } +} + func TestCreateUserPlist(t *testing.T) { handler := NewHandler(SchedulerLaunchd{}) handler.fs = afero.NewMemMapFs() diff --git a/schedule/handler_systemd.go b/schedule/handler_systemd.go index 285d649d6..9a3ea409f 100644 --- a/schedule/handler_systemd.go +++ b/schedule/handler_systemd.go @@ -105,6 +105,7 @@ func (h *HandlerSystemd) CreateJob(job *config.ScheduleConfig, schedules []*cale } err := systemd.Generate(systemd.Config{ CommandLine: job.Command + " --no-prio " + strings.Join(job.Arguments, " "), + Environment: job.Environment, WorkingDirectory: job.WorkingDirectory, Title: job.Title, SubTitle: job.SubTitle, diff --git a/systemd/generate.go b/systemd/generate.go index c024dcf13..92fc8aea4 100644 --- a/systemd/generate.go +++ b/systemd/generate.go @@ -15,6 +15,7 @@ import ( "github.com/creativeprojects/resticprofile/constants" "github.com/creativeprojects/resticprofile/util/templates" "github.com/spf13/afero" + "golang.org/x/exp/slices" ) const ( @@ -78,6 +79,7 @@ type templateInfo struct { // Config for generating systemd unit and timer files type Config struct { CommandLine string + Environment []string WorkingDirectory string Title string SubTitle string @@ -108,7 +110,7 @@ func Generate(config Config) error { } } - environment := make([]string, 0, 2) + environment := slices.Clone(config.Environment) // add $HOME to the environment variables (as a fallback if not defined in profile) if home, err := os.UserHomeDir(); err == nil { environment = append(environment, fmt.Sprintf("HOME=%s", home)) diff --git a/systemd/generate_test.go b/systemd/generate_test.go index 76b35ab84..9e25421c0 100644 --- a/systemd/generate_test.go +++ b/systemd/generate_test.go @@ -1,4 +1,4 @@ -//+build !darwin,!windows +//go:build !darwin && !windows package systemd @@ -25,17 +25,15 @@ func TestGenerateSystemUnit(t *testing.T) { assertNoFileExists(t, timerFile) err := Generate(Config{ - "commandLine", - "workdir", - "name", - "backup", - "job description", - "timer description", - []string{"daily"}, - SystemUnit, - "low", - "", - "", + CommandLine: "commandLine", + WorkingDirectory: "workdir", + Title: "name", + SubTitle: "backup", + JobDescription: "job description", + TimerDescription: "timer description", + Schedules: []string{"daily"}, + UnitType: SystemUnit, + Priority: "low", }) require.NoError(t, err) requireFileExists(t, serviceFile) @@ -77,17 +75,15 @@ WantedBy=timers.target assertNoFileExists(t, timerFile) err = Generate(Config{ - "commandLine", - "workdir", - "name", - "backup", - "job description", - "timer description", - []string{"daily"}, - UserUnit, - "low", - "", - "", + CommandLine: "commandLine", + WorkingDirectory: "workdir", + Title: "name", + SubTitle: "backup", + JobDescription: "job description", + TimerDescription: "timer description", + Schedules: []string{"daily"}, + UnitType: UserUnit, + Priority: "low", }) require.NoError(t, err) requireFileExists(t, serviceFile) @@ -106,17 +102,16 @@ func TestGenerateUnitTemplateNotFound(t *testing.T) { fs = afero.NewMemMapFs() err := Generate(Config{ - "commandLine", - "workdir", - "name", - "backup", - "job description", - "timer description", - []string{"daily"}, - SystemUnit, - "low", - "unit-file", - "", + CommandLine: "commandLine", + WorkingDirectory: "workdir", + Title: "name", + SubTitle: "backup", + JobDescription: "job description", + TimerDescription: "timer description", + Schedules: []string{"daily"}, + UnitType: SystemUnit, + Priority: "low", + UnitFile: "unit-file", }) require.Error(t, err) } @@ -125,17 +120,16 @@ func TestGenerateTimerTemplateNotFound(t *testing.T) { fs = afero.NewMemMapFs() err := Generate(Config{ - "commandLine", - "workdir", - "name", - "backup", - "job description", - "timer description", - []string{"daily"}, - SystemUnit, - "low", - "", - "timer-file", + CommandLine: "commandLine", + WorkingDirectory: "workdir", + Title: "name", + SubTitle: "backup", + JobDescription: "job description", + TimerDescription: "timer description", + Schedules: []string{"daily"}, + UnitType: SystemUnit, + Priority: "low", + TimerFile: "timer-file", }) require.Error(t, err) } @@ -147,17 +141,16 @@ func TestGenerateUnitTemplateFailed(t *testing.T) { require.NoError(t, err) err = Generate(Config{ - "commandLine", - "workdir", - "name", - "backup", - "job description", - "timer description", - []string{"daily"}, - SystemUnit, - "low", - "unit", - "", + CommandLine: "commandLine", + WorkingDirectory: "workdir", + Title: "name", + SubTitle: "backup", + JobDescription: "job description", + TimerDescription: "timer description", + Schedules: []string{"daily"}, + UnitType: SystemUnit, + Priority: "low", + UnitFile: "unit", }) require.Error(t, err) } @@ -169,17 +162,16 @@ func TestGenerateTimerTemplateFailed(t *testing.T) { require.NoError(t, err) err = Generate(Config{ - "commandLine", - "workdir", - "name", - "backup", - "job description", - "timer description", - []string{"daily"}, - SystemUnit, - "low", - "", - "timer", + CommandLine: "commandLine", + WorkingDirectory: "workdir", + Title: "name", + SubTitle: "backup", + JobDescription: "job description", + TimerDescription: "timer description", + Schedules: []string{"daily"}, + UnitType: SystemUnit, + Priority: "low", + TimerFile: "timer", }) require.Error(t, err) } @@ -191,17 +183,16 @@ func TestGenerateUnitTemplateFailedToExecute(t *testing.T) { require.NoError(t, err) err = Generate(Config{ - "commandLine", - "workdir", - "name", - "backup", - "job description", - "timer description", - []string{"daily"}, - SystemUnit, - "low", - "unit", - "", + CommandLine: "commandLine", + WorkingDirectory: "workdir", + Title: "name", + SubTitle: "backup", + JobDescription: "job description", + TimerDescription: "timer description", + Schedules: []string{"daily"}, + UnitType: SystemUnit, + Priority: "low", + UnitFile: "unit", }) require.Error(t, err) } @@ -213,17 +204,16 @@ func TestGenerateTimerTemplateFailedToExecute(t *testing.T) { require.NoError(t, err) err = Generate(Config{ - "commandLine", - "workdir", - "name", - "backup", - "job description", - "timer description", - []string{"daily"}, - SystemUnit, - "low", - "", - "timer", + CommandLine: "commandLine", + WorkingDirectory: "workdir", + Title: "name", + SubTitle: "backup", + JobDescription: "job description", + TimerDescription: "timer description", + Schedules: []string{"daily"}, + UnitType: SystemUnit, + Priority: "low", + TimerFile: "timer", }) require.Error(t, err) } @@ -244,17 +234,17 @@ func TestGenerateFromUserDefinedTemplates(t *testing.T) { require.NoError(t, err) err = Generate(Config{ - "commandLine", - "workdir", - "name", - "backup", - "job description", - "timer description", - []string{"daily"}, - SystemUnit, - "low", - "unit", - "timer", + CommandLine: "commandLine", + WorkingDirectory: "workdir", + Title: "name", + SubTitle: "backup", + JobDescription: "job description", + TimerDescription: "timer description", + Schedules: []string{"daily"}, + UnitType: SystemUnit, + Priority: "low", + UnitFile: "unit", + TimerFile: "timer", }) require.NoError(t, err) requireFileExists(t, serviceFile) @@ -275,17 +265,15 @@ func TestGenerateOnReadOnlyFs(t *testing.T) { fs = afero.NewReadOnlyFs(fs) err = Generate(Config{ - "commandLine", - "workdir", - "name", - "backup", - "job description", - "timer description", - []string{"daily"}, - SystemUnit, - "low", - "", - "", + CommandLine: "commandLine", + WorkingDirectory: "workdir", + Title: "name", + SubTitle: "backup", + JobDescription: "job description", + TimerDescription: "timer description", + Schedules: []string{"daily"}, + UnitType: SystemUnit, + Priority: "low", }) require.Error(t, err) } diff --git a/util/env.go b/util/env.go new file mode 100644 index 000000000..cac2e9782 --- /dev/null +++ b/util/env.go @@ -0,0 +1,130 @@ +package util + +import ( + "golang.org/x/exp/maps" + "strings" + + "github.com/creativeprojects/resticprofile/platform" +) + +// Environment manages a set of environment variables +type Environment struct { + env map[string]string + preserveCase bool +} + +// NewDefaultEnvironment creates Environment with OS defaults for preserveCase and the specified initial values +func NewDefaultEnvironment(values ...string) *Environment { + return NewEnvironment(EnvironmentPreservesCase(), values...) +} + +// EnvironmentPreservesCase returns true if environment variables are case sensitive (all OS except Windows) +func EnvironmentPreservesCase() bool { return !platform.IsWindows() } + +// NewFoldingEnvironment creates an Environment that folds the case of variable names +func NewFoldingEnvironment(values ...string) *Environment { + return NewEnvironment(false, values...) +} + +// NewEnvironment creates Environment with optional preserveCase and the specified initial values +func NewEnvironment(preserveCase bool, values ...string) *Environment { + env := &Environment{ + env: make(map[string]string), + preserveCase: preserveCase, + } + env.SetValues(values...) + return env +} + +func splitEnvironmentValue(keyValue string) (key, value string) { + if index := strings.Index(keyValue, "="); index > 0 { + key = strings.TrimSpace(keyValue[:index]) + value = keyValue[index+1:] + } + return +} + +// SetValues sets one or more values of the format NAME=VALUE +func (e *Environment) SetValues(values ...string) { + for _, kv := range values { + if key, _ := splitEnvironmentValue(kv); key != "" { + if e.preserveCase { + e.env[key] = kv + } else { + e.env[strings.ToUpper(key)] = kv + } + } + } +} + +// Values returns all environment variables as NAME=VALUE lines (case is preserved as inserted) +func (e *Environment) Values() (values []string) { return maps.Values(e.env) } + +// Names returns all environment variables names (case is preserved as inserted) +func (e *Environment) Names() (names []string) { return maps.Keys(e.ValuesAsMap()) } + +// FoldedNames returns all environment variables names (case depends on preserveCase) +func (e *Environment) FoldedNames() (names []string) { return maps.Keys(e.env) } + +// ValuesAsMap returns all environment variables as name & value map +func (e *Environment) ValuesAsMap() (m map[string]string) { + m = make(map[string]string) + for _, kv := range e.env { + k, v := splitEnvironmentValue(kv) + m[k] = v + } + return +} + +// Put sets a single name and value pair +func (e *Environment) Put(name, value string) { + if strings.Contains(name, "=") { + return + } + if value == "" { + e.Remove(name) + } else { + e.SetValues(name + "=" + value) + } +} + +// Get returns the variable value, possibly an empty string when the variable is unset or empty. +func (e *Environment) Get(name string) (value string) { + _, value, _ = e.Find(name) + return +} + +// Has returns true as the variable with name is set. +func (e *Environment) Has(name string) (found bool) { + _, _, found = e.Find(name) + return +} + +// Find returns the variable's original name, its value and ok when the variable exists +func (e *Environment) Find(name string) (originalName, value string, ok bool) { + if !e.preserveCase { + name = strings.ToUpper(name) + } + if value, ok = e.env[name]; ok { + originalName, value = splitEnvironmentValue(value) + } + return +} + +// Remove deletes the named env variable +func (e *Environment) Remove(name string) { + if !e.preserveCase { + name = strings.ToUpper(name) + } + delete(e.env, name) +} + +// ResolveName resolves the specified name to the actual variable name if case folding applies (preserveCase is false) +func (e *Environment) ResolveName(name string) string { + if !e.preserveCase { + if actualName, _, found := e.Find(name); found { + return actualName + } + } + return name +} diff --git a/util/env_test.go b/util/env_test.go new file mode 100644 index 000000000..66ce0db19 --- /dev/null +++ b/util/env_test.go @@ -0,0 +1,100 @@ +package util + +import ( + "os" + "strings" + "testing" + + "github.com/creativeprojects/resticprofile/platform" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCanReadOsEnv(t *testing.T) { + env := NewDefaultEnvironment(os.Environ()...) + + // All values are included + for _, value := range os.Environ() { + kv := strings.SplitN(value, "=", 2) + assert.Equal(t, kv[1], env.Get(strings.TrimSpace(kv[0]))) + } + + // Elements are retained like specified + assert.ElementsMatch(t, env.Values(), os.Environ()) +} + +func TestEnvironmentPreservesCase(t *testing.T) { + assert.Equal(t, !platform.IsWindows(), EnvironmentPreservesCase()) + assert.Equal(t, !platform.IsWindows(), NewDefaultEnvironment().preserveCase) + assert.False(t, NewFoldingEnvironment().preserveCase) +} + +func TestCanSetAndRemove(t *testing.T) { + env := NewDefaultEnvironment() + env.Put("Name", "value") + assert.Equal(t, "value", env.Get("Name")) + + env.Put("N=V", "value") + assert.Equal(t, "", env.Get("N=V")) + assert.False(t, env.Has("N=V")) + + env.Put("Name", "") + assert.Equal(t, "", env.Get("Name")) + assert.False(t, env.Has("Name")) +} + +func TestCaseFolding(t *testing.T) { + env := NewEnvironment(true) + foldingEnv := NewEnvironment(false) + + env.Put("Name", "abc") + foldingEnv.Put("Name", "abc") + + t.Run("Values", func(t *testing.T) { + values := []string{"Name=abc"} + assert.Equal(t, values, env.Values()) + assert.Equal(t, values, foldingEnv.Values()) + }) + + t.Run("ValuesAsMap", func(t *testing.T) { + values := map[string]string{"Name": "abc"} + assert.Equal(t, values, env.ValuesAsMap()) + assert.Equal(t, values, foldingEnv.ValuesAsMap()) + }) + + t.Run("Names", func(t *testing.T) { + names := []string{"Name"} + foldedNames := []string{"NAME"} + assert.Equal(t, names, env.Names()) + assert.Equal(t, names, foldingEnv.Names()) + assert.Equal(t, names, env.FoldedNames()) + assert.Equal(t, foldedNames, foldingEnv.FoldedNames()) + }) + + t.Run("Get", func(t *testing.T) { + assert.Equal(t, env.Get("Name"), foldingEnv.Get("Name")) + + assert.Equal(t, "", env.Get("NAME")) + assert.False(t, env.Has("NAME")) + assert.Equal(t, "abc", foldingEnv.Get("NAME")) + assert.True(t, foldingEnv.Has("NAME")) + }) + + t.Run("ResolveName", func(t *testing.T) { + assert.Equal(t, "NAME", env.ResolveName("NAME")) + assert.Equal(t, "Name", foldingEnv.ResolveName("NAME")) + }) + + t.Run("Remove", func(t *testing.T) { + env.Put("ToRemove", "x") + env.Remove("TOREMOVE") + foldingEnv.Put("ToRemove", "x") + foldingEnv.Remove("TOREMOVE") + + assert.True(t, env.Has("ToRemove")) + assert.False(t, foldingEnv.Has("ToRemove")) + + env.Remove("ToRemove") + require.False(t, env.Has("ToRemove")) + }) +} diff --git a/util/templates/data.go b/util/templates/data.go index 96fd88b4f..2d48d421e 100644 --- a/util/templates/data.go +++ b/util/templates/data.go @@ -8,6 +8,7 @@ import ( "time" "github.com/creativeprojects/clog" + "github.com/creativeprojects/resticprofile/util" ) // DefaultData provides default variables for templates @@ -38,7 +39,6 @@ func NewDefaultData(env map[string]string) (data DefaultData) { OS: runtime.GOOS, Arch: runtime.GOARCH, Hostname: "localhost", - Env: formatEnv(env), StartupDir: startupDir, CurrentDir: startupDir, } @@ -57,29 +57,22 @@ func NewDefaultData(env map[string]string) (data DefaultData) { data.Hostname = hostname } - for _, envValue := range os.Environ() { - kv := strings.SplitN(envValue, "=", 2) - key, value := strings.ToUpper(strings.TrimSpace(kv[0])), kv[1] - if _, contains := data.Env[key]; !contains && key != "" { - data.Env[key] = value - } + osEnv := util.NewDefaultEnvironment(os.Environ()...) + for name, value := range env { + osEnv.Put(osEnv.ResolveName(name), value) } + data.Env = osEnv.ValuesAsMap() - return data -} - -func formatEnv(env map[string]string) map[string]string { - if env == nil { - env = make(map[string]string) - } else { - for name, v := range env { - if un := strings.ToUpper(name); un != name { - delete(env, name) - env[un] = v + // add uppercase env variants to simplify usage in templates + for name, value := range data.Env { + if un := strings.ToUpper(name); un != name { + if _, exists := data.Env[un]; !exists { + data.Env[un] = value } } } - return env + + return data } var startupDir = (func() string { diff --git a/util/templates/data_test.go b/util/templates/data_test.go index 05786db47..bd9ec4210 100644 --- a/util/templates/data_test.go +++ b/util/templates/data_test.go @@ -8,7 +8,7 @@ import ( "testing" "time" - "github.com/creativeprojects/resticprofile/util/collect" + "github.com/creativeprojects/resticprofile/util" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -62,9 +62,7 @@ func TestOsAndArch(t *testing.T) { } func TestEnv(t *testing.T) { - osEnvKeys := collect.From(os.Environ(), func(s string) string { - return strings.SplitN(s, "=", 2)[0] - }) + osEnv := util.NewDefaultEnvironment(os.Environ()...) customEnv := map[string]string{ "path": "my-test-path", @@ -74,12 +72,17 @@ func TestEnv(t *testing.T) { env := NewDefaultData(customEnv).Env - for _, key := range osEnvKeys { - if key != "" && key != "PATH" { - assert.Equal(t, os.Getenv(key), env[strings.ToUpper(key)], "key = %s", key) + for _, key := range osEnv.Names() { + if key != "" && strings.ToUpper(key) != "PATH" { + assert.Equal(t, os.Getenv(key), env[key], "key = %s", key) } } for key := range customEnv { - assert.Equal(t, customEnv[key], env[strings.ToUpper(key)], "key = %s", key) + rKey := osEnv.ResolveName(key) + assert.Equal(t, customEnv[key], env[rKey], "key = %s, rKey = %s", key, rKey) } + + // templates offer uppercase variant + assert.Equal(t, customEnv["__test_k1"], env["__test_k1"]) + assert.Equal(t, customEnv["__test_k1"], env["__TEST_K1"]) } diff --git a/wrapper.go b/wrapper.go index 10e3f0568..6a19b9edf 100644 --- a/wrapper.go +++ b/wrapper.go @@ -630,20 +630,13 @@ func (r *resticWrapper) sendMonitoring(sections []config.SendMonitoringSection, } // getEnvironment returns the environment variables defined in the profile configuration -func (r *resticWrapper) getEnvironment() []string { - if r.profile.Environment == nil || len(r.profile.Environment) == 0 { - return nil - } - env := make([]string, len(r.profile.Environment)) - i := 0 +func (r *resticWrapper) getEnvironment() (env []string) { + // Note: variable names match the original case for OS variables. Custom vars are all uppercase. for key, value := range r.profile.Environment { - // env variables are always uppercase - key = strings.ToUpper(key) - clog.Debugf("setting up environment variable '%s'", key) - env[i] = fmt.Sprintf("%s=%s", key, value.Value()) - i++ + clog.Debugf("setting up environment variable: %s=%s", key, value) + env = append(env, fmt.Sprintf("%s=%s", key, value.Value())) } - return env + return } // getProfileEnvironment returns some environment variables about the current profile diff --git a/wrapper_test.go b/wrapper_test.go index 3319aac67..dfd24d11d 100644 --- a/wrapper_test.go +++ b/wrapper_test.go @@ -209,6 +209,7 @@ func TestGetSingleEnvironment(t *testing.T) { profile.Environment = map[string]config.ConfidentialValue{ "User": config.NewConfidentialValue("me"), } + profile.ResolveConfiguration() wrapper := newResticWrapper(nil, "restic", false, profile, "test", nil, nil) env := wrapper.getEnvironment() assert.Equal(t, []string{"USER=me"}, env) @@ -220,6 +221,7 @@ func TestGetMultipleEnvironment(t *testing.T) { "User": config.NewConfidentialValue("me"), "Password": config.NewConfidentialValue("secret"), } + profile.ResolveConfiguration() wrapper := newResticWrapper(nil, "restic", false, profile, "test", nil, nil) env := wrapper.getEnvironment() assert.Len(t, env, 2)