diff --git a/munch.go b/munch.go index e9305e9..ac4310b 100644 --- a/munch.go +++ b/munch.go @@ -4,6 +4,8 @@ import ( "database/sql" "fmt" "reflect" + "strconv" + "strings" ) // SQL Operations - ENUM @@ -34,30 +36,85 @@ type query struct { operation int table string where []filter + columns []string - data map[string]string + data map[string]interface{} } type filter struct { isOr bool columnName string comparator string - value string + value interface{} } -func (q *query) OrWhere(columnName, comparator, value string) { +func getColumns(in reflect.Type) []string { + cols := make([]string, 0) + for x := 0; x < in.NumField(); x++ { + var columnName string + sqlTag := in.Field(x).Tag.Get("sql") + if len(sqlTag) > 0 { + columnName = sqlTag + } else { + columnName = in.Field(x).Name + } + + cols = append(cols, columnName) + } + + return cols +} + +func (q *query) Select(cols []string) { + q.columns = cols +} + +func (q *query) Where(in interface{}) { + t := reflect.TypeOf(in) + + if t.Kind() == reflect.Slice { + mySlice := reflect.ValueOf(in) + for i := 0; i < mySlice.Len(); i++ { + obj := mySlice.Index(i).Interface() + + q.Where(obj) + } + } else if t.Kind() == reflect.Struct { + cols := getColumns(t) + + q.filterObj(in, cols) + } +} + +func (q *query) WhereIn(columnName string, values interface{}, isOr bool) { + q.addFilter(columnName, "IN", values, isOr) +} + +func (q *query) filterObj(obj interface{}, columns []string) { + v := reflect.ValueOf(obj) + + for i := 0; i < len(columns); i++ { + fieldVal := v.Field(i) + //fieldStr := fieldVal.String() + if fieldVal.IsValid() { + q.addFilter(columns[i], "=", fieldVal.Interface(), false) + } + } +} + +func (q *query) OrWhereRaw(columnName, comparator string, value interface{}) { q.addFilter(columnName, comparator, value, true) } -func (q *query) Where(columnName, comparator, value string) { +func (q *query) WhereRaw(columnName, comparator string, value interface{}) { q.addFilter(columnName, comparator, value, false) } -func (q *query) AndWhere(columnName, comparator, value string) { +func (q *query) AndWhereRaw(columnName, comparator string, value interface{}) { q.addFilter(columnName, comparator, value, false) } -func (q *query) addFilter(columnName, comparator, value string, isOr bool) { +func (q *query) addFilter(columnName, comparator string, value interface{}, isOr bool) { q.where = append(q.where, filter{ isOr: isOr, columnName: columnName, @@ -75,14 +132,13 @@ func (q *query) appendData(in interface{}) { sTag := field.Tag.Get("sql") v := reflect.ValueOf(in) - s := reflect.Indirect(v).FieldByName(field.Name).String() + //s := reflect.Indirect(v).FieldByName(field.Name) + s := v.Field(i).Interface() - if len(s) > 0 { - if len(sTag) > 0 { - q.data[sTag] = s - } else { - q.data[field.Name] = s - } + if len(sTag) > 0 { + q.data[sTag] = s + } else { + q.data[field.Name] = s } } } @@ -91,7 +147,7 @@ func (q *query) Insert(in interface{}) { q.operation = sql_INSERT if q.data == nil { - q.data = make(map[string]string) + q.data = make(map[string]interface{}) } q.appendData(in) @@ -101,7 +157,7 @@ func (q *query) Update(in interface{}) { q.operation = sql_UPDATE if q.data == nil { - q.data = make(map[string]string) + q.data = make(map[string]interface{}) } q.appendData(in) @@ -115,6 +171,43 @@ func (q *query) Delete() { q.operation = sql_DELETE } +func formatValue(t reflect.Type, v reflect.Value) string { + valStr := "" + + switch t.Kind() { + case reflect.Int: + valStr = strconv.FormatInt(v.Int(), 10) + break + case reflect.Bool: + valStr = strings.ToUpper(strconv.FormatBool(v.Bool())) + break + case reflect.Float64: + valStr = strconv.FormatFloat(v.Float(), 'f', -1, 64) + break + case reflect.Slice: + valList := "" + for i := 0; i < v.Len(); i++ { + vIdx := v.Index(i) + if i > 0 { + valList += ", " + } + valList += formatValue(vIdx.Type(), vIdx) + } + if len(valList) > 0 { + valStr = fmt.Sprintf("(%s)", valList) + } + break + default: + s := v.String() + if len(s) > 0 { + valStr = fmt.Sprintf("'%s'", v.String()) + } + break + } + + return valStr +} + func (q *query) ToSQL() string { var sqlStr string filterSql := "" @@ -131,7 +224,13 @@ func (q *query) ToSQL() string { sqlStr = fmt.Sprintf("DELETE FROM `%s`", q.table) break default: - sqlStr = fmt.Sprintf("SELECT * FROM `%s`", q.table) + colString := "*" + + if len(q.columns) > 0 { + colString = "`" + strings.Join(q.columns, "`, `") + "`" + } + + sqlStr = fmt.Sprintf("SELECT %s FROM `%s`", colString, q.table) break } @@ -139,15 +238,31 @@ func (q *query) ToSQL() string { first := true filterSql = "" for _, filter := range q.where { - predicate := "AND" - if first { - predicate = "WHERE" - first = false - } else if filter.isOr { - predicate = "OR" - } + var ( + valStr string + escape = "" + ) + + fType := reflect.TypeOf(filter.value) + fValue := reflect.ValueOf(filter.value) + + valStr = formatValue(fType, fValue) + + if len(valStr) > 2 { + predicate := "AND" + if first { + predicate = "WHERE" + first = false + } else if filter.isOr { + predicate = "OR" + } + + if fType.Kind() == reflect.Slice { + filter.comparator = "IN" + } - filterSql += fmt.Sprintf(" %s `%s` %s \"%s\"", predicate, filter.columnName, filter.comparator, filter.value) + filterSql += fmt.Sprintf(" %s `%s` %s %s%s%s", predicate, filter.columnName, filter.comparator, escape, valStr, escape) + } } } @@ -155,30 +270,47 @@ func (q *query) ToSQL() string { first := true if q.operation == sql_INSERT { cols := "" - vals := "" + values := "" for col, val := range q.data { - if first { - first = false - } else { - cols += ", " - vals += ", " + valT := reflect.TypeOf(val) + valV := reflect.ValueOf(val) + + valString := formatValue(valT, valV) + + if len(valString) > 0 { + if first { + first = false + } else { + cols += ", " + values += ", " + } + cols += "`" + col + "`" + + values += valString } - cols += "`" + col + "`" - vals += "\"" + val + "\"" } - dataSql += fmt.Sprintf(" (%s) VALUES (%s)", cols, vals) + dataSql += fmt.Sprintf(" (%s) VALUES (%s)", cols, values) } else if q.operation == sql_UPDATE { colUpdates := "" first := true for col, val := range q.data { - if first { - first = false - } else { - colUpdates += ", " + valT := reflect.TypeOf(val) + valV := reflect.ValueOf(val) + + valString := formatValue(valT, valV) + + if len(valString) > 0 { + + if first { + first = false + } else { + colUpdates += ", " + } + + colUpdates += fmt.Sprintf("`%s` = %s", col, valString) } - colUpdates += fmt.Sprintf("`%s` = \"%s\"", col, val) } dataSql += fmt.Sprintf(" SET %s", colUpdates) } diff --git a/munch_test.go b/munch_test.go new file mode 100644 index 0000000..3f7ba29 --- /dev/null +++ b/munch_test.go @@ -0,0 +1,233 @@ +package munch + +import ( + "fmt" + "testing" +) + +var qb = &QueryBuilder{} + +/*func TestInitialiseQB(t *testing.T) { + qb := &QueryBuilder{} + + if qb == nil { + t.Error("initialisation of query builder interface failed") + } +}*/ + +func assertEqual(actual, expected string) error { + if actual != expected { + return fmt.Errorf("mismatch!\nexpected:\t %s\ngot:\t\t %s", expected, actual) + } + + return nil +} + +// TestBasicRawSelect +// Expects: SELECT * FROM `TEST_TABLE_1` WHERE `Firstname` = 'Test' AND `Lastname` = 'User' AND `Email` = 'test@test.com'; +func TestBasicRawSelect(t *testing.T) { + query := qb.Table("TEST_TABLE_1") + query.WhereRaw("Firstname", "=", "Test") + query.AndWhereRaw("Lastname", "=", "User") + query.AndWhereRaw("Email", "=", "test@test.com") + + sql := query.ToSQL() + + err := assertEqual(sql, "SELECT * FROM `TEST_TABLE_1` WHERE `Firstname` = 'Test' AND `Lastname` = 'User' AND `Email` = 'test@test.com';") + + if err != nil { + t.Error(err.Error()) + } +} + +// TestWhereOr +// Expects: SELECT * FROM `TEST_TABLE_1` WHERE `Firstname` = 'Test' OR `Email` = 'test@test.com'; +func TestWhereOr(t *testing.T) { + query := qb.Table("TEST_TABLE_1") + query.Where(BasicTestObject{ + Firstname: "Test", + }) + query.OrWhereRaw("Email", "=", "test@test.com") + + sql := query.ToSQL() + + err := assertEqual(sql, "SELECT * FROM `TEST_TABLE_1` WHERE `Firstname` = 'Test' OR `Email` = 'test@test.com';") + if err != nil { + t.Error(err.Error()) + } +} + +type BasicTestObject struct { + Firstname string + Lastname string + Email string +} + +// TestBasicInsert +// Expects: INSERT INTO `TEST_TABLE_2` (`Firstname`, `Lastname`, `Email`) VALUES ('Test', 'User', 'test@test.com'); +func TestBasicInsert(t *testing.T) { + query := qb.Table("TEST_TABLE_2") + query.Insert(BasicTestObject{ + Firstname: "Test", + Lastname: "User", + Email: "test@test.com", + }) + + sql := query.ToSQL() + + err := assertEqual(sql, "INSERT INTO `TEST_TABLE_2` (`Firstname`, `Lastname`, `Email`) VALUES ('Test', 'User', 'test@test.com');") + + if err != nil { + t.Error(err.Error()) + } +} + +// TestBasicUpdate +// Expects: UPDATE `TEST_TABLE_3` SET `Firstname` = 'Test', `Lastname` = 'User' WHERE `Email` = 'test@test.com'; +func TestBasicUpdate(t *testing.T) { + query := qb.Table("TEST_TABLE_3") + query.Update(BasicTestObject{ + Firstname: "Test", + Lastname: "User", + }) + query.Where(BasicTestObject{ + Email: "test@test.com", + }) + + sql := query.ToSQL() + err := assertEqual(sql, "UPDATE `TEST_TABLE_3` SET `Firstname` = 'Test', `Lastname` = 'User' WHERE `Email` = 'test@test.com';") + + if err != nil { + t.Error(err.Error()) + } +} + +// TestBasicDelete +// Expects: DELETE FROM `TEST_TABLE_4` WHERE `Email` = 'test@test.com'; +func TestBasicDelete(t *testing.T) { + query := qb.Table("TEST_TABLE_4") + query.Where(BasicTestObject{Email: "test@test.com"}) + query.Del() + + sql := query.ToSQL() + err := assertEqual(sql, "DELETE FROM `TEST_TABLE_4` WHERE `Email` = 'test@test.com';") + + if err != nil { + t.Error(err.Error()) + } +} + +type WhereInObject struct { + Usernames []string `sql:"Username"` +} + +// TestWhereInString +// Expects: SELECT * FROM `Users` WHERE `Username` IN ('Test', 'Test2', 'Test3'); +func TestWhereInString(t *testing.T) { + query := qb.Table("Users") + query.Where(WhereInObject{Usernames: []string{"Test", "Test2", "Test3"}}) + //query.WhereRaw("Username", "IN", []string{"Test","Test2","Test3"}) + + sql := query.ToSQL() + err := assertEqual(sql, "SELECT * FROM `Users` WHERE `Username` IN ('Test', 'Test2', 'Test3');") + + if err != nil { + t.Error(err.Error()) + } +} + +type TestGroupObj struct { + GID int `sql:"GroupId"` + GName string `sql:"GroupName"` +} + +// TestInsertWithTags +// Expects: INSERT INTO `UserGroups` (`GroupId`, `GroupName`) VALUES (5, 'Test Group 5'); +func TestInsertWithTags(t *testing.T) { + query := qb.Table("UserGroups") + query.Insert(TestGroupObj{ + GID: 5, + GName: "Test Group 5", + }) + + sql := query.ToSQL() + err := assertEqual(sql, "INSERT INTO `UserGroups` (`GroupId`, `GroupName`) VALUES (5, 'Test Group 5');") + + if err != nil { + t.Error(err.Error()) + } +} + +// TestWhereInInt +// Expects: SELECT * FROM `UserGroups` WHERE `GroupId` IN (1, 2, 3, 4); +func TestWhereInInt(t *testing.T) { + query := qb.Table("UserGroups") + query.WhereIn("GroupId", []int{1, 2, 3, 4}, false) + + sql := query.ToSQL() + err := assertEqual(sql, "SELECT * FROM `UserGroups` WHERE `GroupId` IN (1, 2, 3, 4);") + + if err != nil { + t.Error(err.Error()) + } +} + +// TestMultipleWhereInOne +// Expects: SELECT * FROM `TEST_TABLE_1` WHERE `Firstname` = 'Test' AND `Lastname` = 'User' AND `Email` = 'test@test.com'; +func TestMultipleWhereInOne(t *testing.T) { + query := qb.Table("TEST_TABLE_1") + + query.Where([]BasicTestObject{ + {Firstname: "Test"}, + {Lastname: "User"}, + {Email: "test@test.com"}, + }) + + sql := query.ToSQL() + err := assertEqual(sql, "SELECT * FROM `TEST_TABLE_1` WHERE `Firstname` = 'Test' AND `Lastname` = 'User' AND `Email` = 'test@test.com';") + + if err != nil { + t.Error(err.Error()) + } +} + +// TestSelectSpecificColumns +// Expects: SELECT `UserId`, `Username` FROM `Users` WHERE `Email` = 'test@test.com'; +func TestSelectSpecificColumns(t *testing.T) { + query := qb.Table("Users") + query.Where(BasicTestObject{Email: "test@test.com"}) + query.Select([]string{"UserId", "Username"}) + + sql := query.ToSQL() + err := assertEqual(sql, "SELECT `UserId`, `Username` FROM `Users` WHERE `Email` = 'test@test.com';") + + if err != nil { + t.Error(err.Error()) + } +} + +type ComplexTestObj struct { + MyText string + MyFloat float64 + MyInt int + MyBool bool +} + +// TestInsertComplexObject +// Expects: INSERT INTO `ComplexTable` (`MyText`, `MyFloat`, `MyInt`, `MyBool`) VALUES ('Test', 5.66, 9, TRUE); +func TestInsertComplexObject(t *testing.T) { + query := qb.Table("ComplexTable") + query.Insert(ComplexTestObj{ + MyText: "Test", + MyFloat: 5.660, + MyInt: 9, + MyBool: true, + }) + + sql := query.ToSQL() + err := assertEqual(sql, "INSERT INTO `ComplexTable` (`MyText`, `MyFloat`, `MyInt`, `MyBool`) VALUES ('Test', 5.66, 9, TRUE);") + + if err != nil { + t.Error(err.Error()) + } +} diff --git a/tests/basic.go b/tests/basic.go deleted file mode 100644 index 5db080d..0000000 --- a/tests/basic.go +++ /dev/null @@ -1,45 +0,0 @@ -package main - -import ( - "fmt" - "munch" -) - -type TestData1 struct { - Firstname string - Lastname string - Email string -} - -func main() { - - qb := munch.QueryBuilder{} - - query := qb.Table("TEST_TABLE_1") - - query.Where("Count", ">", "10") - query.AndWhere("Age", ">", "5") - query.OrWhere("Name", "=", "Admiral Joshua") - - fmt.Println(query.ToSQL()) - - iQuery := qb.Table("TEST_TABLE_2") - iQuery.Insert(TestData1{ - Firstname: "Test", - Lastname: "User", - Email: "test@test.com", - }) - fmt.Println(iQuery.ToSQL()) - - uQuery := qb.Table("TEST_TABLE_3") - uQuery.Update(TestData1{ - Firstname: "Test", - }) - uQuery.Where("email", "=", "test@test.com") - fmt.Println(uQuery.ToSQL()) - - dQuery := qb.Table("TEST_TABLE_2") - dQuery.Where("Firstname", "=", "Leila") - dQuery.Delete() - fmt.Println(dQuery.ToSQL()) -}