mssqlbulkloader/sql.go

152 lines
3.6 KiB
Go
Raw Normal View History

2023-04-03 19:12:30 +05:00
package main
import (
"database/sql"
"fmt"
"io"
"os"
"strings"
mssql "github.com/denisenkom/go-mssqldb"
)
// prepareTable resolves the destination column names and, depending on the
// create/overwrite options, truncates or (re)creates the destination table.
//
// Column names come from INFORMATION_SCHEMA.COLUMNS when the
// unknownColumnNames option is set, otherwise from reader.GetHeader(); either
// way they are stored in reader.Options().columnNames.
//
// Table handling:
//   - neither create nor overwrite: nothing else is done;
//   - overwrite without create: the table is truncated and left in place;
//   - create: the table is created, after being dropped first when overwrite
//     is also set.
//
// All statements run on tx; committing or rolling back is left to the caller.
//
// TODO: the table name is still interpolated verbatim where it is used as an
// identifier (TRUNCATE/DROP/CREATE); only its use inside string literals and
// the column identifiers are escaped.
func prepareTable(reader Reader, tx *sql.Tx) error {
	opts := reader.Options()

	if opts.unknownColumnNames {
		columnNames, err := tableColumnNames(tx, opts.tableName)
		if err != nil {
			return err
		}
		opts.columnNames = columnNames
	} else {
		opts.columnNames = reader.GetHeader()
	}

	if !opts.create {
		if opts.overwrite {
			logger.Println("Truncating table...")
			if _, err := tx.Exec(fmt.Sprintf("TRUNCATE TABLE %s", opts.tableName)); err != nil {
				return fmt.Errorf("truncate table: %w", err)
			}
		}
		// Without the create option the existing table must be kept, so stop
		// here instead of falling through to DROP/CREATE.
		return nil
	}

	if opts.overwrite {
		logger.Println("Dropping table...")
		// Double single quotes so the name cannot terminate the string literal.
		literal := strings.ReplaceAll(opts.tableName, "'", "''")
		query := fmt.Sprintf("IF object_id('%s', 'U') IS NOT NULL DROP TABLE %s", literal, opts.tableName)
		if _, err := tx.Exec(query); err != nil {
			return fmt.Errorf("drop table: %w", err)
		}
	}

	columnNames := opts.columnNames
	if len(columnNames) == 0 {
		return fmt.Errorf("create table %s: no columns", opts.tableName)
	}

	// One type character per column; spaces in the option are ignored.
	fieldTypes := strings.ReplaceAll(opts.fieldsTypes, " ", "")
	if len(fieldTypes) < len(columnNames) {
		return fmt.Errorf("detect field type: %d field types given for %d columns", len(fieldTypes), len(columnNames))
	}

	var b strings.Builder
	fmt.Fprintf(&b, "CREATE TABLE %s (", opts.tableName)
	for i, columnName := range columnNames {
		var fieldType FieldType
		if err := fieldType.UnmarshalText([]byte{fieldTypes[i]}); err != nil {
			return fmt.Errorf("detect field type: %w", err)
		}
		if i > 0 {
			b.WriteString(", ")
		}
		// Double embedded quotes so the name stays a single quoted identifier.
		fmt.Fprintf(&b, `"%s" %s`, strings.ReplaceAll(columnName, `"`, `""`), fieldType.SqlFieldType())
	}
	b.WriteString(") WITH (DATA_COMPRESSION = PAGE)") // TODO: add optional params
	query := b.String()

	logger.Println("Creating table...")
	logger.Println(query)
	if _, err := tx.Exec(query); err != nil {
		return fmt.Errorf("execute table creation: %w", err)
	}
	return nil
}

// tableColumnNames returns the column names of tableName (compared against
// TABLE_SCHEMA + '.' + TABLE_NAME) in ordinal order, as reported by
// INFORMATION_SCHEMA.COLUMNS.
func tableColumnNames(tx *sql.Tx, tableName string) ([]string, error) {
	// Double single quotes so the name cannot terminate the string literal.
	literal := strings.ReplaceAll(tableName, "'", "''")
	query := fmt.Sprintf("SELECT COLUMN_NAME FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_SCHEMA + '.' + TABLE_NAME = '%s' ORDER BY ORDINAL_POSITION", literal)

	rows, err := tx.Query(query)
	if err != nil {
		return nil, fmt.Errorf("get column names from database: %w", err)
	}
	defer rows.Close()

	var names []string
	for rows.Next() {
		var name string
		if err := rows.Scan(&name); err != nil {
			return nil, fmt.Errorf("get column names from database: %w", err)
		}
		names = append(names, name)
	}
	// Next returns false on both exhaustion and failure; Err tells them apart.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("get column names from database: %w", err)
	}
	return names, nil
}
// insertData streams every row from reader into the destination table using
// a bulk-copy statement built with mssql.CopyIn (TABLOCK enabled) on tx.
//
// The target columns are reader.GetHeader(), or reader.Options().columnNames
// when the unknownColumnNames option is set. Unless the silent option is set,
// progress is written to stderr every 100000 records and once at the end.
//
// On any failure the statement is closed (when it was prepared), the
// transaction is rolled back, and a wrapped error is returned. On success the
// transaction is left open for the caller to commit.
func insertData(reader Reader, tx *sql.Tx) error {
	columnNames := reader.GetHeader()
	opts := reader.Options()
	if opts.unknownColumnNames {
		columnNames = opts.columnNames
	}

	query := mssql.CopyIn(opts.tableName, mssql.BulkOptions{Tablock: true}, columnNames...)
	stmt, err := tx.Prepare(query)
	if err != nil {
		// Best-effort rollback; the prepare error is the one worth reporting.
		_ = tx.Rollback()
		return fmt.Errorf("prepare statement: %w", err)
	}

	// abort releases the statement, rolls the transaction back and wraps
	// cause. Cleanup errors are ignored on purpose: cause is the root failure.
	abort := func(what string, cause error) error {
		_ = stmt.Close()
		_ = tx.Rollback()
		return fmt.Errorf("%s: %w", what, cause)
	}

	for n := 0; ; n++ {
		if n%100000 == 0 && !opts.silent {
			fmt.Fprintf(os.Stderr, "Processed %d records...\r", n)
		}

		record, err := reader.GetRow(false)
		if err == io.EOF {
			break
		}
		if err != nil {
			return abort("read record", err)
		}
		if _, err := stmt.Exec(record...); err != nil {
			return abort("execute statement", err)
		}
	}

	// An Exec without arguments finishes the bulk copy.
	result, err := stmt.Exec()
	if err != nil {
		return abort("execute statement", err)
	}
	rowsAffected, err := result.RowsAffected()
	if err != nil {
		return abort("calc rows affected", err)
	}
	if !opts.silent {
		fmt.Fprintf(os.Stderr, "Processed %d records. \n", rowsAffected)
	}

	if err := stmt.Close(); err != nil {
		// The statement was already closed (unsuccessfully); only roll back.
		_ = tx.Rollback()
		return fmt.Errorf("close statement: %w", err)
	}
	return nil
}