You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
209 lines
4.7 KiB
209 lines
4.7 KiB
3 years ago
|
package postgres
|
||
|
|
||
|
import (
|
||
|
"database/sql"
|
||
|
"fmt"
|
||
|
"regexp"
|
||
|
"strconv"
|
||
|
|
||
|
"github.com/jackc/pgx/v4"
|
||
|
"github.com/jackc/pgx/v4/stdlib"
|
||
|
"gorm.io/gorm"
|
||
|
"gorm.io/gorm/callbacks"
|
||
|
"gorm.io/gorm/clause"
|
||
|
"gorm.io/gorm/logger"
|
||
|
"gorm.io/gorm/migrator"
|
||
|
"gorm.io/gorm/schema"
|
||
|
)
|
||
|
|
||
|
type Dialector struct {
|
||
|
*Config
|
||
|
}
|
||
|
|
||
|
type Config struct {
|
||
|
DriverName string
|
||
|
DSN string
|
||
|
PreferSimpleProtocol bool
|
||
|
WithoutReturning bool
|
||
|
Conn gorm.ConnPool
|
||
|
}
|
||
|
|
||
|
func Open(dsn string) gorm.Dialector {
|
||
|
return &Dialector{&Config{DSN: dsn}}
|
||
|
}
|
||
|
|
||
|
func New(config Config) gorm.Dialector {
|
||
|
return &Dialector{Config: &config}
|
||
|
}
|
||
|
|
||
|
func (dialector Dialector) Name() string {
|
||
|
return "postgres"
|
||
|
}
|
||
|
|
||
|
// timeZoneMatcher finds a time_zone=... or TimeZone=... setting inside a DSN.
// Submatch 2 is the value, which ends at the first '&', the first space, or the
// end of the string.
var timeZoneMatcher = regexp.MustCompile(`(time_zone|TimeZone)=(.*?)($|&| )`)
|
||
|
|
||
|
func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
|
||
|
// register callbacks
|
||
|
callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{
|
||
|
CreateClauses: []string{"INSERT", "VALUES", "ON CONFLICT", "RETURNING"},
|
||
|
UpdateClauses: []string{"UPDATE", "SET", "WHERE", "RETURNING"},
|
||
|
DeleteClauses: []string{"DELETE", "FROM", "WHERE", "RETURNING"},
|
||
|
})
|
||
|
|
||
|
if dialector.Conn != nil {
|
||
|
db.ConnPool = dialector.Conn
|
||
|
} else if dialector.DriverName != "" {
|
||
|
db.ConnPool, err = sql.Open(dialector.DriverName, dialector.Config.DSN)
|
||
|
} else {
|
||
|
var config *pgx.ConnConfig
|
||
|
|
||
|
config, err = pgx.ParseConfig(dialector.Config.DSN)
|
||
|
if err != nil {
|
||
|
return
|
||
|
}
|
||
|
if dialector.Config.PreferSimpleProtocol {
|
||
|
config.PreferSimpleProtocol = true
|
||
|
}
|
||
|
result := timeZoneMatcher.FindStringSubmatch(dialector.Config.DSN)
|
||
|
if len(result) > 2 {
|
||
|
config.RuntimeParams["timezone"] = result[2]
|
||
|
}
|
||
|
db.ConnPool = stdlib.OpenDB(*config)
|
||
|
}
|
||
|
return
|
||
|
}
|
||
|
|
||
|
func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator {
|
||
|
return Migrator{migrator.Migrator{Config: migrator.Config{
|
||
|
DB: db,
|
||
|
Dialector: dialector,
|
||
|
CreateIndexAfterCreateTable: true,
|
||
|
}}}
|
||
|
}
|
||
|
|
||
|
func (dialector Dialector) DefaultValueOf(field *schema.Field) clause.Expression {
|
||
|
return clause.Expr{SQL: "DEFAULT"}
|
||
|
}
|
||
|
|
||
|
func (dialector Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) {
|
||
|
writer.WriteByte('$')
|
||
|
writer.WriteString(strconv.Itoa(len(stmt.Vars)))
|
||
|
}
|
||
|
|
||
|
func (dialector Dialector) QuoteTo(writer clause.Writer, str string) {
|
||
|
var (
|
||
|
underQuoted, selfQuoted bool
|
||
|
continuousBacktick int8
|
||
|
shiftDelimiter int8
|
||
|
)
|
||
|
|
||
|
for _, v := range []byte(str) {
|
||
|
switch v {
|
||
|
case '"':
|
||
|
continuousBacktick++
|
||
|
if continuousBacktick == 2 {
|
||
|
writer.WriteString(`""`)
|
||
|
continuousBacktick = 0
|
||
|
}
|
||
|
case '.':
|
||
|
if continuousBacktick > 0 || !selfQuoted {
|
||
|
shiftDelimiter = 0
|
||
|
underQuoted = false
|
||
|
continuousBacktick = 0
|
||
|
writer.WriteByte('"')
|
||
|
}
|
||
|
writer.WriteByte(v)
|
||
|
continue
|
||
|
default:
|
||
|
if shiftDelimiter-continuousBacktick <= 0 && !underQuoted {
|
||
|
writer.WriteByte('"')
|
||
|
underQuoted = true
|
||
|
if selfQuoted = continuousBacktick > 0; selfQuoted {
|
||
|
continuousBacktick -= 1
|
||
|
}
|
||
|
}
|
||
|
|
||
|
for ; continuousBacktick > 0; continuousBacktick -= 1 {
|
||
|
writer.WriteString(`""`)
|
||
|
}
|
||
|
|
||
|
writer.WriteByte(v)
|
||
|
}
|
||
|
shiftDelimiter++
|
||
|
}
|
||
|
|
||
|
if continuousBacktick > 0 && !selfQuoted {
|
||
|
writer.WriteString(`""`)
|
||
|
}
|
||
|
writer.WriteByte('"')
|
||
|
}
|
||
|
|
||
|
// numericPlaceholder matches a numbered bind placeholder such as $1 or $12 and
// captures its digits as submatch 1.
var numericPlaceholder = regexp.MustCompile(`\$(\d+)`)
|
||
|
|
||
|
func (dialector Dialector) Explain(sql string, vars ...interface{}) string {
|
||
|
return logger.ExplainSQL(sql, numericPlaceholder, `'`, vars...)
|
||
|
}
|
||
|
|
||
|
func (dialector Dialector) DataTypeOf(field *schema.Field) string {
|
||
|
switch field.DataType {
|
||
|
case schema.Bool:
|
||
|
return "boolean"
|
||
|
case schema.Int, schema.Uint:
|
||
|
size := field.Size
|
||
|
if field.DataType == schema.Uint {
|
||
|
size++
|
||
|
}
|
||
|
if field.AutoIncrement {
|
||
|
switch {
|
||
|
case size <= 16:
|
||
|
return "smallserial"
|
||
|
case size <= 32:
|
||
|
return "serial"
|
||
|
default:
|
||
|
return "bigserial"
|
||
|
}
|
||
|
} else {
|
||
|
switch {
|
||
|
case size <= 16:
|
||
|
return "smallint"
|
||
|
case size <= 32:
|
||
|
return "integer"
|
||
|
default:
|
||
|
return "bigint"
|
||
|
}
|
||
|
}
|
||
|
case schema.Float:
|
||
|
if field.Precision > 0 {
|
||
|
if field.Scale > 0 {
|
||
|
return fmt.Sprintf("numeric(%d, %d)", field.Precision, field.Scale)
|
||
|
}
|
||
|
return fmt.Sprintf("numeric(%d)", field.Precision)
|
||
|
}
|
||
|
return "decimal"
|
||
|
case schema.String:
|
||
|
if field.Size > 0 {
|
||
|
return fmt.Sprintf("varchar(%d)", field.Size)
|
||
|
}
|
||
|
return "text"
|
||
|
case schema.Time:
|
||
|
if field.Precision > 0 {
|
||
|
return fmt.Sprintf("timestamptz(%d)", field.Precision)
|
||
|
}
|
||
|
return "timestamptz"
|
||
|
case schema.Bytes:
|
||
|
return "bytea"
|
||
|
}
|
||
|
|
||
|
return string(field.DataType)
|
||
|
}
|
||
|
|
||
|
func (dialectopr Dialector) SavePoint(tx *gorm.DB, name string) error {
|
||
|
tx.Exec("SAVEPOINT " + name)
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (dialectopr Dialector) RollbackTo(tx *gorm.DB, name string) error {
|
||
|
tx.Exec("ROLLBACK TO SAVEPOINT " + name)
|
||
|
return nil
|
||
|
}
|