2022-09-18 12:43:16 +05:00
package main
import (
"bufio"
"database/sql"
"encoding/csv"
"fmt"
"io"
"log"
"os"
"path/filepath"
"strings"
_ "github.com/denisenkom/go-mssqldb"
mssql "github.com/denisenkom/go-mssqldb"
"github.com/jessevdk/go-flags"
)
// db is the connection pool used for all database work. It stays nil until
// it is assigned from sql.Open during startup.
var db *sql.DB

// opts holds the command-line options. go-flags fills it from the struct
// tags: long flag name, help description, default value and allowed choices.
var opts struct {
	// FilePath is the input file; its extension (.csv or .zip) selects how it
	// is processed.
	FilePath string `long:"filepath" description:"CSV file path" required:"true"`
	// ServerAddress is the host part of the sqlserver:// connection URL.
	ServerAddress string `long:"server" description:"server address" default:"127.0.0.1"`
	// DatabaseName is passed as the "database" query parameter of the URL.
	DatabaseName string `long:"database" description:"database name" required:"true"`
	// TableName is the destination table as "schema.name"; when column names
	// are looked up it is compared against TABLE_SCHEMA + '.' + TABLE_NAME.
	TableName string `long:"table" description:"table name in schema.name format" required:"true"`
	// FieldTypes holds one type character per CSV column; a space marks a
	// column that is skipped. NOTE(review): the meaning of the letters s/i/f/d/t
	// is defined by FieldType elsewhere (presumably string/int/float/date/
	// timestamp) — confirm there.
	FieldTypes string `long:"fields" description:"field types in [sifdt ] format" required:"true"`
	// Comma is the CSV delimiter; the choice "t" stands for a tab character.
	Comma string `long:"comma" description:"CSV file comma character" choice:"," choice:";" choice:"t" default:","`
	// CreateTable runs createTable inside the import transaction before loading.
	CreateTable bool `long:"create" description:"create table"`
	// OverwriteTable is not read in this file's visible code; presumably it is
	// consumed by createTable — verify.
	OverwriteTable bool `long:"overwrite" description:"overwrite existing table"`
	// Encoding is the character set the input is decoded from.
	Encoding string `long:"encoding" description:"CSV file charset" choice:"utf8" choice:"win1251" default:"utf8"`
	// SkipRows is the number of raw lines discarded before the CSV header.
	SkipRows int `long:"skiprows" description:"number of rows to skip"`
	// DateFormat and TimestampFormat are Go reference-time layouts; presumably
	// used by FieldType when parsing date/timestamp columns — confirm.
	DateFormat      string `long:"dateformat" description:"date format (Go style)" default:"02.01.2006"`
	TimestampFormat string `long:"timestampformat" description:"timestamp format (Go style)" default:"02.01.2006 15:04:05"`
	// UnknownColumnNames takes the destination column names from
	// INFORMATION_SCHEMA instead of the CSV header.
	UnknownColumnNames bool `long:"unknowncolumnnames" description:"insert to table with unknown column names"`
}
// init configures the package-wide standard logger to print messages with no
// date/time prefix, so log.Fatalln output is just the error text.
func init() {
	std := log.Default()
	std.SetFlags(0)
}
// main parses the command-line options into opts, runs the import and exits
// with a non-zero status if anything fails.
func main() {
	if _, err := flags.Parse(&opts); err != nil {
		// go-flags' default parser options print the parse error (or the help
		// text) themselves, so only the exit status is left to set here.
		os.Exit(1)
	}
	// The work lives in runImport so its deferred cleanup (closing db) runs
	// before log.Fatalln terminates the process via os.Exit.
	if err := runImport(); err != nil {
		log.Fatalln(err)
	}
}

// runImport opens the package-level db handle for opts.ServerAddress /
// opts.DatabaseName and dispatches on the input file's extension
// (case-insensitive). Extensions other than .zip and .csv are reported as an
// error instead of being silently ignored.
func runImport() error {
	var err error
	db, err = sql.Open("sqlserver", fmt.Sprintf("sqlserver://%s?database=%s", opts.ServerAddress, opts.DatabaseName))
	if err != nil {
		return fmt.Errorf("open database: %w", err)
	}
	// Best-effort close at exit; there is nothing useful to do with its error.
	defer db.Close()

	switch ext := strings.ToLower(filepath.Ext(opts.FilePath)); ext {
	case ".zip":
		return processZipFile(opts.FilePath)
	case ".csv":
		return processCsvFile(opts.FilePath)
	default:
		return fmt.Errorf("unsupported file extension %q (expected .csv or .zip)", ext)
	}
}
func processReader ( r io . Reader ) error {
var encoding Encoding
err := encoding . UnmarshalText ( [ ] byte ( opts . Encoding ) )
if err != nil {
return fmt . Errorf ( "get decoder: %v" , opts . Encoding )
}
decoder , err := encoding . Translate ( r )
if err != nil {
return fmt . Errorf ( "enable decoder: %v" , opts . Encoding )
}
bufReader := bufio . NewReaderSize ( decoder , 4 * 1024 * 1024 )
2022-09-21 21:36:40 +05:00
for i := 0 ; i < opts . SkipRows ; i ++ {
_ , _ , err := bufReader . ReadLine ( )
if err != nil {
return fmt . Errorf ( "skip rows: %v" , err )
}
}
2022-09-18 12:43:16 +05:00
reader := csv . NewReader ( bufReader )
2022-10-07 20:17:50 +05:00
reader . LazyQuotes = true
2022-09-21 21:36:40 +05:00
reader . TrimLeadingSpace = false
2022-09-18 12:43:16 +05:00
reader . FieldsPerRecord = len ( opts . FieldTypes )
if [ ] rune ( opts . Comma ) [ 0 ] == 't' {
reader . Comma = '\t'
} else {
reader . Comma = [ ] rune ( opts . Comma ) [ 0 ]
}
header , err := reader . Read ( )
if err != nil {
return fmt . Errorf ( "read header: %v" , err )
}
headerList := ` " `
for i , v := range header {
if opts . FieldTypes [ i ] == ' ' {
continue
}
headerList += v
if i + 1 < len ( header ) {
headerList += ` ", " `
} else {
headerList += ` " `
}
}
var neededHeader [ ] string
2022-09-21 21:36:40 +05:00
if opts . UnknownColumnNames {
sql := fmt . Sprintf ( "SELECT COLUMN_NAME FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_SCHEMA + '.' + TABLE_NAME = '%s' ORDER BY ORDINAL_POSITION" , opts . TableName )
rows , err := db . Query ( sql )
if err != nil {
return fmt . Errorf ( "get column names from database: %v" , err )
}
defer rows . Close ( )
for rows . Next ( ) {
if rows . Err ( ) != nil {
return fmt . Errorf ( "get column names from database: %v" , err )
}
var columnName string
err = rows . Scan ( & columnName )
if err != nil {
return fmt . Errorf ( "get column names from database: %v" , err )
}
neededHeader = append ( neededHeader , columnName )
}
} else {
for i , v := range header {
if opts . FieldTypes [ i ] == ' ' {
continue
}
neededHeader = append ( neededHeader , v )
2022-09-18 12:43:16 +05:00
}
2022-09-21 21:36:40 +05:00
}
2022-09-18 12:43:16 +05:00
2022-09-21 21:36:40 +05:00
if len ( neededHeader ) == 0 {
return fmt . Errorf ( "no columns to process (check table name or field types)" )
2022-09-18 12:43:16 +05:00
}
tx , err := db . Begin ( )
if err != nil {
return fmt . Errorf ( "start transaction: %v" , err )
}
if opts . CreateTable {
2022-09-21 21:36:40 +05:00
err = createTable ( tx , header , opts . FieldTypes )
2022-09-18 12:43:16 +05:00
if err != nil {
_ = tx . Rollback ( )
return fmt . Errorf ( "create table: %v" , err )
}
}
sql := mssql . CopyIn ( opts . TableName , mssql . BulkOptions { Tablock : true } , neededHeader ... )
stmt , err := tx . Prepare ( sql )
if err != nil {
_ = tx . Rollback ( )
return fmt . Errorf ( "prepare statement: %v" , err )
}
n := 0
for {
if n % 100000 == 0 {
fmt . Fprintf ( os . Stderr , "Processed %d records...\r" , n )
}
record , err := reader . Read ( )
if err == io . EOF {
break
}
if err != nil {
return fmt . Errorf ( "read record: %v" , err )
}
var args [ ] any
for i , v := range record {
var fieldType FieldType
err = fieldType . UnmarshalText ( [ ] byte { opts . FieldTypes [ i ] } )
if err != nil {
return fmt . Errorf ( "get record type: %v" , err )
}
if fieldType == Skip {
continue
}
parsedValue , err := fieldType . ParseValue ( v )
if err != nil {
return fmt . Errorf ( "parse value: %v" , err )
}
args = append ( args , parsedValue )
}
_ , err = stmt . Exec ( args ... )
if err != nil {
_ = stmt . Close ( )
_ = tx . Rollback ( )
return fmt . Errorf ( "execute statement: %v" , err )
}
n ++
}
result , err := stmt . Exec ( )
if err != nil {
_ = tx . Rollback ( )
return fmt . Errorf ( "execute statement: %v" , err )
}
rowsAffected , err := result . RowsAffected ( )
if err != nil {
_ = tx . Rollback ( )
return fmt . Errorf ( "calc rows affected: %v" , err )
}
fmt . Fprintf ( os . Stderr , "Processed %d records. \n" , rowsAffected )
err = stmt . Close ( )
if err != nil {
_ = tx . Rollback ( )
return fmt . Errorf ( "close statement: %v" , err )
}
err = tx . Commit ( )
if err != nil {
return fmt . Errorf ( "commit: %v" , err )
}
return nil
}