csv2db/main.go
2022-10-07 20:17:50 +05:00

238 lines
5.7 KiB
Go

package main
import (
"bufio"
"database/sql"
"encoding/csv"
"fmt"
"io"
"log"
"os"
"path/filepath"
"strings"
_ "github.com/denisenkom/go-mssqldb"
mssql "github.com/denisenkom/go-mssqldb"
"github.com/jessevdk/go-flags"
)
// db is the process-wide database handle. main opens it with the "sqlserver"
// driver and processReader uses it for its queries and its load transaction.
var db *sql.DB
// opts holds the command-line options. main fills it in by passing it to
// flags.Parse (github.com/jessevdk/go-flags), which reads the struct tags.
var opts struct {
// FilePath is the input file; main dispatches on its extension (.zip or .csv).
FilePath string `long:"filepath" description:"CSV file path" required:"true"`
// ServerAddress is interpolated as the host part of the sqlserver:// URL in main.
ServerAddress string `long:"server" description:"server address" default:"127.0.0.1"`
// DatabaseName is passed as the "database" query parameter of the connection URL.
DatabaseName string `long:"database" description:"database name" required:"true"`
// TableName is the bulk-copy target; with UnknownColumnNames it is also matched
// against TABLE_SCHEMA + '.' + TABLE_NAME in INFORMATION_SCHEMA.COLUMNS.
TableName string `long:"table" description:"table name in schema.name format" required:"true"`
// FieldTypes has one byte per CSV column, indexed by column position; its
// length is the required number of fields per record. A space marks a column
// that is left out of the header list. The remaining letters are decoded by
// FieldType.UnmarshalText, which is defined outside this view.
FieldTypes string `long:"fields" description:"field types in [sifdt ] format" required:"true"`
// Comma is the CSV delimiter; the value "t" is translated to a tab character.
Comma string `long:"comma" description:"CSV file comma character" choice:"," choice:";" choice:"t" default:","`
// CreateTable makes processReader call createTable inside the load
// transaction before the bulk copy starts.
CreateTable bool `long:"create" description:"create table"`
// OverwriteTable is not referenced in the code visible here; presumably it is
// consumed by createTable — TODO confirm.
OverwriteTable bool `long:"overwrite" description:"overwrite existing table"`
// Encoding names the input charset; processReader decodes it with
// Encoding.UnmarshalText to build the input decoder.
Encoding string `long:"encoding" description:"CSV file charset" choice:"utf8" choice:"win1251" default:"utf8"`
// SkipRows is the number of lines discarded before the header row is read.
SkipRows int `long:"skiprows" description:"number of rows to skip"`
// DateFormat and TimestampFormat are not referenced in the code visible here;
// presumably FieldType.ParseValue reads them — TODO confirm.
DateFormat string `long:"dateformat" description:"date format (Go style)" default:"02.01.2006"`
TimestampFormat string `long:"timestampformat" description:"timestamp format (Go style)" default:"02.01.2006 15:04:05"`
// UnknownColumnNames makes processReader take the target column names from
// INFORMATION_SCHEMA.COLUMNS (in ordinal order) instead of the CSV header.
UnknownColumnNames bool `long:"unknowncolumnnames" description:"insert to table with unknown column names"`
}
// init clears every prefix flag on the standard logger, so messages written
// through the log package carry no date/time stamp.
func init() { log.SetFlags(0) }
// main parses the command-line options into opts, opens the SQL Server
// connection pool and loads the input file. The loader is chosen by the file
// extension: .zip goes to processZipFile, .csv to processCsvFile. Any other
// extension is reported as an error instead of being silently ignored.
//
// The process exits with status 1 on a flag-parsing error or on any
// load error.
func main() {
	if _, err := flags.Parse(&opts); err != nil {
		os.Exit(1)
	}

	// Assign to the package-level db with "=", not ":=", so the loaders see it.
	var err error
	db, err = sql.Open("sqlserver", fmt.Sprintf("sqlserver://%s?database=%s", opts.ServerAddress, opts.DatabaseName))
	if err != nil {
		log.Fatalln(fmt.Errorf("open database: %w", err))
	}

	switch ext := strings.ToLower(filepath.Ext(opts.FilePath)); ext {
	case ".zip":
		err = processZipFile(opts.FilePath)
	case ".csv":
		err = processCsvFile(opts.FilePath)
	default:
		err = fmt.Errorf("unsupported file extension %q: expected .csv or .zip", ext)
	}

	// log.Fatalln exits through os.Exit, which skips deferred calls, so the
	// pool is closed explicitly before the error is reported. The close error
	// is ignored on purpose: the process is about to exit either way.
	_ = db.Close()

	if err != nil {
		log.Fatalln(err)
	}
}
// processReader bulk-loads the CSV data read from r into opts.TableName.
//
// It works in these steps:
//  1. Wrap r in the decoder selected by opts.Encoding.
//  2. Decode opts.FieldTypes once, one FieldType per CSV column.
//  3. Discard opts.SkipRows lines, then read the header row.
//  4. Pick the target column names: from INFORMATION_SCHEMA.COLUMNS when
//     opts.UnknownColumnNames is set, otherwise the header fields whose
//     field type byte is not a space.
//  5. In one transaction, optionally call createTable, then stream every
//     record through a go-mssqldb bulk-copy statement and commit.
//
// Progress is written to os.Stderr. Every failure is returned as a wrapped
// error; the transaction is rolled back on every error path after it starts.
func processReader(r io.Reader) error {
	var encoding Encoding
	if err := encoding.UnmarshalText([]byte(opts.Encoding)); err != nil {
		return fmt.Errorf("get decoder %q: %w", opts.Encoding, err)
	}
	decoder, err := encoding.Translate(r)
	if err != nil {
		return fmt.Errorf("enable decoder %q: %w", opts.Encoding, err)
	}

	// With no field types, FieldsPerRecord would be 0 (meaning "infer from the
	// first record") and indexing opts.FieldTypes by column would panic.
	if len(opts.FieldTypes) == 0 {
		return fmt.Errorf("get record type: no field types given")
	}
	// Decode the per-column types once instead of once per field per record.
	fieldTypes := make([]FieldType, len(opts.FieldTypes))
	for i := range fieldTypes {
		if err := fieldTypes[i].UnmarshalText([]byte{opts.FieldTypes[i]}); err != nil {
			return fmt.Errorf("get record type: %w", err)
		}
	}

	commaRunes := []rune(opts.Comma)
	if len(commaRunes) == 0 {
		return fmt.Errorf("comma character is empty")
	}
	delimiter := commaRunes[0]
	if delimiter == 't' {
		// A literal tab is awkward to pass on a command line, so "t" stands for it.
		delimiter = '\t'
	}

	bufReader := bufio.NewReaderSize(decoder, 4*1024*1024)
	for i := 0; i < opts.SkipRows; i++ {
		// ReadLine returns a line longer than the buffer in pieces; keep
		// reading until the whole line has been consumed.
		for {
			_, isPrefix, err := bufReader.ReadLine()
			if err != nil {
				return fmt.Errorf("skip rows: %w", err)
			}
			if !isPrefix {
				break
			}
		}
	}

	reader := csv.NewReader(bufReader)
	reader.LazyQuotes = true
	reader.TrimLeadingSpace = false
	// Every row, header included, must have exactly one field per field type,
	// which keeps the column-indexed lookups below in range.
	reader.FieldsPerRecord = len(opts.FieldTypes)
	reader.Comma = delimiter

	header, err := reader.Read()
	if err != nil {
		return fmt.Errorf("read header: %w", err)
	}

	var neededHeader []string
	if opts.UnknownColumnNames {
		// The table name is passed as a parameter, not spliced into the SQL text.
		const query = "SELECT COLUMN_NAME FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_SCHEMA + '.' + TABLE_NAME = @p1 ORDER BY ORDINAL_POSITION"
		rows, err := db.Query(query, opts.TableName)
		if err != nil {
			return fmt.Errorf("get column names from database: %w", err)
		}
		defer rows.Close()
		for rows.Next() {
			var columnName string
			if err := rows.Scan(&columnName); err != nil {
				return fmt.Errorf("get column names from database: %w", err)
			}
			neededHeader = append(neededHeader, columnName)
		}
		// Next returns false both at the end of the result set and on error;
		// only Err, checked after the loop, tells the two apart.
		if err := rows.Err(); err != nil {
			return fmt.Errorf("get column names from database: %w", err)
		}
	} else {
		for i, name := range header {
			if opts.FieldTypes[i] == ' ' {
				continue
			}
			neededHeader = append(neededHeader, name)
		}
	}
	if len(neededHeader) == 0 {
		return fmt.Errorf("no columns to process (check table name or field types)")
	}

	tx, err := db.Begin()
	if err != nil {
		return fmt.Errorf("start transaction: %w", err)
	}
	// Covers every early return below. After a successful Commit this is a
	// no-op that returns sql.ErrTxDone, which is safe to ignore.
	defer func() { _ = tx.Rollback() }()

	if opts.CreateTable {
		if err := createTable(tx, header, opts.FieldTypes); err != nil {
			return fmt.Errorf("create table: %w", err)
		}
	}

	copyStmt := mssql.CopyIn(opts.TableName, mssql.BulkOptions{Tablock: true}, neededHeader...)
	stmt, err := tx.Prepare(copyStmt)
	if err != nil {
		return fmt.Errorf("prepare statement: %w", err)
	}
	// Registered after the rollback, so it runs first on error paths. On the
	// success path the statement is closed explicitly below and this second
	// Close is a no-op.
	defer func() { _ = stmt.Close() }()

	for n := 0; ; n++ {
		if n%100000 == 0 {
			fmt.Fprintf(os.Stderr, "Processed %d records...\r", n)
		}
		record, err := reader.Read()
		if err == io.EOF {
			break
		}
		if err != nil {
			return fmt.Errorf("read record: %w", err)
		}

		args := make([]any, 0, len(neededHeader))
		for i, v := range record {
			if fieldTypes[i] == Skip {
				continue
			}
			parsedValue, err := fieldTypes[i].ParseValue(v)
			if err != nil {
				return fmt.Errorf("parse value: %w", err)
			}
			args = append(args, parsedValue)
		}
		if _, err := stmt.Exec(args...); err != nil {
			return fmt.Errorf("execute statement: %w", err)
		}
	}

	// An Exec with no arguments flushes the buffered rows of the bulk copy.
	result, err := stmt.Exec()
	if err != nil {
		return fmt.Errorf("execute statement: %w", err)
	}
	rowsAffected, err := result.RowsAffected()
	if err != nil {
		return fmt.Errorf("calc rows affected: %w", err)
	}
	fmt.Fprintf(os.Stderr, "Processed %d records. \n", rowsAffected)

	if err := stmt.Close(); err != nil {
		return fmt.Errorf("close statement: %w", err)
	}
	if err := tx.Commit(); err != nil {
		return fmt.Errorf("commit: %w", err)
	}
	return nil
}