Skip to content

Commit

Permalink
Refactor database engines registration (#10074)
Browse files Browse the repository at this point in the history
  • Loading branch information
r0mant authored Feb 1, 2022
1 parent 0ab7a7a commit 21b6b17
Show file tree
Hide file tree
Showing 6 changed files with 240 additions and 78 deletions.
108 changes: 108 additions & 0 deletions lib/srv/db/common/engines.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
/*
Copyright 2022 Gravitational, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package common

import (
"context"
"sync"

"github.com/gravitational/teleport/lib/auth"
"github.com/gravitational/trace"

"github.com/jonboulle/clockwork"
"github.com/sirupsen/logrus"
)

var (
// engines is a global database engines registry.
engines map[string]EngineFn
// enginesMu protects access to the global engines registry map.
enginesMu sync.RWMutex
)

// EngineFn defines a database engine constructor function.
type EngineFn func(EngineConfig) Engine

// RegisterEngine registers a new engine constructor.
func RegisterEngine(fn EngineFn, names ...string) {
enginesMu.Lock()
defer enginesMu.Unlock()
if engines == nil {
engines = make(map[string]EngineFn)
}
for _, name := range names {
engines[name] = fn
}
}

// GetEngine returns a new engine for the provided configuration.
func GetEngine(name string, conf EngineConfig) (Engine, error) {
if err := conf.CheckAndSetDefaults(); err != nil {
return nil, trace.Wrap(err)
}
enginesMu.RLock()
engineFn := engines[name]
enginesMu.RUnlock()
if engineFn == nil {
return nil, trace.NotFound("database engine %q is not registered", name)
}
return engineFn(conf), nil
}

// EngineConfig is the common configuration every database engine uses.
type EngineConfig struct {
// Auth handles database access authentication.
Auth Auth
// Audit emits database access audit events.
Audit Audit
// AuthClient is the cluster auth server client.
AuthClient *auth.Client
// CloudClients provides access to cloud API clients.
CloudClients CloudClients
// Context is the database server close context.
Context context.Context
// Clock is the clock interface.
Clock clockwork.Clock
// Log is used for logging.
Log logrus.FieldLogger
}

// CheckAndSetDefaults validates the config and sets default values.
func (c *EngineConfig) CheckAndSetDefaults() error {
if c.Auth == nil {
return trace.BadParameter("engine config Auth is missing")
}
if c.Audit == nil {
return trace.BadParameter("engine config Audit is missing")
}
if c.AuthClient == nil {
return trace.BadParameter("engine config AuthClient is missing")
}
if c.CloudClients == nil {
return trace.BadParameter("engine config CloudClients are missing")
}
if c.Context == nil {
c.Context = context.Background()
}
if c.Clock == nil {
c.Clock = clockwork.NewRealClock()
}
if c.Log == nil {
c.Log = logrus.StandardLogger()
}
return nil
}
75 changes: 75 additions & 0 deletions lib/srv/db/common/engines_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/*
Copyright 2022 Gravitational, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package common

import (
"context"
"testing"

"github.com/gravitational/teleport/lib/auth"

"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/require"
)

// TestRegisterEngine verifies database engine registration.
func TestRegisterEngine(t *testing.T) {
ec := EngineConfig{
Context: context.Background(),
Clock: clockwork.NewFakeClock(),
Log: logrus.StandardLogger(),
Auth: &testAuth{},
Audit: &testAudit{},
AuthClient: &auth.Client{},
CloudClients: NewCloudClients(),
}

// No engine is registered initially.
engine, err := GetEngine("test", ec)
require.Nil(t, engine)
require.IsType(t, trace.NotFound(""), err)

// Register a "test" engine.
RegisterEngine(func(ec EngineConfig) Engine {
return &testEngine{ec: ec}
}, "test")

// Create the registered engine instance.
engine, err = GetEngine("test", ec)
require.NoError(t, err)
require.NotNil(t, engine)

// Verify it's the one we registered.
engineInst, ok := engine.(*testEngine)
require.True(t, ok)
require.Equal(t, ec, engineInst.ec)
}

type testEngine struct {
Engine
ec EngineConfig
}

type testAudit struct {
Audit
}

type testAuth struct {
Auth
}
24 changes: 12 additions & 12 deletions lib/srv/db/mongodb/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,26 +30,26 @@ import (
"go.mongodb.org/mongo-driver/x/mongo/driver"

"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
"github.com/sirupsen/logrus"
)

func init() {
common.RegisterEngine(newEngine, defaults.ProtocolMongoDB)
}

func newEngine(ec common.EngineConfig) common.Engine {
return &Engine{
EngineConfig: ec,
}
}

// Engine implements the MongoDB database service that accepts client
// connections coming over reverse tunnel from the proxy and proxies
// them between the proxy and the MongoDB database instance.
//
// Implements common.Engine.
type Engine struct {
// Auth handles database access authentication.
Auth common.Auth
// Audit emits database access audit events.
Audit common.Audit
// Context is the database server close context.
Context context.Context
// Clock is the clock interface.
Clock clockwork.Clock
// Log is used for logging.
Log logrus.FieldLogger
// EngineConfig is the common database engine configuration.
common.EngineConfig
// clientConn is an incoming client connection.
clientConn net.Conn
}
Expand Down
28 changes: 12 additions & 16 deletions lib/srv/db/mysql/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ import (
"time"

"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/auth"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/srv/db/cloud"
Expand All @@ -38,30 +37,27 @@ import (
"github.com/siddontang/go-mysql/server"

"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
"github.com/sirupsen/logrus"
)

func init() {
common.RegisterEngine(newEngine, defaults.ProtocolMySQL)
}

func newEngine(ec common.EngineConfig) common.Engine {
return &Engine{
EngineConfig: ec,
}
}

// Engine implements the MySQL database service that accepts client
// connections coming over reverse tunnel from the proxy and proxies
// them between the proxy and the MySQL database instance.
//
// Implements common.Engine.
type Engine struct {
// Auth handles database access authentication.
Auth common.Auth
// Audit emits database access audit events.
Audit common.Audit
// AuthClient is the cluster auth server client.
AuthClient *auth.Client
// Context is the database server close context.
Context context.Context
// Clock is the clock interface.
Clock clockwork.Clock
// CloudClients provides access to cloud API clients.
CloudClients common.CloudClients
// Log is used for logging.
Log logrus.FieldLogger
// EngineConfig is the common database engine configuration.
common.EngineConfig
// proxyConn is a client connection.
proxyConn server.Conn
}
Expand Down
28 changes: 15 additions & 13 deletions lib/srv/db/postgres/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"net"

"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/srv/db/cloud"
"github.com/gravitational/teleport/lib/srv/db/common"
Expand All @@ -33,28 +34,29 @@ import (
"github.com/jackc/pgproto3/v2"

"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
"github.com/sirupsen/logrus"
)

func init() {
common.RegisterEngine(newEngine,
defaults.ProtocolPostgres,
defaults.ProtocolCockroachDB)
}

func newEngine(ec common.EngineConfig) common.Engine {
return &Engine{
EngineConfig: ec,
}
}

// Engine implements the Postgres database service that accepts client
// connections coming over reverse tunnel from the proxy and proxies
// them between the proxy and the Postgres database instance.
//
// Implements common.Engine.
type Engine struct {
// Auth handles database access authentication.
Auth common.Auth
// Audit emits database access audit events.
Audit common.Audit
// Context is the database server close context.
Context context.Context
// Clock is the clock interface.
Clock clockwork.Clock
// CloudClients provides access to cloud API clients.
CloudClients common.CloudClients
// Log is used for logging.
Log logrus.FieldLogger
// EngineConfig is the common database engine configuration.
common.EngineConfig
// client is a client connection.
client *pgproto3.Backend
}
Expand Down
55 changes: 18 additions & 37 deletions lib/srv/db/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,15 @@ import (
"github.com/gravitational/teleport/lib/srv"
"github.com/gravitational/teleport/lib/srv/db/cloud"
"github.com/gravitational/teleport/lib/srv/db/common"
"github.com/gravitational/teleport/lib/srv/db/mongodb"
"github.com/gravitational/teleport/lib/srv/db/mysql"
"github.com/gravitational/teleport/lib/srv/db/postgres"
"github.com/gravitational/teleport/lib/utils"

// Import to register MongoDB engine.
_ "github.com/gravitational/teleport/lib/srv/db/mongodb"
// Import to register MySQL engine.
_ "github.com/gravitational/teleport/lib/srv/db/mysql"
// Import to register Postgres engine.
_ "github.com/gravitational/teleport/lib/srv/db/postgres"

"github.com/google/uuid"
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
Expand Down Expand Up @@ -752,41 +756,18 @@ func (s *Server) dispatch(sessionCtx *common.Session, streamWriter events.Stream
return engine, nil
}

// createEngine creates a new database engine base on the database protocol. An error is returned when
// a protocol is not supported.
// createEngine creates a new database engine based on the database protocol.
// An error is returned when a protocol is not supported.
func (s *Server) createEngine(sessionCtx *common.Session, audit common.Audit) (common.Engine, error) {
switch sessionCtx.Database.GetProtocol() {
case defaults.ProtocolPostgres, defaults.ProtocolCockroachDB:
return &postgres.Engine{
Auth: s.cfg.Auth,
Audit: audit,
Context: s.closeContext,
Clock: s.cfg.Clock,
CloudClients: s.cfg.CloudClients,
Log: sessionCtx.Log,
}, nil
case defaults.ProtocolMySQL:
return &mysql.Engine{
Auth: s.cfg.Auth,
Audit: audit,
AuthClient: s.cfg.AuthClient,
Context: s.closeContext,
Clock: s.cfg.Clock,
CloudClients: s.cfg.CloudClients,
Log: sessionCtx.Log,
}, nil
case defaults.ProtocolMongoDB:
return &mongodb.Engine{
Auth: s.cfg.Auth,
Audit: audit,
Context: s.closeContext,
Clock: s.cfg.Clock,
Log: sessionCtx.Log,
}, nil
}

return nil, trace.BadParameter("unsupported database protocol %q",
sessionCtx.Database.GetProtocol())
return common.GetEngine(sessionCtx.Database.GetProtocol(), common.EngineConfig{
Auth: s.cfg.Auth,
Audit: audit,
AuthClient: s.cfg.AuthClient,
CloudClients: s.cfg.CloudClients,
Context: s.closeContext,
Clock: s.cfg.Clock,
Log: sessionCtx.Log,
})
}

func (s *Server) authorize(ctx context.Context) (*common.Session, error) {
Expand Down

0 comments on commit 21b6b17

Please sign in to comment.