Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add cert authorization with common name support. #241

Merged
merged 18 commits into from
Jun 13, 2024
Merged
1 change: 1 addition & 0 deletions dialout/dialout_server_cli/dialout_server_cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ var (
serverKey = flag.String("server_key", "", "TLS server private key")
insecure = flag.Bool("insecure", false, "Skip providing TLS cert and key, for testing only!")
allowNoClientCert = flag.Bool("allow_no_client_auth", false, "When set, telemetry server will request but not require a client certificate.")
clientCrtCname = flag.String("client_crt_cname", "", "Client cert common name")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need cname for dialout server?

Copy link
Contributor Author

@liuh-80 liuh-80 Jun 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reverted, confirmed with Zain, dialout not use in prod.

)

func main() {
Expand Down
43 changes: 39 additions & 4 deletions gnmi_server/clientCertAuth.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package gnmi

import (
"github.com/sonic-net/sonic-gnmi/common_utils"
"github.com/sonic-net/sonic-gnmi/swsscommon"
"github.com/golang/glog"
"golang.org/x/net/context"
"google.golang.org/grpc/codes"
Expand All @@ -10,7 +11,7 @@ import (
"google.golang.org/grpc/status"
)

func ClientCertAuthenAndAuthor(ctx context.Context) (context.Context, error) {
func ClientCertAuthenAndAuthor(ctx context.Context, serviceConfigTableName string) (context.Context, error) {
rc, ctx := common_utils.GetContext(ctx)
p, ok := peer.FromContext(ctx)
if !ok {
Expand All @@ -32,10 +33,44 @@ func ClientCertAuthenAndAuthor(ctx context.Context) (context.Context, error) {
return ctx, status.Error(codes.Unauthenticated, "invalid username in certificate common name.")
}

if err := PopulateAuthStruct(username, &rc.Auth, nil); err != nil {
glog.Infof("[%s] Failed to retrieve authentication information; %v", rc.ID, err)
return ctx, status.Errorf(codes.Unauthenticated, "")
if serviceConfigTableName != "" {
if err := PopulateAuthStructByCommonName(username, &rc.Auth, serviceConfigTableName); err != nil {
return ctx, err
}
} else {
if err := PopulateAuthStruct(username, &rc.Auth, nil); err != nil {
glog.Infof("[%s] Failed to retrieve authentication information; %v", rc.ID, err)
return ctx, status.Errorf(codes.Unauthenticated, "")
}
}

return ctx, nil
}

func PopulateAuthStructByCommonName(certCommonName string, auth *common_utils.AuthInfo, serviceConfigTableName string) error {
if serviceConfigTableName == "" {
return status.Errorf(codes.Unauthenticated, "Service config table name should not be empty")
}

var configDbConnector = swsscommon.NewConfigDBConnector()
defer swsscommon.DeleteConfigDBConnector_Native(configDbConnector.ConfigDBConnector_Native)
configDbConnector.Connect(false)

var fieldValuePairs = configDbConnector.Get_entry(serviceConfigTableName, certCommonName)
if fieldValuePairs.Size() > 0 {
if fieldValuePairs.Has_key("role") {
var role = fieldValuePairs.Get("role")
auth.Roles = []string{role}
}
} else {
glog.Warningf("Failed to retrieve cert common name mapping; %s", certCommonName)
}

swsscommon.DeleteFieldValueMap(fieldValuePairs)

if len(auth.Roles) == 0 {
return status.Errorf(codes.Unauthenticated, "Invalid cert cname:'%s', not a trusted cert common name.", certCommonName)
} else {
return nil
}
}
2 changes: 1 addition & 1 deletion gnmi_server/debug.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ import (

func (srv *Server) GetSubscribePreferences(req *spb_gnoi.SubscribePreferencesReq, stream spb_gnoi.Debug_GetSubscribePreferencesServer) error {
ctx := stream.Context()
ctx, err := authenticate(srv.config.UserAuth, ctx)
ctx, err := authenticate(srv.config, ctx)
if err != nil {
return err
}
Expand Down
34 changes: 17 additions & 17 deletions gnmi_server/gnoi.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func KillOrRestartProcess(restart bool, serviceName string) error {
}

func (srv *Server) KillProcess(ctx context.Context, req *gnoi_system_pb.KillProcessRequest) (*gnoi_system_pb.KillProcessResponse, error) {
_, err := authenticate(srv.config.UserAuth, ctx)
_, err := authenticate(srv.config, ctx)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -78,7 +78,7 @@ func RebootSystem(fileName string) error {
func (srv *Server) Reboot(ctx context.Context, req *gnoi_system_pb.RebootRequest) (*gnoi_system_pb.RebootResponse, error) {
fileName := common_utils.GNMI_WORK_PATH + "/config_db.json.tmp"

_, err := authenticate(srv.config.UserAuth, ctx)
_, err := authenticate(srv.config, ctx)
if err != nil {
return nil, err
}
Expand All @@ -102,7 +102,7 @@ func (srv *Server) Reboot(ctx context.Context, req *gnoi_system_pb.RebootRequest

// TODO: Support GNOI RebootStatus
func (srv *Server) RebootStatus(ctx context.Context, req *gnoi_system_pb.RebootStatusRequest) (*gnoi_system_pb.RebootStatusResponse, error) {
_, err := authenticate(srv.config.UserAuth, ctx)
_, err := authenticate(srv.config, ctx)
if err != nil {
return nil, err
}
Expand All @@ -112,7 +112,7 @@ func (srv *Server) RebootStatus(ctx context.Context, req *gnoi_system_pb.RebootS

// TODO: Support GNOI CancelReboot
func (srv *Server) CancelReboot(ctx context.Context, req *gnoi_system_pb.CancelRebootRequest) (*gnoi_system_pb.CancelRebootResponse, error) {
_, err := authenticate(srv.config.UserAuth, ctx)
_, err := authenticate(srv.config, ctx)
if err != nil {
return nil, err
}
Expand All @@ -121,7 +121,7 @@ func (srv *Server) CancelReboot(ctx context.Context, req *gnoi_system_pb.CancelR
}
func (srv *Server) Ping(req *gnoi_system_pb.PingRequest, rs gnoi_system_pb.System_PingServer) error {
ctx := rs.Context()
_, err := authenticate(srv.config.UserAuth, ctx)
_, err := authenticate(srv.config, ctx)
if err != nil {
return err
}
Expand All @@ -130,7 +130,7 @@ func (srv *Server) Ping(req *gnoi_system_pb.PingRequest, rs gnoi_system_pb.Syste
}
func (srv *Server) Traceroute(req *gnoi_system_pb.TracerouteRequest, rs gnoi_system_pb.System_TracerouteServer) error {
ctx := rs.Context()
_, err := authenticate(srv.config.UserAuth, ctx)
_, err := authenticate(srv.config, ctx)
if err != nil {
return err
}
Expand All @@ -139,23 +139,23 @@ func (srv *Server) Traceroute(req *gnoi_system_pb.TracerouteRequest, rs gnoi_sys
}
func (srv *Server) SetPackage(rs gnoi_system_pb.System_SetPackageServer) error {
ctx := rs.Context()
_, err := authenticate(srv.config.UserAuth, ctx)
_, err := authenticate(srv.config, ctx)
if err != nil {
return err
}
log.V(1).Info("gNOI: SetPackage")
return status.Errorf(codes.Unimplemented, "")
}
func (srv *Server) SwitchControlProcessor(ctx context.Context, req *gnoi_system_pb.SwitchControlProcessorRequest) (*gnoi_system_pb.SwitchControlProcessorResponse, error) {
_, err := authenticate(srv.config.UserAuth, ctx)
_, err := authenticate(srv.config, ctx)
if err != nil {
return nil, err
}
log.V(1).Info("gNOI: SwitchControlProcessor")
return nil, status.Errorf(codes.Unimplemented, "")
}
func (srv *Server) Time(ctx context.Context, req *gnoi_system_pb.TimeRequest) (*gnoi_system_pb.TimeResponse, error) {
_, err := authenticate(srv.config.UserAuth, ctx)
_, err := authenticate(srv.config, ctx)
if err != nil {
return nil, err
}
Expand All @@ -167,7 +167,7 @@ func (srv *Server) Time(ctx context.Context, req *gnoi_system_pb.TimeRequest) (*

func (srv *Server) Authenticate(ctx context.Context, req *spb_jwt.AuthenticateRequest) (*spb_jwt.AuthenticateResponse, error) {
// Can't enforce normal authentication here.. maybe only enforce client cert auth if enabled?
// ctx,err := authenticate(srv.config.UserAuth, ctx)
// ctx,err := authenticate(srv.config, ctx)
// if err != nil {
// return nil, err
// }
Expand All @@ -192,7 +192,7 @@ func (srv *Server) Authenticate(ctx context.Context, req *spb_jwt.AuthenticateRe

}
func (srv *Server) Refresh(ctx context.Context, req *spb_jwt.RefreshRequest) (*spb_jwt.RefreshResponse, error) {
ctx, err := authenticate(srv.config.UserAuth, ctx)
ctx, err := authenticate(srv.config, ctx)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -220,7 +220,7 @@ func (srv *Server) Refresh(ctx context.Context, req *spb_jwt.RefreshRequest) (*s
}

func (srv *Server) ClearNeighbors(ctx context.Context, req *spb.ClearNeighborsRequest) (*spb.ClearNeighborsResponse, error) {
ctx, err := authenticate(srv.config.UserAuth, ctx)
ctx, err := authenticate(srv.config, ctx)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -252,7 +252,7 @@ func (srv *Server) ClearNeighbors(ctx context.Context, req *spb.ClearNeighborsRe
}

func (srv *Server) CopyConfig(ctx context.Context, req *spb.CopyConfigRequest) (*spb.CopyConfigResponse, error) {
ctx, err := authenticate(srv.config.UserAuth, ctx)
ctx, err := authenticate(srv.config, ctx)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -283,7 +283,7 @@ func (srv *Server) CopyConfig(ctx context.Context, req *spb.CopyConfigRequest) (
}

func (srv *Server) ShowTechsupport(ctx context.Context, req *spb.TechsupportRequest) (*spb.TechsupportResponse, error) {
ctx, err := authenticate(srv.config.UserAuth, ctx)
ctx, err := authenticate(srv.config, ctx)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -315,7 +315,7 @@ func (srv *Server) ShowTechsupport(ctx context.Context, req *spb.TechsupportRequ
}

func (srv *Server) ImageInstall(ctx context.Context, req *spb.ImageInstallRequest) (*spb.ImageInstallResponse, error) {
ctx, err := authenticate(srv.config.UserAuth, ctx)
ctx, err := authenticate(srv.config, ctx)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -347,7 +347,7 @@ func (srv *Server) ImageInstall(ctx context.Context, req *spb.ImageInstallReques
}

func (srv *Server) ImageRemove(ctx context.Context, req *spb.ImageRemoveRequest) (*spb.ImageRemoveResponse, error) {
ctx, err := authenticate(srv.config.UserAuth, ctx)
ctx, err := authenticate(srv.config, ctx)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -379,7 +379,7 @@ func (srv *Server) ImageRemove(ctx context.Context, req *spb.ImageRemoveRequest)
}

func (srv *Server) ImageDefault(ctx context.Context, req *spb.ImageDefaultRequest) (*spb.ImageDefaultResponse, error) {
ctx, err := authenticate(srv.config.UserAuth, ctx)
ctx, err := authenticate(srv.config, ctx)
if err != nil {
return nil, err
}
Expand Down
22 changes: 12 additions & 10 deletions gnmi_server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ type Config struct {
EnableNativeWrite bool
ZmqPort string
IdleConnDuration int
ConfigTableName string
}

var AuthLock sync.Mutex
Expand Down Expand Up @@ -207,30 +208,31 @@ func (srv *Server) Port() int64 {
return srv.config.Port
}

func authenticate(UserAuth AuthTypes, ctx context.Context) (context.Context, error) {
func authenticate(config *Config, ctx context.Context) (context.Context, error) {
var err error
success := false
rc, ctx := common_utils.GetContext(ctx)
if !UserAuth.Any() {
if !config.UserAuth.Any() {
//No Auth enabled
rc.Auth.AuthEnabled = false
return ctx, nil
}

rc.Auth.AuthEnabled = true
if UserAuth.Enabled("password") {
if config.UserAuth.Enabled("password") {
ctx, err = BasicAuthenAndAuthor(ctx)
if err == nil {
success = true
}
}
if !success && UserAuth.Enabled("jwt") {
if !success && config.UserAuth.Enabled("jwt") {
_, ctx, err = JwtAuthenAndAuthor(ctx)
if err == nil {
success = true
}
}
if !success && UserAuth.Enabled("cert") {
ctx, err = ClientCertAuthenAndAuthor(ctx)
if !success && config.UserAuth.Enabled("cert") {
ctx, err = ClientCertAuthenAndAuthor(ctx, config.ConfigTableName)
if err == nil {
success = true
}
Expand All @@ -249,7 +251,7 @@ func authenticate(UserAuth AuthTypes, ctx context.Context) (context.Context, err
// Subscribe implements the gNMI Subscribe RPC.
func (s *Server) Subscribe(stream gnmipb.GNMI_SubscribeServer) error {
ctx := stream.Context()
ctx, err := authenticate(s.config.UserAuth, ctx)
ctx, err := authenticate(s.config, ctx)
if err != nil {
return err
}
Expand Down Expand Up @@ -334,7 +336,7 @@ func IsNativeOrigin(origin string) bool {
// Get implements the Get RPC in gNMI spec.
func (s *Server) Get(ctx context.Context, req *gnmipb.GetRequest) (*gnmipb.GetResponse, error) {
common_utils.IncCounter(common_utils.GNMI_GET)
ctx, err := authenticate(s.config.UserAuth, ctx)
ctx, err := authenticate(s.config, ctx)
if err != nil {
common_utils.IncCounter(common_utils.GNMI_GET_FAIL)
return nil, err
Expand Down Expand Up @@ -440,7 +442,7 @@ func (s *Server) Set(ctx context.Context, req *gnmipb.SetRequest) (*gnmipb.SetRe
common_utils.IncCounter(common_utils.GNMI_SET_FAIL)
return nil, grpc.Errorf(codes.Unimplemented, "GNMI is in read-only mode")
}
ctx, err := authenticate(s.config.UserAuth, ctx)
ctx, err := authenticate(s.config, ctx)
if err != nil {
common_utils.IncCounter(common_utils.GNMI_SET_FAIL)
return nil, err
Expand Down Expand Up @@ -541,7 +543,7 @@ func (s *Server) Set(ctx context.Context, req *gnmipb.SetRequest) (*gnmipb.SetRe
}

func (s *Server) Capabilities(ctx context.Context, req *gnmipb.CapabilityRequest) (*gnmipb.CapabilityResponse, error) {
ctx, err := authenticate(s.config.UserAuth, ctx)
ctx, err := authenticate(s.config, ctx)
if err != nil {
return nil, err
}
Expand Down
Loading
Loading