diff --git a/grpc_api/authorization.go b/grpc_api/authorization.go new file mode 100644 index 0000000..bb55a5c --- /dev/null +++ b/grpc_api/authorization.go @@ -0,0 +1,50 @@ +package grpc_api + +import ( + "Simple-Bank/token" + "context" + "fmt" + "google.golang.org/grpc/metadata" + "strings" +) + +const ( + // authorizationHeader is the name of the header containing authorization information, including access token. + authorizationHeader = "authorization" + // authorizationTypeBearer is the type of authorization bearer token. + authorizationTypeBearer = "bearer" +) + +// authorizeUser authorizes the user based on the access token provided in the context. +// It extracts the access token from the authorization header and verifies it using the token maker. +// It returns the token payload if the access token is valid, otherwise it returns an error. +func (server *GrpcServer) authorizeUser(ctx context.Context) (*token.Payload, error) { + mtdt, ok := metadata.FromIncomingContext(ctx) + if !ok { + return nil, fmt.Errorf("missing metadata") + } + + values := mtdt.Get(authorizationHeader) + if len(values) == 0 { + return nil, fmt.Errorf("missing authorization header") + } + + authHeader := values[0] + fields := strings.Fields(authHeader) + if len(fields) < 2 { + return nil, fmt.Errorf("invalid aithorization header format") + } + + authType := strings.ToLower(fields[0]) + if authType != authorizationTypeBearer { + return nil, fmt.Errorf("unsupported authorization type: %s", authType) + } + + accessToken := fields[1] + payload, err := server.tokenMaker.VerifyToken(accessToken) + if err != nil { + return nil, fmt.Errorf("invalid access token: %s", accessToken) + } + + return payload, nil +} diff --git a/grpc_api/error.go b/grpc_api/error.go index 1688ea3..ec967e8 100644 --- a/grpc_api/error.go +++ b/grpc_api/error.go @@ -24,3 +24,8 @@ func invalidArgumentError(violations []*errdetails.BadRequest_FieldViolation) er return statusDetails.Err() } + +// unAuthenticatedError creates and returns an unauthenticated gRPC error with the input error message. +func unAuthenticatedError(err error) error { + return status.Errorf(codes.Unauthenticated, "unauthorized: %s", err) +} diff --git a/grpc_api/rpc_update_user.go b/grpc_api/rpc_update_user.go index 135071a..bf599c4 100644 --- a/grpc_api/rpc_update_user.go +++ b/grpc_api/rpc_update_user.go @@ -11,12 +11,21 @@ import ( ) func (server *GrpcServer) UpdateUser(context context.Context, req *pb.UpdateUserRequest) (*pb.UpdateUserResponse, error) { + payload, err := server.authorizeUser(context) + if err != nil { + return nil, unAuthenticatedError(err) + } + violations := validateUpdateUserRequest(req) if violations != nil { err := invalidArgumentError(violations) return nil, err } + if payload.Username != req.Username { + return nil, status.Errorf(codes.PermissionDenied, "cannot update other users info") + } + updatedUser, err := server.dbServices.UpdateUser(services.UpdateUserRequest{ Username: req.Username, Password: req.Password,