You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
208 lines
4.7 KiB
208 lines
4.7 KiB
package postgres |
|
|
|
import ( |
|
"database/sql" |
|
"fmt" |
|
"regexp" |
|
"strconv" |
|
|
|
"github.com/jackc/pgx/v4" |
|
"github.com/jackc/pgx/v4/stdlib" |
|
"gorm.io/gorm" |
|
"gorm.io/gorm/callbacks" |
|
"gorm.io/gorm/clause" |
|
"gorm.io/gorm/logger" |
|
"gorm.io/gorm/migrator" |
|
"gorm.io/gorm/schema" |
|
) |
|
|
|
type Dialector struct { |
|
*Config |
|
} |
|
|
|
type Config struct { |
|
DriverName string |
|
DSN string |
|
PreferSimpleProtocol bool |
|
WithoutReturning bool |
|
Conn gorm.ConnPool |
|
} |
|
|
|
func Open(dsn string) gorm.Dialector { |
|
return &Dialector{&Config{DSN: dsn}} |
|
} |
|
|
|
func New(config Config) gorm.Dialector { |
|
return &Dialector{Config: &config} |
|
} |
|
|
|
func (dialector Dialector) Name() string { |
|
return "postgres" |
|
} |
|
|
|
// timeZoneMatcher finds a "time_zone=" or "TimeZone=" setting in a DSN.
// Submatch 2 is the value, taken lazily up to the end of the string, the
// next '&', or the next space.
var timeZoneMatcher = regexp.MustCompile(`(time_zone|TimeZone)=(.*?)($|&| )`)
|
|
|
func (dialector Dialector) Initialize(db *gorm.DB) (err error) { |
|
// register callbacks |
|
callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{ |
|
CreateClauses: []string{"INSERT", "VALUES", "ON CONFLICT", "RETURNING"}, |
|
UpdateClauses: []string{"UPDATE", "SET", "WHERE", "RETURNING"}, |
|
DeleteClauses: []string{"DELETE", "FROM", "WHERE", "RETURNING"}, |
|
}) |
|
|
|
if dialector.Conn != nil { |
|
db.ConnPool = dialector.Conn |
|
} else if dialector.DriverName != "" { |
|
db.ConnPool, err = sql.Open(dialector.DriverName, dialector.Config.DSN) |
|
} else { |
|
var config *pgx.ConnConfig |
|
|
|
config, err = pgx.ParseConfig(dialector.Config.DSN) |
|
if err != nil { |
|
return |
|
} |
|
if dialector.Config.PreferSimpleProtocol { |
|
config.PreferSimpleProtocol = true |
|
} |
|
result := timeZoneMatcher.FindStringSubmatch(dialector.Config.DSN) |
|
if len(result) > 2 { |
|
config.RuntimeParams["timezone"] = result[2] |
|
} |
|
db.ConnPool = stdlib.OpenDB(*config) |
|
} |
|
return |
|
} |
|
|
|
func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator { |
|
return Migrator{migrator.Migrator{Config: migrator.Config{ |
|
DB: db, |
|
Dialector: dialector, |
|
CreateIndexAfterCreateTable: true, |
|
}}} |
|
} |
|
|
|
func (dialector Dialector) DefaultValueOf(field *schema.Field) clause.Expression { |
|
return clause.Expr{SQL: "DEFAULT"} |
|
} |
|
|
|
func (dialector Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) { |
|
writer.WriteByte('$') |
|
writer.WriteString(strconv.Itoa(len(stmt.Vars))) |
|
} |
|
|
|
// QuoteTo writes str to writer as a double-quoted PostgreSQL identifier.
//
// str is split on '.' and each segment is quoted separately. For example,
// users becomes "users" and public.users becomes "public"."users".
//
// A segment that already starts with '"' is treated as pre-quoted. So
// "users" is written unchanged, and a '.' inside such a segment ("a.b") is
// kept literally. A lone '"' inside a segment is doubled (a"b becomes
// "a""b"). An existing doubled quote ("") is passed through as one escaped
// quote.
//
// NOTE(review): a segment beginning with two quotes, e.g. `""x`, is never
// opened; traced output is `""x"` (unbalanced).
func (dialector Dialector) QuoteTo(writer clause.Writer, str string) {
	var (
		// underQuoted: an opening '"' has been written for the current segment.
		// selfQuoted: the current segment began with a '"' in str.
		underQuoted, selfQuoted bool
		// continuousBacktick counts consecutive '"' bytes read but not yet
		// written. Despite the name, it counts double quotes, not backticks.
		continuousBacktick int8
		// shiftDelimiter counts bytes consumed in the current segment and is
		// reset when a segment is closed at '.'.
		// NOTE(review): int8, so it wraps for segments longer than 127 bytes.
		shiftDelimiter int8
	)

	for _, v := range []byte(str) {
		switch v {
		case '"':
			// Buffer the quote. Two in a row are flushed at once as `""`.
			continuousBacktick++
			if continuousBacktick == 2 {
				writer.WriteString(`""`)
				continuousBacktick = 0
			}
		case '.':
			// A segment boundary, unless we are inside a self-quoted segment
			// with no pending quote (then the dot is part of the name). On a
			// boundary, close the quoted segment and reset per-segment state.
			if continuousBacktick > 0 || !selfQuoted {
				shiftDelimiter = 0
				underQuoted = false
				continuousBacktick = 0
				writer.WriteByte('"')
			}
			writer.WriteByte(v)
			continue
		default:
			// Open the quoted identifier if none is open yet and every byte
			// consumed in this segment so far is a still-pending '"'. If one
			// is pending, the segment is self-quoted and that quote serves as
			// the opener.
			if shiftDelimiter-continuousBacktick <= 0 && !underQuoted {
				writer.WriteByte('"')
				underQuoted = true
				if selfQuoted = continuousBacktick > 0; selfQuoted {
					continuousBacktick -= 1
				}
			}

			// Any quotes still pending are embedded in the name; escape each
			// by doubling it.
			for ; continuousBacktick > 0; continuousBacktick -= 1 {
				writer.WriteString(`""`)
			}

			writer.WriteByte(v)
		}
		shiftDelimiter++
	}

	// In a segment that was not self-quoted, a trailing pending quote belongs
	// to the name, so escape it. In a self-quoted segment it is the closing
	// quote, which the final WriteByte supplies.
	if continuousBacktick > 0 && !selfQuoted {
		writer.WriteString(`""`)
	}
	writer.WriteByte('"')
}
|
|
|
// numericPlaceholder matches PostgreSQL positional placeholders such as $1
// or $12. Submatch 1 is the digit run.
var numericPlaceholder = regexp.MustCompile(`\$(\d+)`)
|
|
|
func (dialector Dialector) Explain(sql string, vars ...interface{}) string { |
|
return logger.ExplainSQL(sql, numericPlaceholder, `'`, vars...) |
|
} |
|
|
|
func (dialector Dialector) DataTypeOf(field *schema.Field) string { |
|
switch field.DataType { |
|
case schema.Bool: |
|
return "boolean" |
|
case schema.Int, schema.Uint: |
|
size := field.Size |
|
if field.DataType == schema.Uint { |
|
size++ |
|
} |
|
if field.AutoIncrement { |
|
switch { |
|
case size <= 16: |
|
return "smallserial" |
|
case size <= 32: |
|
return "serial" |
|
default: |
|
return "bigserial" |
|
} |
|
} else { |
|
switch { |
|
case size <= 16: |
|
return "smallint" |
|
case size <= 32: |
|
return "integer" |
|
default: |
|
return "bigint" |
|
} |
|
} |
|
case schema.Float: |
|
if field.Precision > 0 { |
|
if field.Scale > 0 { |
|
return fmt.Sprintf("numeric(%d, %d)", field.Precision, field.Scale) |
|
} |
|
return fmt.Sprintf("numeric(%d)", field.Precision) |
|
} |
|
return "decimal" |
|
case schema.String: |
|
if field.Size > 0 { |
|
return fmt.Sprintf("varchar(%d)", field.Size) |
|
} |
|
return "text" |
|
case schema.Time: |
|
if field.Precision > 0 { |
|
return fmt.Sprintf("timestamptz(%d)", field.Precision) |
|
} |
|
return "timestamptz" |
|
case schema.Bytes: |
|
return "bytea" |
|
} |
|
|
|
return string(field.DataType) |
|
} |
|
|
|
func (dialectopr Dialector) SavePoint(tx *gorm.DB, name string) error { |
|
tx.Exec("SAVEPOINT " + name) |
|
return nil |
|
} |
|
|
|
func (dialectopr Dialector) RollbackTo(tx *gorm.DB, name string) error { |
|
tx.Exec("ROLLBACK TO SAVEPOINT " + name) |
|
return nil |
|
}
|
|
|