From 1c1060bc4195d0755b020d05d3bc32fcb8873ba1 Mon Sep 17 00:00:00 2001 From: Joshua Richardson-Noyes Date: Wed, 20 Jan 2021 21:06:23 +0000 Subject: [PATCH] Switched origninal Where function to a WhereRaw function. Created new Where function(s) capable of taking structs in and storing the filters this way. Added support for multiple data types and for handling of arrays of filters as well as an array of values for one filter. (I.E. WHERE X IN (1, 2, 3)) --- munch.go | 208 +++++++++++++++++++++++++++++++++++-------- munch_test.go | 233 +++++++++++++++++++++++++++++++++++++++++++++++++ tests/basic.go | 45 ---------- 3 files changed, 403 insertions(+), 83 deletions(-) create mode 100644 munch_test.go delete mode 100644 tests/basic.go diff --git a/munch.go b/munch.go index e9305e9..ac4310b 100644 --- a/munch.go +++ b/munch.go @@ -4,6 +4,8 @@ import ( "database/sql" "fmt" "reflect" + "strconv" + "strings" ) // SQL Operations - ENUM @@ -34,30 +36,85 @@ type query struct { operation int table string where []filter + columns []string - data map[string]string + data map[string]interface{} } type filter struct { isOr bool columnName string comparator string - value string + value interface{} } -func (q *query) OrWhere(columnName, comparator, value string) { +func getColumns(in reflect.Type) []string { + cols := make([]string, 0) + for x := 0; x < in.NumField(); x++ { + var columnName string + sqlTag := in.Field(x).Tag.Get("sql") + if len(sqlTag) > 0 { + columnName = sqlTag + } else { + columnName = in.Field(x).Name + } + + cols = append(cols, columnName) + } + + return cols +} + +func (q *query) Select(cols []string) { + q.columns = cols +} + +func (q *query) Where(in interface{}) { + t := reflect.TypeOf(in) + + if t.Kind() == reflect.Slice { + mySlice := reflect.ValueOf(in) + for i := 0; i < mySlice.Len(); i++ { + obj := mySlice.Index(i).Interface() + + q.Where(obj) + } + } else if t.Kind() == reflect.Struct { + cols := getColumns(t) + + q.filterObj(in, cols) + } +} + +func (q *query) WhereIn(columnName string, values interface{}, isOr bool) { + q.addFilter(columnName, "IN", values, isOr) +} + +func (q *query) filterObj(obj interface{}, columns []string) { + v := reflect.ValueOf(obj) + + for i := 0; i < len(columns); i++ { + fieldVal := v.Field(i) + //fieldStr := fieldVal.String() + if fieldVal.IsValid() { + q.addFilter(columns[i], "=", fieldVal.Interface(), false) + } + } +} + +func (q *query) OrWhereRaw(columnName, comparator string, value interface{}) { q.addFilter(columnName, comparator, value, true) } -func (q *query) Where(columnName, comparator, value string) { +func (q *query) WhereRaw(columnName, comparator string, value interface{}) { q.addFilter(columnName, comparator, value, false) } -func (q *query) AndWhere(columnName, comparator, value string) { +func (q *query) AndWhereRaw(columnName, comparator string, value interface{}) { q.addFilter(columnName, comparator, value, false) } -func (q *query) addFilter(columnName, comparator, value string, isOr bool) { +func (q *query) addFilter(columnName, comparator string, value interface{}, isOr bool) { q.where = append(q.where, filter{ isOr: isOr, columnName: columnName, @@ -75,14 +132,13 @@ func (q *query) appendData(in interface{}) { sTag := field.Tag.Get("sql") v := reflect.ValueOf(in) - s := reflect.Indirect(v).FieldByName(field.Name).String() + //s := reflect.Indirect(v).FieldByName(field.Name) + s := v.Field(i).Interface() - if len(s) > 0 { - if len(sTag) > 0 { - q.data[sTag] = s - } else { - q.data[field.Name] = s - } + if len(sTag) > 0 { + q.data[sTag] = s + } else { + q.data[field.Name] = s } } } @@ -91,7 +147,7 @@ func (q *query) Insert(in interface{}) { q.operation = sql_INSERT if q.data == nil { - q.data = make(map[string]string) + q.data = make(map[string]interface{}) } q.appendData(in) @@ -101,7 +157,7 @@ func (q *query) Update(in interface{}) { q.operation = sql_UPDATE if q.data == nil { - q.data = make(map[string]string) + q.data = make(map[string]interface{}) } q.appendData(in) @@ -115,6 +171,43 @@ func (q *query) Delete() { q.operation = sql_DELETE } +func formatValue(t reflect.Type, v reflect.Value) string { + valStr := "" + + switch t.Kind() { + case reflect.Int: + valStr = strconv.FormatInt(v.Int(), 10) + break + case reflect.Bool: + valStr = strings.ToUpper(strconv.FormatBool(v.Bool())) + break + case reflect.Float64: + valStr = strconv.FormatFloat(v.Float(), 'f', -1, 64) + break + case reflect.Slice: + valList := "" + for i := 0; i < v.Len(); i++ { + vIdx := v.Index(i) + if i > 0 { + valList += ", " + } + valList += formatValue(vIdx.Type(), vIdx) + } + if len(valList) > 0 { + valStr = fmt.Sprintf("(%s)", valList) + } + break + default: + s := v.String() + if len(s) > 0 { + valStr = fmt.Sprintf("'%s'", v.String()) + } + break + } + + return valStr +} + func (q *query) ToSQL() string { var sqlStr string filterSql := "" @@ -131,7 +224,13 @@ func (q *query) ToSQL() string { sqlStr = fmt.Sprintf("DELETE FROM `%s`", q.table) break default: - sqlStr = fmt.Sprintf("SELECT * FROM `%s`", q.table) + colString := "*" + + if len(q.columns) > 0 { + colString = "`" + strings.Join(q.columns, "`, `") + "`" + } + + sqlStr = fmt.Sprintf("SELECT %s FROM `%s`", colString, q.table) break } @@ -139,15 +238,31 @@ func (q *query) ToSQL() string { first := true filterSql = "" for _, filter := range q.where { - predicate := "AND" - if first { - predicate = "WHERE" - first = false - } else if filter.isOr { - predicate = "OR" - } + var ( + valStr string + escape = "" + ) + + fType := reflect.TypeOf(filter.value) + fValue := reflect.ValueOf(filter.value) + + valStr = formatValue(fType, fValue) + + if len(valStr) > 2 { + predicate := "AND" + if first { + predicate = "WHERE" + first = false + } else if filter.isOr { + predicate = "OR" + } + + if fType.Kind() == reflect.Slice { + filter.comparator = "IN" + } - filterSql += fmt.Sprintf(" %s `%s` %s \"%s\"", predicate, filter.columnName, filter.comparator, filter.value) + filterSql += fmt.Sprintf(" %s `%s` %s %s%s%s", predicate, filter.columnName, filter.comparator, escape, valStr, escape) + } } } @@ -155,30 +270,47 @@ func (q *query) ToSQL() string { first := true if q.operation == sql_INSERT { cols := "" - vals := "" + values := "" for col, val := range q.data { - if first { - first = false - } else { - cols += ", " - vals += ", " + valT := reflect.TypeOf(val) + valV := reflect.ValueOf(val) + + valString := formatValue(valT, valV) + + if len(valString) > 0 { + if first { + first = false + } else { + cols += ", " + values += ", " + } + cols += "`" + col + "`" + + values += valString } - cols += "`" + col + "`" - vals += "\"" + val + "\"" } - dataSql += fmt.Sprintf(" (%s) VALUES (%s)", cols, vals) + dataSql += fmt.Sprintf(" (%s) VALUES (%s)", cols, values) } else if q.operation == sql_UPDATE { colUpdates := "" first := true for col, val := range q.data { - if first { - first = false - } else { - colUpdates += ", " + valT := reflect.TypeOf(val) + valV := reflect.ValueOf(val) + + valString := formatValue(valT, valV) + + if len(valString) > 0 { + + if first { + first = false + } else { + colUpdates += ", " + } + + colUpdates += fmt.Sprintf("`%s` = %s", col, valString) } - colUpdates += fmt.Sprintf("`%s` = \"%s\"", col, val) } dataSql += fmt.Sprintf(" SET %s", colUpdates) } diff --git a/munch_test.go b/munch_test.go new file mode 100644 index 0000000..3f7ba29 --- /dev/null +++ b/munch_test.go @@ -0,0 +1,233 @@ +package munch + +import ( + "fmt" + "testing" +) + +var qb = &QueryBuilder{} + +/*func TestInitialiseQB(t *testing.T) { + qb := &QueryBuilder{} + + if qb == nil { + t.Error("initialisation of query builder interface failed") + } +}*/ + +func assertEqual(actual, expected string) error { + if actual != expected { + return fmt.Errorf("mismatch!\nexpected:\t %s\ngot:\t\t %s", expected, actual) + } + + return nil +} + +// TestBasicRawSelect +// Expects: SELECT * FROM `TEST_TABLE_1` WHERE `Firstname` = 'Test' AND `Lastname` = 'User' AND `Email` = 'test@test.com'; +func TestBasicRawSelect(t *testing.T) { + query := qb.Table("TEST_TABLE_1") + query.WhereRaw("Firstname", "=", "Test") + query.AndWhereRaw("Lastname", "=", "User") + query.AndWhereRaw("Email", "=", "test@test.com") + + sql := query.ToSQL() + + err := assertEqual(sql, "SELECT * FROM `TEST_TABLE_1` WHERE `Firstname` = 'Test' AND `Lastname` = 'User' AND `Email` = 'test@test.com';") + + if err != nil { + t.Error(err.Error()) + } +} + +// TestWhereOr +// Expects: SELECT * FROM `TEST_TABLE_1` WHERE `Firstname` = 'Test' OR `Email` = 'test@test.com'; +func TestWhereOr(t *testing.T) { + query := qb.Table("TEST_TABLE_1") + query.Where(BasicTestObject{ + Firstname: "Test", + }) + query.OrWhereRaw("Email", "=", "test@test.com") + + sql := query.ToSQL() + + err := assertEqual(sql, "SELECT * FROM `TEST_TABLE_1` WHERE `Firstname` = 'Test' OR `Email` = 'test@test.com';") + if err != nil { + t.Error(err.Error()) + } +} + +type BasicTestObject struct { + Firstname string + Lastname string + Email string +} + +// TestBasicInsert +// Expects: INSERT INTO `TEST_TABLE_2` (`Firstname`, `Lastname`, `Email`) VALUES ('Test', 'User', 'test@test.com'); +func TestBasicInsert(t *testing.T) { + query := qb.Table("TEST_TABLE_2") + query.Insert(BasicTestObject{ + Firstname: "Test", + Lastname: "User", + Email: "test@test.com", + }) + + sql := query.ToSQL() + + err := assertEqual(sql, "INSERT INTO `TEST_TABLE_2` (`Firstname`, `Lastname`, `Email`) VALUES ('Test', 'User', 'test@test.com');") + + if err != nil { + t.Error(err.Error()) + } +} + +// TestBasicUpdate +// Expects: UPDATE `TEST_TABLE_3` SET `Firstname` = 'Test', `Lastname` = 'User' WHERE `Email` = 'test@test.com'; +func TestBasicUpdate(t *testing.T) { + query := qb.Table("TEST_TABLE_3") + query.Update(BasicTestObject{ + Firstname: "Test", + Lastname: "User", + }) + query.Where(BasicTestObject{ + Email: "test@test.com", + }) + + sql := query.ToSQL() + err := assertEqual(sql, "UPDATE `TEST_TABLE_3` SET `Firstname` = 'Test', `Lastname` = 'User' WHERE `Email` = 'test@test.com';") + + if err != nil { + t.Error(err.Error()) + } +} + +// TestBasicDelete +// Expects: DELETE FROM `TEST_TABLE_4` WHERE `Email` = 'test@test.com'; +func TestBasicDelete(t *testing.T) { + query := qb.Table("TEST_TABLE_4") + query.Where(BasicTestObject{Email: "test@test.com"}) + query.Del() + + sql := query.ToSQL() + err := assertEqual(sql, "DELETE FROM `TEST_TABLE_4` WHERE `Email` = 'test@test.com';") + + if err != nil { + t.Error(err.Error()) + } +} + +type WhereInObject struct { + Usernames []string `sql:"Username"` +} + +// TestWhereInString +// Expects: SELECT * FROM `Users` WHERE `Username` IN ('Test', 'Test2', 'Test3'); +func TestWhereInString(t *testing.T) { + query := qb.Table("Users") + query.Where(WhereInObject{Usernames: []string{"Test", "Test2", "Test3"}}) + //query.WhereRaw("Username", "IN", []string{"Test","Test2","Test3"}) + + sql := query.ToSQL() + err := assertEqual(sql, "SELECT * FROM `Users` WHERE `Username` IN ('Test', 'Test2', 'Test3');") + + if err != nil { + t.Error(err.Error()) + } +} + +type TestGroupObj struct { + GID int `sql:"GroupId"` + GName string `sql:"GroupName"` +} + +// TestInsertWithTags +// Expects: INSERT INTO `UserGroups` (`GroupId`, `GroupName`) VALUES (5, 'Test Group 5'); +func TestInsertWithTags(t *testing.T) { + query := qb.Table("UserGroups") + query.Insert(TestGroupObj{ + GID: 5, + GName: "Test Group 5", + }) + + sql := query.ToSQL() + err := assertEqual(sql, "INSERT INTO `UserGroups` (`GroupId`, `GroupName`) VALUES (5, 'Test Group 5');") + + if err != nil { + t.Error(err.Error()) + } +} + +// TestWhereInInt +// Expects: SELECT * FROM `UserGroups` WHERE `GroupId` IN (1, 2, 3, 4); +func TestWhereInInt(t *testing.T) { + query := qb.Table("UserGroups") + query.WhereIn("GroupId", []int{1, 2, 3, 4}, false) + + sql := query.ToSQL() + err := assertEqual(sql, "SELECT * FROM `UserGroups` WHERE `GroupId` IN (1, 2, 3, 4);") + + if err != nil { + t.Error(err.Error()) + } +} + +// TestMultipleWhereInOne +// Expects: SELECT * FROM `TEST_TABLE_1` WHERE `Firstname` = 'Test' AND `Lastname` = 'User' AND `Email` = 'test@test.com'; +func TestMultipleWhereInOne(t *testing.T) { + query := qb.Table("TEST_TABLE_1") + + query.Where([]BasicTestObject{ + {Firstname: "Test"}, + {Lastname: "User"}, + {Email: "test@test.com"}, + }) + + sql := query.ToSQL() + err := assertEqual(sql, "SELECT * FROM `TEST_TABLE_1` WHERE `Firstname` = 'Test' AND `Lastname` = 'User' AND `Email` = 'test@test.com';") + + if err != nil { + t.Error(err.Error()) + } +} + +// TestSelectSpecificColumns +// Expects: SELECT `UserId`, `Username` FROM `Users` WHERE `Email` = 'test@test.com'; +func TestSelectSpecificColumns(t *testing.T) { + query := qb.Table("Users") + query.Where(BasicTestObject{Email: "test@test.com"}) + query.Select([]string{"UserId", "Username"}) + + sql := query.ToSQL() + err := assertEqual(sql, "SELECT `UserId`, `Username` FROM `Users` WHERE `Email` = 'test@test.com';") + + if err != nil { + t.Error(err.Error()) + } +} + +type ComplexTestObj struct { + MyText string + MyFloat float64 + MyInt int + MyBool bool +} + +// TestInsertComplexObject +// Expects: INSERT INTO `ComplexTable` (`MyText`, `MyFloat`, `MyInt`, `MyBool`) VALUES ('Test', 5.66, 9, TRUE); +func TestInsertComplexObject(t *testing.T) { + query := qb.Table("ComplexTable") + query.Insert(ComplexTestObj{ + MyText: "Test", + MyFloat: 5.660, + MyInt: 9, + MyBool: true, + }) + + sql := query.ToSQL() + err := assertEqual(sql, "INSERT INTO `ComplexTable` (`MyText`, `MyFloat`, `MyInt`, `MyBool`) VALUES ('Test', 5.66, 9, TRUE);") + + if err != nil { + t.Error(err.Error()) + } +} diff --git a/tests/basic.go b/tests/basic.go deleted file mode 100644 index 5db080d..0000000 --- a/tests/basic.go +++ /dev/null @@ -1,45 +0,0 @@ -package main - -import ( - "fmt" - "munch" -) - -type TestData1 struct { - Firstname string - Lastname string - Email string -} - -func main() { - - qb := munch.QueryBuilder{} - - query := qb.Table("TEST_TABLE_1") - - query.Where("Count", ">", "10") - query.AndWhere("Age", ">", "5") - query.OrWhere("Name", "=", "Admiral Joshua") - - fmt.Println(query.ToSQL()) - - iQuery := qb.Table("TEST_TABLE_2") - iQuery.Insert(TestData1{ - Firstname: "Test", - Lastname: "User", - Email: "test@test.com", - }) - fmt.Println(iQuery.ToSQL()) - - uQuery := qb.Table("TEST_TABLE_3") - uQuery.Update(TestData1{ - Firstname: "Test", - }) - uQuery.Where("email", "=", "test@test.com") - fmt.Println(uQuery.ToSQL()) - - dQuery := qb.Table("TEST_TABLE_2") - dQuery.Where("Firstname", "=", "Leila") - dQuery.Delete() - fmt.Println(dQuery.ToSQL()) -}