diff --git a/go/vt/topo/keyspace.go b/go/vt/topo/keyspace.go index fa41d423fd3..14a9563db87 100755 --- a/go/vt/topo/keyspace.go +++ b/go/vt/topo/keyspace.go @@ -18,6 +18,7 @@ package topo import ( "path" + "strings" "context" @@ -53,6 +54,16 @@ func (ki *KeyspaceInfo) SetKeyspaceName(name string) { ki.keyspace = name } +var ksNameReplacer = strings.NewReplacer("/", "") + +func ValidateKeyspaceName(name string) (string, error) { + if validated := ksNameReplacer.Replace(name); name != validated { + return validated, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "keyspace name %s contains invalid characters; expected %s instead", name, validated) + } + + return name, nil +} + // GetServedFrom returns a Keyspace_ServedFrom record if it exists. func (ki *KeyspaceInfo) GetServedFrom(tabletType topodatapb.TabletType) *topodatapb.Keyspace_ServedFrom { for _, ksf := range ki.ServedFroms { @@ -160,6 +171,10 @@ func (ki *KeyspaceInfo) ComputeCellServedFrom(cell string) []*topodatapb.SrvKeys // CreateKeyspace wraps the underlying Conn.Create // and dispatches the event. func (ts *Server) CreateKeyspace(ctx context.Context, keyspace string, value *topodatapb.Keyspace) error { + if _, err := ValidateKeyspaceName(keyspace); err != nil { + return vterrors.Wrap(err, "CreateKeyspace got invalid keyspace name") + } + data, err := value.MarshalVT() if err != nil { return err @@ -180,6 +195,10 @@ func (ts *Server) CreateKeyspace(ctx context.Context, keyspace string, value *to // GetKeyspace reads the given keyspace and returns it func (ts *Server) GetKeyspace(ctx context.Context, keyspace string) (*KeyspaceInfo, error) { + if _, err := ValidateKeyspaceName(keyspace); err != nil { + return nil, vterrors.Wrap(err, "GetKeyspace got invalid keyspace name") + } + keyspacePath := path.Join(KeyspacesPath, keyspace, KeyspaceFile) data, version, err := ts.globalCell.Get(ctx, keyspacePath) if err != nil { diff --git a/go/vt/topo/topotests/keyspace_test.go b/go/vt/topo/topotests/keyspace_test.go new file mode 100644 index 00000000000..1b511ef5ff3 --- /dev/null +++ b/go/vt/topo/topotests/keyspace_test.go @@ -0,0 +1,75 @@ +/* +Copyright 2023 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package topotests + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "vitess.io/vitess/go/vt/topo/memorytopo" + "vitess.io/vitess/go/vt/vterrors" + + topodatapb "vitess.io/vitess/go/vt/proto/topodata" + "vitess.io/vitess/go/vt/proto/vtrpc" +) + +// TODO: add copyright +// TODO: test valid vs invalid ks names + +func TestCreateKeyspace(t *testing.T) { + ts := memorytopo.NewServer("zone1") + ctx := context.Background() + + t.Run("valid name", func(t *testing.T) { + err := ts.CreateKeyspace(ctx, "ks", &topodatapb.Keyspace{}) + require.NoError(t, err) + }) + t.Run("invalid name", func(t *testing.T) { + err := ts.CreateKeyspace(ctx, "no/slashes/allowed", &topodatapb.Keyspace{}) + assert.Error(t, err) + assert.Equal(t, vtrpc.Code_INVALID_ARGUMENT, vterrors.Code(err), "%+v", err) + }) +} + +func TestGetKeyspace(t *testing.T) { + ts := memorytopo.NewServer("zone1") + ctx := context.Background() + + t.Run("valid name", func(t *testing.T) { + // First, create the keyspace. + err := ts.CreateKeyspace(ctx, "ks", &topodatapb.Keyspace{}) + require.NoError(t, err) + + // Now, get it. + ks, err := ts.GetKeyspace(ctx, "ks") + require.NoError(t, err) + assert.NotNil(t, ks) + }) + + t.Run("invalid name", func(t *testing.T) { + // We can't create the keyspace (because we can't create a keyspace + // with an invalid name), so we'll validate the error we get is *not* + // NOT_FOUND. + ks, err := ts.GetKeyspace(ctx, "no/slashes/allowed") + assert.Error(t, err) + assert.Equal(t, vtrpc.Code_INVALID_ARGUMENT, vterrors.Code(err), "%+v", err) + assert.Nil(t, ks) + }) +} diff --git a/go/vt/vtorc/inst/keyspace_dao.go b/go/vt/vtorc/inst/keyspace_dao.go index f3624449001..db13a55538c 100644 --- a/go/vt/vtorc/inst/keyspace_dao.go +++ b/go/vt/vtorc/inst/keyspace_dao.go @@ -30,6 +30,10 @@ var ErrKeyspaceNotFound = errors.New("keyspace not found") // ReadKeyspace reads the vitess keyspace record. func ReadKeyspace(keyspaceName string) (*topo.KeyspaceInfo, error) { + if _, err := topo.ValidateKeyspaceName(keyspaceName); err != nil { + return nil, err + } + query := ` select keyspace_type,