From 3561edb13d8c5d3e8b7a8e937aea2c9cd7250fb1 Mon Sep 17 00:00:00 2001 From: suntt2019 Date: Thu, 25 Feb 2021 22:58:12 +0800 Subject: [PATCH] Fix class; Update letterBytes for random string --- app/controller/class.go | 23 ++-------------------- app/controller/class_test.go | 16 +++++----------- app/request/class.go | 2 +- base/utils/class.go | 22 +++++++++++++++++++++ base/utils/class_test.go | 37 ++++++++++++++++++++++++++++++++++++ base/utils/utils.go | 2 +- 6 files changed, 68 insertions(+), 34 deletions(-) create mode 100644 base/utils/class.go create mode 100644 base/utils/class_test.go diff --git a/app/controller/class.go b/app/controller/class.go index fc629e3c..dcd3199d 100644 --- a/app/controller/class.go +++ b/app/controller/class.go @@ -11,27 +11,8 @@ import ( "github.com/pkg/errors" "gorm.io/gorm" "net/http" - "sync" ) -var inviteCodeLock sync.Mutex - -func GenerateInviteCode() (code string) { - inviteCodeLock.Lock() - defer inviteCodeLock.Unlock() - crashed := true - for crashed { - // 5: Fixed invite code length - code = utils.RandStr(5) - crashed = false - var count int64 - utils.PanicIfDBError(base.DB.Model(models.Class{}).Where("invite_code = ?", code).Count(&count), - "could not check if invite code crashed for generating invite code") - crashed = count >= 1 - } - return -} - func CreateClass(c echo.Context) error { user := c.Get("user").(models.User) req := request.CreateClassRequest{} @@ -43,7 +24,7 @@ func CreateClass(c echo.Context) error { Name: req.Name, CourseName: req.CourseName, Description: req.Description, - InviteCode: GenerateInviteCode(), + InviteCode: utils.GenerateInviteCode(), Managers: []models.User{ user, }, @@ -171,7 +152,7 @@ func RefreshInviteCode(c echo.Context) error { panic(errors.Wrap(err, "could not find class for refreshing invite code")) } } - class.InviteCode = GenerateInviteCode() + class.InviteCode = utils.GenerateInviteCode() utils.PanicIfDBError(base.DB.Save(&class), "could not update class for refreshing invite code") return c.JSON(http.StatusOK, response.RefreshInviteCodeResponse{ Message: "SUCCESS", diff --git a/app/controller/class_test.go b/app/controller/class_test.go index 3ea0f92a..3f7a2a00 100644 --- a/app/controller/class_test.go +++ b/app/controller/class_test.go @@ -2,11 +2,11 @@ package controller_test import ( "fmt" - "github.com/leoleoasd/EduOJBackend/app/controller" "github.com/leoleoasd/EduOJBackend/app/request" "github.com/leoleoasd/EduOJBackend/app/response" "github.com/leoleoasd/EduOJBackend/app/response/resource" "github.com/leoleoasd/EduOJBackend/base" + "github.com/leoleoasd/EduOJBackend/base/utils" "github.com/leoleoasd/EduOJBackend/database/models" "github.com/stretchr/testify/assert" "gorm.io/gorm" @@ -16,14 +16,14 @@ import ( ) func checkInviteCode(t *testing.T, code string) { - assert.Regexp(t, regexp.MustCompile("^[a-zA-Z]{5}$"), code) + assert.Regexp(t, regexp.MustCompile("^[a-zA-Z2-9]{5}$"), code) var count int64 assert.NoError(t, base.DB.Model(models.Class{}).Where("invite_code = ?", code).Count(&count).Error) assert.Equal(t, int64(1), count) } func createClassForTest(t *testing.T, name string, id int, managers, students []models.User) models.Class { - inviteCode := controller.GenerateInviteCode() + inviteCode := utils.GenerateInviteCode() class := models.Class{ Name: fmt.Sprintf("test_%s_%d_name", name, id), CourseName: fmt.Sprintf("test_%s_%d_course_name", name, id), @@ -36,12 +36,6 @@ func createClassForTest(t *testing.T, name string, id int, managers, students [] return class } -func TestGenerateInviteCode(t *testing.T) { - t.Parallel() - class := createClassForTest(t, "test_generate_invite_code_success", 0, nil, nil) - checkInviteCode(t, class.InviteCode) -} - func TestCreateClass(t *testing.T) { t.Parallel() @@ -733,7 +727,7 @@ func TestJoinClass(t *testing.T) { method: "POST", path: base.Echo.Reverse("class.joinClass", -1), req: request.JoinClassRequest{ - InviteCode: controller.GenerateInviteCode(), + InviteCode: utils.GenerateInviteCode(), }, reqOptions: []reqOption{applyAdminUser}, statusCode: http.StatusNotFound, @@ -745,7 +739,7 @@ func TestJoinClass(t *testing.T) { method: "POST", path: base.Echo.Reverse("class.joinClass", class1.ID), req: request.JoinClassRequest{ - InviteCode: controller.GenerateInviteCode(), + InviteCode: utils.GenerateInviteCode(), }, reqOptions: []reqOption{applyNormalUser}, statusCode: http.StatusForbidden, diff --git a/app/request/class.go b/app/request/class.go index 19555784..74eea608 100644 --- a/app/request/class.go +++ b/app/request/class.go @@ -33,7 +33,7 @@ type RefreshInviteCodeRequest struct { } type JoinClassRequest struct { - InviteCode string `json:"invite_code" form:"invite_code" query:"invite_code" validate:"required,alpha,max=255"` + InviteCode string `json:"invite_code" form:"invite_code" query:"invite_code" validate:"required,max=255"` } type DeleteClassRequest struct { diff --git a/base/utils/class.go b/base/utils/class.go new file mode 100644 index 00000000..6fbb32da --- /dev/null +++ b/base/utils/class.go @@ -0,0 +1,22 @@ +package utils + +import ( + "github.com/leoleoasd/EduOJBackend/base" + "github.com/leoleoasd/EduOJBackend/database/models" + "sync" +) + +var inviteCodeLock sync.Mutex + +func GenerateInviteCode() (code string) { + inviteCodeLock.Lock() + defer inviteCodeLock.Unlock() + var count int64 = 1 + for count > 0 { + // 5: Fixed invite code length + code = RandStr(5) + PanicIfDBError(base.DB.Model(models.Class{}).Where("invite_code = ?", code).Count(&count), + "could not check if invite code crashed for generating invite code") + } + return +} diff --git a/base/utils/class_test.go b/base/utils/class_test.go new file mode 100644 index 00000000..cca509e6 --- /dev/null +++ b/base/utils/class_test.go @@ -0,0 +1,37 @@ +package utils + +import ( + "fmt" + "github.com/leoleoasd/EduOJBackend/base" + "github.com/leoleoasd/EduOJBackend/database/models" + "github.com/stretchr/testify/assert" + "regexp" + "testing" +) + +func checkInviteCode(t *testing.T, code string) { + assert.Regexp(t, regexp.MustCompile("^[a-zA-Z2-9]{5}$"), code) + var count int64 + assert.NoError(t, base.DB.Model(models.Class{}).Where("invite_code = ?", code).Count(&count).Error) + assert.Equal(t, int64(1), count) +} + +func createClassForTest(t *testing.T, name string, id int, managers, students []models.User) models.Class { + inviteCode := GenerateInviteCode() + class := models.Class{ + Name: fmt.Sprintf("test_%s_%d_name", name, id), + CourseName: fmt.Sprintf("test_%s_%d_course_name", name, id), + Description: fmt.Sprintf("test_%s_%d_description", name, id), + InviteCode: inviteCode, + Managers: managers, + Students: students, + } + assert.NoError(t, base.DB.Create(&class).Error) + return class +} + +func TestGenerateInviteCode(t *testing.T) { + t.Parallel() + class := createClassForTest(t, "test_generate_invite_code_success", 0, nil, nil) + checkInviteCode(t, class.InviteCode) +} diff --git a/base/utils/utils.go b/base/utils/utils.go index 4d0ef926..10a5f135 100644 --- a/base/utils/utils.go +++ b/base/utils/utils.go @@ -8,7 +8,7 @@ import ( // Random string generator by https://stackoverflow.com/a/22892986/8031146 -const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" +const letterBytes = "abcdefghjkmnpqrstuvwxyzABCDEFGHJKMNPQRSTUVWXYZ23456789" const ( letterIdxBits = 6 // 6 bits to represent a letter index letterIdxMask = 1<