diff --git a/importer.go b/importer.go
index 4d1802f..d137e94 100644
--- a/importer.go
+++ b/importer.go
@@ -303,23 +303,8 @@ func guessFormat(fileName string) Format {
 			return CSV
 		}
 		ext := strings.ToUpper(strings.TrimLeft(dotExt, "."))
-		switch ext {
-		case "CSV":
-			return CSV
-		case "TSV":
-			return TSV
-		case "PSV":
-			return PSV
-		case "LTSV":
-			return LTSV
-		case "JSON", "JSONL":
-			return JSON
-		case "YAML", "YML":
-			return YAML
-		case "TBLN":
-			return TBLN
-		case "WIDTH":
-			return WIDTH
+		if format, ok := extToFormat[ext]; ok {
+			return format
 		}
 		fileName = fileName[:len(fileName)-len(dotExt)]
 	}
diff --git a/input_csv.go b/input_csv.go
index fd6f8c6..d339ff6 100644
--- a/input_csv.go
+++ b/input_csv.go
@@ -90,6 +90,20 @@ func NewCSVReader(reader io.Reader, opts *ReadOpts) (*CSVReader, error) {
 	return r, nil
 }
 
+// NewTSVReader returns a CSVReader for tab-separated values.
+// It sets opts.InDelimiter to "\t" and then calls NewCSVReader.
+func NewTSVReader(reader io.Reader, opts *ReadOpts) (*CSVReader, error) {
+	opts.InDelimiter = "\t"
+	return NewCSVReader(reader, opts)
+}
+
+// NewPSVReader returns a CSVReader for pipe-separated values.
+// It sets opts.InDelimiter to "|" and then calls NewCSVReader.
+func NewPSVReader(reader io.Reader, opts *ReadOpts) (*CSVReader, error) {
+	opts.InDelimiter = "|"
+	return NewCSVReader(reader, opts)
+}
+
 func (r *CSVReader) setColumnType() {
 	if r.names == nil {
 		return
diff --git a/reader.go b/reader.go
index da46c03..efec33b 100644
--- a/reader.go
+++ b/reader.go
@@ -3,8 +3,76 @@ package trdsql
 import (
 	"io"
 	"log"
+	"strings"
+	"sync"
 )
 
+// extToFormat is a map of file extensions to formats.
+var extToFormat = map[string]Format{
+	"CSV":   CSV,
+	"LTSV":  LTSV,
+	"JSON":  JSON,
+	"JSONL": JSON,
+	"YAML":  YAML,
+	"YML":   YAML,
+	"TBLN":  TBLN,
+	"TSV":   TSV,
+	"PSV":   PSV,
+	"WIDTH": WIDTH,
+}
+
+// ReaderFunc is a function that creates a new Reader.
+type ReaderFunc func(io.Reader, *ReadOpts) (Reader, error)
+
+// readerFuncs maps formats to their corresponding ReaderFunc.
+var readerFuncs = map[Format]ReaderFunc{
+	CSV: func(reader io.Reader, opts *ReadOpts) (Reader, error) {
+		return NewCSVReader(reader, opts)
+	},
+	LTSV: func(reader io.Reader, opts *ReadOpts) (Reader, error) {
+		return NewLTSVReader(reader, opts)
+	},
+	JSON: func(reader io.Reader, opts *ReadOpts) (Reader, error) {
+		return NewJSONReader(reader, opts)
+	},
+	YAML: func(reader io.Reader, opts *ReadOpts) (Reader, error) {
+		return NewYAMLReader(reader, opts)
+	},
+	TBLN: func(reader io.Reader, opts *ReadOpts) (Reader, error) {
+		return NewTBLNReader(reader, opts)
+	},
+	TSV: func(reader io.Reader, opts *ReadOpts) (Reader, error) {
+		return NewTSVReader(reader, opts)
+	},
+	PSV: func(reader io.Reader, opts *ReadOpts) (Reader, error) {
+		return NewPSVReader(reader, opts)
+	},
+	WIDTH: func(reader io.Reader, opts *ReadOpts) (Reader, error) {
+		return NewGWReader(reader, opts)
+	},
+}
+
+// extFormat is the Format assigned to the next registered reader.
+var extFormat Format = 100
+
+// registerMux serializes calls to RegisterReaderFunc.
+var registerMux = &sync.Mutex{}
+
+// RegisterReaderFunc registers readerFunc for files whose extension is ext
+// and assigns it the next extFormat value.
+// ext is normalized like guessFormat normalizes file extensions (leading
+// dots trimmed, upper-cased), so "xml", ".xml" and "XML" name the same entry.
+// registerMux only serializes registrations; the lookups in guessFormat and
+// NewReader do not appear to lock, so register before readers are created.
+func RegisterReaderFunc(ext string, readerFunc ReaderFunc) {
+	registerMux.Lock()
+	defer registerMux.Unlock()
+	key := strings.ToUpper(strings.TrimLeft(ext, "."))
+	extToFormat[key] = extFormat
+	readerFuncs[extFormat] = readerFunc
+	extFormat++
+}
+
 // Reader is wrap the reader.
 // Reader reads from tabular files.
 type Reader interface {
@@ -12,7 +80,7 @@ type Reader interface {
 	Names() ([]string, error)
 	// Types returns column types.
 	Types() ([]string, error)
-	// PreReadRow is returns only columns that store preread rows.
+	// PreReadRow is returns only columns that store preRead rows.
 	PreReadRow() [][]interface{}
 	// ReadRow is read the rest of the row.
 	ReadRow(row []interface{}) ([]interface{}, error)
@@ -157,28 +225,12 @@ func NewReader(reader io.Reader, readOpts *ReadOpts) (Reader, error) {
 	if reader == nil {
 		return nil, ErrNoReader
 	}
-	switch readOpts.realFormat {
-	case CSV:
-		return NewCSVReader(reader, readOpts)
-	case TSV:
-		readOpts.InDelimiter = "\t"
-		return NewCSVReader(reader, readOpts)
-	case PSV:
-		readOpts.InDelimiter = "|"
-		return NewCSVReader(reader, readOpts)
-	case LTSV:
-		return NewLTSVReader(reader, readOpts)
-	case JSON:
-		return NewJSONReader(reader, readOpts)
-	case YAML:
-		return NewYAMLReader(reader, readOpts)
-	case TBLN:
-		return NewTBLNReader(reader, readOpts)
-	case WIDTH:
-		return NewGWReader(reader, readOpts)
-	default:
+	readerFunc, ok := readerFuncs[readOpts.realFormat]
+	if !ok {
 		return nil, ErrUnknownFormat
 	}
+
+	return readerFunc(reader, readOpts)
 }
 
 func skipRead(r Reader, skipNum int) {
diff --git a/trdsql.go b/trdsql.go
index 8e44e67..742d1d1 100644
--- a/trdsql.go
+++ b/trdsql.go
@@ -92,11 +92,11 @@ const (
 	YAML
 
 	// import
-	// Tab-Separated Values format.
+	// Tab-Separated Values format. Format using go standard CSV library.
 	TSV
 
 	// import
-	// Pipe-Separated Values format.
+	// Pipe-Separated Values format. Format using go standard CSV library.
 	PSV
 )
 