From 6ca41ad59a31d2f5c5d09b2589621e6f5c837ebb Mon Sep 17 00:00:00 2001 From: Sourr_cream Date: Sat, 20 Apr 2024 18:49:45 +0300 Subject: [PATCH] auth --- backend/cmd/orchestrator/main.go | 7 +- backend/internal/agent/agent.go | 7 +- backend/internal/config/config.go | 1 + .../domain/messages/expression_message.go | 1 + .../http-server/handlers/expression.go | 24 ++++-- .../http-server/handlers/operation.go | 22 ++++- backend/internal/http-server/handlers/user.go | 6 ++ backend/internal/lib/jwt/jwt_from_header.go | 54 ++++++++++++ backend/internal/orchestrator/orchestrator.go | 2 + .../storage/postgres/expressions.sql.go | 5 +- .../storage/postgres/model_transformers.go | 1 + backend/internal/storage/postgres/models.go | 1 + .../storage/postgres/operations.sql.go | 82 +++++++++---------- backend/sql/queries/expressions.sql | 1 + backend/sql/queries/operations.sql | 29 +++---- backend/sql/schema/0005_operations.sql | 6 +- .../sql/schema/0007_add_simple_operations.sql | 12 --- 17 files changed, 174 insertions(+), 87 deletions(-) create mode 100644 backend/internal/lib/jwt/jwt_from_header.go delete mode 100644 backend/sql/schema/0007_add_simple_operations.sql diff --git a/backend/cmd/orchestrator/main.go b/backend/cmd/orchestrator/main.go index 5e15b13..b02f160 100644 --- a/backend/cmd/orchestrator/main.go +++ b/backend/cmd/orchestrator/main.go @@ -91,14 +91,15 @@ func main() { v1Router.Post("/expressions", handlers.HandlerCreateExpression( log, dbCfg, + cfg.JWTSecret, application.OrchestratorApp, application.Producer, )) - v1Router.Get("/expressions", handlers.HandlerGetExpressions(log, dbCfg)) + v1Router.Get("/expressions", handlers.HandlerGetExpressions(log, dbCfg, cfg.JWTSecret)) // Operation endpoints - v1Router.Get("/operations", handlers.HandlerGetOperations(log, dbCfg)) - v1Router.Patch("/operations", handlers.HandlerUpdateOperation(log, dbCfg)) + v1Router.Get("/operations", handlers.HandlerGetOperations(log, dbCfg, cfg.JWTSecret)) + v1Router.Patch("/operations", handlers.HandlerUpdateOperation(log, dbCfg, cfg.JWTSecret)) // Agent endpoints v1Router.Get("/agents", handlers.HandlerGetAgents(log, dbCfg)) diff --git a/backend/internal/agent/agent.go b/backend/internal/agent/agent.go index ff19757..9e54bf9 100644 --- a/backend/internal/agent/agent.go +++ b/backend/internal/agent/agent.go @@ -148,7 +148,10 @@ func (a *Agent) RunSimpleComputer(ctx context.Context, exprMsg *messages.Express return fmt.Errorf("can't convert int to str: %v, fn: %s", err, fn) } - time_for_oper, err := a.dbConfig.Queries.GetOperationTimeByType(ctx, oper) + time_for_oper, err := a.dbConfig.Queries.GetOperationTimeByType(ctx, postgres.GetOperationTimeByTypeParams{ + OperationType: oper, + UserID: exprMsg.UserID, + }) if err != nil { return fmt.Errorf("can't get execution time by operation type: %v, fn: %s", err, fn) } @@ -162,8 +165,6 @@ func (a *Agent) RunSimpleComputer(ctx context.Context, exprMsg *messages.Express return fmt.Errorf("can't increment number of active calculations: %v, fn: %s", err, fn) } - // atomic.AddInt32(&a.NumberOfActiveCalculations, 1) - return nil } diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 2cf5b73..9b07f31 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -19,6 +19,7 @@ type Config struct { InactiveTimeForAgent int32 `yaml:"inactive_time_for_agent" env-default:"200"` TimeForPing int32 `yaml:"time_for_ping" end-default:"100"` TokenTTL time.Duration `yaml:"tokenTTL" env-default:"1h"` + JWTSecret string `env:"JWT_SECRET" env-required:"true"` GRPCServer `yaml:"grpc_server" env-required:"true"` DatabaseInstance `yaml:"database_instance" env-required:"true"` RabbitQueue `yaml:"rabbit_queue" env-required:"true"` diff --git a/backend/internal/domain/messages/expression_message.go b/backend/internal/domain/messages/expression_message.go index ece0212..4db3f40 100644 --- a/backend/internal/domain/messages/expression_message.go +++ b/backend/internal/domain/messages/expression_message.go @@ -7,6 +7,7 @@ type ExpressionMessage struct { Result int `json:"result"` IsPing bool `json:"is_ping"` AgentID int32 `json:"agent_id"` + UserID int32 `json:"user_id"` } type ResultAndTokenMessage struct { diff --git a/backend/internal/http-server/handlers/expression.go b/backend/internal/http-server/handlers/expression.go index 0d1bb96..c9a7782 100644 --- a/backend/internal/http-server/handlers/expression.go +++ b/backend/internal/http-server/handlers/expression.go @@ -9,6 +9,7 @@ import ( "github.com/Prrromanssss/DAEE-fullstack/internal/domain/brokers" "github.com/Prrromanssss/DAEE-fullstack/internal/domain/messages" + "github.com/Prrromanssss/DAEE-fullstack/internal/lib/jwt" "github.com/Prrromanssss/DAEE-fullstack/internal/orchestrator" "github.com/Prrromanssss/DAEE-fullstack/internal/orchestrator/parser" @@ -16,11 +17,11 @@ import ( "github.com/Prrromanssss/DAEE-fullstack/internal/storage/postgres" ) -// TODO: user // HandlerCreateExpression is a http.Handler to create new expression. func HandlerCreateExpression( log *slog.Logger, dbCfg *storage.Storage, + secret string, orc *orchestrator.Orchestrator, producer brokers.Producer, ) http.HandlerFunc { @@ -31,13 +32,19 @@ func HandlerCreateExpression( slog.String("fn", fn), ) + userID, err := jwt.GetUidFromJWT(r, secret) + if err != nil { + respondWithError(log, w, 403, "Status Forbidden") + return + } + type parametrs struct { Data string `json:"data"` } decoder := json.NewDecoder(r.Body) params := parametrs{} - err := decoder.Decode(¶ms) + err = decoder.Decode(¶ms) if err != nil { respondWithError(log, w, 400, fmt.Sprintf("error parsing JSON: %v", err)) return @@ -56,7 +63,7 @@ func HandlerCreateExpression( Data: params.Data, ParseData: parseData, Status: "ready_for_computation", - UserID: 1, // TODO: UserID !!! + UserID: userID, }) if err != nil { respondWithError(log, w, 400, fmt.Sprintf("can't create expression: %v", err)) @@ -66,6 +73,7 @@ func HandlerCreateExpression( msgToQueue := messages.ExpressionMessage{ ExpressionID: expression.ExpressionID, Expression: parseData, + UserID: userID, } orc.AddTask(msgToQueue, producer) @@ -77,7 +85,7 @@ func HandlerCreateExpression( } // HandlerGetExpressions is a http.Handler to get all expressions from storage. -func HandlerGetExpressions(log *slog.Logger, dbCfg *storage.Storage) http.HandlerFunc { +func HandlerGetExpressions(log *slog.Logger, dbCfg *storage.Storage, secret string) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { const fn = "handlers.HandlerCreateExpression" @@ -85,7 +93,13 @@ func HandlerGetExpressions(log *slog.Logger, dbCfg *storage.Storage) http.Handle slog.String("fn", fn), ) - expressions, err := dbCfg.Queries.GetExpressions(r.Context()) + userID, err := jwt.GetUidFromJWT(r, secret) + if err != nil { + respondWithError(log, w, 403, "Status Forbidden") + return + } + + expressions, err := dbCfg.Queries.GetExpressions(r.Context(), userID) if err != nil { respondWithError(log, w, 400, fmt.Sprintf("Couldn't get expressions: %v", err)) return diff --git a/backend/internal/http-server/handlers/operation.go b/backend/internal/http-server/handlers/operation.go index 399007a..ccdcc59 100644 --- a/backend/internal/http-server/handlers/operation.go +++ b/backend/internal/http-server/handlers/operation.go @@ -6,12 +6,13 @@ import ( "log/slog" "net/http" + "github.com/Prrromanssss/DAEE-fullstack/internal/lib/jwt" "github.com/Prrromanssss/DAEE-fullstack/internal/storage" "github.com/Prrromanssss/DAEE-fullstack/internal/storage/postgres" ) // HandlerGetOperations is a http.Handler to get all operations from storage. -func HandlerGetOperations(log *slog.Logger, dbCfg *storage.Storage) http.HandlerFunc { +func HandlerGetOperations(log *slog.Logger, dbCfg *storage.Storage, secret string) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { const fn = "hadlers.HandlerGetOperations" @@ -19,7 +20,13 @@ func HandlerGetOperations(log *slog.Logger, dbCfg *storage.Storage) http.Handler slog.String("fn", fn), ) - operations, err := dbCfg.Queries.GetOperations(r.Context()) + userID, err := jwt.GetUidFromJWT(r, secret) + if err != nil { + respondWithError(log, w, 403, "Status Forbidden") + return + } + + operations, err := dbCfg.Queries.GetOperations(r.Context(), userID) if err != nil { respondWithError(log, w, 400, fmt.Sprintf("can't get operations: %v", err)) return @@ -30,7 +37,7 @@ func HandlerGetOperations(log *slog.Logger, dbCfg *storage.Storage) http.Handler } // HandlerUpdateOperation is a http.Handler to update execution time of the certain operation type. -func HandlerUpdateOperation(log *slog.Logger, dbCfg *storage.Storage) http.HandlerFunc { +func HandlerUpdateOperation(log *slog.Logger, dbCfg *storage.Storage, secret string) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { const fn = "handlers.HandlerUpdateOperation" @@ -38,6 +45,12 @@ func HandlerUpdateOperation(log *slog.Logger, dbCfg *storage.Storage) http.Handl slog.String("fn", fn), ) + userID, err := jwt.GetUidFromJWT(r, secret) + if err != nil { + respondWithError(log, w, 403, "Status Forbidden") + return + } + type parametrs struct { OperationType string `json:"operation_type"` ExecutionTime int32 `json:"execution_time"` @@ -45,7 +58,7 @@ func HandlerUpdateOperation(log *slog.Logger, dbCfg *storage.Storage) http.Handl decoder := json.NewDecoder(r.Body) params := parametrs{} - err := decoder.Decode(¶ms) + err = decoder.Decode(¶ms) if err != nil { respondWithError(log, w, 400, fmt.Sprintf("error parsing JSON: %v", err)) } @@ -53,6 +66,7 @@ func HandlerUpdateOperation(log *slog.Logger, dbCfg *storage.Storage) http.Handl operation, err := dbCfg.Queries.UpdateOperationTime(r.Context(), postgres.UpdateOperationTimeParams{ OperationType: params.OperationType, ExecutionTime: params.ExecutionTime, + UserID: userID, }) if err != nil { diff --git a/backend/internal/http-server/handlers/user.go b/backend/internal/http-server/handlers/user.go index f99ed3a..d23d239 100644 --- a/backend/internal/http-server/handlers/user.go +++ b/backend/internal/http-server/handlers/user.go @@ -84,6 +84,12 @@ func HandlerRegisterNewUser( return } + err = dbCfg.Queries.NewOperationsForUser(r.Context(), int32(registerResponse.UserId)) + if err != nil { + respondWithError(log, w, 400, fmt.Sprintf("can't create new operations for user: %v", err)) + return + } + respondWithJson(log, w, 200, registerResponse) } } diff --git a/backend/internal/lib/jwt/jwt_from_header.go b/backend/internal/lib/jwt/jwt_from_header.go new file mode 100644 index 0000000..d96b318 --- /dev/null +++ b/backend/internal/lib/jwt/jwt_from_header.go @@ -0,0 +1,54 @@ +package jwt + +import ( + "fmt" + "net/http" + "strings" + + "github.com/golang-jwt/jwt/v5" +) + +func getTokenFromHeader(r *http.Request) (string, error) { + authHeader := r.Header.Get("Authorization") + if authHeader == "" { + return "", fmt.Errorf("authorization header is missing") + } + + // Checks that header starts with "Bearer". + parts := strings.Split(authHeader, " ") + if len(parts) != 2 || parts[0] != "Bearer" { + return "", fmt.Errorf("invalid Authorization header format") + } + + return parts[1], nil // returns token without "Bearer". +} + +func GetUidFromJWT(r *http.Request, secret string) (int32, error) { + jwtToken, err := getTokenFromHeader(r) + if err != nil { + return 0, err + } + + // Parse JWT Token. + token, err := jwt.Parse(jwtToken, func(token *jwt.Token) (interface{}, error) { + return []byte(secret), nil + }) + if err != nil { + return 0, err + } + + if !token.Valid { + return 0, err + } + + claims, ok := token.Claims.(jwt.MapClaims) + if !ok { + return 0, err + } + userID, ok := claims["uid"].(int32) + if !ok { + return 0, err + } + + return userID, nil +} diff --git a/backend/internal/orchestrator/orchestrator.go b/backend/internal/orchestrator/orchestrator.go index 558c594..2381b4a 100644 --- a/backend/internal/orchestrator/orchestrator.go +++ b/backend/internal/orchestrator/orchestrator.go @@ -85,6 +85,7 @@ func (o *Orchestrator) ReloadComputingExpressions( msgToQueue := messages.ExpressionMessage{ ExpressionID: expr.ExpressionID, Expression: expr.ParseData, + UserID: expr.UserID, } o.AddTask(msgToQueue, producer) } @@ -170,6 +171,7 @@ func (o *Orchestrator) CheckPing(ctx context.Context, producer brokers.Producer) msgToQueue := messages.ExpressionMessage{ ExpressionID: expr.ExpressionID, Expression: expr.ParseData, + UserID: expr.UserID, } o.AddTask(msgToQueue, producer) } diff --git a/backend/internal/storage/postgres/expressions.sql.go b/backend/internal/storage/postgres/expressions.sql.go index 9bf7b2c..b6037b7 100644 --- a/backend/internal/storage/postgres/expressions.sql.go +++ b/backend/internal/storage/postgres/expressions.sql.go @@ -149,11 +149,12 @@ SELECT created_at, updated_at, data, parse_data, status, result, is_ready FROM expressions +WHERE user_id = $1 ORDER BY created_at DESC ` -func (q *Queries) GetExpressions(ctx context.Context) ([]Expression, error) { - rows, err := q.db.QueryContext(ctx, getExpressions) +func (q *Queries) GetExpressions(ctx context.Context, userID int32) ([]Expression, error) { + rows, err := q.db.QueryContext(ctx, getExpressions, userID) if err != nil { return nil, err } diff --git a/backend/internal/storage/postgres/model_transformers.go b/backend/internal/storage/postgres/model_transformers.go index 3b4be58..eb02d04 100644 --- a/backend/internal/storage/postgres/model_transformers.go +++ b/backend/internal/storage/postgres/model_transformers.go @@ -34,6 +34,7 @@ type OperationTransformed struct { OperationID int32 `json:"operation_id"` OperationType string `json:"operation_type"` ExecutionTime int32 `json:"execution_time"` + UserID int32 `json:"user_id"` } func DatabaseOperationToOperation(dbOper Operation) OperationTransformed { diff --git a/backend/internal/storage/postgres/models.go b/backend/internal/storage/postgres/models.go index 738089b..df04377 100644 --- a/backend/internal/storage/postgres/models.go +++ b/backend/internal/storage/postgres/models.go @@ -125,6 +125,7 @@ type Operation struct { OperationID int32 OperationType string ExecutionTime int32 + UserID int32 } type User struct { diff --git a/backend/internal/storage/postgres/operations.sql.go b/backend/internal/storage/postgres/operations.sql.go index 54e330c..690a5d1 100644 --- a/backend/internal/storage/postgres/operations.sql.go +++ b/backend/internal/storage/postgres/operations.sql.go @@ -9,44 +9,19 @@ import ( "context" ) -const createOperation = `-- name: CreateOperation :exec -INSERT INTO operations (operation_type, execution_time) -VALUES ($1, $2) -RETURNING - operation_id, operation_type, execution_time -` - -type CreateOperationParams struct { - OperationType string - ExecutionTime int32 -} - -func (q *Queries) CreateOperation(ctx context.Context, arg CreateOperationParams) error { - _, err := q.db.ExecContext(ctx, createOperation, arg.OperationType, arg.ExecutionTime) - return err -} - -const getOperationByType = `-- name: GetOperationByType :one -SELECT - operation_id, operation_type, execution_time +const getOperationTimeByType = `-- name: GetOperationTimeByType :one +SELECT execution_time FROM operations -WHERE operation_type = $1 +WHERE operation_type = $1 AND user_id = $2 ` -func (q *Queries) GetOperationByType(ctx context.Context, operationType string) (Operation, error) { - row := q.db.QueryRowContext(ctx, getOperationByType, operationType) - var i Operation - err := row.Scan(&i.OperationID, &i.OperationType, &i.ExecutionTime) - return i, err +type GetOperationTimeByTypeParams struct { + OperationType string + UserID int32 } -const getOperationTimeByType = `-- name: GetOperationTimeByType :one -SELECT execution_time FROM operations -WHERE operation_type = $1 -` - -func (q *Queries) GetOperationTimeByType(ctx context.Context, operationType string) (int32, error) { - row := q.db.QueryRowContext(ctx, getOperationTimeByType, operationType) +func (q *Queries) GetOperationTimeByType(ctx context.Context, arg GetOperationTimeByTypeParams) (int32, error) { + row := q.db.QueryRowContext(ctx, getOperationTimeByType, arg.OperationType, arg.UserID) var execution_time int32 err := row.Scan(&execution_time) return execution_time, err @@ -54,13 +29,14 @@ func (q *Queries) GetOperationTimeByType(ctx context.Context, operationType stri const getOperations = `-- name: GetOperations :many SELECT - operation_id, operation_type, execution_time + operation_id, operation_type, execution_time, user_id FROM operations +WHERE user_id = $1 ORDER BY operation_type DESC ` -func (q *Queries) GetOperations(ctx context.Context) ([]Operation, error) { - rows, err := q.db.QueryContext(ctx, getOperations) +func (q *Queries) GetOperations(ctx context.Context, userID int32) ([]Operation, error) { + rows, err := q.db.QueryContext(ctx, getOperations, userID) if err != nil { return nil, err } @@ -68,7 +44,12 @@ func (q *Queries) GetOperations(ctx context.Context) ([]Operation, error) { var items []Operation for rows.Next() { var i Operation - if err := rows.Scan(&i.OperationID, &i.OperationType, &i.ExecutionTime); err != nil { + if err := rows.Scan( + &i.OperationID, + &i.OperationType, + &i.ExecutionTime, + &i.UserID, + ); err != nil { return nil, err } items = append(items, i) @@ -82,21 +63,40 @@ func (q *Queries) GetOperations(ctx context.Context) ([]Operation, error) { return items, nil } +const newOperationsForUser = `-- name: NewOperationsForUser :exec +INSERT INTO operations (operation_type, user_id) VALUES +('+', $1), +('-', $1), +('*', $1), +('/', $1) +` + +func (q *Queries) NewOperationsForUser(ctx context.Context, userID int32) error { + _, err := q.db.ExecContext(ctx, newOperationsForUser, userID) + return err +} + const updateOperationTime = `-- name: UpdateOperationTime :one UPDATE operations SET execution_time = $1 -WHERE operation_type = $2 -RETURNING operation_id, operation_type, execution_time +WHERE operation_type = $2 AND user_id = $3 +RETURNING operation_id, operation_type, execution_time, user_id ` type UpdateOperationTimeParams struct { ExecutionTime int32 OperationType string + UserID int32 } func (q *Queries) UpdateOperationTime(ctx context.Context, arg UpdateOperationTimeParams) (Operation, error) { - row := q.db.QueryRowContext(ctx, updateOperationTime, arg.ExecutionTime, arg.OperationType) + row := q.db.QueryRowContext(ctx, updateOperationTime, arg.ExecutionTime, arg.OperationType, arg.UserID) var i Operation - err := row.Scan(&i.OperationID, &i.OperationType, &i.ExecutionTime) + err := row.Scan( + &i.OperationID, + &i.OperationType, + &i.ExecutionTime, + &i.UserID, + ) return i, err } diff --git a/backend/sql/queries/expressions.sql b/backend/sql/queries/expressions.sql index bb177a2..4172b2c 100644 --- a/backend/sql/queries/expressions.sql +++ b/backend/sql/queries/expressions.sql @@ -14,6 +14,7 @@ SELECT created_at, updated_at, data, parse_data, status, result, is_ready FROM expressions +WHERE user_id = $1 ORDER BY created_at DESC; -- name: GetExpressionByID :one diff --git a/backend/sql/queries/operations.sql b/backend/sql/queries/operations.sql index 7c63bf8..a250267 100644 --- a/backend/sql/queries/operations.sql +++ b/backend/sql/queries/operations.sql @@ -1,27 +1,24 @@ --- name: CreateOperation :exec -INSERT INTO operations (operation_type, execution_time) -VALUES ($1, $2) -RETURNING - operation_id, operation_type, execution_time; - -- name: UpdateOperationTime :one UPDATE operations SET execution_time = $1 -WHERE operation_type = $2 -RETURNING *; +WHERE operation_type = $2 AND user_id = $3 +RETURNING operation_id, operation_type, execution_time, user_id; -- name: GetOperations :many SELECT - operation_id, operation_type, execution_time + operation_id, operation_type, execution_time, user_id FROM operations +WHERE user_id = $1 ORDER BY operation_type DESC; -- name: GetOperationTimeByType :one -SELECT execution_time FROM operations -WHERE operation_type = $1; - --- name: GetOperationByType :one -SELECT - operation_id, operation_type, execution_time +SELECT execution_time FROM operations -WHERE operation_type = $1; +WHERE operation_type = $1 AND user_id = $2; + +-- name: NewOperationsForUser :exec +INSERT INTO operations (operation_type, user_id) VALUES +('+', $1), +('-', $1), +('*', $1), +('/', $1); \ No newline at end of file diff --git a/backend/sql/schema/0005_operations.sql b/backend/sql/schema/0005_operations.sql index 44d4717..6592476 100644 --- a/backend/sql/schema/0005_operations.sql +++ b/backend/sql/schema/0005_operations.sql @@ -3,8 +3,12 @@ CREATE TABLE IF NOT EXISTS operations ( operation_id int GENERATED ALWAYS AS IDENTITY, operation_type varchar(1) UNIQUE NOT NULL, execution_time int NOT NULL DEFAULT 100, + user_id int NOT NULL, - PRIMARY KEY(operation_id) + PRIMARY KEY(operation_id), + FOREIGN KEY(user_id) + REFERENCES users(user_id) + ON DELETE CASCADE ); -- +goose Down diff --git a/backend/sql/schema/0007_add_simple_operations.sql b/backend/sql/schema/0007_add_simple_operations.sql deleted file mode 100644 index 3f67e49..0000000 --- a/backend/sql/schema/0007_add_simple_operations.sql +++ /dev/null @@ -1,12 +0,0 @@ --- +goose Up -INSERT INTO operations (operation_type) VALUES -('+'), -('-'), -('*'), -('/'); - --- +goose Down -DELETE FROM operations WHERE operation_type = '+'; -DELETE FROM operations WHERE operation_type = '-'; -DELETE FROM operations WHERE operation_type = '/'; -DELETE FROM operations WHERE operation_type = '*'; \ No newline at end of file