Skip to content

Commit

Permalink
Switched origninal Where function to a WhereRaw function.
Browse files Browse the repository at this point in the history
Created new Where function(s) capable of taking structs in and storing the filters this way.
Added support for multiple data types and for handling of arrays of filters as well as an array of values for one filter. (I.E. WHERE X IN (1, 2, 3))
  • Loading branch information
Leila-Codes committed Jan 20, 2021
1 parent edf5cd3 commit 1c1060b
Show file tree
Hide file tree
Showing 3 changed files with 403 additions and 83 deletions.
208 changes: 170 additions & 38 deletions munch.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"database/sql"
"fmt"
"reflect"
"strconv"
"strings"
)

// SQL Operations - ENUM
Expand Down Expand Up @@ -34,30 +36,85 @@ type query struct {
operation int
table string
where []filter
columns []string

data map[string]string
data map[string]interface{}
}

type filter struct {
isOr bool
columnName string
comparator string
value string
value interface{}
}

func (q *query) OrWhere(columnName, comparator, value string) {
func getColumns(in reflect.Type) []string {
cols := make([]string, 0)
for x := 0; x < in.NumField(); x++ {
var columnName string
sqlTag := in.Field(x).Tag.Get("sql")
if len(sqlTag) > 0 {
columnName = sqlTag
} else {
columnName = in.Field(x).Name
}

cols = append(cols, columnName)
}

return cols
}

func (q *query) Select(cols []string) {
q.columns = cols
}

func (q *query) Where(in interface{}) {
t := reflect.TypeOf(in)

if t.Kind() == reflect.Slice {
mySlice := reflect.ValueOf(in)
for i := 0; i < mySlice.Len(); i++ {
obj := mySlice.Index(i).Interface()

q.Where(obj)
}
} else if t.Kind() == reflect.Struct {
cols := getColumns(t)

q.filterObj(in, cols)
}
}

func (q *query) WhereIn(columnName string, values interface{}, isOr bool) {
q.addFilter(columnName, "IN", values, isOr)
}

func (q *query) filterObj(obj interface{}, columns []string) {
v := reflect.ValueOf(obj)

for i := 0; i < len(columns); i++ {
fieldVal := v.Field(i)
//fieldStr := fieldVal.String()
if fieldVal.IsValid() {
q.addFilter(columns[i], "=", fieldVal.Interface(), false)
}
}
}

func (q *query) OrWhereRaw(columnName, comparator string, value interface{}) {
q.addFilter(columnName, comparator, value, true)
}

func (q *query) Where(columnName, comparator, value string) {
func (q *query) WhereRaw(columnName, comparator string, value interface{}) {
q.addFilter(columnName, comparator, value, false)
}

func (q *query) AndWhere(columnName, comparator, value string) {
func (q *query) AndWhereRaw(columnName, comparator string, value interface{}) {
q.addFilter(columnName, comparator, value, false)
}

func (q *query) addFilter(columnName, comparator, value string, isOr bool) {
func (q *query) addFilter(columnName, comparator string, value interface{}, isOr bool) {
q.where = append(q.where, filter{
isOr: isOr,
columnName: columnName,
Expand All @@ -75,14 +132,13 @@ func (q *query) appendData(in interface{}) {
sTag := field.Tag.Get("sql")

v := reflect.ValueOf(in)
s := reflect.Indirect(v).FieldByName(field.Name).String()
//s := reflect.Indirect(v).FieldByName(field.Name)
s := v.Field(i).Interface()

if len(s) > 0 {
if len(sTag) > 0 {
q.data[sTag] = s
} else {
q.data[field.Name] = s
}
if len(sTag) > 0 {
q.data[sTag] = s
} else {
q.data[field.Name] = s
}
}
}
Expand All @@ -91,7 +147,7 @@ func (q *query) Insert(in interface{}) {
q.operation = sql_INSERT

if q.data == nil {
q.data = make(map[string]string)
q.data = make(map[string]interface{})
}

q.appendData(in)
Expand All @@ -101,7 +157,7 @@ func (q *query) Update(in interface{}) {
q.operation = sql_UPDATE

if q.data == nil {
q.data = make(map[string]string)
q.data = make(map[string]interface{})
}

q.appendData(in)
Expand All @@ -115,6 +171,43 @@ func (q *query) Delete() {
q.operation = sql_DELETE
}

func formatValue(t reflect.Type, v reflect.Value) string {
valStr := ""

switch t.Kind() {
case reflect.Int:
valStr = strconv.FormatInt(v.Int(), 10)
break
case reflect.Bool:
valStr = strings.ToUpper(strconv.FormatBool(v.Bool()))
break
case reflect.Float64:
valStr = strconv.FormatFloat(v.Float(), 'f', -1, 64)
break
case reflect.Slice:
valList := ""
for i := 0; i < v.Len(); i++ {
vIdx := v.Index(i)
if i > 0 {
valList += ", "
}
valList += formatValue(vIdx.Type(), vIdx)
}
if len(valList) > 0 {
valStr = fmt.Sprintf("(%s)", valList)
}
break
default:
s := v.String()
if len(s) > 0 {
valStr = fmt.Sprintf("'%s'", v.String())
}
break
}

return valStr
}

func (q *query) ToSQL() string {
var sqlStr string
filterSql := ""
Expand All @@ -131,54 +224,93 @@ func (q *query) ToSQL() string {
sqlStr = fmt.Sprintf("DELETE FROM `%s`", q.table)
break
default:
sqlStr = fmt.Sprintf("SELECT * FROM `%s`", q.table)
colString := "*"

if len(q.columns) > 0 {
colString = "`" + strings.Join(q.columns, "`, `") + "`"
}

sqlStr = fmt.Sprintf("SELECT %s FROM `%s`", colString, q.table)
break
}

if q.operation != sql_INSERT && len(q.where) > 0 {
first := true
filterSql = ""
for _, filter := range q.where {
predicate := "AND"
if first {
predicate = "WHERE"
first = false
} else if filter.isOr {
predicate = "OR"
}
var (
valStr string
escape = ""
)

fType := reflect.TypeOf(filter.value)
fValue := reflect.ValueOf(filter.value)

valStr = formatValue(fType, fValue)

if len(valStr) > 2 {
predicate := "AND"
if first {
predicate = "WHERE"
first = false
} else if filter.isOr {
predicate = "OR"
}

if fType.Kind() == reflect.Slice {
filter.comparator = "IN"
}

filterSql += fmt.Sprintf(" %s `%s` %s \"%s\"", predicate, filter.columnName, filter.comparator, filter.value)
filterSql += fmt.Sprintf(" %s `%s` %s %s%s%s", predicate, filter.columnName, filter.comparator, escape, valStr, escape)
}
}
}

if len(q.data) > 0 {
first := true
if q.operation == sql_INSERT {
cols := ""
vals := ""
values := ""

for col, val := range q.data {
if first {
first = false
} else {
cols += ", "
vals += ", "
valT := reflect.TypeOf(val)
valV := reflect.ValueOf(val)

valString := formatValue(valT, valV)

if len(valString) > 0 {
if first {
first = false
} else {
cols += ", "
values += ", "
}
cols += "`" + col + "`"

values += valString
}
cols += "`" + col + "`"
vals += "\"" + val + "\""
}

dataSql += fmt.Sprintf(" (%s) VALUES (%s)", cols, vals)
dataSql += fmt.Sprintf(" (%s) VALUES (%s)", cols, values)
} else if q.operation == sql_UPDATE {
colUpdates := ""
first := true
for col, val := range q.data {
if first {
first = false
} else {
colUpdates += ", "
valT := reflect.TypeOf(val)
valV := reflect.ValueOf(val)

valString := formatValue(valT, valV)

if len(valString) > 0 {

if first {
first = false
} else {
colUpdates += ", "
}

colUpdates += fmt.Sprintf("`%s` = %s", col, valString)
}
colUpdates += fmt.Sprintf("`%s` = \"%s\"", col, val)
}
dataSql += fmt.Sprintf(" SET %s", colUpdates)
}
Expand Down
Loading

0 comments on commit 1c1060b

Please sign in to comment.