You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
522 lines
19 KiB
522 lines
19 KiB
package gorm |
|
|
|
import ( |
|
"fmt" |
|
"reflect" |
|
"strings" |
|
|
|
"gorm.io/gorm/clause" |
|
"gorm.io/gorm/schema" |
|
"gorm.io/gorm/utils" |
|
) |
|
|
|
// Association Mode contains some helper methods to handle relationship things easily. |
|
type Association struct { |
|
DB *DB |
|
Relationship *schema.Relationship |
|
Error error |
|
} |
|
|
|
func (db *DB) Association(column string) *Association { |
|
association := &Association{DB: db} |
|
table := db.Statement.Table |
|
|
|
if err := db.Statement.Parse(db.Statement.Model); err == nil { |
|
db.Statement.Table = table |
|
association.Relationship = db.Statement.Schema.Relationships.Relations[column] |
|
|
|
if association.Relationship == nil { |
|
association.Error = fmt.Errorf("%w: %s", ErrUnsupportedRelation, column) |
|
} |
|
|
|
db.Statement.ReflectValue = reflect.ValueOf(db.Statement.Model) |
|
for db.Statement.ReflectValue.Kind() == reflect.Ptr { |
|
db.Statement.ReflectValue = db.Statement.ReflectValue.Elem() |
|
} |
|
} else { |
|
association.Error = err |
|
} |
|
|
|
return association |
|
} |
|
|
|
func (association *Association) Find(out interface{}, conds ...interface{}) error { |
|
if association.Error == nil { |
|
association.Error = association.buildCondition().Find(out, conds...).Error |
|
} |
|
return association.Error |
|
} |
|
|
|
func (association *Association) Append(values ...interface{}) error { |
|
if association.Error == nil { |
|
switch association.Relationship.Type { |
|
case schema.HasOne, schema.BelongsTo: |
|
if len(values) > 0 { |
|
association.Error = association.Replace(values...) |
|
} |
|
default: |
|
association.saveAssociation( /*clear*/ false, values...) |
|
} |
|
} |
|
|
|
return association.Error |
|
} |
|
|
|
func (association *Association) Replace(values ...interface{}) error { |
|
if association.Error == nil { |
|
// save associations |
|
if association.saveAssociation( /*clear*/ true, values...); association.Error != nil { |
|
return association.Error |
|
} |
|
|
|
// set old associations's foreign key to null |
|
reflectValue := association.DB.Statement.ReflectValue |
|
rel := association.Relationship |
|
switch rel.Type { |
|
case schema.BelongsTo: |
|
if len(values) == 0 { |
|
updateMap := map[string]interface{}{} |
|
switch reflectValue.Kind() { |
|
case reflect.Slice, reflect.Array: |
|
for i := 0; i < reflectValue.Len(); i++ { |
|
association.Error = rel.Field.Set(association.DB.Statement.Context, reflectValue.Index(i), reflect.Zero(rel.Field.FieldType).Interface()) |
|
} |
|
case reflect.Struct: |
|
association.Error = rel.Field.Set(association.DB.Statement.Context, reflectValue, reflect.Zero(rel.Field.FieldType).Interface()) |
|
} |
|
|
|
for _, ref := range rel.References { |
|
updateMap[ref.ForeignKey.DBName] = nil |
|
} |
|
|
|
association.Error = association.DB.UpdateColumns(updateMap).Error |
|
} |
|
case schema.HasOne, schema.HasMany: |
|
var ( |
|
primaryFields []*schema.Field |
|
foreignKeys []string |
|
updateMap = map[string]interface{}{} |
|
relValues = schema.GetRelationsValues(association.DB.Statement.Context, reflectValue, []*schema.Relationship{rel}) |
|
modelValue = reflect.New(rel.FieldSchema.ModelType).Interface() |
|
tx = association.DB.Model(modelValue) |
|
) |
|
|
|
if _, rvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, relValues, rel.FieldSchema.PrimaryFields); len(rvs) > 0 { |
|
if column, values := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs); len(values) > 0 { |
|
tx.Not(clause.IN{Column: column, Values: values}) |
|
} |
|
} |
|
|
|
for _, ref := range rel.References { |
|
if ref.OwnPrimaryKey { |
|
primaryFields = append(primaryFields, ref.PrimaryKey) |
|
foreignKeys = append(foreignKeys, ref.ForeignKey.DBName) |
|
updateMap[ref.ForeignKey.DBName] = nil |
|
} else if ref.PrimaryValue != "" { |
|
tx.Where(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue}) |
|
} |
|
} |
|
|
|
if _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields); len(pvs) > 0 { |
|
column, values := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs) |
|
association.Error = tx.Where(clause.IN{Column: column, Values: values}).UpdateColumns(updateMap).Error |
|
} |
|
case schema.Many2Many: |
|
var ( |
|
primaryFields, relPrimaryFields []*schema.Field |
|
joinPrimaryKeys, joinRelPrimaryKeys []string |
|
modelValue = reflect.New(rel.JoinTable.ModelType).Interface() |
|
tx = association.DB.Model(modelValue) |
|
) |
|
|
|
for _, ref := range rel.References { |
|
if ref.PrimaryValue == "" { |
|
if ref.OwnPrimaryKey { |
|
primaryFields = append(primaryFields, ref.PrimaryKey) |
|
joinPrimaryKeys = append(joinPrimaryKeys, ref.ForeignKey.DBName) |
|
} else { |
|
relPrimaryFields = append(relPrimaryFields, ref.PrimaryKey) |
|
joinRelPrimaryKeys = append(joinRelPrimaryKeys, ref.ForeignKey.DBName) |
|
} |
|
} else { |
|
tx.Clauses(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue}) |
|
} |
|
} |
|
|
|
_, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields) |
|
if column, values := schema.ToQueryValues(rel.JoinTable.Table, joinPrimaryKeys, pvs); len(values) > 0 { |
|
tx.Where(clause.IN{Column: column, Values: values}) |
|
} else { |
|
return ErrPrimaryKeyRequired |
|
} |
|
|
|
_, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, relPrimaryFields) |
|
if relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs); len(relValues) > 0 { |
|
tx.Where(clause.Not(clause.IN{Column: relColumn, Values: relValues})) |
|
} |
|
|
|
association.Error = tx.Delete(modelValue).Error |
|
} |
|
} |
|
return association.Error |
|
} |
|
|
|
func (association *Association) Delete(values ...interface{}) error { |
|
if association.Error == nil { |
|
var ( |
|
reflectValue = association.DB.Statement.ReflectValue |
|
rel = association.Relationship |
|
primaryFields []*schema.Field |
|
foreignKeys []string |
|
updateAttrs = map[string]interface{}{} |
|
conds []clause.Expression |
|
) |
|
|
|
for _, ref := range rel.References { |
|
if ref.PrimaryValue == "" { |
|
primaryFields = append(primaryFields, ref.PrimaryKey) |
|
foreignKeys = append(foreignKeys, ref.ForeignKey.DBName) |
|
updateAttrs[ref.ForeignKey.DBName] = nil |
|
} else { |
|
conds = append(conds, clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue}) |
|
} |
|
} |
|
|
|
switch rel.Type { |
|
case schema.BelongsTo: |
|
tx := association.DB.Model(reflect.New(rel.Schema.ModelType).Interface()) |
|
|
|
_, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, rel.Schema.PrimaryFields) |
|
if pcolumn, pvalues := schema.ToQueryValues(rel.Schema.Table, rel.Schema.PrimaryFieldDBNames, pvs); len(pvalues) > 0 { |
|
conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues}) |
|
} else { |
|
return ErrPrimaryKeyRequired |
|
} |
|
|
|
_, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, primaryFields) |
|
relColumn, relValues := schema.ToQueryValues(rel.Schema.Table, foreignKeys, rvs) |
|
conds = append(conds, clause.IN{Column: relColumn, Values: relValues}) |
|
|
|
association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error |
|
case schema.HasOne, schema.HasMany: |
|
tx := association.DB.Model(reflect.New(rel.FieldSchema.ModelType).Interface()) |
|
|
|
_, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields) |
|
if pcolumn, pvalues := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs); len(pvalues) > 0 { |
|
conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues}) |
|
} else { |
|
return ErrPrimaryKeyRequired |
|
} |
|
|
|
_, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, rel.FieldSchema.PrimaryFields) |
|
relColumn, relValues := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs) |
|
conds = append(conds, clause.IN{Column: relColumn, Values: relValues}) |
|
|
|
association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error |
|
case schema.Many2Many: |
|
var ( |
|
primaryFields, relPrimaryFields []*schema.Field |
|
joinPrimaryKeys, joinRelPrimaryKeys []string |
|
joinValue = reflect.New(rel.JoinTable.ModelType).Interface() |
|
) |
|
|
|
for _, ref := range rel.References { |
|
if ref.PrimaryValue == "" { |
|
if ref.OwnPrimaryKey { |
|
primaryFields = append(primaryFields, ref.PrimaryKey) |
|
joinPrimaryKeys = append(joinPrimaryKeys, ref.ForeignKey.DBName) |
|
} else { |
|
relPrimaryFields = append(relPrimaryFields, ref.PrimaryKey) |
|
joinRelPrimaryKeys = append(joinRelPrimaryKeys, ref.ForeignKey.DBName) |
|
} |
|
} else { |
|
conds = append(conds, clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue}) |
|
} |
|
} |
|
|
|
_, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields) |
|
if pcolumn, pvalues := schema.ToQueryValues(rel.JoinTable.Table, joinPrimaryKeys, pvs); len(pvalues) > 0 { |
|
conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues}) |
|
} else { |
|
return ErrPrimaryKeyRequired |
|
} |
|
|
|
_, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, relPrimaryFields) |
|
relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs) |
|
conds = append(conds, clause.IN{Column: relColumn, Values: relValues}) |
|
|
|
association.Error = association.DB.Where(clause.Where{Exprs: conds}).Model(nil).Delete(joinValue).Error |
|
} |
|
|
|
if association.Error == nil { |
|
// clean up deleted values's foreign key |
|
relValuesMap, _ := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, rel.FieldSchema.PrimaryFields) |
|
|
|
cleanUpDeletedRelations := func(data reflect.Value) { |
|
if _, zero := rel.Field.ValueOf(association.DB.Statement.Context, data); !zero { |
|
fieldValue := reflect.Indirect(rel.Field.ReflectValueOf(association.DB.Statement.Context, data)) |
|
primaryValues := make([]interface{}, len(rel.FieldSchema.PrimaryFields)) |
|
|
|
switch fieldValue.Kind() { |
|
case reflect.Slice, reflect.Array: |
|
validFieldValues := reflect.Zero(rel.Field.IndirectFieldType) |
|
for i := 0; i < fieldValue.Len(); i++ { |
|
for idx, field := range rel.FieldSchema.PrimaryFields { |
|
primaryValues[idx], _ = field.ValueOf(association.DB.Statement.Context, fieldValue.Index(i)) |
|
} |
|
|
|
if _, ok := relValuesMap[utils.ToStringKey(primaryValues...)]; !ok { |
|
validFieldValues = reflect.Append(validFieldValues, fieldValue.Index(i)) |
|
} |
|
} |
|
|
|
association.Error = rel.Field.Set(association.DB.Statement.Context, data, validFieldValues.Interface()) |
|
case reflect.Struct: |
|
for idx, field := range rel.FieldSchema.PrimaryFields { |
|
primaryValues[idx], _ = field.ValueOf(association.DB.Statement.Context, fieldValue) |
|
} |
|
|
|
if _, ok := relValuesMap[utils.ToStringKey(primaryValues...)]; ok { |
|
if association.Error = rel.Field.Set(association.DB.Statement.Context, data, reflect.Zero(rel.FieldSchema.ModelType).Interface()); association.Error != nil { |
|
break |
|
} |
|
|
|
if rel.JoinTable == nil { |
|
for _, ref := range rel.References { |
|
if ref.OwnPrimaryKey || ref.PrimaryValue != "" { |
|
association.Error = ref.ForeignKey.Set(association.DB.Statement.Context, fieldValue, reflect.Zero(ref.ForeignKey.FieldType).Interface()) |
|
} else { |
|
association.Error = ref.ForeignKey.Set(association.DB.Statement.Context, data, reflect.Zero(ref.ForeignKey.FieldType).Interface()) |
|
} |
|
} |
|
} |
|
} |
|
} |
|
} |
|
} |
|
|
|
switch reflectValue.Kind() { |
|
case reflect.Slice, reflect.Array: |
|
for i := 0; i < reflectValue.Len(); i++ { |
|
cleanUpDeletedRelations(reflect.Indirect(reflectValue.Index(i))) |
|
} |
|
case reflect.Struct: |
|
cleanUpDeletedRelations(reflectValue) |
|
} |
|
} |
|
} |
|
|
|
return association.Error |
|
} |
|
|
|
func (association *Association) Clear() error { |
|
return association.Replace() |
|
} |
|
|
|
func (association *Association) Count() (count int64) { |
|
if association.Error == nil { |
|
association.Error = association.buildCondition().Count(&count).Error |
|
} |
|
return |
|
} |
|
|
|
type assignBack struct { |
|
Source reflect.Value |
|
Index int |
|
Dest reflect.Value |
|
} |
|
|
|
func (association *Association) saveAssociation(clear bool, values ...interface{}) { |
|
var ( |
|
reflectValue = association.DB.Statement.ReflectValue |
|
assignBacks []assignBack // assign association values back to arguments after save |
|
) |
|
|
|
appendToRelations := func(source, rv reflect.Value, clear bool) { |
|
switch association.Relationship.Type { |
|
case schema.HasOne, schema.BelongsTo: |
|
switch rv.Kind() { |
|
case reflect.Slice, reflect.Array: |
|
if rv.Len() > 0 { |
|
association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, source, rv.Index(0).Addr().Interface()) |
|
|
|
if association.Relationship.Field.FieldType.Kind() == reflect.Struct { |
|
assignBacks = append(assignBacks, assignBack{Source: source, Dest: rv.Index(0)}) |
|
} |
|
} |
|
case reflect.Struct: |
|
association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, source, rv.Addr().Interface()) |
|
|
|
if association.Relationship.Field.FieldType.Kind() == reflect.Struct { |
|
assignBacks = append(assignBacks, assignBack{Source: source, Dest: rv}) |
|
} |
|
} |
|
case schema.HasMany, schema.Many2Many: |
|
elemType := association.Relationship.Field.IndirectFieldType.Elem() |
|
fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(association.DB.Statement.Context, source)) |
|
if clear { |
|
fieldValue = reflect.New(association.Relationship.Field.IndirectFieldType).Elem() |
|
} |
|
|
|
appendToFieldValues := func(ev reflect.Value) { |
|
if ev.Type().AssignableTo(elemType) { |
|
fieldValue = reflect.Append(fieldValue, ev) |
|
} else if ev.Type().Elem().AssignableTo(elemType) { |
|
fieldValue = reflect.Append(fieldValue, ev.Elem()) |
|
} else { |
|
association.Error = fmt.Errorf("unsupported data type: %v for relation %s", ev.Type(), association.Relationship.Name) |
|
} |
|
|
|
if elemType.Kind() == reflect.Struct { |
|
assignBacks = append(assignBacks, assignBack{Source: source, Dest: ev, Index: fieldValue.Len()}) |
|
} |
|
} |
|
|
|
switch rv.Kind() { |
|
case reflect.Slice, reflect.Array: |
|
for i := 0; i < rv.Len(); i++ { |
|
appendToFieldValues(reflect.Indirect(rv.Index(i)).Addr()) |
|
} |
|
case reflect.Struct: |
|
appendToFieldValues(rv.Addr()) |
|
} |
|
|
|
if association.Error == nil { |
|
association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, source, fieldValue.Interface()) |
|
} |
|
} |
|
} |
|
|
|
selectedSaveColumns := []string{association.Relationship.Name} |
|
omitColumns := []string{} |
|
selectColumns, _ := association.DB.Statement.SelectAndOmitColumns(true, false) |
|
for name, ok := range selectColumns { |
|
columnName := "" |
|
if strings.HasPrefix(name, association.Relationship.Name) { |
|
if columnName = strings.TrimPrefix(name, association.Relationship.Name); columnName == ".*" { |
|
columnName = name |
|
} |
|
} else if strings.HasPrefix(name, clause.Associations) { |
|
columnName = name |
|
} |
|
|
|
if columnName != "" { |
|
if ok { |
|
selectedSaveColumns = append(selectedSaveColumns, columnName) |
|
} else { |
|
omitColumns = append(omitColumns, columnName) |
|
} |
|
} |
|
} |
|
|
|
for _, ref := range association.Relationship.References { |
|
if !ref.OwnPrimaryKey { |
|
selectedSaveColumns = append(selectedSaveColumns, ref.ForeignKey.Name) |
|
} |
|
} |
|
|
|
associationDB := association.DB.Session(&Session{}).Model(nil) |
|
if !association.DB.FullSaveAssociations { |
|
associationDB.Select(selectedSaveColumns) |
|
} |
|
if len(omitColumns) > 0 { |
|
associationDB.Omit(omitColumns...) |
|
} |
|
associationDB = associationDB.Session(&Session{}) |
|
|
|
switch reflectValue.Kind() { |
|
case reflect.Slice, reflect.Array: |
|
if len(values) != reflectValue.Len() { |
|
// clear old data |
|
if clear && len(values) == 0 { |
|
for i := 0; i < reflectValue.Len(); i++ { |
|
if err := association.Relationship.Field.Set(association.DB.Statement.Context, reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType).Interface()); err != nil { |
|
association.Error = err |
|
break |
|
} |
|
|
|
if association.Relationship.JoinTable == nil { |
|
for _, ref := range association.Relationship.References { |
|
if !ref.OwnPrimaryKey && ref.PrimaryValue == "" { |
|
if err := ref.ForeignKey.Set(association.DB.Statement.Context, reflectValue.Index(i), reflect.Zero(ref.ForeignKey.FieldType).Interface()); err != nil { |
|
association.Error = err |
|
break |
|
} |
|
} |
|
} |
|
} |
|
} |
|
break |
|
} |
|
|
|
association.Error = ErrInvalidValueOfLength |
|
return |
|
} |
|
|
|
for i := 0; i < reflectValue.Len(); i++ { |
|
appendToRelations(reflectValue.Index(i), reflect.Indirect(reflect.ValueOf(values[i])), clear) |
|
|
|
// TODO support save slice data, sql with case? |
|
association.Error = associationDB.Updates(reflectValue.Index(i).Addr().Interface()).Error |
|
} |
|
case reflect.Struct: |
|
// clear old data |
|
if clear && len(values) == 0 { |
|
association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType).Interface()) |
|
|
|
if association.Relationship.JoinTable == nil && association.Error == nil { |
|
for _, ref := range association.Relationship.References { |
|
if !ref.OwnPrimaryKey && ref.PrimaryValue == "" { |
|
association.Error = ref.ForeignKey.Set(association.DB.Statement.Context, reflectValue, reflect.Zero(ref.ForeignKey.FieldType).Interface()) |
|
} |
|
} |
|
} |
|
} |
|
|
|
for idx, value := range values { |
|
rv := reflect.Indirect(reflect.ValueOf(value)) |
|
appendToRelations(reflectValue, rv, clear && idx == 0) |
|
} |
|
|
|
if len(values) > 0 { |
|
association.Error = associationDB.Updates(reflectValue.Addr().Interface()).Error |
|
} |
|
} |
|
|
|
for _, assignBack := range assignBacks { |
|
fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(association.DB.Statement.Context, assignBack.Source)) |
|
if assignBack.Index > 0 { |
|
reflect.Indirect(assignBack.Dest).Set(fieldValue.Index(assignBack.Index - 1)) |
|
} else { |
|
reflect.Indirect(assignBack.Dest).Set(fieldValue) |
|
} |
|
} |
|
} |
|
|
|
func (association *Association) buildCondition() *DB { |
|
var ( |
|
queryConds = association.Relationship.ToQueryConditions(association.DB.Statement.Context, association.DB.Statement.ReflectValue) |
|
modelValue = reflect.New(association.Relationship.FieldSchema.ModelType).Interface() |
|
tx = association.DB.Model(modelValue) |
|
) |
|
|
|
if association.Relationship.JoinTable != nil { |
|
if !tx.Statement.Unscoped && len(association.Relationship.JoinTable.QueryClauses) > 0 { |
|
joinStmt := Statement{DB: tx, Context: tx.Statement.Context, Schema: association.Relationship.JoinTable, Table: association.Relationship.JoinTable.Table, Clauses: map[string]clause.Clause{}} |
|
for _, queryClause := range association.Relationship.JoinTable.QueryClauses { |
|
joinStmt.AddClause(queryClause) |
|
} |
|
joinStmt.Build("WHERE") |
|
tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars}) |
|
} |
|
|
|
tx = tx.Session(&Session{QueryFields: true}).Clauses(clause.From{Joins: []clause.Join{{ |
|
Table: clause.Table{Name: association.Relationship.JoinTable.Table}, |
|
ON: clause.Where{Exprs: queryConds}, |
|
}}}) |
|
} else { |
|
tx.Clauses(clause.Where{Exprs: queryConds}) |
|
} |
|
|
|
return tx |
|
}
|
|
|