Skip to content

Commit

Permalink
Merge pull request #241 from liuh-80/dev/liuh/add_gnmi_cert_name
Browse files Browse the repository at this point in the history
Add cert authorization with common name support.
  • Loading branch information
liuh-80 authored Jun 13, 2024
2 parents 424212b + 9b65bde commit f0d0959
Show file tree
Hide file tree
Showing 7 changed files with 264 additions and 33 deletions.
43 changes: 39 additions & 4 deletions gnmi_server/clientCertAuth.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package gnmi

import (
"github.com/sonic-net/sonic-gnmi/common_utils"
"github.com/sonic-net/sonic-gnmi/swsscommon"
"github.com/golang/glog"
"golang.org/x/net/context"
"google.golang.org/grpc/codes"
Expand All @@ -10,7 +11,7 @@ import (
"google.golang.org/grpc/status"
)

func ClientCertAuthenAndAuthor(ctx context.Context) (context.Context, error) {
func ClientCertAuthenAndAuthor(ctx context.Context, serviceConfigTableName string) (context.Context, error) {
rc, ctx := common_utils.GetContext(ctx)
p, ok := peer.FromContext(ctx)
if !ok {
Expand All @@ -32,10 +33,44 @@ func ClientCertAuthenAndAuthor(ctx context.Context) (context.Context, error) {
return ctx, status.Error(codes.Unauthenticated, "invalid username in certificate common name.")
}

if err := PopulateAuthStruct(username, &rc.Auth, nil); err != nil {
glog.Infof("[%s] Failed to retrieve authentication information; %v", rc.ID, err)
return ctx, status.Errorf(codes.Unauthenticated, "")
if serviceConfigTableName != "" {
if err := PopulateAuthStructByCommonName(username, &rc.Auth, serviceConfigTableName); err != nil {
return ctx, err
}
} else {
if err := PopulateAuthStruct(username, &rc.Auth, nil); err != nil {
glog.Infof("[%s] Failed to retrieve authentication information; %v", rc.ID, err)
return ctx, status.Errorf(codes.Unauthenticated, "")
}
}

return ctx, nil
}

func PopulateAuthStructByCommonName(certCommonName string, auth *common_utils.AuthInfo, serviceConfigTableName string) error {
if serviceConfigTableName == "" {
return status.Errorf(codes.Unauthenticated, "Service config table name should not be empty")
}

var configDbConnector = swsscommon.NewConfigDBConnector()
defer swsscommon.DeleteConfigDBConnector_Native(configDbConnector.ConfigDBConnector_Native)
configDbConnector.Connect(false)

var fieldValuePairs = configDbConnector.Get_entry(serviceConfigTableName, certCommonName)
if fieldValuePairs.Size() > 0 {
if fieldValuePairs.Has_key("role") {
var role = fieldValuePairs.Get("role")
auth.Roles = []string{role}
}
} else {
glog.Warningf("Failed to retrieve cert common name mapping; %s", certCommonName)
}

swsscommon.DeleteFieldValueMap(fieldValuePairs)

if len(auth.Roles) == 0 {
return status.Errorf(codes.Unauthenticated, "Invalid cert cname:'%s', not a trusted cert common name.", certCommonName)
} else {
return nil
}
}
2 changes: 1 addition & 1 deletion gnmi_server/debug.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ import (

func (srv *Server) GetSubscribePreferences(req *spb_gnoi.SubscribePreferencesReq, stream spb_gnoi.Debug_GetSubscribePreferencesServer) error {
ctx := stream.Context()
ctx, err := authenticate(srv.config.UserAuth, ctx)
ctx, err := authenticate(srv.config, ctx)
if err != nil {
return err
}
Expand Down
34 changes: 17 additions & 17 deletions gnmi_server/gnoi.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func KillOrRestartProcess(restart bool, serviceName string) error {
}

func (srv *Server) KillProcess(ctx context.Context, req *gnoi_system_pb.KillProcessRequest) (*gnoi_system_pb.KillProcessResponse, error) {
_, err := authenticate(srv.config.UserAuth, ctx)
_, err := authenticate(srv.config, ctx)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -78,7 +78,7 @@ func RebootSystem(fileName string) error {
func (srv *Server) Reboot(ctx context.Context, req *gnoi_system_pb.RebootRequest) (*gnoi_system_pb.RebootResponse, error) {
fileName := common_utils.GNMI_WORK_PATH + "/config_db.json.tmp"

_, err := authenticate(srv.config.UserAuth, ctx)
_, err := authenticate(srv.config, ctx)
if err != nil {
return nil, err
}
Expand All @@ -102,7 +102,7 @@ func (srv *Server) Reboot(ctx context.Context, req *gnoi_system_pb.RebootRequest

// TODO: Support GNOI RebootStatus
func (srv *Server) RebootStatus(ctx context.Context, req *gnoi_system_pb.RebootStatusRequest) (*gnoi_system_pb.RebootStatusResponse, error) {
_, err := authenticate(srv.config.UserAuth, ctx)
_, err := authenticate(srv.config, ctx)
if err != nil {
return nil, err
}
Expand All @@ -112,7 +112,7 @@ func (srv *Server) RebootStatus(ctx context.Context, req *gnoi_system_pb.RebootS

// TODO: Support GNOI CancelReboot
func (srv *Server) CancelReboot(ctx context.Context, req *gnoi_system_pb.CancelRebootRequest) (*gnoi_system_pb.CancelRebootResponse, error) {
_, err := authenticate(srv.config.UserAuth, ctx)
_, err := authenticate(srv.config, ctx)
if err != nil {
return nil, err
}
Expand All @@ -121,7 +121,7 @@ func (srv *Server) CancelReboot(ctx context.Context, req *gnoi_system_pb.CancelR
}
func (srv *Server) Ping(req *gnoi_system_pb.PingRequest, rs gnoi_system_pb.System_PingServer) error {
ctx := rs.Context()
_, err := authenticate(srv.config.UserAuth, ctx)
_, err := authenticate(srv.config, ctx)
if err != nil {
return err
}
Expand All @@ -130,7 +130,7 @@ func (srv *Server) Ping(req *gnoi_system_pb.PingRequest, rs gnoi_system_pb.Syste
}
func (srv *Server) Traceroute(req *gnoi_system_pb.TracerouteRequest, rs gnoi_system_pb.System_TracerouteServer) error {
ctx := rs.Context()
_, err := authenticate(srv.config.UserAuth, ctx)
_, err := authenticate(srv.config, ctx)
if err != nil {
return err
}
Expand All @@ -139,23 +139,23 @@ func (srv *Server) Traceroute(req *gnoi_system_pb.TracerouteRequest, rs gnoi_sys
}
func (srv *Server) SetPackage(rs gnoi_system_pb.System_SetPackageServer) error {
ctx := rs.Context()
_, err := authenticate(srv.config.UserAuth, ctx)
_, err := authenticate(srv.config, ctx)
if err != nil {
return err
}
log.V(1).Info("gNOI: SetPackage")
return status.Errorf(codes.Unimplemented, "")
}
func (srv *Server) SwitchControlProcessor(ctx context.Context, req *gnoi_system_pb.SwitchControlProcessorRequest) (*gnoi_system_pb.SwitchControlProcessorResponse, error) {
_, err := authenticate(srv.config.UserAuth, ctx)
_, err := authenticate(srv.config, ctx)
if err != nil {
return nil, err
}
log.V(1).Info("gNOI: SwitchControlProcessor")
return nil, status.Errorf(codes.Unimplemented, "")
}
func (srv *Server) Time(ctx context.Context, req *gnoi_system_pb.TimeRequest) (*gnoi_system_pb.TimeResponse, error) {
_, err := authenticate(srv.config.UserAuth, ctx)
_, err := authenticate(srv.config, ctx)
if err != nil {
return nil, err
}
Expand All @@ -167,7 +167,7 @@ func (srv *Server) Time(ctx context.Context, req *gnoi_system_pb.TimeRequest) (*

func (srv *Server) Authenticate(ctx context.Context, req *spb_jwt.AuthenticateRequest) (*spb_jwt.AuthenticateResponse, error) {
// Can't enforce normal authentication here.. maybe only enforce client cert auth if enabled?
// ctx,err := authenticate(srv.config.UserAuth, ctx)
// ctx,err := authenticate(srv.config, ctx)
// if err != nil {
// return nil, err
// }
Expand All @@ -192,7 +192,7 @@ func (srv *Server) Authenticate(ctx context.Context, req *spb_jwt.AuthenticateRe

}
func (srv *Server) Refresh(ctx context.Context, req *spb_jwt.RefreshRequest) (*spb_jwt.RefreshResponse, error) {
ctx, err := authenticate(srv.config.UserAuth, ctx)
ctx, err := authenticate(srv.config, ctx)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -220,7 +220,7 @@ func (srv *Server) Refresh(ctx context.Context, req *spb_jwt.RefreshRequest) (*s
}

func (srv *Server) ClearNeighbors(ctx context.Context, req *spb.ClearNeighborsRequest) (*spb.ClearNeighborsResponse, error) {
ctx, err := authenticate(srv.config.UserAuth, ctx)
ctx, err := authenticate(srv.config, ctx)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -252,7 +252,7 @@ func (srv *Server) ClearNeighbors(ctx context.Context, req *spb.ClearNeighborsRe
}

func (srv *Server) CopyConfig(ctx context.Context, req *spb.CopyConfigRequest) (*spb.CopyConfigResponse, error) {
ctx, err := authenticate(srv.config.UserAuth, ctx)
ctx, err := authenticate(srv.config, ctx)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -283,7 +283,7 @@ func (srv *Server) CopyConfig(ctx context.Context, req *spb.CopyConfigRequest) (
}

func (srv *Server) ShowTechsupport(ctx context.Context, req *spb.TechsupportRequest) (*spb.TechsupportResponse, error) {
ctx, err := authenticate(srv.config.UserAuth, ctx)
ctx, err := authenticate(srv.config, ctx)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -315,7 +315,7 @@ func (srv *Server) ShowTechsupport(ctx context.Context, req *spb.TechsupportRequ
}

func (srv *Server) ImageInstall(ctx context.Context, req *spb.ImageInstallRequest) (*spb.ImageInstallResponse, error) {
ctx, err := authenticate(srv.config.UserAuth, ctx)
ctx, err := authenticate(srv.config, ctx)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -347,7 +347,7 @@ func (srv *Server) ImageInstall(ctx context.Context, req *spb.ImageInstallReques
}

func (srv *Server) ImageRemove(ctx context.Context, req *spb.ImageRemoveRequest) (*spb.ImageRemoveResponse, error) {
ctx, err := authenticate(srv.config.UserAuth, ctx)
ctx, err := authenticate(srv.config, ctx)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -379,7 +379,7 @@ func (srv *Server) ImageRemove(ctx context.Context, req *spb.ImageRemoveRequest)
}

func (srv *Server) ImageDefault(ctx context.Context, req *spb.ImageDefaultRequest) (*spb.ImageDefaultResponse, error) {
ctx, err := authenticate(srv.config.UserAuth, ctx)
ctx, err := authenticate(srv.config, ctx)
if err != nil {
return nil, err
}
Expand Down
22 changes: 12 additions & 10 deletions gnmi_server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ type Config struct {
EnableNativeWrite bool
ZmqPort string
IdleConnDuration int
ConfigTableName string
}

var AuthLock sync.Mutex
Expand Down Expand Up @@ -207,30 +208,31 @@ func (srv *Server) Port() int64 {
return srv.config.Port
}

func authenticate(UserAuth AuthTypes, ctx context.Context) (context.Context, error) {
func authenticate(config *Config, ctx context.Context) (context.Context, error) {
var err error
success := false
rc, ctx := common_utils.GetContext(ctx)
if !UserAuth.Any() {
if !config.UserAuth.Any() {
//No Auth enabled
rc.Auth.AuthEnabled = false
return ctx, nil
}

rc.Auth.AuthEnabled = true
if UserAuth.Enabled("password") {
if config.UserAuth.Enabled("password") {
ctx, err = BasicAuthenAndAuthor(ctx)
if err == nil {
success = true
}
}
if !success && UserAuth.Enabled("jwt") {
if !success && config.UserAuth.Enabled("jwt") {
_, ctx, err = JwtAuthenAndAuthor(ctx)
if err == nil {
success = true
}
}
if !success && UserAuth.Enabled("cert") {
ctx, err = ClientCertAuthenAndAuthor(ctx)
if !success && config.UserAuth.Enabled("cert") {
ctx, err = ClientCertAuthenAndAuthor(ctx, config.ConfigTableName)
if err == nil {
success = true
}
Expand All @@ -249,7 +251,7 @@ func authenticate(UserAuth AuthTypes, ctx context.Context) (context.Context, err
// Subscribe implements the gNMI Subscribe RPC.
func (s *Server) Subscribe(stream gnmipb.GNMI_SubscribeServer) error {
ctx := stream.Context()
ctx, err := authenticate(s.config.UserAuth, ctx)
ctx, err := authenticate(s.config, ctx)
if err != nil {
return err
}
Expand Down Expand Up @@ -334,7 +336,7 @@ func IsNativeOrigin(origin string) bool {
// Get implements the Get RPC in gNMI spec.
func (s *Server) Get(ctx context.Context, req *gnmipb.GetRequest) (*gnmipb.GetResponse, error) {
common_utils.IncCounter(common_utils.GNMI_GET)
ctx, err := authenticate(s.config.UserAuth, ctx)
ctx, err := authenticate(s.config, ctx)
if err != nil {
common_utils.IncCounter(common_utils.GNMI_GET_FAIL)
return nil, err
Expand Down Expand Up @@ -440,7 +442,7 @@ func (s *Server) Set(ctx context.Context, req *gnmipb.SetRequest) (*gnmipb.SetRe
common_utils.IncCounter(common_utils.GNMI_SET_FAIL)
return nil, grpc.Errorf(codes.Unimplemented, "GNMI is in read-only mode")
}
ctx, err := authenticate(s.config.UserAuth, ctx)
ctx, err := authenticate(s.config, ctx)
if err != nil {
common_utils.IncCounter(common_utils.GNMI_SET_FAIL)
return nil, err
Expand Down Expand Up @@ -541,7 +543,7 @@ func (s *Server) Set(ctx context.Context, req *gnmipb.SetRequest) (*gnmipb.SetRe
}

func (s *Server) Capabilities(ctx context.Context, req *gnmipb.CapabilityRequest) (*gnmipb.CapabilityResponse, error) {
ctx, err := authenticate(s.config.UserAuth, ctx)
ctx, err := authenticate(s.config, ctx)
if err != nil {
return nil, err
}
Expand Down
Loading

0 comments on commit f0d0959

Please sign in to comment.