Skip to content

Commit

Permalink
Added option to specify table name
Browse files Browse the repository at this point in the history
Added option to specify table name instead of query.
  • Loading branch information
noborus committed Oct 26, 2023
1 parent eba2309 commit f324ddc
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 19 deletions.
16 changes: 12 additions & 4 deletions cmd/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ import (
"github.com/ulikunitz/xz"
)

const TableQuery = "SELECT * FROM"

// Cli wraps stdout and error output specification.
type Cli struct {
// OutStream is the output destination.
Expand Down Expand Up @@ -69,6 +71,7 @@ func (cli Cli) Run(args []string) int {
queryFile string
analyze string
onlySQL string
tableName string

inFlag inputFlag
inDelimiter string
Expand Down Expand Up @@ -107,6 +110,7 @@ func (cli Cli) Run(args []string) int {
flags.StringVar(&queryFile, "q", "", "read query from the specified file.")
flags.StringVar(&analyze, "a", "", "analyze the file and suggest SQL.")
flags.StringVar(&onlySQL, "A", "", "analyze the file but only suggest SQL.")
flags.StringVar(&tableName, "t", "", "read table name from the specified file.")
flags.BoolVar(&usage, "help", false, "display usage information.")
flags.BoolVar(&version, "version", false, "display version information.")
flags.BoolVar(&Debug, "debug", false, "debug print.")
Expand Down Expand Up @@ -204,7 +208,7 @@ func (cli Cli) Run(args []string) int {
return 0
}

query, err := getQuery(flags.Args(), queryFile)
query, err := getQuery(flags.Args(), tableName, queryFile)
if err != nil {
log.Printf("ERROR: %s", err)
return 1
Expand Down Expand Up @@ -424,12 +428,16 @@ func trimQuery(query string) string {
return strings.TrimRight(strings.TrimSpace(query), ";")
}

func getQuery(args []string, fileName string) (string, error) {
if fileName == "" {
func getQuery(args []string, tableName string, queryFile string) (string, error) {
if tableName != "" {
return trimQuery(strings.Join([]string{TableQuery, tableName}, " ")), nil
}

if queryFile == "" {
return trimQuery(strings.Join(args, " ")), nil
}

sqlByte, err := os.ReadFile(fileName)
sqlByte, err := os.ReadFile(queryFile)
if err != nil {
return "", err
}
Expand Down
46 changes: 31 additions & 15 deletions cmd/cmd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,8 +197,9 @@ func Test_outputFormat(t *testing.T) {

func Test_getQuery(t *testing.T) {
type argss struct {
args []string
fileName string
args []string
tableName string
queryFile string
}
tests := []struct {
name string
Expand All @@ -209,52 +210,67 @@ func Test_getQuery(t *testing.T) {
{
name: "testARGS",
argss: argss{
[]string{"SELECT 1"},
"",
args: []string{"SELECT 1"},
tableName: "",
queryFile: "",
},
want: "SELECT 1",
wantErr: false,
},
{
name: "testARGS2",
argss: argss{
[]string{"SELECT", "1"},
"",
args: []string{"SELECT", "1"},
tableName: "",
queryFile: "",
},
want: "SELECT 1",
wantErr: false,
},
{
name: "testTrim",
argss: argss{
[]string{"SELECT * FROM test; "},
"",
args: []string{"SELECT * FROM test; "},
tableName: "",
queryFile: "",
},
want: "SELECT * FROM test",
wantErr: false,
},
{
name: "testFileErr",
name: "testTableName",
argss: argss{
[]string{},
filepath.Join("..", "testdata", "noFile.sql"),
args: []string{},
tableName: filepath.Join("..", "testdata", "test.csv"),
queryFile: "",
},
want: "SELECT * FROM ../testdata/test.csv",
wantErr: false,
},
{
name: "testQueryFileErr",
argss: argss{
args: []string{},
tableName: "",
queryFile: filepath.Join("..", "testdata", "noFile.sql"),
},
want: "",
wantErr: true,
},
{
name: "testFile",
name: "testQueryFile",
argss: argss{
[]string{},
filepath.Join("..", "testdata", "test.sql"),
args: []string{},
tableName: "",
queryFile: filepath.Join("..", "testdata", "test.sql"),
},
want: "SELECT * FROM testdata/test.csv",
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := getQuery(tt.argss.args, tt.argss.fileName)
got, err := getQuery(tt.argss.args, tt.argss.tableName, tt.argss.queryFile)
if (err != nil) != tt.wantErr {
t.Errorf("getQuery() error = %v, wantErr %v", err, tt.wantErr)
return
Expand Down

0 comments on commit f324ddc

Please sign in to comment.