Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix errors in /account #14

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions api/account/changepw.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ import (
)

func ChangePW(uuid []byte, password string) error {
if len(password) < 6 {
return fmt.Errorf("invalid password")
if err := validatePassword(password); err != nil {
return err
}

salt := make([]byte, ArgonSaltSize)
Expand Down
22 changes: 22 additions & 0 deletions api/account/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,12 @@
package account

import (
"cmp"
"net/http"
"regexp"
"runtime"

"github.com/pagefaultgames/rogueserver/errors"
"golang.org/x/crypto/argon2"
)

Expand Down Expand Up @@ -52,3 +55,22 @@ func deriveArgon2IDKey(password, salt []byte) []byte {

return argon2.IDKey(password, salt, ArgonTime, ArgonMemory, ArgonThreads, ArgonKeySize)
}

func validateUsernamePassword(username string, password string) error {
return cmp.Or(validateUsername(username), validatePassword(password))
}

func validateUsername(username string) error {
if !isValidUsername(username) {
return errors.NewHttpError(http.StatusBadRequest, "invalid username")
}
return nil
}

func validatePassword(password string) error {
if len(password) < 6 {
return errors.NewHttpError(http.StatusBadRequest, "invalid password")
}

return nil
}
61 changes: 61 additions & 0 deletions api/account/common_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package account

import (
"net/http"
"testing"

"github.com/pagefaultgames/rogueserver/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestValidateUsernamePassword(t *testing.T) {
t.Run("valid username and password", func(t *testing.T) {
err := validateUsernamePassword("validUser", "validPass")
assert.NoError(t, err)
})

t.Run("invalid username", func(t *testing.T) {
err := validateUsernamePassword("", "validPass")
require.NotNil(t, err)
assert.Equal(t, err, errors.NewHttpError(http.StatusBadRequest, "invalid username"))
})

t.Run("invalid password", func(t *testing.T) {
err := validateUsernamePassword("validUser", "123")
require.NotNil(t, err)
assert.Equal(t, err, errors.NewHttpError(http.StatusBadRequest, "invalid password"))
})

t.Run("invalid username and password", func(t *testing.T) {
err := validateUsernamePassword("", "123")
require.NotNil(t, err)
assert.Equal(t, err, errors.NewHttpError(http.StatusBadRequest, "invalid username"))
})
}

func TestValidateUsername(t *testing.T) {
t.Run("valid username", func(t *testing.T) {
err := validateUsername("validUser")
assert.NoError(t, err)
})

t.Run("invalid username", func(t *testing.T) {
err := validateUsername("")
require.NotNil(t, err)
assert.Equal(t, err, errors.NewHttpError(http.StatusBadRequest, "invalid username"))
})
}

func TestValidatePassword(t *testing.T) {
t.Run("valid password", func(t *testing.T) {
err := validatePassword("validPass")
assert.NoError(t, err)
})

t.Run("invalid password", func(t *testing.T) {
err := validatePassword("123")
require.NotNil(t, err)
assert.Equal(t, err, errors.NewHttpError(http.StatusBadRequest, "invalid password"))
})
}
14 changes: 6 additions & 8 deletions api/account/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,10 @@ import (
"database/sql"
"encoding/base64"
"fmt"
"net/http"

"github.com/pagefaultgames/rogueserver/db"
"github.com/pagefaultgames/rogueserver/errors"
)

type LoginResponse GenericAuthResponse
Expand All @@ -33,25 +35,21 @@ type LoginResponse GenericAuthResponse
func Login(username, password string) (LoginResponse, error) {
var response LoginResponse

if !isValidUsername(username) {
return response, fmt.Errorf("invalid username")
}

if len(password) < 6 {
return response, fmt.Errorf("invalid password")
if err := validateUsernamePassword(username, password); err != nil {
return response, err
}

key, salt, err := db.FetchAccountKeySaltFromUsername(username)
if err != nil {
if err == sql.ErrNoRows {
return response, fmt.Errorf("account doesn't exist")
return response, errors.NewHttpError(http.StatusNotFound, "account doesn't exist")
}

return response, err
}

if !bytes.Equal(key, deriveArgon2IDKey([]byte(password), salt)) {
return response, fmt.Errorf("password doesn't match")
return response, errors.NewHttpError(http.StatusUnauthorized, "password doesn't match")
}

token := make([]byte, TokenSize)
Expand Down
15 changes: 9 additions & 6 deletions api/account/register.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,18 @@ package account

import (
"crypto/rand"
stderrors "errors"
"fmt"
"net/http"

"github.com/pagefaultgames/rogueserver/db"
"github.com/pagefaultgames/rogueserver/errors"
)

// /account/register - register account
func Register(username, password string) error {
if !isValidUsername(username) {
return fmt.Errorf("invalid username")
}

if len(password) < 6 {
return fmt.Errorf("invalid password")
if err := validateUsernamePassword(username, password); err != nil {
return err
}

uuid := make([]byte, UUIDSize)
Expand All @@ -47,6 +47,9 @@ func Register(username, password string) error {

err = db.AddAccountRecord(uuid, username, deriveArgon2IDKey([]byte(password), salt), salt)
if err != nil {
if stderrors.Is(err, db.ErrAccountAlreadyExists) {
return errors.NewHttpError(http.StatusConflict, fmt.Sprintf(`username "%s" already taken`, username))
}
return fmt.Errorf("failed to add account record: %s", err)
}

Expand Down
30 changes: 23 additions & 7 deletions api/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,15 @@ package api
import (
"encoding/base64"
"encoding/json"
stderrors "errors"
"fmt"
"log"
"net/http"

"github.com/pagefaultgames/rogueserver/api/account"
"github.com/pagefaultgames/rogueserver/api/daily"
"github.com/pagefaultgames/rogueserver/db"
"log"
"net/http"
"github.com/pagefaultgames/rogueserver/errors"
)

func Init(mux *http.ServeMux) error {
Expand Down Expand Up @@ -69,16 +72,16 @@ func Init(mux *http.ServeMux) error {

func tokenFromRequest(r *http.Request) ([]byte, error) {
if r.Header.Get("Authorization") == "" {
return nil, fmt.Errorf("missing token")
return nil, errors.NewHttpError(http.StatusBadRequest, "missing token")
}

token, err := base64.StdEncoding.DecodeString(r.Header.Get("Authorization"))
if err != nil {
return nil, fmt.Errorf("failed to decode token: %s", err)
return nil, errors.NewHttpError(http.StatusBadRequest, "failed to decode token")
}

if len(token) != account.TokenSize {
return nil, fmt.Errorf("invalid token length: got %d, expected %d", len(token), account.TokenSize)
return nil, errors.NewHttpError(http.StatusBadRequest, "invalid token length")
}

return token, nil
Expand All @@ -97,14 +100,17 @@ func tokenAndUuidFromRequest(r *http.Request) ([]byte, []byte, error) {

uuid, err := db.FetchUUIDFromToken(token)
if err != nil {
return nil, nil, fmt.Errorf("failed to validate token: %s", err)
if stderrors.Is(err, db.ErrTokenNotFound) {
return nil, nil, errors.NewHttpError(http.StatusUnauthorized, "bad token")
}
return nil, nil, fmt.Errorf("failed to fetch uuid from db: %w", err)
}

return token, uuid, nil
}

func httpError(w http.ResponseWriter, r *http.Request, err error, code int) {
log.Printf("%s: %s\n", r.URL.Path, err)
log.Printf("%s: %s\n", r.URL.Path, err.Error())
http.Error(w, err.Error(), code)
}

Expand All @@ -116,3 +122,13 @@ func jsonResponse(w http.ResponseWriter, r *http.Request, data any) {
return
}
}

func statusCodeFromError(err error) int {
var httpErr *errors.HttpError

if stderrors.As(err, &httpErr) {
return httpErr.Code
}

return http.StatusInternalServerError
}
28 changes: 28 additions & 0 deletions api/common_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package api

import (
stderrors "errors"
"net/http"
"testing"

"github.com/pagefaultgames/rogueserver/errors"
"github.com/stretchr/testify/assert"
)

func TestStatusCodeFromError(t *testing.T) {
t.Run("nil", func(t *testing.T) {
code := statusCodeFromError(nil)
assert.Equal(t, http.StatusInternalServerError, code)
})
t.Run("http error", func(t *testing.T) {
err := errors.NewHttpError(http.StatusTeapot, "teapot")
code := statusCodeFromError(err)
assert.Equal(t, http.StatusTeapot, code)
})

t.Run("standard error", func(t *testing.T) {
err := stderrors.New("standard error")
code := statusCodeFromError(err)
assert.Equal(t, http.StatusInternalServerError, code)
})
}
Loading
Loading