You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
291 lines
9.1 KiB
291 lines
9.1 KiB
package callbacks |
|
|
|
import ( |
|
"reflect" |
|
"sort" |
|
|
|
"gorm.io/gorm" |
|
"gorm.io/gorm/clause" |
|
"gorm.io/gorm/schema" |
|
"gorm.io/gorm/utils" |
|
) |
|
|
|
func SetupUpdateReflectValue(db *gorm.DB) { |
|
if db.Error == nil && db.Statement.Schema != nil { |
|
if !db.Statement.ReflectValue.CanAddr() || db.Statement.Model != db.Statement.Dest { |
|
db.Statement.ReflectValue = reflect.ValueOf(db.Statement.Model) |
|
for db.Statement.ReflectValue.Kind() == reflect.Ptr { |
|
db.Statement.ReflectValue = db.Statement.ReflectValue.Elem() |
|
} |
|
|
|
if dest, ok := db.Statement.Dest.(map[string]interface{}); ok { |
|
for _, rel := range db.Statement.Schema.Relationships.BelongsTo { |
|
if _, ok := dest[rel.Name]; ok { |
|
db.AddError(rel.Field.Set(db.Statement.Context, db.Statement.ReflectValue, dest[rel.Name])) |
|
} |
|
} |
|
} |
|
} |
|
} |
|
} |
|
|
|
// BeforeUpdate before update hooks |
|
func BeforeUpdate(db *gorm.DB) { |
|
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) { |
|
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { |
|
if db.Statement.Schema.BeforeSave { |
|
if i, ok := value.(BeforeSaveInterface); ok { |
|
called = true |
|
db.AddError(i.BeforeSave(tx)) |
|
} |
|
} |
|
|
|
if db.Statement.Schema.BeforeUpdate { |
|
if i, ok := value.(BeforeUpdateInterface); ok { |
|
called = true |
|
db.AddError(i.BeforeUpdate(tx)) |
|
} |
|
} |
|
|
|
return called |
|
}) |
|
} |
|
} |
|
|
|
// Update update hook |
|
func Update(config *Config) func(db *gorm.DB) { |
|
supportReturning := utils.Contains(config.UpdateClauses, "RETURNING") |
|
|
|
return func(db *gorm.DB) { |
|
if db.Error != nil { |
|
return |
|
} |
|
|
|
if db.Statement.Schema != nil { |
|
for _, c := range db.Statement.Schema.UpdateClauses { |
|
db.Statement.AddClause(c) |
|
} |
|
} |
|
|
|
if db.Statement.SQL.Len() == 0 { |
|
db.Statement.SQL.Grow(180) |
|
db.Statement.AddClauseIfNotExists(clause.Update{}) |
|
if set := ConvertToAssignments(db.Statement); len(set) != 0 { |
|
db.Statement.AddClause(set) |
|
} else if _, ok := db.Statement.Clauses["SET"]; !ok { |
|
return |
|
} |
|
|
|
db.Statement.Build(db.Statement.BuildClauses...) |
|
} |
|
|
|
checkMissingWhereConditions(db) |
|
|
|
if !db.DryRun && db.Error == nil { |
|
if ok, mode := hasReturning(db, supportReturning); ok { |
|
if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil { |
|
dest := db.Statement.Dest |
|
db.Statement.Dest = db.Statement.ReflectValue.Addr().Interface() |
|
gorm.Scan(rows, db, mode) |
|
db.Statement.Dest = dest |
|
db.AddError(rows.Close()) |
|
} |
|
} else { |
|
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) |
|
|
|
if db.AddError(err) == nil { |
|
db.RowsAffected, _ = result.RowsAffected() |
|
} |
|
} |
|
} |
|
} |
|
} |
|
|
|
// AfterUpdate after update hooks |
|
func AfterUpdate(db *gorm.DB) { |
|
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) { |
|
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { |
|
if db.Statement.Schema.AfterUpdate { |
|
if i, ok := value.(AfterUpdateInterface); ok { |
|
called = true |
|
db.AddError(i.AfterUpdate(tx)) |
|
} |
|
} |
|
|
|
if db.Statement.Schema.AfterSave { |
|
if i, ok := value.(AfterSaveInterface); ok { |
|
called = true |
|
db.AddError(i.AfterSave(tx)) |
|
} |
|
} |
|
|
|
return called |
|
}) |
|
} |
|
} |
|
|
|
// ConvertToAssignments convert to update assignments |
|
func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { |
|
var ( |
|
selectColumns, restricted = stmt.SelectAndOmitColumns(false, true) |
|
assignValue func(field *schema.Field, value interface{}) |
|
) |
|
|
|
switch stmt.ReflectValue.Kind() { |
|
case reflect.Slice, reflect.Array: |
|
assignValue = func(field *schema.Field, value interface{}) { |
|
for i := 0; i < stmt.ReflectValue.Len(); i++ { |
|
field.Set(stmt.Context, stmt.ReflectValue.Index(i), value) |
|
} |
|
} |
|
case reflect.Struct: |
|
assignValue = func(field *schema.Field, value interface{}) { |
|
if stmt.ReflectValue.CanAddr() { |
|
field.Set(stmt.Context, stmt.ReflectValue, value) |
|
} |
|
} |
|
default: |
|
assignValue = func(field *schema.Field, value interface{}) { |
|
} |
|
} |
|
|
|
updatingValue := reflect.ValueOf(stmt.Dest) |
|
for updatingValue.Kind() == reflect.Ptr { |
|
updatingValue = updatingValue.Elem() |
|
} |
|
|
|
if !updatingValue.CanAddr() || stmt.Dest != stmt.Model { |
|
switch stmt.ReflectValue.Kind() { |
|
case reflect.Slice, reflect.Array: |
|
if size := stmt.ReflectValue.Len(); size > 0 { |
|
var primaryKeyExprs []clause.Expression |
|
for i := 0; i < size; i++ { |
|
exprs := make([]clause.Expression, len(stmt.Schema.PrimaryFields)) |
|
var notZero bool |
|
for idx, field := range stmt.Schema.PrimaryFields { |
|
value, isZero := field.ValueOf(stmt.Context, stmt.ReflectValue.Index(i)) |
|
exprs[idx] = clause.Eq{Column: field.DBName, Value: value} |
|
notZero = notZero || !isZero |
|
} |
|
if notZero { |
|
primaryKeyExprs = append(primaryKeyExprs, clause.And(exprs...)) |
|
} |
|
} |
|
|
|
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(primaryKeyExprs...)}}) |
|
} |
|
case reflect.Struct: |
|
for _, field := range stmt.Schema.PrimaryFields { |
|
if value, isZero := field.ValueOf(stmt.Context, stmt.ReflectValue); !isZero { |
|
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}}) |
|
} |
|
} |
|
} |
|
} |
|
|
|
switch value := updatingValue.Interface().(type) { |
|
case map[string]interface{}: |
|
set = make([]clause.Assignment, 0, len(value)) |
|
|
|
keys := make([]string, 0, len(value)) |
|
for k := range value { |
|
keys = append(keys, k) |
|
} |
|
sort.Strings(keys) |
|
|
|
for _, k := range keys { |
|
kv := value[k] |
|
if _, ok := kv.(*gorm.DB); ok { |
|
kv = []interface{}{kv} |
|
} |
|
|
|
if stmt.Schema != nil { |
|
if field := stmt.Schema.LookUpField(k); field != nil { |
|
if field.DBName != "" { |
|
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { |
|
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: kv}) |
|
assignValue(field, value[k]) |
|
} |
|
} else if v, ok := selectColumns[field.Name]; (ok && v) || (!ok && !restricted) { |
|
assignValue(field, value[k]) |
|
} |
|
continue |
|
} |
|
} |
|
|
|
if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) { |
|
set = append(set, clause.Assignment{Column: clause.Column{Name: k}, Value: kv}) |
|
} |
|
} |
|
|
|
if !stmt.SkipHooks && stmt.Schema != nil { |
|
for _, dbName := range stmt.Schema.DBNames { |
|
field := stmt.Schema.LookUpField(dbName) |
|
if field.AutoUpdateTime > 0 && value[field.Name] == nil && value[field.DBName] == nil { |
|
if v, ok := selectColumns[field.DBName]; (ok && v) || !ok { |
|
now := stmt.DB.NowFunc() |
|
assignValue(field, now) |
|
|
|
if field.AutoUpdateTime == schema.UnixNanosecond { |
|
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano()}) |
|
} else if field.AutoUpdateTime == schema.UnixMillisecond { |
|
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano() / 1e6}) |
|
} else if field.AutoUpdateTime == schema.UnixSecond { |
|
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.Unix()}) |
|
} else { |
|
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now}) |
|
} |
|
} |
|
} |
|
} |
|
} |
|
default: |
|
updatingSchema := stmt.Schema |
|
if !updatingValue.CanAddr() || stmt.Dest != stmt.Model { |
|
// different schema |
|
updatingStmt := &gorm.Statement{DB: stmt.DB} |
|
if err := updatingStmt.Parse(stmt.Dest); err == nil { |
|
updatingSchema = updatingStmt.Schema |
|
} |
|
} |
|
|
|
switch updatingValue.Kind() { |
|
case reflect.Struct: |
|
set = make([]clause.Assignment, 0, len(stmt.Schema.FieldsByDBName)) |
|
for _, dbName := range stmt.Schema.DBNames { |
|
if field := updatingSchema.LookUpField(dbName); field != nil { |
|
if !field.PrimaryKey || !updatingValue.CanAddr() || stmt.Dest != stmt.Model { |
|
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && (!restricted || (!stmt.SkipHooks && field.AutoUpdateTime > 0))) { |
|
value, isZero := field.ValueOf(stmt.Context, updatingValue) |
|
if !stmt.SkipHooks && field.AutoUpdateTime > 0 { |
|
if field.AutoUpdateTime == schema.UnixNanosecond { |
|
value = stmt.DB.NowFunc().UnixNano() |
|
} else if field.AutoUpdateTime == schema.UnixMillisecond { |
|
value = stmt.DB.NowFunc().UnixNano() / 1e6 |
|
} else if field.AutoUpdateTime == schema.UnixSecond { |
|
value = stmt.DB.NowFunc().Unix() |
|
} else { |
|
value = stmt.DB.NowFunc() |
|
} |
|
isZero = false |
|
} |
|
|
|
if (ok || !isZero) && field.Updatable { |
|
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value}) |
|
assignValue(field, value) |
|
} |
|
} |
|
} else { |
|
if value, isZero := field.ValueOf(stmt.Context, updatingValue); !isZero { |
|
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}}) |
|
} |
|
} |
|
} |
|
} |
|
default: |
|
stmt.AddError(gorm.ErrInvalidData) |
|
} |
|
} |
|
|
|
return |
|
}
|
|
|