From 765332445d7915472a18121da58c5a1d348a428a Mon Sep 17 00:00:00 2001 From: mitchell Date: Tue, 22 Oct 2024 11:50:50 -0400 Subject: [PATCH] Handle case-insensitive env var names on Windows. We read the case-insensitive version as needed, but replace it with our case-sensitive version (e.g. "Path" -> "PATH"). --- pkg/runtime/internal/envdef/environment.go | 38 ++++++++++++++--- .../internal/envdef/environment_test.go | 41 +++++++++---------- 2 files changed, 52 insertions(+), 27 deletions(-) diff --git a/pkg/runtime/internal/envdef/environment.go b/pkg/runtime/internal/envdef/environment.go index 9cda4e2cbc..d5742ca2c0 100644 --- a/pkg/runtime/internal/envdef/environment.go +++ b/pkg/runtime/internal/envdef/environment.go @@ -6,6 +6,7 @@ import ( "maps" "os" "path/filepath" + "runtime" "strings" "github.com/ActiveState/cli/internal/constants" @@ -344,19 +345,40 @@ func (ev *EnvironmentVariable) ValueString() string { ev.Separator) } -// GetEnvBasedOn returns the environment variable names and values defined by +// getEnvBasedOn returns the environment variable names and values defined by // the EnvironmentDefinition. // If an environment variable is configured to inherit from the base // environment (`Inherit==true`), the base environment defined by the // `envLookup` method is joined with these environment variables. // This function is mostly used for testing. Use GetEnv() in production. -func (ed *EnvironmentDefinition) GetEnvBasedOn(envLookup map[string]string) (map[string]string, error) { +func (ed *EnvironmentDefinition) getEnvBasedOn(envLookup map[string]string) (map[string]string, error) { res := maps.Clone(envLookup) + // On Windows, environment variable names are case-insensitive. + // For example, it uses "Path", but responds to "PATH" as well. + // This causes trouble with our environment merging, which will end up adding "PATH" (with the + // correct value) alongside "Path" (with the old value). + // In order to remedy this, track the OS-specific environment variable name and if it's + // modified/merged, replace it with our version (e.g. "Path" -> "PATH"). We do not use the OS name + // because we assume ours is the one that's used elsewhere in the codebase, and Windows will + // properly respond to a changed-case name anyway. + osEnvNames := map[string]string{} + if runtime.GOOS == "windows" { + for k := range envLookup { + osEnvNames[strings.ToLower(k)] = k + } + } + for _, ev := range ed.Env { pev := &ev + osName := pev.Name + if runtime.GOOS == "windows" { + if name, ok := osEnvNames[strings.ToLower(osName)]; ok { + osName = name + } + } + osValue, hasOsValue := envLookup[osName] if pev.Inherit { - osValue, hasOsValue := envLookup[pev.Name] if hasOsValue { osEv := ev osEv.Values = []string{osValue} @@ -364,15 +386,19 @@ func (ed *EnvironmentDefinition) GetEnvBasedOn(envLookup map[string]string) (map pev, err = osEv.Merge(ev) if err != nil { return nil, err - } } - } else if _, hasOsValue := envLookup[pev.Name]; hasOsValue { + } else if hasOsValue { res[pev.Name] = "" // unset } // only add environment variable if at least one value is set (This allows us to remove variables from the environment.) if len(ev.Values) > 0 { res[pev.Name] = pev.ValueString() + if pev.Name != osName { + // On Windows, delete the case-insensitive version. Our case-sensitive version has already + // processed the value of the case-insensitive version. + delete(res, osName) + } } } return res, nil @@ -388,7 +414,7 @@ func (ed *EnvironmentDefinition) GetEnv(inherit bool) map[string]string { if inherit { lookupEnv = osutils.EnvSliceToMap(os.Environ()) } - res, err := ed.GetEnvBasedOn(lookupEnv) + res, err := ed.getEnvBasedOn(lookupEnv) if err != nil { panic(fmt.Sprintf("Could not inherit OS environment variable: %v", err)) } diff --git a/pkg/runtime/internal/envdef/environment_test.go b/pkg/runtime/internal/envdef/environment_test.go index 7fd8754f7e..3741d5d393 100644 --- a/pkg/runtime/internal/envdef/environment_test.go +++ b/pkg/runtime/internal/envdef/environment_test.go @@ -1,4 +1,4 @@ -package envdef_test +package envdef import ( "encoding/json" @@ -9,7 +9,6 @@ import ( "github.com/ActiveState/cli/internal/osutils" "github.com/ActiveState/cli/internal/testhelpers/suite" - "github.com/ActiveState/cli/pkg/runtime/internal/envdef" "github.com/stretchr/testify/require" "github.com/ActiveState/cli/internal/fileutils" @@ -21,20 +20,20 @@ type EnvironmentTestSuite struct { func (suite *EnvironmentTestSuite) TestMergeVariables() { - ev1 := envdef.EnvironmentVariable{} + ev1 := EnvironmentVariable{} err := json.Unmarshal([]byte(`{ "env_name": "V", "values": ["a", "b"] }`), &ev1) require.NoError(suite.T(), err) - ev2 := envdef.EnvironmentVariable{} + ev2 := EnvironmentVariable{} err = json.Unmarshal([]byte(`{ "env_name": "V", "values": ["b", "c"] }`), &ev2) require.NoError(suite.T(), err) - expected := &envdef.EnvironmentVariable{} + expected := &EnvironmentVariable{} err = json.Unmarshal([]byte(`{ "env_name": "V", "values": ["b", "c", "a"], @@ -51,7 +50,7 @@ func (suite *EnvironmentTestSuite) TestMergeVariables() { } func (suite *EnvironmentTestSuite) TestMerge() { - ed1 := &envdef.EnvironmentDefinition{} + ed1 := &EnvironmentDefinition{} err := json.Unmarshal([]byte(`{ "env": [{"env_name": "V", "values": ["a", "b"]}], @@ -59,14 +58,14 @@ func (suite *EnvironmentTestSuite) TestMerge() { }`), ed1) require.NoError(suite.T(), err) - ed2 := envdef.EnvironmentDefinition{} + ed2 := EnvironmentDefinition{} err = json.Unmarshal([]byte(`{ "env": [{"env_name": "V", "values": ["c", "d"]}], "installdir": "abc" }`), &ed2) require.NoError(suite.T(), err) - expected := envdef.EnvironmentDefinition{} + expected := EnvironmentDefinition{} err = json.Unmarshal([]byte(`{ "env": [{"env_name": "V", "values": ["c", "d", "a", "b"]}], "installdir": "abc" @@ -80,7 +79,7 @@ func (suite *EnvironmentTestSuite) TestMerge() { } func (suite *EnvironmentTestSuite) TestInheritPath() { - ed1 := &envdef.EnvironmentDefinition{} + ed1 := &EnvironmentDefinition{} err := json.Unmarshal([]byte(`{ "env": [{"env_name": "PATH", "values": ["NEWVALUE"]}], @@ -90,7 +89,7 @@ func (suite *EnvironmentTestSuite) TestInheritPath() { }`), ed1) require.NoError(suite.T(), err) - env, err := ed1.GetEnvBasedOn(map[string]string{"PATH": "OLDVALUE"}) + env, err := ed1.getEnvBasedOn(map[string]string{"PATH": "OLDVALUE"}) require.NoError(suite.T(), err) suite.True(strings.HasPrefix(env["PATH"], "NEWVALUE"), "%s does not start with NEWVALUE", env["PATH"]) suite.True(strings.HasSuffix(env["PATH"], "OLDVALUE"), "%s does not end with OLDVALUE", env["PATH"]) @@ -99,11 +98,11 @@ func (suite *EnvironmentTestSuite) TestInheritPath() { func (suite *EnvironmentTestSuite) TestSharedTests() { type testCase struct { - Name string `json:"name"` - Definitions []envdef.EnvironmentDefinition `json:"definitions"` - BaseEnv map[string]string `json:"base_env"` - Expected map[string]string `json:"result"` - IsError bool `json:"error"` + Name string `json:"name"` + Definitions []EnvironmentDefinition `json:"definitions"` + BaseEnv map[string]string `json:"base_env"` + Expected map[string]string `json:"result"` + IsError bool `json:"error"` } td, err := os.ReadFile("runtime_test_cases.json") @@ -126,7 +125,7 @@ func (suite *EnvironmentTestSuite) TestSharedTests() { suite.Assert().NoError(err, "error merging %d-th definition", i) } - res, err := ed.GetEnvBasedOn(tc.BaseEnv) + res, err := ed.getEnvBasedOn(tc.BaseEnv) if tc.IsError { suite.Assert().Error(err) return @@ -139,7 +138,7 @@ func (suite *EnvironmentTestSuite) TestSharedTests() { } func (suite *EnvironmentTestSuite) TestValueString() { - ev1 := envdef.EnvironmentVariable{} + ev1 := EnvironmentVariable{} err := json.Unmarshal([]byte(`{ "env_name": "V", "values": ["a", "b"] @@ -151,7 +150,7 @@ func (suite *EnvironmentTestSuite) TestValueString() { } func (suite *EnvironmentTestSuite) TestGetEnv() { - ed1 := envdef.EnvironmentDefinition{} + ed1 := EnvironmentDefinition{} err := json.Unmarshal([]byte(`{ "env": [{"env_name": "V", "values": ["a", "b"]}], "installdir": "abc" @@ -177,7 +176,7 @@ func (suite *EnvironmentTestSuite) TestFindBinPathFor() { require.NoError(suite.T(), err, "creating temporary directory") defer os.RemoveAll(tmpDir) - ed1 := envdef.EnvironmentDefinition{} + ed1 := EnvironmentDefinition{} err = json.Unmarshal([]byte(`{ "env": [{"env_name": "PATH", "values": ["${INSTALLDIR}/bin", "${INSTALLDIR}/bin2"]}], "installdir": "abc" @@ -187,7 +186,7 @@ func (suite *EnvironmentTestSuite) TestFindBinPathFor() { tmpDir, err = fileutils.GetLongPathName(tmpDir) require.NoError(suite.T(), err) - constants := envdef.NewConstants(tmpDir) + constants := NewConstants(tmpDir) // expand variables ed1.ExpandVariables(constants) @@ -248,7 +247,7 @@ func TestFilterPATH(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - envdef.FilterPATH(tt.args.env, tt.args.excludes...) + FilterPATH(tt.args.env, tt.args.excludes...) require.Equal(t, tt.want, tt.args.env["PATH"]) }) }