From 08efc157a12717e1ce3022ab204942b2cf1216c3 Mon Sep 17 00:00:00 2001 From: Clivern Date: Mon, 19 Jul 2021 13:37:31 +0200 Subject: [PATCH] init --- core/component/authentication.go | 36 ++++++++++++++++++++++++++++++-- core/module/oauth_access_data.go | 2 +- 2 files changed, 35 insertions(+), 3 deletions(-) diff --git a/core/component/authentication.go b/core/component/authentication.go index 380a1ac..f6c3e4c 100644 --- a/core/component/authentication.go +++ b/core/component/authentication.go @@ -8,6 +8,7 @@ import ( "encoding/base64" "fmt" "strings" + "time" "github.com/spacewalkio/helmet/core/model" "github.com/spacewalkio/helmet/core/module" @@ -93,6 +94,37 @@ func (b *BasicAuthMethod) Authenticate(endpoint model.Endpoint, authKey string) } // Authenticate validates auth headers -func (o *OAuthAuthMethod) Authenticate(endpoint model.Endpoint, accessToken string) error { - return nil +func (o *OAuthAuthMethod) Authenticate(endpoint model.Endpoint, accessToken string) (model.OAuthAccessData, error) { + var data model.OAuthAccessData + + if accessToken == "" { + return data, fmt.Errorf("Access token is missing") + } + + accessToken = strings.Replace(accessToken, "Bearer ", "", -1) + + data = o.Database.GetOAuthAccessDataByKey(accessToken) + + if data.ID < 1 { + return data, fmt.Errorf("Access token is invalid") + } + + // Validate if access token is expired + if time.Now().Unix() >= (data.ExpireAt.UnixNano() / int64(time.Millisecond)) { + return data, fmt.Errorf("Access token is expired") + } + + oauthData := o.Database.GetOAuthDataByID(data.OAuthDataID) + + if oauthData.ID < 1 { + return data, fmt.Errorf("Access token credentials are missing") + } + + authMethod := o.Database.GetAuthMethodByID(oauthData.AuthMethodID) + + if authMethod.Endpoints == "" || !util.InArray(endpoint.Name, strings.Split(authMethod.Endpoints, ";")) { + return data, fmt.Errorf("Access token is invalid") + } + + return data, nil } diff --git a/core/module/oauth_access_data.go b/core/module/oauth_access_data.go index 8add1ea..3a03dd2 100644 --- a/core/module/oauth_access_data.go +++ b/core/module/oauth_access_data.go @@ -33,7 +33,7 @@ func (db *Database) GetOAuthAccessDataByID(id int) model.OAuthAccessData { } // GetOAuthAccessDataByKeys gets an entity by keys -func (db *Database) GetOAuthAccessDataByKeys(accessToken string) model.OAuthAccessData { +func (db *Database) GetOAuthAccessDataByKey(accessToken string) model.OAuthAccessData { oauthAccessData := model.OAuthAccessData{} db.Connection.Where(