Skip to content

Commit

Permalink
More DB unit tests (#234)
Browse files Browse the repository at this point in the history
* Fix EarlyStopParam and SuggestionParam DB methods

GetEarlyStopParamList and GetSuggestionParamList mixed up the column
order and they returned nothing. Also, SetEarlyStopParam didn't
return an ID properly.

Signed-off-by: IWAMOTO Toshihiro <[email protected]>

* Add more DB UTs

Signed-off-by: IWAMOTO Toshihiro <[email protected]>
  • Loading branch information
toshiiw authored and k8s-ci-robot committed Nov 5, 2018
1 parent 8e90513 commit 04837a4
Show file tree
Hide file tree
Showing 2 changed files with 233 additions and 32 deletions.
18 changes: 5 additions & 13 deletions pkg/db/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -967,7 +967,7 @@ func (d *dbConn) GetSuggestionParam(paramID string) ([]*api.SuggestionParameter,
func (d *dbConn) GetSuggestionParamList(studyID string) ([]*api.SuggestionParameterSet, error) {
var rows *sql.Rows
var err error
rows, err = d.db.Query("SELECT * FROM suggestion_param WHERE study_id = ?", studyID)
rows, err = d.db.Query("SELECT id, suggestion_algo, parameters FROM suggestion_param WHERE study_id = ?", studyID)
if err != nil {
return nil, err
}
Expand All @@ -976,14 +976,10 @@ func (d *dbConn) GetSuggestionParamList(studyID string) ([]*api.SuggestionParame
var id string
var algorithm string
var params string
var sID string
err := rows.Scan(&id, &sID, &algorithm, &params)
err := rows.Scan(&id, &algorithm, &params)
if err != nil {
return nil, err
}
if studyID != sID {
continue
}
var pArray []string
if len(params) > 0 {
pArray = strings.Split(params, ",\n")
Expand Down Expand Up @@ -1021,7 +1017,7 @@ func (d *dbConn) SetEarlyStopParam(algorithm string, studyID string, params []*a
}
var paramID string
for true {
paramID := generateRandid()
paramID = generateRandid()
_, err = d.db.Exec("INSERT INTO earlystopping_param VALUES (?,?, ?, ?)",
paramID, algorithm, studyID, strings.Join(ps, ",\n"))
if err == nil {
Expand Down Expand Up @@ -1077,7 +1073,7 @@ func (d *dbConn) GetEarlyStopParam(paramID string) ([]*api.EarlyStoppingParamete
func (d *dbConn) GetEarlyStopParamList(studyID string) ([]*api.EarlyStoppingParameterSet, error) {
var rows *sql.Rows
var err error
rows, err = d.db.Query("SELECT * FROM earlystopping_param WHERE study_id = ?", studyID)
rows, err = d.db.Query("SELECT id, earlystop_algo, parameters FROM earlystopping_param WHERE study_id = ?", studyID)
if err != nil {
return nil, err
}
Expand All @@ -1086,14 +1082,10 @@ func (d *dbConn) GetEarlyStopParamList(studyID string) ([]*api.EarlyStoppingPara
var id string
var algorithm string
var params string
var sID string
err := rows.Scan(&id, &sID, &algorithm, &params)
err := rows.Scan(&id, &algorithm, &params)
if err != nil {
return nil, err
}
if studyID != sID {
continue
}
var pArray []string
if len(params) > 0 {
pArray = strings.Split(params, ",\n")
Expand Down
247 changes: 228 additions & 19 deletions pkg/db/interface_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,16 @@ import (
var dbInterface VizierDBInterface
var mock sqlmock.Sqlmock

var studyColumns = []string{
"id", "name", "owner", "optimization_type", "optimization_goal",
"parameter_configs", "tags", "objective_value_name",
"metrics", "job_id"}
var trialColumns = []string{
"id", "study_id", "parameters", "objective_value", "tags"}
var workerColumns = []string{"id",
"study_id", "trial_id", "type",
"status", "template_path", "tags"}

func TestMain(m *testing.M) {
db, sm, err := sqlmock.New()
mock = sm
Expand Down Expand Up @@ -55,25 +65,43 @@ func TestGetStudyConfig(t *testing.T) {
}
// mock.ExpectExec("SELECT * FROM studies WHERE id").WithArgs(id).WillReturnRows(sqlmock.NewRows())
mock.ExpectQuery("SELECT").WillReturnRows(
sqlmock.NewRows([]string{
"id",
"name",
"owner",
"optimization_type",
"optimization_goal",
"parameter_configs",
"tags",
"objective_value_name",
"metrics",
"job_id",
}).
AddRow("abc", "test", "admin", 1, 0.99, "{}", "", "", "", "test"))
sqlmock.NewRows(studyColumns).AddRow(
"abc", "test", "admin", 1, 0.99, "{}", "", "", "", "test"))
study, err := dbInterface.GetStudyConfig(id)
if err != nil {
t.Errorf("GetStudyConfig failed: %v", err)
} else if study.Name != "test" || study.Owner != "admin" {
t.Errorf("GetStudyConfig incorrect return %v", study)
}
}

func TestGetStudyList(t *testing.T) {
ids := []string{"abcde1234567890f", "bcde1234567890fa"}
mock.ExpectQuery("SELECT id FROM studies").WillReturnRows(
sqlmock.NewRows([]string{"id"}).AddRow(ids[0]).AddRow(ids[1]))
r, err := dbInterface.GetStudyList()
if err != nil {
t.Errorf("GetStudyList error %v", err)
}
if len(r) != len(ids) {
t.Errorf("GetStudyList returned incorrect number of ids %d != %d",
len(r), len(ids))
}
for i, id := range r {
if ids[i] != id {
t.Errorf("GetStudyList returned incorrect ID %s != %s",
id, ids[i])
}
}
}

func TestDeleteStudy(t *testing.T) {
studyID := generateRandid()
mock.ExpectExec(`DELETE FROM studies WHERE id = \?`).WithArgs(studyID).WillReturnResult(sqlmock.NewResult(1, 1))
err := dbInterface.DeleteStudy(studyID)
if err != nil {
t.Errorf("DeleteStudy error %v", err)
}
fmt.Printf("%v", study)
// TODO: check study data
}

func TestCreateStudyIdGeneration(t *testing.T) {
Expand Down Expand Up @@ -108,6 +136,58 @@ func TestCreateStudyIdGeneration(t *testing.T) {
}
}

func TestGetTrial(t *testing.T) {
id := generateRandid()
mock.ExpectQuery(`SELECT \* FROM trials WHERE id = \?`).WillReturnRows(
sqlmock.NewRows(trialColumns).AddRow(
id, "s1234567890abcde",
"{\"name\": \"1\"},\n{}", "obj_val",
"{\"name\": \"foo\"},\n{}"))
trial, err := dbInterface.GetTrial(id)
if err != nil {
t.Errorf("GetTrial error %v", err)
} else if len((*trial).Tags) != 2 {
t.Errorf("GetTrial returned incorrect Tag %v", (*trial).Tags)
}
}

func TestGetTrialList(t *testing.T) {
studyID := generateRandid()
var ids = []string{"abcdef1234567890", "bcdef1234567890a"}
rows := sqlmock.NewRows(trialColumns)
for _, id := range ids {
rows.AddRow(id, studyID, "", "obj_val", "")
}
mock.ExpectQuery(`SELECT \* FROM trials WHERE study_id = \?`).WithArgs(studyID).WillReturnRows(rows)
trials, err := dbInterface.GetTrialList(studyID)
if err != nil {
t.Errorf("GetTrialList error %v", err)
} else if len(trials) != len(ids) {
t.Errorf("GetTrialList returned incorrect number of trials %d != %d",
len(trials), len(ids))
}
}

func TestCreateTrial(t *testing.T) {
var trial api.Trial
trial.StudyId = generateRandid()
mock.ExpectExec(`INSERT INTO trials VALUES \(`).WithArgs(sqlmock.AnyArg(),
trial.StudyId, "", "", "").WillReturnResult(sqlmock.NewResult(1, 1))
err := dbInterface.CreateTrial(&trial)
if err != nil {
t.Errorf("CreateTrial error %v", err)
}
}

func TestDeleteTrial(t *testing.T) {
id := generateRandid()
mock.ExpectExec(`DELETE FROM trials WHERE id = \?`).WithArgs(id).WillReturnResult(sqlmock.NewResult(1, 1))
err := dbInterface.DeleteTrial(id)
if err != nil {
t.Errorf("DeleteTrial error %v", err)
}
}

func TestCreateWorker(t *testing.T) {
var w api.Worker
w.StudyId = generateRandid()
Expand All @@ -124,10 +204,6 @@ func TestCreateWorker(t *testing.T) {
}
}

var workerColumns = []string{"id",
"study_id", "trial_id", "type",
"status", "template_path", "tags"}

const defaultWorkerID = "w123456789abcdef"
const objValueName = "obj_value"

Expand Down Expand Up @@ -189,6 +265,31 @@ func TestDeleteWorker(t *testing.T) {

}

func TestGetWorkerFullInfo(t *testing.T) {
studyID := generateRandid()
wRows := sqlmock.NewRows(workerColumns)
wRows.AddRow("w1134567890abcde", studyID, "", "", "1", "", "")
wRows.AddRow("w2234567890abcde", studyID, "", "", "2", "", "")
mock.ExpectQuery(`SELECT \* FROM workers WHERE study_id = \?`).WithArgs(studyID).WillReturnRows(wRows)
mock.ExpectQuery(`SELECT \* FROM trials WHERE study_id = \?`).WithArgs(studyID).WillReturnRows(
sqlmock.NewRows(trialColumns))
mock.ExpectQuery(`SELECT \* FROM studies WHERE id = \?`).WithArgs(studyID).WillReturnRows(
sqlmock.NewRows(studyColumns).AddRow(
studyID, "test", "admin", 1, 0.99, "{}", "", "", "foo,\nbar", "test"))
WMRows := sqlmock.NewRows([]string{"WM.worker_id", "WM.time", "WM.name", "WM.value"})
WMRows.AddRow("w1134567890abcde", "2012-01-01 09:54:32", "foo", "1")
WMRows.AddRow("w1134567890abcde", "2012-01-01 09:54:32", "bar", "1")
mock.ExpectQuery(`SELECT WM.worker_id, WM.time, WM.name, WM.value FROM .* MaxID .* ON WM.worker_id`).WithArgs(studyID).WillReturnRows(WMRows)

fi, err := dbInterface.GetWorkerFullInfo(studyID, "", "", true)
if err != nil {
t.Errorf("GetWorkerFullInfo error %v", err)
} else if len(fi.WorkerFullInfos) != 2 ||
len(fi.WorkerFullInfos[0].MetricsLogs) != 2 {
t.Errorf("GetWorkerFullInfo incorrect return %v", fi)
}
}

type MetricsLogData struct {
stored bool
name string
Expand Down Expand Up @@ -337,3 +438,111 @@ func TestGetWorkerLogs(t *testing.T) {
}
}
}

func TestSetSuggestionParam(t *testing.T) {
sp := make([]*api.SuggestionParameter, 1)
sp[0] = &api.SuggestionParameter{Name: "DefaultGrid", Value: "1"}
studyID := generateRandid()
mock.ExpectExec("INSERT INTO suggestion_param VALUES").WithArgs(
sqlmock.AnyArg(), "grid", studyID,
`{"name":"DefaultGrid","value":"1"}`).WillReturnResult(sqlmock.NewResult(1, 1))
id, err := dbInterface.SetSuggestionParam("grid", studyID, sp)
if err != nil {
t.Errorf("SetSuggestionParam error %v", err)
} else if len(id) != 16 {
t.Errorf("SetSuggestionParam returned incorrect ID %s", id)
}
}

func TestUpdateSuggestionParam(t *testing.T) {
sp := make([]*api.SuggestionParameter, 1)
sp[0] = &api.SuggestionParameter{Name: "DefaultGrid", Value: "12"}
id := generateRandid()
mock.ExpectExec(`UPDATE suggestion_param SET parameters = \? WHERE id = \?`).WithArgs(
`{"name":"DefaultGrid","value":"12"}`, id).WillReturnResult(sqlmock.NewResult(1, 1))
err := dbInterface.UpdateSuggestionParam(id, sp)
if err != nil {
t.Errorf("UpdateSuggestionParam error %v", err)
}
}

func TestGetSuggestionParam(t *testing.T) {
id := generateRandid()
mock.ExpectQuery(`SELECT parameters FROM suggestion_param WHERE id = \?`).WithArgs(id).WillReturnRows(
sqlmock.NewRows([]string{"parameters"}).AddRow(
`{"name":"DefaultGrid","value":"12"}`))
sp, err := dbInterface.GetSuggestionParam(id)
if err != nil {
t.Errorf("GetSuggestionParam error %v", err)
} else if len(sp) != 1 {
t.Errorf("GetSuggestionParam returned incorrect number of data %v", sp)
}
}

func TestGetSuggestionParamList(t *testing.T) {
studyID := generateRandid()
mock.ExpectQuery(`SELECT id, suggestion_algo, parameters FROM suggestion_param WHERE study_id = \?`).WithArgs(studyID).WillReturnRows(
sqlmock.NewRows([]string{"id", "suggestion_algo", "parameters"}).AddRow(
generateRandid(), "random", "{}"))

sp, err := dbInterface.GetSuggestionParamList(studyID)
if err != nil {
t.Errorf("GetSuggestionParamList error %v", err)
} else if len(sp) != 1 {
t.Errorf("GetSuggestionParamList returned incorrect number of data %v", sp)
}
}

func TestSetEarlyStopParam(t *testing.T) {
ep := make([]*api.EarlyStoppingParameter, 1)
ep[0] = &api.EarlyStoppingParameter{Name: "LeastStep", Value: "1"}
studyID := generateRandid()
mock.ExpectExec("INSERT INTO earlystopping_param VALUES").WithArgs(
sqlmock.AnyArg(), "medianstopping", studyID,
`{"name":"LeastStep","value":"1"}`).WillReturnResult(sqlmock.NewResult(1, 1))
id, err := dbInterface.SetEarlyStopParam("medianstopping", studyID, ep)
if err != nil {
t.Errorf("SetEarlyStopParam error %v", err)
} else if len(id) != 16 {
t.Errorf("SetEarlyStopParam returned incorrect ID %s", id)
}
}

func TestUpdateEarlyStopParam(t *testing.T) {
ep := make([]*api.EarlyStoppingParameter, 1)
ep[0] = &api.EarlyStoppingParameter{Name: "LeastStep", Value: "12"}
id := generateRandid()
mock.ExpectExec(`UPDATE earlystopping_param SET parameters = \? WHERE id = \?`).WithArgs(
`{"name":"LeastStep","value":"12"}`, id).WillReturnResult(sqlmock.NewResult(1, 1))
err := dbInterface.UpdateEarlyStopParam(id, ep)
if err != nil {
t.Errorf("UpdateEarlyStopParamerror %v", err)
}
}

func TestGetEarlyStopParam(t *testing.T) {
id := generateRandid()
mock.ExpectQuery(`SELECT parameters FROM earlystopping_param WHERE id = \?`).WithArgs(id).WillReturnRows(
sqlmock.NewRows([]string{"parameters"}).AddRow(
`{"name":"LeastStep","value":"12"}`))
ep, err := dbInterface.GetEarlyStopParam(id)
if err != nil {
t.Errorf("GetEarlyStopParam error %v", err)
} else if len(ep) != 1 {
t.Errorf("GetEarlyStopParam returned incorrect number of data %v", ep)
}
}

func TestGetEarlyStopParamList(t *testing.T) {
studyID := generateRandid()
mock.ExpectQuery(`SELECT id, earlystop_algo, parameters FROM earlystopping_param WHERE study_id = \?`).WithArgs(studyID).WillReturnRows(
sqlmock.NewRows([]string{"id", "earlystop_algo", "parameters"}).AddRow(
generateRandid(), "medianstopping", "{}"))

ep, err := dbInterface.GetEarlyStopParamList(studyID)
if err != nil {
t.Errorf("GetEarlyStopParamList error %v", err)
} else if len(ep) != 1 {
t.Errorf("GetEarlyStopParamList returned incorrect number of data %v", ep)
}
}

0 comments on commit 04837a4

Please sign in to comment.