diff --git a/pkg/entity/flag_test.go b/pkg/entity/flag_test.go index a1d3e7c9..9bfb9d3b 100644 --- a/pkg/entity/flag_test.go +++ b/pkg/entity/flag_test.go @@ -33,7 +33,7 @@ func TestCreateFlagKey(t *testing.T) { }) t.Run("invalid key", func(t *testing.T) { - key, err := CreateFlagKey("1-2-3") + key, err := CreateFlagKey(" spaces in key are not allowed 1-2-3") assert.Error(t, err) assert.Zero(t, key) }) @@ -52,7 +52,7 @@ func TestCreateFlagEntityType(t *testing.T) { f := GenFixtureFlag() db := PopulateTestDB(f) - err := CreateFlagEntityType(db, "123-invalid-key") + err := CreateFlagEntityType(db, " spaces in key are not allowed 123-invalid-key") assert.Error(t, err) }) } diff --git a/pkg/handler/crud_test.go b/pkg/handler/crud_test.go index 56dec80f..e6ede4f7 100644 --- a/pkg/handler/crud_test.go +++ b/pkg/handler/crud_test.go @@ -191,8 +191,8 @@ func TestCrudFlagsWithFailures(t *testing.T) { t.Run("CreateFlag - invalid key error", func(t *testing.T) { res = c.CreateFlag(flag.CreateFlagParams{ Body: &models.CreateFlagRequest{ - Description: util.StringPtr("funny flag"), - Key: "1-2-3", // invalid key + Description: util.StringPtr(" flag with a space"), + Key: " 1-2-3", // invalid key }, }) assert.NotZero(t, res.(*flag.CreateFlagDefault).Payload) @@ -811,7 +811,7 @@ func TestCrudVariantsWithFailures(t *testing.T) { res = c.CreateVariant(variant.CreateVariantParams{ FlagID: int64(1), Body: &models.CreateVariantRequest{ - Key: util.StringPtr("123_invalid_key"), + Key: util.StringPtr(" 123_invalid_key"), }, }) assert.NotZero(t, res.(*variant.CreateVariantDefault).Payload) @@ -868,7 +868,7 @@ func TestCrudVariantsWithFailures(t *testing.T) { FlagID: int64(1), VariantID: int64(1), Body: &models.PutVariantRequest{ - Key: util.StringPtr("123_invalid_key"), + Key: util.StringPtr(" spaces in key 123_invalid_key"), }, }) assert.NotZero(t, *res.(*variant.PutVariantDefault).Payload) diff --git a/pkg/util/util.go b/pkg/util/util.go index cfd08918..b3d0baaf 100644 --- a/pkg/util/util.go +++ b/pkg/util/util.go @@ -12,7 +12,7 @@ import ( var ( keyLengthLimit = 63 - keyRegex = regexp.MustCompile("^[a-z]+[a-z0-9_]*$") + keyRegex = regexp.MustCompile(`^[\w\d-]+$`) randomKeyCharset = []byte("123456789abcdefghijkmnopqrstuvwxyz") randomKeyPrefix = "k" diff --git a/pkg/util/util_test.go b/pkg/util/util_test.go index 3e2379a8..87adda69 100644 --- a/pkg/util/util_test.go +++ b/pkg/util/util_test.go @@ -82,12 +82,16 @@ func TestIsSafeKey(t *testing.T) { assert.Empty(t, msg) b, msg = IsSafeKey("1a") + assert.True(t, b) + assert.Empty(t, msg) + + b, msg = IsSafeKey(" spaces in key are not allowed ") assert.False(t, b) - assert.NotEmpty(t, msg) + assert.NotEmpty(t, msg) b, msg = IsSafeKey("_a") - assert.False(t, b) - assert.NotEmpty(t, msg) + assert.True(t, b) + assert.Empty(t, msg) b, msg = IsSafeKey(strings.Repeat("a", 64)) assert.False(t, b)