This commit is contained in:
nxshock 2022-09-21 21:36:40 +05:00
parent 4c0b45e2b8
commit 219a519ff8
6 changed files with 90 additions and 53 deletions

View File

@ -9,16 +9,19 @@ Usage:
csv2db [OPTIONS]
Application Options:
/f, /file: CSV file path
/s, /server: server address (default: 127.0.0.1)
/d, /database: database name
/t, /table: table name
/l, /fields: field types
/c, /comma:[,|;|t] CSV file comma character (default: ,)
/x, /create create table
/o, /overwrite overwrite existing table
/e, /encoding:[utf8|win1251] CSV file charset (default: utf8)
/r, /skiprows: number of rows to skip
/filepath: CSV file path
/server: server address (default: 127.0.0.1)
/database: database name
/table: table name in schema.name format
/fields: field types in [sifdt ] format
/comma:[,|;|t] CSV file comma character (default: ,)
/create create table
/overwrite overwrite existing table
/encoding:[utf8|win1251] CSV file charset (default: utf8)
/skiprows: number of rows to skip
/dateformat: date format (Go style) (default: 02.01.2006)
/timestampformat: timestamp format (Go style) (default: 02.01.2006 15:04:05)
/unknowncolumnnames insert to table with unknown column names
```
## Build

View File

@ -41,8 +41,10 @@ func (e *Encoding) UnmarshalText(text []byte) error {
switch string(text) {
case "utf8":
*e = Utf8
return nil
case "win1251":
*e = Win1251
return nil
}
return fmt.Errorf("unknown encoding: %s", string(text))

80
main.go
View File

@ -18,16 +18,19 @@ import (
var db *sql.DB
var opts struct {
FilePath string `short:"f" long:"file" description:"CSV file path" required:"true"`
ServerAddress string `short:"s" long:"server" description:"server address" default:"127.0.0.1"`
DatabaseName string `short:"d" long:"database" description:"database name" required:"true"`
TableName string `short:"t" long:"table" description:"table name" required:"true"`
FieldTypes string `short:"l" long:"fields" description:"field types" required:"true"`
Comma string `short:"c" long:"comma" description:"CSV file comma character" choice:"," choice:";" choice:"t" default:","`
CreateTable bool `short:"x" long:"create" description:"create table"`
OverwriteTable bool `short:"o" long:"overwrite" description:"overwrite existing table"`
Encoding string `short:"e" long:"encoding" description:"CSV file charset" choice:"utf8" choice:"win1251" default:"utf8"`
SkipRows int `short:"r" long:"skiprows" description:"number of rows to skip"`
FilePath string `long:"filepath" description:"CSV file path" required:"true"`
ServerAddress string `long:"server" description:"server address" default:"127.0.0.1"`
DatabaseName string `long:"database" description:"database name" required:"true"`
TableName string `long:"table" description:"table name in schema.name format" required:"true"`
FieldTypes string `long:"fields" description:"field types in [sifdt ] format" required:"true"`
Comma string `long:"comma" description:"CSV file comma character" choice:"," choice:";" choice:"t" default:","`
CreateTable bool `long:"create" description:"create table"`
OverwriteTable bool `long:"overwrite" description:"overwrite existing table"`
Encoding string `long:"encoding" description:"CSV file charset" choice:"utf8" choice:"win1251" default:"utf8"`
SkipRows int `long:"skiprows" description:"number of rows to skip"`
DateFormat string `long:"dateformat" description:"date format (Go style)" default:"02.01.2006"`
TimestampFormat string `long:"timestampformat" description:"timestamp format (Go style)" default:"02.01.2006 15:04:05"`
UnknownColumnNames bool `long:"unknowncolumnnames" description:"insert to table with unknown column names"`
}
func init() {
@ -71,8 +74,15 @@ func processReader(r io.Reader) error {
bufReader := bufio.NewReaderSize(decoder, 4*1024*1024)
for i := 0; i < opts.SkipRows; i++ {
_, _, err := bufReader.ReadLine()
if err != nil {
return fmt.Errorf("skip rows: %v", err)
}
}
reader := csv.NewReader(bufReader)
reader.TrimLeadingSpace = true
reader.TrimLeadingSpace = false
reader.FieldsPerRecord = len(opts.FieldTypes)
if []rune(opts.Comma)[0] == 't' {
@ -81,16 +91,6 @@ func processReader(r io.Reader) error {
reader.Comma = []rune(opts.Comma)[0]
}
for i := 0; i < opts.SkipRows; i++ {
_, err := reader.Read()
if err == csv.ErrFieldCount {
continue
}
if err != nil {
return fmt.Errorf("skip rows: %v", err)
}
}
header, err := reader.Read()
if err != nil {
return fmt.Errorf("read header: %v", err)
@ -112,12 +112,38 @@ func processReader(r io.Reader) error {
}
var neededHeader []string
for i, v := range header {
if opts.FieldTypes[i] == ' ' {
continue
}
neededHeader = append(neededHeader, v)
// Decide which column names the INSERT will target.
if opts.UnknownColumnNames {
	// The CSV header is not trusted for names: take the column names from
	// the existing table's metadata, ordered by their position in the table.
	// NOTE(review): the table name is interpolated into the statement text;
	// a name containing a single quote will break the query — consider a
	// driver-level parameter once the driver's placeholder syntax is confirmed.
	query := fmt.Sprintf("SELECT COLUMN_NAME FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_SCHEMA + '.' + TABLE_NAME = '%s' ORDER BY ORDINAL_POSITION", opts.TableName)

	rows, err := db.Query(query)
	if err != nil {
		return fmt.Errorf("get column names from database: %v", err)
	}
	defer rows.Close()

	for rows.Next() {
		var columnName string
		if err := rows.Scan(&columnName); err != nil {
			return fmt.Errorf("get column names from database: %v", err)
		}

		neededHeader = append(neededHeader, columnName)
	}

	// Rows.Err must be consulted after Next returns false: that is the only
	// point where an iteration failure (as opposed to a normal end of the
	// result set) becomes visible. Report the iteration error itself rather
	// than the already-checked error from db.Query.
	if err := rows.Err(); err != nil {
		return fmt.Errorf("get column names from database: %v", err)
	}
} else {
	// Use the CSV header, dropping every column whose field type code is a
	// space (the "skip" marker in the field types option).
	for i, v := range header {
		if opts.FieldTypes[i] == ' ' {
			continue
		}

		neededHeader = append(neededHeader, v)
	}
}
if len(neededHeader) == 0 {
return fmt.Errorf("no columns to process (check table name or field types)")
}
tx, err := db.Begin()
@ -126,7 +152,7 @@ func processReader(r io.Reader) error {
}
if opts.CreateTable {
err = createTable(tx, header)
err = createTable(tx, header, opts.FieldTypes)
if err != nil {
_ = tx.Rollback()
return fmt.Errorf("create table: %v", err)

8
sql.go
View File

@ -5,7 +5,7 @@ import (
"fmt"
)
func createTable(tx *sql.Tx, header []string) error {
func createTable(tx *sql.Tx, header []string, fieldTypes string) error {
if opts.OverwriteTable {
_, err := tx.Exec(fmt.Sprintf("IF object_id('%s', 'U') IS NOT NULL DROP TABLE %s", opts.TableName, opts.TableName))
if err != nil {
@ -17,11 +17,15 @@ func createTable(tx *sql.Tx, header []string) error {
for i, v := range header {
var fieldType FieldType
err := fieldType.UnmarshalText([]byte(v))
err := fieldType.UnmarshalText([]byte(fieldTypes[i : i+1]))
if err != nil {
return fmt.Errorf("detect field type: %v", err)
}
if fieldType == Skip {
continue
}
sql += fmt.Sprintf(`"%s" %s`, v, fieldType.SqlFieldType())
if i+1 < len(header) {

View File

@ -17,10 +17,11 @@ const (
Money
Date
Timestamp
TimestampWithoutSeconds
)
func (ft FieldType) ParseValue(s string) (any, error) {
s = strings.TrimSpace(s)
switch ft {
case String:
return s, nil
@ -29,11 +30,9 @@ func (ft FieldType) ParseValue(s string) (any, error) {
case Float:
return strconv.ParseFloat(strings.ReplaceAll(s, ",", "."), 64)
case Date:
return time.Parse("02.01.2006", s)
return time.Parse(opts.DateFormat, s)
case Timestamp:
return time.Parse("02.01.2006 15:04:05", s)
case TimestampWithoutSeconds:
return time.Parse("02.01.2006 15:04", s)
return time.Parse(opts.TimestampFormat, s)
}
return nil, fmt.Errorf("unknown type id = %d", ft)
@ -51,7 +50,7 @@ func (ft FieldType) SqlFieldType() string {
panic("do not implemented - see https://github.com/denisenkom/go-mssqldb/issues/460") // TODO: https://github.com/denisenkom/go-mssqldb/issues/460
case Date:
return "date"
case Timestamp, TimestampWithoutSeconds:
case Timestamp:
return "datetime2"
}
@ -74,8 +73,6 @@ func (ft FieldType) MarshalText() (text []byte, err error) {
return []byte("d"), nil
case Timestamp:
return []byte("t"), nil
case TimestampWithoutSeconds:
return []byte("w"), nil
}
return nil, fmt.Errorf("unknown type id = %d", ft)
@ -85,21 +82,26 @@ func (ft *FieldType) UnmarshalText(text []byte) error {
switch string(text) {
case " ":
*ft = Skip
return nil
case "i":
*ft = Integer
return nil
case "s":
*ft = String
return nil
case "f":
*ft = Float
return nil
case "m":
*ft = Money
return nil
case "d":
*ft = Date
return nil
case "t":
*ft = Timestamp
case "w":
*ft = TimestampWithoutSeconds
return nil
}
return fmt.Errorf("unknown format code %s", string(text))
return fmt.Errorf(`unknown format code "%s"`, string(text))
}

4
zip.go
View File

@ -8,7 +8,7 @@ import (
func processZipFile(filePath string) error {
r, err := zip.OpenReader(filePath)
if err != nil {
return err
return fmt.Errorf("open ZIP file: %v", err)
}
if len(r.File) != 1 {
@ -17,7 +17,7 @@ func processZipFile(filePath string) error {
zipFileReader, err := r.File[0].Open()
if err != nil {
return err
return fmt.Errorf("open file from ZIP archive: %v", err)
}
defer zipFileReader.Close()