diff --git a/gnmi_server/clientCertAuth.go b/gnmi_server/clientCertAuth.go index 1c44d9c5..db6ebe12 100644 --- a/gnmi_server/clientCertAuth.go +++ b/gnmi_server/clientCertAuth.go @@ -2,6 +2,7 @@ package gnmi import ( "github.com/sonic-net/sonic-gnmi/common_utils" + "github.com/sonic-net/sonic-gnmi/swsscommon" "github.com/golang/glog" "golang.org/x/net/context" "google.golang.org/grpc/codes" @@ -10,7 +11,7 @@ import ( "google.golang.org/grpc/status" ) -func ClientCertAuthenAndAuthor(ctx context.Context) (context.Context, error) { +func ClientCertAuthenAndAuthor(ctx context.Context, serviceConfigTableName string) (context.Context, error) { rc, ctx := common_utils.GetContext(ctx) p, ok := peer.FromContext(ctx) if !ok { @@ -32,10 +33,44 @@ func ClientCertAuthenAndAuthor(ctx context.Context) (context.Context, error) { return ctx, status.Error(codes.Unauthenticated, "invalid username in certificate common name.") } - if err := PopulateAuthStruct(username, &rc.Auth, nil); err != nil { - glog.Infof("[%s] Failed to retrieve authentication information; %v", rc.ID, err) - return ctx, status.Errorf(codes.Unauthenticated, "") + if serviceConfigTableName != "" { + if err := PopulateAuthStructByCommonName(username, &rc.Auth, serviceConfigTableName); err != nil { + return ctx, err + } + } else { + if err := PopulateAuthStruct(username, &rc.Auth, nil); err != nil { + glog.Infof("[%s] Failed to retrieve authentication information; %v", rc.ID, err) + return ctx, status.Errorf(codes.Unauthenticated, "") + } } return ctx, nil } + +func PopulateAuthStructByCommonName(certCommonName string, auth *common_utils.AuthInfo, serviceConfigTableName string) error { + if serviceConfigTableName == "" { + return status.Errorf(codes.Unauthenticated, "Service config table name should not be empty") + } + + var configDbConnector = swsscommon.NewConfigDBConnector_Native() + defer swsscommon.DeleteConfigDBConnector_Native(configDbConnector) + configDbConnector.Connect(false) + + var fieldValuePairs = configDbConnector.Get_entry(serviceConfigTableName, certCommonName) + if fieldValuePairs.Size() > 0 { + if fieldValuePairs.Has_key("role") { + var role = fieldValuePairs.Get("role") + auth.Roles = []string{role} + } + } else { + glog.Warningf("Failed to retrieve cert common name mapping; %s", certCommonName) + } + + swsscommon.DeleteFieldValueMap(fieldValuePairs) + + if len(auth.Roles) == 0 { + return status.Errorf(codes.Unauthenticated, "Invalid cert cname:'%s', not a trusted cert common name.", certCommonName) + } else { + return nil + } +} diff --git a/gnmi_server/debug.go b/gnmi_server/debug.go index 5239b72e..6099630e 100644 --- a/gnmi_server/debug.go +++ b/gnmi_server/debug.go @@ -35,7 +35,7 @@ import ( func (srv *Server) GetSubscribePreferences(req *spb_gnoi.SubscribePreferencesReq, stream spb_gnoi.Debug_GetSubscribePreferencesServer) error { ctx := stream.Context() - ctx, err := authenticate(srv.config.UserAuth, ctx) + ctx, err := authenticate(srv.config, ctx) if err != nil { return err } diff --git a/gnmi_server/gnoi.go b/gnmi_server/gnoi.go index 8bd96536..241089cf 100644 --- a/gnmi_server/gnoi.go +++ b/gnmi_server/gnoi.go @@ -33,7 +33,7 @@ func RebootSystem(fileName string) error { func (srv *Server) Reboot(ctx context.Context, req *gnoi_system_pb.RebootRequest) (*gnoi_system_pb.RebootResponse, error) { fileName := common_utils.GNMI_WORK_PATH + "/config_db.json.tmp" - _, err := authenticate(srv.config.UserAuth, ctx) + _, err := authenticate(srv.config, ctx) if err != nil { return nil, err } @@ -57,7 +57,7 @@ func (srv *Server) Reboot(ctx context.Context, req *gnoi_system_pb.RebootRequest // TODO: Support GNOI RebootStatus func (srv *Server) RebootStatus(ctx context.Context, req *gnoi_system_pb.RebootStatusRequest) (*gnoi_system_pb.RebootStatusResponse, error) { - _, err := authenticate(srv.config.UserAuth, ctx) + _, err := authenticate(srv.config, ctx) if err != nil { return nil, err } @@ -67,7 +67,7 @@ func (srv *Server) RebootStatus(ctx context.Context, req *gnoi_system_pb.RebootS // TODO: Support GNOI CancelReboot func (srv *Server) CancelReboot(ctx context.Context, req *gnoi_system_pb.CancelRebootRequest) (*gnoi_system_pb.CancelRebootResponse, error) { - _, err := authenticate(srv.config.UserAuth, ctx) + _, err := authenticate(srv.config, ctx) if err != nil { return nil, err } @@ -76,7 +76,7 @@ func (srv *Server) CancelReboot(ctx context.Context, req *gnoi_system_pb.CancelR } func (srv *Server) Ping(req *gnoi_system_pb.PingRequest, rs gnoi_system_pb.System_PingServer) error { ctx := rs.Context() - _, err := authenticate(srv.config.UserAuth, ctx) + _, err := authenticate(srv.config, ctx) if err != nil { return err } @@ -85,7 +85,7 @@ func (srv *Server) Ping(req *gnoi_system_pb.PingRequest, rs gnoi_system_pb.Syste } func (srv *Server) Traceroute(req *gnoi_system_pb.TracerouteRequest, rs gnoi_system_pb.System_TracerouteServer) error { ctx := rs.Context() - _, err := authenticate(srv.config.UserAuth, ctx) + _, err := authenticate(srv.config, ctx) if err != nil { return err } @@ -94,7 +94,7 @@ func (srv *Server) Traceroute(req *gnoi_system_pb.TracerouteRequest, rs gnoi_sys } func (srv *Server) SetPackage(rs gnoi_system_pb.System_SetPackageServer) error { ctx := rs.Context() - _, err := authenticate(srv.config.UserAuth, ctx) + _, err := authenticate(srv.config, ctx) if err != nil { return err } @@ -102,7 +102,7 @@ func (srv *Server) SetPackage(rs gnoi_system_pb.System_SetPackageServer) error { return status.Errorf(codes.Unimplemented, "") } func (srv *Server) SwitchControlProcessor(ctx context.Context, req *gnoi_system_pb.SwitchControlProcessorRequest) (*gnoi_system_pb.SwitchControlProcessorResponse, error) { - _, err := authenticate(srv.config.UserAuth, ctx) + _, err := authenticate(srv.config, ctx) if err != nil { return nil, err } @@ -110,7 +110,7 @@ func (srv *Server) SwitchControlProcessor(ctx context.Context, req *gnoi_system_ return nil, status.Errorf(codes.Unimplemented, "") } func (srv *Server) Time(ctx context.Context, req *gnoi_system_pb.TimeRequest) (*gnoi_system_pb.TimeResponse, error) { - _, err := authenticate(srv.config.UserAuth, ctx) + _, err := authenticate(srv.config, ctx) if err != nil { return nil, err } @@ -147,7 +147,7 @@ func (srv *Server) Authenticate(ctx context.Context, req *spb_jwt.AuthenticateRe } func (srv *Server) Refresh(ctx context.Context, req *spb_jwt.RefreshRequest) (*spb_jwt.RefreshResponse, error) { - ctx, err := authenticate(srv.config.UserAuth, ctx) + ctx, err := authenticate(srv.config, ctx) if err != nil { return nil, err } @@ -175,7 +175,7 @@ func (srv *Server) Refresh(ctx context.Context, req *spb_jwt.RefreshRequest) (*s } func (srv *Server) ClearNeighbors(ctx context.Context, req *spb.ClearNeighborsRequest) (*spb.ClearNeighborsResponse, error) { - ctx, err := authenticate(srv.config.UserAuth, ctx) + ctx, err := authenticate(srv.config, ctx) if err != nil { return nil, err } @@ -207,7 +207,7 @@ func (srv *Server) ClearNeighbors(ctx context.Context, req *spb.ClearNeighborsRe } func (srv *Server) CopyConfig(ctx context.Context, req *spb.CopyConfigRequest) (*spb.CopyConfigResponse, error) { - ctx, err := authenticate(srv.config.UserAuth, ctx) + ctx, err := authenticate(srv.config, ctx) if err != nil { return nil, err } @@ -238,7 +238,7 @@ func (srv *Server) CopyConfig(ctx context.Context, req *spb.CopyConfigRequest) ( } func (srv *Server) ShowTechsupport(ctx context.Context, req *spb.TechsupportRequest) (*spb.TechsupportResponse, error) { - ctx, err := authenticate(srv.config.UserAuth, ctx) + ctx, err := authenticate(srv.config, ctx) if err != nil { return nil, err } @@ -270,7 +270,7 @@ func (srv *Server) ShowTechsupport(ctx context.Context, req *spb.TechsupportRequ } func (srv *Server) ImageInstall(ctx context.Context, req *spb.ImageInstallRequest) (*spb.ImageInstallResponse, error) { - ctx, err := authenticate(srv.config.UserAuth, ctx) + ctx, err := authenticate(srv.config, ctx) if err != nil { return nil, err } @@ -302,7 +302,7 @@ func (srv *Server) ImageInstall(ctx context.Context, req *spb.ImageInstallReques } func (srv *Server) ImageRemove(ctx context.Context, req *spb.ImageRemoveRequest) (*spb.ImageRemoveResponse, error) { - ctx, err := authenticate(srv.config.UserAuth, ctx) + ctx, err := authenticate(srv.config, ctx) if err != nil { return nil, err } @@ -334,7 +334,7 @@ func (srv *Server) ImageRemove(ctx context.Context, req *spb.ImageRemoveRequest) } func (srv *Server) ImageDefault(ctx context.Context, req *spb.ImageDefaultRequest) (*spb.ImageDefaultResponse, error) { - ctx, err := authenticate(srv.config.UserAuth, ctx) + ctx, err := authenticate(srv.config, ctx) if err != nil { return nil, err } diff --git a/gnmi_server/server.go b/gnmi_server/server.go index 4ae32c18..d45ea29d 100644 --- a/gnmi_server/server.go +++ b/gnmi_server/server.go @@ -58,6 +58,7 @@ type Config struct { EnableNativeWrite bool ZmqAddress string IdleConnDuration int + ConfigTableName string } var AuthLock sync.Mutex @@ -188,30 +189,30 @@ func (srv *Server) Port() int64 { return srv.config.Port } -func authenticate(UserAuth AuthTypes, ctx context.Context) (context.Context, error) { +func authenticate(config *Config, ctx context.Context) (context.Context, error) { var err error success := false rc, ctx := common_utils.GetContext(ctx) - if !UserAuth.Any() { + if !config.UserAuth.Any() { //No Auth enabled rc.Auth.AuthEnabled = false return ctx, nil } rc.Auth.AuthEnabled = true - if UserAuth.Enabled("password") { + if config.UserAuth.Enabled("password") { ctx, err = BasicAuthenAndAuthor(ctx) if err == nil { success = true } } - if !success && UserAuth.Enabled("jwt") { + if !success && config.UserAuth.Enabled("jwt") { _, ctx, err = JwtAuthenAndAuthor(ctx) if err == nil { success = true } } - if !success && UserAuth.Enabled("cert") { - ctx, err = ClientCertAuthenAndAuthor(ctx) + if !success && config.UserAuth.Enabled("cert") { + ctx, err = ClientCertAuthenAndAuthor(ctx, config.ConfigTableName) if err == nil { success = true } @@ -230,7 +231,7 @@ func authenticate(UserAuth AuthTypes, ctx context.Context) (context.Context, err // Subscribe implements the gNMI Subscribe RPC. func (s *Server) Subscribe(stream gnmipb.GNMI_SubscribeServer) error { ctx := stream.Context() - ctx, err := authenticate(s.config.UserAuth, ctx) + ctx, err := authenticate(s.config, ctx) if err != nil { return err } @@ -315,7 +316,7 @@ func IsNativeOrigin(origin string) bool { // Get implements the Get RPC in gNMI spec. func (s *Server) Get(ctx context.Context, req *gnmipb.GetRequest) (*gnmipb.GetResponse, error) { common_utils.IncCounter(common_utils.GNMI_GET) - ctx, err := authenticate(s.config.UserAuth, ctx) + ctx, err := authenticate(s.config, ctx) if err != nil { common_utils.IncCounter(common_utils.GNMI_GET_FAIL) return nil, err @@ -402,7 +403,7 @@ func (s *Server) Set(ctx context.Context, req *gnmipb.SetRequest) (*gnmipb.SetRe common_utils.IncCounter(common_utils.GNMI_SET_FAIL) return nil, grpc.Errorf(codes.Unimplemented, "GNMI is in read-only mode") } - ctx, err := authenticate(s.config.UserAuth, ctx) + ctx, err := authenticate(s.config, ctx) if err != nil { common_utils.IncCounter(common_utils.GNMI_SET_FAIL) return nil, err @@ -502,7 +503,7 @@ func (s *Server) Set(ctx context.Context, req *gnmipb.SetRequest) (*gnmipb.SetRe } func (s *Server) Capabilities(ctx context.Context, req *gnmipb.CapabilityRequest) (*gnmipb.CapabilityResponse, error) { - ctx, err := authenticate(s.config.UserAuth, ctx) + ctx, err := authenticate(s.config, ctx) if err != nil { return nil, err } diff --git a/gnmi_server/server_test.go b/gnmi_server/server_test.go index 55a02d63..0eee36fd 100644 --- a/gnmi_server/server_test.go +++ b/gnmi_server/server_test.go @@ -25,6 +25,10 @@ import ( "time" "runtime" + "crypto/x509" + "crypto/x509/pkix" + spb_jwt "github.com/sonic-net/sonic-gnmi/proto/gnoi/jwt" + "github.com/kylelemons/godebug/pretty" "github.com/openconfig/gnmi/client" pb "github.com/openconfig/gnmi/proto/gnmi" @@ -36,6 +40,7 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" + "google.golang.org/grpc/peer" "google.golang.org/grpc/status" "google.golang.org/grpc/keepalive" @@ -55,6 +60,7 @@ import ( "github.com/agiledragon/gomonkey/v2" "github.com/godbus/dbus/v5" cacheclient "github.com/openconfig/gnmi/client" + "github.com/sonic-net/sonic-gnmi/swsscommon" ) @@ -4101,6 +4107,188 @@ func TestMasterArbitration(t *testing.T) { }) }*/ + +func TestPopulateAuthStructByCommonName(t *testing.T) { + // check auth with nil cert name + err := PopulateAuthStructByCommonName("certname1", nil, "") + if err == nil { + t.Errorf("PopulateAuthStructByCommonName with empty config table should failed: %v", err) + } +} + +func CreateAuthorizationCtx() (context.Context, context.CancelFunc) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + cert := x509.Certificate{ + Subject: pkix.Name{ + CommonName: "certname1", + }, + } + verifiedCerts := make([][]*x509.Certificate, 1) + verifiedCerts[0] = make([]*x509.Certificate, 1) + verifiedCerts[0][0] = &cert + p := peer.Peer{ + AuthInfo: credentials.TLSInfo{ + State: tls.ConnectionState{ + VerifiedChains: verifiedCerts, + }, + }, + } + ctx = peer.NewContext(ctx, &p) + return ctx, cancel +} + + func TestClientCertAuthenAndAuthor(t *testing.T) { + if !swsscommon.SonicDBConfigIsInit() { + swsscommon.SonicDBConfigInitialize() + } + + var configDb = swsscommon.NewDBConnector("CONFIG_DB", uint(0), true) + var gnmiTable = swsscommon.NewTable(configDb, "GNMI_CLIENT_CERT") + configDb.Flushdb() + + // initialize err variable + err := status.Error(codes.Unauthenticated, "") + + // when config table is empty, will authorize with PopulateAuthStruct + mockpopulate := gomonkey.ApplyFunc(PopulateAuthStruct, func(username string, auth *common_utils.AuthInfo, r []string) error { + return nil + }) + defer mockpopulate.Reset() + + // check auth with nil cert name + ctx, cancel := CreateAuthorizationCtx() + ctx, err = ClientCertAuthenAndAuthor(ctx, "") + if err != nil { + t.Errorf("CommonNameMatch with empty config table should success: %v", err) + } + + cancel() + + // check get 1 cert name + ctx, cancel = CreateAuthorizationCtx() + configDb.Flushdb() + gnmiTable.Hset("certname1", "role", "role1") + ctx, err = ClientCertAuthenAndAuthor(ctx, "GNMI_CLIENT_CERT") + if err != nil { + t.Errorf("CommonNameMatch with correct cert name should success: %v", err) + } + + cancel() + + // check get multiple cert names + ctx, cancel = CreateAuthorizationCtx() + configDb.Flushdb() + gnmiTable.Hset("certname1", "role", "role1") + gnmiTable.Hset("certname2", "role", "role2") + ctx, err = ClientCertAuthenAndAuthor(ctx, "GNMI_CLIENT_CERT") + if err != nil { + t.Errorf("CommonNameMatch with correct cert name should success: %v", err) + } + + cancel() + + // check a invalid cert cname + ctx, cancel = CreateAuthorizationCtx() + configDb.Flushdb() + gnmiTable.Hset("certname2", "role", "role2") + ctx, err = ClientCertAuthenAndAuthor(ctx, "GNMI_CLIENT_CERT") + if err == nil { + t.Errorf("CommonNameMatch with invalid cert name should fail: %v", err) + } + + cancel() + + swsscommon.DeleteTable(gnmiTable) + swsscommon.DeleteDBConnector(configDb) +} + +type MockServerStream struct { + grpc.ServerStream +} + +func (x *MockServerStream) Context() context.Context { + return context.Background() +} + +type MockPingServer struct { + MockServerStream +} + +func (x *MockPingServer) Send(m *gnoi_system_pb.PingResponse) error { + return nil +} + +type MockTracerouteServer struct { + MockServerStream +} + +func (x *MockTracerouteServer) Send(m *gnoi_system_pb.TracerouteResponse) error { + return nil +} + +type MockSetPackageServer struct { + MockServerStream +} + +func (x *MockSetPackageServer) Send(m *gnoi_system_pb.SetPackageResponse) error { + return nil +} + +func (x *MockSetPackageServer) SendAndClose(m *gnoi_system_pb.SetPackageResponse) error { + return nil +} + +func (x *MockSetPackageServer) Recv() (*gnoi_system_pb.SetPackageRequest, error) { + return nil, nil +} + +func TestGnoiAuthorization(t *testing.T) { + s := createServer(t, 8081) + go runServer(t, s) + mockAuthenticate := gomonkey.ApplyFunc(s.Authenticate, func(ctx context.Context, req *spb_jwt.AuthenticateRequest) (*spb_jwt.AuthenticateResponse, error) { + return nil, nil + }) + defer mockAuthenticate.Reset() + + err := s.Ping(new(gnoi_system_pb.PingRequest), new(MockPingServer)) + if err == nil { + t.Errorf("Ping should failed, because not implement.") + } + + s.Traceroute(new(gnoi_system_pb.TracerouteRequest), new(MockTracerouteServer)) + if err == nil { + t.Errorf("Traceroute should failed, because not implement.") + } + + s.SetPackage(new(MockSetPackageServer)) + if err == nil { + t.Errorf("SetPackage should failed, because not implement.") + } + + ctx := context.Background() + s.SwitchControlProcessor(ctx, new(gnoi_system_pb.SwitchControlProcessorRequest)) + if err == nil { + t.Errorf("SwitchControlProcessor should failed, because not implement.") + } + + s.Refresh(ctx, new(spb_jwt.RefreshRequest)) + if err == nil { + t.Errorf("Refresh should failed, because not implement.") + } + + s.ClearNeighbors(ctx, new(sgpb.ClearNeighborsRequest)) + if err == nil { + t.Errorf("ClearNeighbors should failed, because not implement.") + } + + s.CopyConfig(ctx, new(sgpb.CopyConfigRequest)) + if err == nil { + t.Errorf("CopyConfig should failed, because not implement.") + } + + s.s.Stop() +} + func init() { // Enable logs at UT setup flag.Lookup("v").Value.Set("10") diff --git a/telemetry/telemetry.go b/telemetry/telemetry.go index 6cc128fe..6bcb835f 100644 --- a/telemetry/telemetry.go +++ b/telemetry/telemetry.go @@ -25,6 +25,7 @@ var ( caCert = flag.String("ca_crt", "", "CA certificate for client certificate validation. Optional.") serverCert = flag.String("server_crt", "", "TLS server certificate") serverKey = flag.String("server_key", "", "TLS server private key") + configTableName = flag.String("config_table_name", "", "Config table name") zmqAddress = flag.String("zmq_address", "", "Orchagent ZMQ address, when not set or empty string telemetry server will switch to Redis based communication channel.") insecure = flag.Bool("insecure", false, "Skip providing TLS cert and key, for testing only!") noTLS = flag.Bool("noTLS", false, "disable TLS, for testing only!") @@ -86,6 +87,7 @@ func main() { cfg.ZmqAddress = *zmqAddress cfg.Threshold = int(*threshold) cfg.IdleConnDuration = int(*idle_conn_duration) + cfg.ConfigTableName = *configTableName var opts []grpc.ServerOption if val, err := strconv.Atoi(getflag("v")); err == nil {