diff --git a/README.md b/README.md index bc3589c..2ef44c6 100644 --- a/README.md +++ b/README.md @@ -9,18 +9,21 @@ Usage: csv2db [OPTIONS] Application Options: - /f, /file: CSV file path - /s, /server: server address (default: 127.0.0.1) - /d, /database: database name - /t, /table: table name - /l, /fields: field types - /c, /comma:[,|;|t] CSV file comma character (default: ,) - /x, /create create table - /o, /overwrite overwrite existing table - /e, /encoding:[utf8|win1251] CSV file charset (default: utf8) - /r, /skiprows: number of rows to skip + /filepath: CSV file path + /server: server address (default: 127.0.0.1) + /database: database name + /table: table name in schema.name format + /fields: field types in [sifdt ] format + /comma:[,|;|t] CSV file comma character (default: ,) + /create create table + /overwrite overwrite existing table + /encoding:[utf8|win1251] CSV file charset (default: utf8) + /skiprows: number of rows to skip + /dateformat: date format (Go style) (default: 02.01.2006) + /timestampformat: timestamp format (Go style) (default: 02.01.2006 15:04:05) + /unknowncolumnnames insert to table with unknown column names ``` ## Build -Use `make.bat` file to build `csv2db.exe` executable. \ No newline at end of file +Use `make.bat` file to build `csv2db.exe` executable. diff --git a/encodings.go b/encodings.go index 2f7c7c2..ef77dd4 100644 --- a/encodings.go +++ b/encodings.go @@ -41,8 +41,10 @@ func (e *Encoding) UnmarshalText(text []byte) error { switch string(text) { case "utf8": *e = Utf8 + return nil case "win1251": *e = Win1251 + return nil } return fmt.Errorf("unknown encoding: %s", string(text)) diff --git a/main.go b/main.go index a19a65c..1bc8d95 100644 --- a/main.go +++ b/main.go @@ -18,16 +18,19 @@ import ( var db *sql.DB var opts struct { - FilePath string `short:"f" long:"file" description:"CSV file path" required:"true"` - ServerAddress string `short:"s" long:"server" description:"server address" default:"127.0.0.1"` - DatabaseName string `short:"d" long:"database" description:"database name" required:"true"` - TableName string `short:"t" long:"table" description:"table name" required:"true"` - FieldTypes string `short:"l" long:"fields" description:"field types" required:"true"` - Comma string `short:"c" long:"comma" description:"CSV file comma character" choice:"," choice:";" choice:"t" default:","` - CreateTable bool `short:"x" long:"create" description:"create table"` - OverwriteTable bool `short:"o" long:"overwrite" description:"overwrite existing table"` - Encoding string `short:"e" long:"encoding" description:"CSV file charset" choice:"utf8" choice:"win1251" default:"utf8"` - SkipRows int `short:"r" long:"skiprows" description:"number of rows to skip"` + FilePath string `long:"filepath" description:"CSV file path" required:"true"` + ServerAddress string `long:"server" description:"server address" default:"127.0.0.1"` + DatabaseName string `long:"database" description:"database name" required:"true"` + TableName string `long:"table" description:"table name in schema.name format" required:"true"` + FieldTypes string `long:"fields" description:"field types in [sifdt ] format" required:"true"` + Comma string `long:"comma" description:"CSV file comma character" choice:"," choice:";" choice:"t" default:","` + CreateTable bool `long:"create" description:"create table"` + OverwriteTable bool `long:"overwrite" description:"overwrite existing table"` + Encoding string `long:"encoding" description:"CSV file charset" choice:"utf8" choice:"win1251" default:"utf8"` + SkipRows int `long:"skiprows" description:"number of rows to skip"` + DateFormat string `long:"dateformat" description:"date format (Go style)" default:"02.01.2006"` + TimestampFormat string `long:"timestampformat" description:"timestamp format (Go style)" default:"02.01.2006 15:04:05"` + UnknownColumnNames bool `long:"unknowncolumnnames" description:"insert to table with unknown column names"` } func init() { @@ -71,8 +74,15 @@ func processReader(r io.Reader) error { bufReader := bufio.NewReaderSize(decoder, 4*1024*1024) + for i := 0; i < opts.SkipRows; i++ { + _, _, err := bufReader.ReadLine() + if err != nil { + return fmt.Errorf("skip rows: %v", err) + } + } + reader := csv.NewReader(bufReader) - reader.TrimLeadingSpace = true + reader.TrimLeadingSpace = false reader.FieldsPerRecord = len(opts.FieldTypes) if []rune(opts.Comma)[0] == 't' { @@ -81,16 +91,6 @@ func processReader(r io.Reader) error { reader.Comma = []rune(opts.Comma)[0] } - for i := 0; i < opts.SkipRows; i++ { - _, err := reader.Read() - if err == csv.ErrFieldCount { - continue - } - if err != nil { - return fmt.Errorf("skip rows: %v", err) - } - } - header, err := reader.Read() if err != nil { return fmt.Errorf("read header: %v", err) @@ -112,12 +112,38 @@ func processReader(r io.Reader) error { } var neededHeader []string - for i, v := range header { - if opts.FieldTypes[i] == ' ' { - continue - } - neededHeader = append(neededHeader, v) + if opts.UnknownColumnNames { + sql := fmt.Sprintf("SELECT COLUMN_NAME FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_SCHEMA + '.' + TABLE_NAME = '%s' ORDER BY ORDINAL_POSITION", opts.TableName) + rows, err := db.Query(sql) + if err != nil { + return fmt.Errorf("get column names from database: %v", err) + } + defer rows.Close() + + for rows.Next() { + if rows.Err() != nil { + return fmt.Errorf("get column names from database: %v", err) + } + var columnName string + err = rows.Scan(&columnName) + if err != nil { + return fmt.Errorf("get column names from database: %v", err) + } + neededHeader = append(neededHeader, columnName) + } + } else { + for i, v := range header { + if opts.FieldTypes[i] == ' ' { + continue + } + + neededHeader = append(neededHeader, v) + } + } + + if len(neededHeader) == 0 { + return fmt.Errorf("no columns to process (check table name or field types)") } tx, err := db.Begin() @@ -126,7 +152,7 @@ func processReader(r io.Reader) error { } if opts.CreateTable { - err = createTable(tx, header) + err = createTable(tx, header, opts.FieldTypes) if err != nil { _ = tx.Rollback() return fmt.Errorf("create table: %v", err) diff --git a/sql.go b/sql.go index f4a6561..f04afd9 100644 --- a/sql.go +++ b/sql.go @@ -5,7 +5,7 @@ import ( "fmt" ) -func createTable(tx *sql.Tx, header []string) error { +func createTable(tx *sql.Tx, header []string, fieldTypes string) error { if opts.OverwriteTable { _, err := tx.Exec(fmt.Sprintf("IF object_id('%s', 'U') IS NOT NULL DROP TABLE %s", opts.TableName, opts.TableName)) if err != nil { @@ -17,11 +17,15 @@ func createTable(tx *sql.Tx, header []string) error { for i, v := range header { var fieldType FieldType - err := fieldType.UnmarshalText([]byte(v)) + err := fieldType.UnmarshalText([]byte(fieldTypes[i : i+1])) if err != nil { return fmt.Errorf("detect field type: %v", err) } + if fieldType == Skip { + continue + } + sql += fmt.Sprintf(`"%s" %s`, v, fieldType.SqlFieldType()) if i+1 < len(header) { diff --git a/types.go b/types.go index e77225a..8c4f933 100644 --- a/types.go +++ b/types.go @@ -17,10 +17,11 @@ const ( Money Date Timestamp - TimestampWithoutSeconds ) func (ft FieldType) ParseValue(s string) (any, error) { + s = strings.TrimSpace(s) + switch ft { case String: return s, nil @@ -29,11 +30,9 @@ func (ft FieldType) ParseValue(s string) (any, error) { case Float: return strconv.ParseFloat(strings.ReplaceAll(s, ",", "."), 64) case Date: - return time.Parse("02.01.2006", s) + return time.Parse(opts.DateFormat, s) case Timestamp: - return time.Parse("02.01.2006 15:04:05", s) - case TimestampWithoutSeconds: - return time.Parse("02.01.2006 15:04", s) + return time.Parse(opts.TimestampFormat, s) } return nil, fmt.Errorf("unknown type id = %d", ft) @@ -51,7 +50,7 @@ func (ft FieldType) SqlFieldType() string { panic("do not implemented - see https://github.com/denisenkom/go-mssqldb/issues/460") // TODO: https://github.com/denisenkom/go-mssqldb/issues/460 case Date: return "date" - case Timestamp, TimestampWithoutSeconds: + case Timestamp: return "datetime2" } @@ -74,8 +73,6 @@ func (ft FieldType) MarshalText() (text []byte, err error) { return []byte("d"), nil case Timestamp: return []byte("t"), nil - case TimestampWithoutSeconds: - return []byte("w"), nil } return nil, fmt.Errorf("unknown type id = %d", ft) @@ -85,21 +82,26 @@ func (ft *FieldType) UnmarshalText(text []byte) error { switch string(text) { case " ": *ft = Skip + return nil case "i": *ft = Integer + return nil case "s": *ft = String + return nil case "f": *ft = Float + return nil case "m": *ft = Money + return nil case "d": *ft = Date + return nil case "t": *ft = Timestamp - case "w": - *ft = TimestampWithoutSeconds + return nil } - return fmt.Errorf("unknown format code %s", string(text)) + return fmt.Errorf(`unknown format code "%s"`, string(text)) } diff --git a/zip.go b/zip.go index ed35ea8..c190793 100644 --- a/zip.go +++ b/zip.go @@ -8,7 +8,7 @@ import ( func processZipFile(filePath string) error { r, err := zip.OpenReader(filePath) if err != nil { - return err + return fmt.Errorf("open ZIP file: %v", err) } if len(r.File) != 1 { @@ -17,7 +17,7 @@ func processZipFile(filePath string) error { zipFileReader, err := r.File[0].Open() if err != nil { - return err + return fmt.Errorf("open file from ZIP archive: %v", err) } defer zipFileReader.Close()