This commit is contained in:
nxshock 2022-09-21 21:36:40 +05:00
parent 4c0b45e2b8
commit 219a519ff8
6 changed files with 90 additions and 53 deletions

View File

@ -9,16 +9,19 @@ Usage:
csv2db [OPTIONS] csv2db [OPTIONS]
Application Options: Application Options:
/f, /file: CSV file path /filepath: CSV file path
/s, /server: server address (default: 127.0.0.1) /server: server address (default: 127.0.0.1)
/d, /database: database name /database: database name
/t, /table: table name /table: table name in schema.name format
/l, /fields: field types /fields: field types in [sifdt ] format
/c, /comma:[,|;|t] CSV file comma character (default: ,) /comma:[,|;|t] CSV file comma character (default: ,)
/x, /create create table /create create table
/o, /overwrite overwrite existing table /overwrite overwrite existing table
/e, /encoding:[utf8|win1251] CSV file charset (default: utf8) /encoding:[utf8|win1251] CSV file charset (default: utf8)
/r, /skiprows: number of rows to skip /skiprows: number of rows to skip
/dateformat: date format (Go style) (default: 02.01.2006)
/timestampformat: timestamp format (Go style) (default: 02.01.2006 15:04:05)
/unknowncolumnnames insert to table with unknown column names
``` ```
## Build ## Build

View File

@ -41,8 +41,10 @@ func (e *Encoding) UnmarshalText(text []byte) error {
switch string(text) { switch string(text) {
case "utf8": case "utf8":
*e = Utf8 *e = Utf8
return nil
case "win1251": case "win1251":
*e = Win1251 *e = Win1251
return nil
} }
return fmt.Errorf("unknown encoding: %s", string(text)) return fmt.Errorf("unknown encoding: %s", string(text))

80
main.go
View File

@ -18,16 +18,19 @@ import (
var db *sql.DB var db *sql.DB
var opts struct { var opts struct {
FilePath string `short:"f" long:"file" description:"CSV file path" required:"true"` FilePath string `long:"filepath" description:"CSV file path" required:"true"`
ServerAddress string `short:"s" long:"server" description:"server address" default:"127.0.0.1"` ServerAddress string `long:"server" description:"server address" default:"127.0.0.1"`
DatabaseName string `short:"d" long:"database" description:"database name" required:"true"` DatabaseName string `long:"database" description:"database name" required:"true"`
TableName string `short:"t" long:"table" description:"table name" required:"true"` TableName string `long:"table" description:"table name in schema.name format" required:"true"`
FieldTypes string `short:"l" long:"fields" description:"field types" required:"true"` FieldTypes string `long:"fields" description:"field types in [sifdt ] format" required:"true"`
Comma string `short:"c" long:"comma" description:"CSV file comma character" choice:"," choice:";" choice:"t" default:","` Comma string `long:"comma" description:"CSV file comma character" choice:"," choice:";" choice:"t" default:","`
CreateTable bool `short:"x" long:"create" description:"create table"` CreateTable bool `long:"create" description:"create table"`
OverwriteTable bool `short:"o" long:"overwrite" description:"overwrite existing table"` OverwriteTable bool `long:"overwrite" description:"overwrite existing table"`
Encoding string `short:"e" long:"encoding" description:"CSV file charset" choice:"utf8" choice:"win1251" default:"utf8"` Encoding string `long:"encoding" description:"CSV file charset" choice:"utf8" choice:"win1251" default:"utf8"`
SkipRows int `short:"r" long:"skiprows" description:"number of rows to skip"` SkipRows int `long:"skiprows" description:"number of rows to skip"`
DateFormat string `long:"dateformat" description:"date format (Go style)" default:"02.01.2006"`
TimestampFormat string `long:"timestampformat" description:"timestamp format (Go style)" default:"02.01.2006 15:04:05"`
UnknownColumnNames bool `long:"unknowncolumnnames" description:"insert to table with unknown column names"`
} }
func init() { func init() {
@ -71,8 +74,15 @@ func processReader(r io.Reader) error {
bufReader := bufio.NewReaderSize(decoder, 4*1024*1024) bufReader := bufio.NewReaderSize(decoder, 4*1024*1024)
for i := 0; i < opts.SkipRows; i++ {
_, _, err := bufReader.ReadLine()
if err != nil {
return fmt.Errorf("skip rows: %v", err)
}
}
reader := csv.NewReader(bufReader) reader := csv.NewReader(bufReader)
reader.TrimLeadingSpace = true reader.TrimLeadingSpace = false
reader.FieldsPerRecord = len(opts.FieldTypes) reader.FieldsPerRecord = len(opts.FieldTypes)
if []rune(opts.Comma)[0] == 't' { if []rune(opts.Comma)[0] == 't' {
@ -81,16 +91,6 @@ func processReader(r io.Reader) error {
reader.Comma = []rune(opts.Comma)[0] reader.Comma = []rune(opts.Comma)[0]
} }
for i := 0; i < opts.SkipRows; i++ {
_, err := reader.Read()
if err == csv.ErrFieldCount {
continue
}
if err != nil {
return fmt.Errorf("skip rows: %v", err)
}
}
header, err := reader.Read() header, err := reader.Read()
if err != nil { if err != nil {
return fmt.Errorf("read header: %v", err) return fmt.Errorf("read header: %v", err)
@ -112,12 +112,38 @@ func processReader(r io.Reader) error {
} }
var neededHeader []string var neededHeader []string
for i, v := range header {
if opts.FieldTypes[i] == ' ' {
continue
}
neededHeader = append(neededHeader, v) if opts.UnknownColumnNames {
sql := fmt.Sprintf("SELECT COLUMN_NAME FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_SCHEMA + '.' + TABLE_NAME = '%s' ORDER BY ORDINAL_POSITION", opts.TableName)
rows, err := db.Query(sql)
if err != nil {
return fmt.Errorf("get column names from database: %v", err)
}
defer rows.Close()
for rows.Next() {
if rows.Err() != nil {
return fmt.Errorf("get column names from database: %v", err)
}
var columnName string
err = rows.Scan(&columnName)
if err != nil {
return fmt.Errorf("get column names from database: %v", err)
}
neededHeader = append(neededHeader, columnName)
}
} else {
for i, v := range header {
if opts.FieldTypes[i] == ' ' {
continue
}
neededHeader = append(neededHeader, v)
}
}
if len(neededHeader) == 0 {
return fmt.Errorf("no columns to process (check table name or field types)")
} }
tx, err := db.Begin() tx, err := db.Begin()
@ -126,7 +152,7 @@ func processReader(r io.Reader) error {
} }
if opts.CreateTable { if opts.CreateTable {
err = createTable(tx, header) err = createTable(tx, header, opts.FieldTypes)
if err != nil { if err != nil {
_ = tx.Rollback() _ = tx.Rollback()
return fmt.Errorf("create table: %v", err) return fmt.Errorf("create table: %v", err)

8
sql.go
View File

@ -5,7 +5,7 @@ import (
"fmt" "fmt"
) )
func createTable(tx *sql.Tx, header []string) error { func createTable(tx *sql.Tx, header []string, fieldTypes string) error {
if opts.OverwriteTable { if opts.OverwriteTable {
_, err := tx.Exec(fmt.Sprintf("IF object_id('%s', 'U') IS NOT NULL DROP TABLE %s", opts.TableName, opts.TableName)) _, err := tx.Exec(fmt.Sprintf("IF object_id('%s', 'U') IS NOT NULL DROP TABLE %s", opts.TableName, opts.TableName))
if err != nil { if err != nil {
@ -17,11 +17,15 @@ func createTable(tx *sql.Tx, header []string) error {
for i, v := range header { for i, v := range header {
var fieldType FieldType var fieldType FieldType
err := fieldType.UnmarshalText([]byte(v)) err := fieldType.UnmarshalText([]byte(fieldTypes[i : i+1]))
if err != nil { if err != nil {
return fmt.Errorf("detect field type: %v", err) return fmt.Errorf("detect field type: %v", err)
} }
if fieldType == Skip {
continue
}
sql += fmt.Sprintf(`"%s" %s`, v, fieldType.SqlFieldType()) sql += fmt.Sprintf(`"%s" %s`, v, fieldType.SqlFieldType())
if i+1 < len(header) { if i+1 < len(header) {

View File

@ -17,10 +17,11 @@ const (
Money Money
Date Date
Timestamp Timestamp
TimestampWithoutSeconds
) )
func (ft FieldType) ParseValue(s string) (any, error) { func (ft FieldType) ParseValue(s string) (any, error) {
s = strings.TrimSpace(s)
switch ft { switch ft {
case String: case String:
return s, nil return s, nil
@ -29,11 +30,9 @@ func (ft FieldType) ParseValue(s string) (any, error) {
case Float: case Float:
return strconv.ParseFloat(strings.ReplaceAll(s, ",", "."), 64) return strconv.ParseFloat(strings.ReplaceAll(s, ",", "."), 64)
case Date: case Date:
return time.Parse("02.01.2006", s) return time.Parse(opts.DateFormat, s)
case Timestamp: case Timestamp:
return time.Parse("02.01.2006 15:04:05", s) return time.Parse(opts.TimestampFormat, s)
case TimestampWithoutSeconds:
return time.Parse("02.01.2006 15:04", s)
} }
return nil, fmt.Errorf("unknown type id = %d", ft) return nil, fmt.Errorf("unknown type id = %d", ft)
@ -51,7 +50,7 @@ func (ft FieldType) SqlFieldType() string {
panic("do not implemented - see https://github.com/denisenkom/go-mssqldb/issues/460") // TODO: https://github.com/denisenkom/go-mssqldb/issues/460 panic("do not implemented - see https://github.com/denisenkom/go-mssqldb/issues/460") // TODO: https://github.com/denisenkom/go-mssqldb/issues/460
case Date: case Date:
return "date" return "date"
case Timestamp, TimestampWithoutSeconds: case Timestamp:
return "datetime2" return "datetime2"
} }
@ -74,8 +73,6 @@ func (ft FieldType) MarshalText() (text []byte, err error) {
return []byte("d"), nil return []byte("d"), nil
case Timestamp: case Timestamp:
return []byte("t"), nil return []byte("t"), nil
case TimestampWithoutSeconds:
return []byte("w"), nil
} }
return nil, fmt.Errorf("unknown type id = %d", ft) return nil, fmt.Errorf("unknown type id = %d", ft)
@ -85,21 +82,26 @@ func (ft *FieldType) UnmarshalText(text []byte) error {
switch string(text) { switch string(text) {
case " ": case " ":
*ft = Skip *ft = Skip
return nil
case "i": case "i":
*ft = Integer *ft = Integer
return nil
case "s": case "s":
*ft = String *ft = String
return nil
case "f": case "f":
*ft = Float *ft = Float
return nil
case "m": case "m":
*ft = Money *ft = Money
return nil
case "d": case "d":
*ft = Date *ft = Date
return nil
case "t": case "t":
*ft = Timestamp *ft = Timestamp
case "w": return nil
*ft = TimestampWithoutSeconds
} }
return fmt.Errorf("unknown format code %s", string(text)) return fmt.Errorf(`unknown format code "%s"`, string(text))
} }

4
zip.go
View File

@ -8,7 +8,7 @@ import (
func processZipFile(filePath string) error { func processZipFile(filePath string) error {
r, err := zip.OpenReader(filePath) r, err := zip.OpenReader(filePath)
if err != nil { if err != nil {
return err return fmt.Errorf("open ZIP file: %v", err)
} }
if len(r.File) != 1 { if len(r.File) != 1 {
@ -17,7 +17,7 @@ func processZipFile(filePath string) error {
zipFileReader, err := r.File[0].Open() zipFileReader, err := r.File[0].Open()
if err != nil { if err != nil {
return err return fmt.Errorf("open file from ZIP archive: %v", err)
} }
defer zipFileReader.Close() defer zipFileReader.Close()