mirror of
https://github.com/nxshock/mssqlbulkloader.git
synced 2024-11-28 00:21:03 +05:00
152 lines
3.6 KiB
Go
152 lines
3.6 KiB
Go
|
package main
|
||
|
|
||
|
import (
|
||
|
"database/sql"
|
||
|
"fmt"
|
||
|
"io"
|
||
|
"os"
|
||
|
"strings"
|
||
|
|
||
|
mssql "github.com/denisenkom/go-mssqldb"
|
||
|
)
|
||
|
|
||
|
// TODO: add escaping
|
||
|
func prepareTable(reader Reader, tx *sql.Tx) error {
|
||
|
if reader.Options().unknownColumnNames {
|
||
|
var columnNames []string
|
||
|
|
||
|
sql := fmt.Sprintf("SELECT COLUMN_NAME FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_SCHEMA + '.' + TABLE_NAME = '%s' ORDER BY ORDINAL_POSITION", reader.Options().tableName)
|
||
|
rows, err := tx.Query(sql)
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("get column names from database: %w", err)
|
||
|
}
|
||
|
defer rows.Close()
|
||
|
|
||
|
for rows.Next() {
|
||
|
if rows.Err() != nil {
|
||
|
return fmt.Errorf("get column names from database: %w", err)
|
||
|
}
|
||
|
var columnName string
|
||
|
err = rows.Scan(&columnName)
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("get column names from database: %w", err)
|
||
|
}
|
||
|
columnNames = append(columnNames, columnName)
|
||
|
}
|
||
|
|
||
|
reader.Options().columnNames = columnNames
|
||
|
} else {
|
||
|
reader.Options().columnNames = reader.GetHeader()
|
||
|
}
|
||
|
|
||
|
if !reader.Options().create && !reader.Options().overwrite {
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
if !reader.Options().create && reader.Options().overwrite {
|
||
|
logger.Println("Truncating table...")
|
||
|
_, err := tx.Exec(fmt.Sprintf("TRUNCATE TABLE %s", reader.Options().tableName))
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if reader.Options().overwrite {
|
||
|
logger.Println("Dropping table...")
|
||
|
_, err := tx.Exec(fmt.Sprintf("IF object_id('%s', 'U') IS NOT NULL DROP TABLE %s", reader.Options().tableName, reader.Options().tableName))
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("drop table: %w", err)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
sql := fmt.Sprintf("CREATE TABLE %s (", reader.Options().tableName)
|
||
|
|
||
|
fieldTypes := strings.ReplaceAll(reader.Options().fieldsTypes, " ", "")
|
||
|
|
||
|
for i, columnName := range reader.Options().columnNames {
|
||
|
var fieldType FieldType
|
||
|
err := fieldType.UnmarshalText([]byte{fieldTypes[i]})
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("detect field type: %w", err)
|
||
|
}
|
||
|
|
||
|
sql += fmt.Sprintf(`"%s" %s`, columnName, fieldType.SqlFieldType())
|
||
|
|
||
|
if i+1 < len(reader.GetHeader()) {
|
||
|
sql += ", "
|
||
|
} else {
|
||
|
sql += ") WITH (DATA_COMPRESSION = PAGE)" // TODO: add optional params
|
||
|
}
|
||
|
}
|
||
|
|
||
|
logger.Println("Creating table...")
|
||
|
logger.Println(sql)
|
||
|
_, err := tx.Exec(sql)
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("execute table creation: %w", err)
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func insertData(reader Reader, tx *sql.Tx) error {
|
||
|
columnNames := reader.GetHeader()
|
||
|
if reader.Options().unknownColumnNames {
|
||
|
columnNames = reader.Options().columnNames
|
||
|
}
|
||
|
|
||
|
sql := mssql.CopyIn(reader.Options().tableName, mssql.BulkOptions{Tablock: true}, columnNames...)
|
||
|
|
||
|
stmt, err := tx.Prepare(sql)
|
||
|
if err != nil {
|
||
|
_ = tx.Rollback()
|
||
|
return fmt.Errorf("prepare statement: %w", err)
|
||
|
}
|
||
|
|
||
|
n := 0
|
||
|
for {
|
||
|
if n%100000 == 0 {
|
||
|
if !reader.Options().silent {
|
||
|
fmt.Fprintf(os.Stderr, "Processed %d records...\r", n)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
record, err := reader.GetRow(false)
|
||
|
if err == io.EOF {
|
||
|
break
|
||
|
}
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("read record: %w", err)
|
||
|
}
|
||
|
|
||
|
_, err = stmt.Exec(record...)
|
||
|
if err != nil {
|
||
|
_ = stmt.Close()
|
||
|
_ = tx.Rollback()
|
||
|
return fmt.Errorf("execute statement: %w", err)
|
||
|
}
|
||
|
n++
|
||
|
}
|
||
|
result, err := stmt.Exec()
|
||
|
if err != nil {
|
||
|
_ = tx.Rollback()
|
||
|
return fmt.Errorf("execute statement: %w", err)
|
||
|
}
|
||
|
rowsAffected, err := result.RowsAffected()
|
||
|
if err != nil {
|
||
|
_ = tx.Rollback()
|
||
|
return fmt.Errorf("calc rows affected: %w", err)
|
||
|
}
|
||
|
if !reader.Options().silent {
|
||
|
fmt.Fprintf(os.Stderr, "Processed %d records. \n", rowsAffected)
|
||
|
}
|
||
|
|
||
|
err = stmt.Close()
|
||
|
if err != nil {
|
||
|
_ = tx.Rollback()
|
||
|
return fmt.Errorf("close statement: %w", err)
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|