You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
469 lines
11 KiB
469 lines
11 KiB
package gorm |
|
|
|
import ( |
|
"context" |
|
"database/sql" |
|
"fmt" |
|
"sort" |
|
"sync" |
|
"time" |
|
|
|
"gorm.io/gorm/clause" |
|
"gorm.io/gorm/logger" |
|
"gorm.io/gorm/schema" |
|
) |
|
|
|
// preparedStmtDBKey is the key under which Open stores the shared
// *PreparedStmtDB inside Config.cacheStore; Session looks it up again
// when a prepared-statement session is requested.
const preparedStmtDBKey = "preparedStmt"
|
|
|
// Config GORM config |
|
type Config struct { |
|
// GORM perform single create, update, delete operations in transactions by default to ensure database data integrity |
|
// You can disable it by setting `SkipDefaultTransaction` to true |
|
SkipDefaultTransaction bool |
|
// NamingStrategy tables, columns naming strategy |
|
NamingStrategy schema.Namer |
|
// FullSaveAssociations full save associations |
|
FullSaveAssociations bool |
|
// Logger |
|
Logger logger.Interface |
|
// NowFunc the function to be used when creating a new timestamp |
|
NowFunc func() time.Time |
|
// DryRun generate sql without execute |
|
DryRun bool |
|
// PrepareStmt executes the given query in cached statement |
|
PrepareStmt bool |
|
// DisableAutomaticPing |
|
DisableAutomaticPing bool |
|
// DisableForeignKeyConstraintWhenMigrating |
|
DisableForeignKeyConstraintWhenMigrating bool |
|
// DisableNestedTransaction disable nested transaction |
|
DisableNestedTransaction bool |
|
// AllowGlobalUpdate allow global update |
|
AllowGlobalUpdate bool |
|
// QueryFields executes the SQL query with all fields of the table |
|
QueryFields bool |
|
// CreateBatchSize default create batch size |
|
CreateBatchSize int |
|
|
|
// ClauseBuilders clause builder |
|
ClauseBuilders map[string]clause.ClauseBuilder |
|
// ConnPool db conn pool |
|
ConnPool ConnPool |
|
// Dialector database dialector |
|
Dialector |
|
// Plugins registered plugins |
|
Plugins map[string]Plugin |
|
|
|
callbacks *callbacks |
|
cacheStore *sync.Map |
|
} |
|
|
|
// Apply update config to new config |
|
func (c *Config) Apply(config *Config) error { |
|
if config != c { |
|
*config = *c |
|
} |
|
return nil |
|
} |
|
|
|
// AfterInitialize initialize plugins after db connected |
|
func (c *Config) AfterInitialize(db *DB) error { |
|
if db != nil { |
|
for _, plugin := range c.Plugins { |
|
if err := plugin.Initialize(db); err != nil { |
|
return err |
|
} |
|
} |
|
} |
|
return nil |
|
} |
|
|
|
// Option gorm option interface |
|
type Option interface { |
|
Apply(*Config) error |
|
AfterInitialize(*DB) error |
|
} |
|
|
|
// DB GORM DB definition |
|
type DB struct { |
|
*Config |
|
Error error |
|
RowsAffected int64 |
|
Statement *Statement |
|
clone int |
|
} |
|
|
|
// Session session config when create session with Session() method |
|
type Session struct { |
|
DryRun bool |
|
PrepareStmt bool |
|
NewDB bool |
|
Initialized bool |
|
SkipHooks bool |
|
SkipDefaultTransaction bool |
|
DisableNestedTransaction bool |
|
AllowGlobalUpdate bool |
|
FullSaveAssociations bool |
|
QueryFields bool |
|
Context context.Context |
|
Logger logger.Interface |
|
NowFunc func() time.Time |
|
CreateBatchSize int |
|
} |
|
|
|
// Open initialize db session based on dialector |
|
func Open(dialector Dialector, opts ...Option) (db *DB, err error) { |
|
config := &Config{} |
|
|
|
sort.Slice(opts, func(i, j int) bool { |
|
_, isConfig := opts[i].(*Config) |
|
_, isConfig2 := opts[j].(*Config) |
|
return isConfig && !isConfig2 |
|
}) |
|
|
|
for _, opt := range opts { |
|
if opt != nil { |
|
if applyErr := opt.Apply(config); applyErr != nil { |
|
return nil, applyErr |
|
} |
|
defer func(opt Option) { |
|
if errr := opt.AfterInitialize(db); errr != nil { |
|
err = errr |
|
} |
|
}(opt) |
|
} |
|
} |
|
|
|
if d, ok := dialector.(interface{ Apply(*Config) error }); ok { |
|
if err = d.Apply(config); err != nil { |
|
return |
|
} |
|
} |
|
|
|
if config.NamingStrategy == nil { |
|
config.NamingStrategy = schema.NamingStrategy{} |
|
} |
|
|
|
if config.Logger == nil { |
|
config.Logger = logger.Default |
|
} |
|
|
|
if config.NowFunc == nil { |
|
config.NowFunc = func() time.Time { return time.Now().Local() } |
|
} |
|
|
|
if dialector != nil { |
|
config.Dialector = dialector |
|
} |
|
|
|
if config.Plugins == nil { |
|
config.Plugins = map[string]Plugin{} |
|
} |
|
|
|
if config.cacheStore == nil { |
|
config.cacheStore = &sync.Map{} |
|
} |
|
|
|
db = &DB{Config: config, clone: 1} |
|
|
|
db.callbacks = initializeCallbacks(db) |
|
|
|
if config.ClauseBuilders == nil { |
|
config.ClauseBuilders = map[string]clause.ClauseBuilder{} |
|
} |
|
|
|
if config.Dialector != nil { |
|
err = config.Dialector.Initialize(db) |
|
} |
|
|
|
preparedStmt := &PreparedStmtDB{ |
|
ConnPool: db.ConnPool, |
|
Stmts: map[string]Stmt{}, |
|
Mux: &sync.RWMutex{}, |
|
PreparedSQL: make([]string, 0, 100), |
|
} |
|
db.cacheStore.Store(preparedStmtDBKey, preparedStmt) |
|
|
|
if config.PrepareStmt { |
|
db.ConnPool = preparedStmt |
|
} |
|
|
|
db.Statement = &Statement{ |
|
DB: db, |
|
ConnPool: db.ConnPool, |
|
Context: context.Background(), |
|
Clauses: map[string]clause.Clause{}, |
|
} |
|
|
|
if err == nil && !config.DisableAutomaticPing { |
|
if pinger, ok := db.ConnPool.(interface{ Ping() error }); ok { |
|
err = pinger.Ping() |
|
} |
|
} |
|
|
|
if err != nil { |
|
config.Logger.Error(context.Background(), "failed to initialize database, got error %v", err) |
|
} |
|
|
|
return |
|
} |
|
|
|
// Session create new db session |
|
func (db *DB) Session(config *Session) *DB { |
|
var ( |
|
txConfig = *db.Config |
|
tx = &DB{ |
|
Config: &txConfig, |
|
Statement: db.Statement, |
|
Error: db.Error, |
|
clone: 1, |
|
} |
|
) |
|
if config.CreateBatchSize > 0 { |
|
tx.Config.CreateBatchSize = config.CreateBatchSize |
|
} |
|
|
|
if config.SkipDefaultTransaction { |
|
tx.Config.SkipDefaultTransaction = true |
|
} |
|
|
|
if config.AllowGlobalUpdate { |
|
txConfig.AllowGlobalUpdate = true |
|
} |
|
|
|
if config.FullSaveAssociations { |
|
txConfig.FullSaveAssociations = true |
|
} |
|
|
|
if config.Context != nil || config.PrepareStmt || config.SkipHooks { |
|
tx.Statement = tx.Statement.clone() |
|
tx.Statement.DB = tx |
|
} |
|
|
|
if config.Context != nil { |
|
tx.Statement.Context = config.Context |
|
} |
|
|
|
if config.PrepareStmt { |
|
if v, ok := db.cacheStore.Load(preparedStmtDBKey); ok { |
|
preparedStmt := v.(*PreparedStmtDB) |
|
tx.Statement.ConnPool = &PreparedStmtDB{ |
|
ConnPool: db.Config.ConnPool, |
|
Mux: preparedStmt.Mux, |
|
Stmts: preparedStmt.Stmts, |
|
} |
|
txConfig.ConnPool = tx.Statement.ConnPool |
|
txConfig.PrepareStmt = true |
|
} |
|
} |
|
|
|
if config.SkipHooks { |
|
tx.Statement.SkipHooks = true |
|
} |
|
|
|
if config.DisableNestedTransaction { |
|
txConfig.DisableNestedTransaction = true |
|
} |
|
|
|
if !config.NewDB { |
|
tx.clone = 2 |
|
} |
|
|
|
if config.DryRun { |
|
tx.Config.DryRun = true |
|
} |
|
|
|
if config.QueryFields { |
|
tx.Config.QueryFields = true |
|
} |
|
|
|
if config.Logger != nil { |
|
tx.Config.Logger = config.Logger |
|
} |
|
|
|
if config.NowFunc != nil { |
|
tx.Config.NowFunc = config.NowFunc |
|
} |
|
|
|
if config.Initialized { |
|
tx = tx.getInstance() |
|
} |
|
|
|
return tx |
|
} |
|
|
|
// WithContext change current instance db's context to ctx |
|
func (db *DB) WithContext(ctx context.Context) *DB { |
|
return db.Session(&Session{Context: ctx}) |
|
} |
|
|
|
// Debug start debug mode |
|
func (db *DB) Debug() (tx *DB) { |
|
return db.Session(&Session{ |
|
Logger: db.Logger.LogMode(logger.Info), |
|
}) |
|
} |
|
|
|
// Set store value with key into current db instance's context |
|
func (db *DB) Set(key string, value interface{}) *DB { |
|
tx := db.getInstance() |
|
tx.Statement.Settings.Store(key, value) |
|
return tx |
|
} |
|
|
|
// Get get value with key from current db instance's context |
|
func (db *DB) Get(key string) (interface{}, bool) { |
|
return db.Statement.Settings.Load(key) |
|
} |
|
|
|
// InstanceSet store value with key into current db instance's context |
|
func (db *DB) InstanceSet(key string, value interface{}) *DB { |
|
tx := db.getInstance() |
|
tx.Statement.Settings.Store(fmt.Sprintf("%p", tx.Statement)+key, value) |
|
return tx |
|
} |
|
|
|
// InstanceGet get value with key from current db instance's context |
|
func (db *DB) InstanceGet(key string) (interface{}, bool) { |
|
return db.Statement.Settings.Load(fmt.Sprintf("%p", db.Statement) + key) |
|
} |
|
|
|
// Callback returns callback manager |
|
func (db *DB) Callback() *callbacks { |
|
return db.callbacks |
|
} |
|
|
|
// AddError add error to db |
|
func (db *DB) AddError(err error) error { |
|
if db.Error == nil { |
|
db.Error = err |
|
} else if err != nil { |
|
db.Error = fmt.Errorf("%v; %w", db.Error, err) |
|
} |
|
return db.Error |
|
} |
|
|
|
// DB returns `*sql.DB` |
|
func (db *DB) DB() (*sql.DB, error) { |
|
connPool := db.ConnPool |
|
|
|
if dbConnector, ok := connPool.(GetDBConnector); ok && dbConnector != nil { |
|
return dbConnector.GetDBConn() |
|
} |
|
|
|
if sqldb, ok := connPool.(*sql.DB); ok { |
|
return sqldb, nil |
|
} |
|
|
|
return nil, ErrInvalidDB |
|
} |
|
|
|
func (db *DB) getInstance() *DB { |
|
if db.clone > 0 { |
|
tx := &DB{Config: db.Config, Error: db.Error} |
|
|
|
if db.clone == 1 { |
|
// clone with new statement |
|
tx.Statement = &Statement{ |
|
DB: tx, |
|
ConnPool: db.Statement.ConnPool, |
|
Context: db.Statement.Context, |
|
Clauses: map[string]clause.Clause{}, |
|
Vars: make([]interface{}, 0, 8), |
|
} |
|
} else { |
|
// with clone statement |
|
tx.Statement = db.Statement.clone() |
|
tx.Statement.DB = tx |
|
} |
|
|
|
return tx |
|
} |
|
|
|
return db |
|
} |
|
|
|
// Expr returns clause.Expr, which can be used to pass SQL expression as params |
|
func Expr(expr string, args ...interface{}) clause.Expr { |
|
return clause.Expr{SQL: expr, Vars: args} |
|
} |
|
|
|
// SetupJoinTable setup join table schema |
|
func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interface{}) error { |
|
var ( |
|
tx = db.getInstance() |
|
stmt = tx.Statement |
|
modelSchema, joinSchema *schema.Schema |
|
) |
|
|
|
err := stmt.Parse(model) |
|
if err != nil { |
|
return err |
|
} |
|
modelSchema = stmt.Schema |
|
|
|
err = stmt.Parse(joinTable) |
|
if err != nil { |
|
return err |
|
} |
|
joinSchema = stmt.Schema |
|
|
|
relation, ok := modelSchema.Relationships.Relations[field] |
|
isRelation := ok && relation.JoinTable != nil |
|
if !isRelation { |
|
return fmt.Errorf("failed to found relation: %s", field) |
|
} |
|
|
|
for _, ref := range relation.References { |
|
f := joinSchema.LookUpField(ref.ForeignKey.DBName) |
|
if f == nil { |
|
return fmt.Errorf("missing field %s for join table", ref.ForeignKey.DBName) |
|
} |
|
|
|
f.DataType = ref.ForeignKey.DataType |
|
f.GORMDataType = ref.ForeignKey.GORMDataType |
|
if f.Size == 0 { |
|
f.Size = ref.ForeignKey.Size |
|
} |
|
ref.ForeignKey = f |
|
} |
|
|
|
for name, rel := range relation.JoinTable.Relationships.Relations { |
|
if _, ok := joinSchema.Relationships.Relations[name]; !ok { |
|
rel.Schema = joinSchema |
|
joinSchema.Relationships.Relations[name] = rel |
|
} |
|
} |
|
relation.JoinTable = joinSchema |
|
|
|
return nil |
|
} |
|
|
|
// Use use plugin |
|
func (db *DB) Use(plugin Plugin) error { |
|
name := plugin.Name() |
|
if _, ok := db.Plugins[name]; ok { |
|
return ErrRegistered |
|
} |
|
if err := plugin.Initialize(db); err != nil { |
|
return err |
|
} |
|
db.Plugins[name] = plugin |
|
return nil |
|
} |
|
|
|
// ToSQL for generate SQL string. |
|
// |
|
// db.ToSQL(func(tx *gorm.DB) *gorm.DB { |
|
// return tx.Model(&User{}).Where(&User{Name: "foo", Age: 20}) |
|
// .Limit(10).Offset(5) |
|
// .Order("name ASC") |
|
// .First(&User{}) |
|
// }) |
|
func (db *DB) ToSQL(queryFn func(tx *DB) *DB) string { |
|
tx := queryFn(db.Session(&Session{DryRun: true, SkipDefaultTransaction: true})) |
|
stmt := tx.Statement |
|
|
|
return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...) |
|
}
|
|
|