You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
331 lines
7.9 KiB
331 lines
7.9 KiB
package gorm |
|
|
|
import ( |
|
"context" |
|
"errors" |
|
"fmt" |
|
"reflect" |
|
"sort" |
|
"time" |
|
|
|
"gorm.io/gorm/schema" |
|
"gorm.io/gorm/utils" |
|
) |
|
|
|
func initializeCallbacks(db *DB) *callbacks { |
|
return &callbacks{ |
|
processors: map[string]*processor{ |
|
"create": {db: db}, |
|
"query": {db: db}, |
|
"update": {db: db}, |
|
"delete": {db: db}, |
|
"row": {db: db}, |
|
"raw": {db: db}, |
|
}, |
|
} |
|
} |
|
|
|
// callbacks gorm callbacks manager |
|
type callbacks struct { |
|
processors map[string]*processor |
|
} |
|
|
|
type processor struct { |
|
db *DB |
|
Clauses []string |
|
fns []func(*DB) |
|
callbacks []*callback |
|
} |
|
|
|
type callback struct { |
|
name string |
|
before string |
|
after string |
|
remove bool |
|
replace bool |
|
match func(*DB) bool |
|
handler func(*DB) |
|
processor *processor |
|
} |
|
|
|
func (cs *callbacks) Create() *processor { |
|
return cs.processors["create"] |
|
} |
|
|
|
func (cs *callbacks) Query() *processor { |
|
return cs.processors["query"] |
|
} |
|
|
|
func (cs *callbacks) Update() *processor { |
|
return cs.processors["update"] |
|
} |
|
|
|
func (cs *callbacks) Delete() *processor { |
|
return cs.processors["delete"] |
|
} |
|
|
|
func (cs *callbacks) Row() *processor { |
|
return cs.processors["row"] |
|
} |
|
|
|
func (cs *callbacks) Raw() *processor { |
|
return cs.processors["raw"] |
|
} |
|
|
|
func (p *processor) Execute(db *DB) *DB { |
|
// call scopes |
|
for len(db.Statement.scopes) > 0 { |
|
scopes := db.Statement.scopes |
|
db.Statement.scopes = nil |
|
for _, scope := range scopes { |
|
db = scope(db) |
|
} |
|
} |
|
|
|
var ( |
|
curTime = time.Now() |
|
stmt = db.Statement |
|
resetBuildClauses bool |
|
) |
|
|
|
if len(stmt.BuildClauses) == 0 { |
|
stmt.BuildClauses = p.Clauses |
|
resetBuildClauses = true |
|
} |
|
|
|
// assign model values |
|
if stmt.Model == nil { |
|
stmt.Model = stmt.Dest |
|
} else if stmt.Dest == nil { |
|
stmt.Dest = stmt.Model |
|
} |
|
|
|
// parse model values |
|
if stmt.Model != nil { |
|
if err := stmt.Parse(stmt.Model); err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || (stmt.Table == "" && stmt.TableExpr == nil && stmt.SQL.Len() == 0)) { |
|
if errors.Is(err, schema.ErrUnsupportedDataType) && stmt.Table == "" && stmt.TableExpr == nil { |
|
db.AddError(fmt.Errorf("%w: Table not set, please set it like: db.Model(&user) or db.Table(\"users\")", err)) |
|
} else { |
|
db.AddError(err) |
|
} |
|
} |
|
} |
|
|
|
// assign stmt.ReflectValue |
|
if stmt.Dest != nil { |
|
stmt.ReflectValue = reflect.ValueOf(stmt.Dest) |
|
for stmt.ReflectValue.Kind() == reflect.Ptr { |
|
if stmt.ReflectValue.IsNil() && stmt.ReflectValue.CanAddr() { |
|
stmt.ReflectValue.Set(reflect.New(stmt.ReflectValue.Type().Elem())) |
|
} |
|
|
|
stmt.ReflectValue = stmt.ReflectValue.Elem() |
|
} |
|
if !stmt.ReflectValue.IsValid() { |
|
db.AddError(ErrInvalidValue) |
|
} |
|
} |
|
|
|
for _, f := range p.fns { |
|
f(db) |
|
} |
|
|
|
if stmt.SQL.Len() > 0 { |
|
db.Logger.Trace(stmt.Context, curTime, func() (string, int64) { |
|
return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...), db.RowsAffected |
|
}, db.Error) |
|
} |
|
|
|
if !stmt.DB.DryRun { |
|
stmt.SQL.Reset() |
|
stmt.Vars = nil |
|
} |
|
|
|
if resetBuildClauses { |
|
stmt.BuildClauses = nil |
|
} |
|
|
|
return db |
|
} |
|
|
|
func (p *processor) Get(name string) func(*DB) { |
|
for i := len(p.callbacks) - 1; i >= 0; i-- { |
|
if v := p.callbacks[i]; v.name == name && !v.remove { |
|
return v.handler |
|
} |
|
} |
|
return nil |
|
} |
|
|
|
func (p *processor) Before(name string) *callback { |
|
return &callback{before: name, processor: p} |
|
} |
|
|
|
func (p *processor) After(name string) *callback { |
|
return &callback{after: name, processor: p} |
|
} |
|
|
|
func (p *processor) Match(fc func(*DB) bool) *callback { |
|
return &callback{match: fc, processor: p} |
|
} |
|
|
|
func (p *processor) Register(name string, fn func(*DB)) error { |
|
return (&callback{processor: p}).Register(name, fn) |
|
} |
|
|
|
func (p *processor) Remove(name string) error { |
|
return (&callback{processor: p}).Remove(name) |
|
} |
|
|
|
func (p *processor) Replace(name string, fn func(*DB)) error { |
|
return (&callback{processor: p}).Replace(name, fn) |
|
} |
|
|
|
// compile drops every entry whose match predicate rejects p.db (entries
// without a predicate are kept), then rebuilds p.fns from the remaining
// entries via sortCallbacks. A sorting error is logged through p.db.Logger
// and returned; in that case p.fns is set to whatever sortCallbacks returned.
func (p *processor) compile() error {
	var kept []*callback
	for _, cb := range p.callbacks {
		if cb.match != nil && !cb.match(p.db) {
			continue
		}
		kept = append(kept, cb)
	}
	p.callbacks = kept

	fns, err := sortCallbacks(p.callbacks)
	p.fns = fns
	if err != nil {
		p.db.Logger.Error(context.Background(), "Got error when compile callbacks, got %v", err)
	}
	return err
}
|
|
|
// Before sets the callback c must run before ("*" means all others) and
// returns c for chaining.
func (c *callback) Before(name string) *callback {
	c.before = name
	return c
}

// After sets the callback c must run after ("*" means all others) and
// returns c for chaining.
func (c *callback) After(name string) *callback {
	c.after = name
	return c
}

// Register names c, attaches fn as its handler, appends it to its processor
// and recompiles that processor, returning any compile error.
func (c *callback) Register(name string, fn func(*DB)) error {
	c.name, c.handler = name, fn
	owner := c.processor
	owner.callbacks = append(owner.callbacks, c)
	return owner.compile()
}

// Remove logs a warning, appends c as a removal entry for name and recompiles
// the processor, returning any compile error.
func (c *callback) Remove(name string) error {
	owner := c.processor
	owner.db.Logger.Warn(context.Background(), "removing callback `%s` from %s\n", name, utils.FileWithLineNum())
	c.name, c.remove = name, true
	owner.callbacks = append(owner.callbacks, c)
	return owner.compile()
}

// Replace logs an info message, appends c as a replacement entry carrying fn
// for name and recompiles the processor, returning any compile error.
func (c *callback) Replace(name string, fn func(*DB)) error {
	owner := c.processor
	owner.db.Logger.Info(context.Background(), "replacing callback `%s` from %s\n", name, utils.FileWithLineNum())
	c.name, c.handler, c.replace = name, fn, true
	owner.callbacks = append(owner.callbacks, c)
	return owner.compile()
}
|
|
|
// getRIndex returns the index of the last element of strs equal to str, or -1
// when str does not occur (including when strs is empty).
func getRIndex(strs []string, str string) int {
	i := len(strs) - 1
	for i >= 0 && strs[i] != str {
		i--
	}
	return i
}
|
|
|
// sortCallbacks orders the entries in cs according to their before/after
// constraints and returns the handlers to run, in order. When several entries
// share a name, the last one in cs decides; if that entry is a removal the
// callback is left out.
//
// cs is reordered in place: entries anchored to "*" (Before("*") or
// After("*")) are moved behind all others, so they are positioned only after
// every other callback has been placed. A warning is logged for duplicate
// names that are neither replacements nor removals. An error is returned when
// two constraints contradict each other.
func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {
	var (
		names, sorted []string
		sortCallback  func(*callback) error
	)

	// isStarred reports whether c is anchored to "*".
	isStarred := func(c *callback) bool { return c.before == "*" || c.after == "*" }

	// A strict weak ordering plus a stable sort keeps the result deterministic.
	// compile re-sorts p.callbacks in place on every registration, so an
	// inconsistent or unstable ordering would reshuffle starred callbacks
	// between compiles.
	sort.SliceStable(cs, func(i, j int) bool {
		return !isStarred(cs[i]) && isStarred(cs[j])
	})

	for _, c := range cs {
		// show warning message the callback name already exists
		if idx := getRIndex(names, c.name); idx > -1 && !c.replace && !c.remove && !cs[idx].remove {
			c.processor.db.Logger.Warn(context.Background(), "duplicated callback `%s` from %s\n", c.name, utils.FileWithLineNum())
		}
		names = append(names, c.name)
	}

	sortCallback = func(c *callback) error {
		if c.before != "" { // if defined before callback
			if c.before == "*" && len(sorted) > 0 {
				if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
					sorted = append([]string{c.name}, sorted...)
				}
			} else if sortedIdx := getRIndex(sorted, c.before); sortedIdx != -1 {
				if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
					// before callback already sorted: insert current callback just ahead of it
					sorted = append(sorted[:sortedIdx], append([]string{c.name}, sorted[sortedIdx:]...)...)
				} else if curIdx > sortedIdx {
					return fmt.Errorf("conflicting callback %s with before %s", c.name, c.before)
				}
			} else if idx := getRIndex(names, c.before); idx != -1 {
				// before callback exists but is not sorted yet: make it run after the current one.
				// NOTE(review): this overwrites any after constraint cs[idx] already had, and the
				// change persists on the shared *callback across compiles.
				cs[idx].after = c.name
			}
		}

		if c.after != "" { // if defined after callback
			if c.after == "*" && len(sorted) > 0 {
				if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
					sorted = append(sorted, c.name)
				}
			} else if sortedIdx := getRIndex(sorted, c.after); sortedIdx != -1 {
				if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
					// if after callback sorted, append current callback to last
					sorted = append(sorted, c.name)
				} else if curIdx < sortedIdx {
					return fmt.Errorf("conflicting callback %s with after %s", c.name, c.after)
				}
			} else if idx := getRIndex(names, c.after); idx != -1 {
				// if after callback exists but haven't sorted
				// set after callback's before callback to current callback
				after := cs[idx]

				if after.before == "" {
					after.before = c.name
				}

				if err := sortCallback(after); err != nil {
					return err
				}

				if err := sortCallback(c); err != nil {
					return err
				}
			}
		}

		// if current callback haven't been sorted, append it to last
		if getRIndex(sorted, c.name) == -1 {
			sorted = append(sorted, c.name)
		}

		return nil
	}

	for _, c := range cs {
		if err = sortCallback(c); err != nil {
			return
		}
	}

	// The last entry for each name decides whether its handler runs.
	for _, name := range sorted {
		if idx := getRIndex(names, name); !cs[idx].remove {
			fns = append(fns, cs[idx].handler)
		}
	}

	return
}
|
|
|