You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
323 lines
10 KiB
323 lines
10 KiB
package schema |
|
|
|
import ( |
|
"context" |
|
"errors" |
|
"fmt" |
|
"go/ast" |
|
"reflect" |
|
"sync" |
|
|
|
"gorm.io/gorm/clause" |
|
"gorm.io/gorm/logger" |
|
) |
|
|
|
var (
	// ErrUnsupportedDataType is the sentinel wrapped (with %w) into the error
	// returned when the value handed to the schema parser does not resolve to
	// a struct type. Match it with errors.Is.
	ErrUnsupportedDataType = errors.New("unsupported data type")
)
|
|
|
type Schema struct { |
|
Name string |
|
ModelType reflect.Type |
|
Table string |
|
PrioritizedPrimaryField *Field |
|
DBNames []string |
|
PrimaryFields []*Field |
|
PrimaryFieldDBNames []string |
|
Fields []*Field |
|
FieldsByName map[string]*Field |
|
FieldsByDBName map[string]*Field |
|
FieldsWithDefaultDBValue []*Field // fields with default value assigned by database |
|
Relationships Relationships |
|
CreateClauses []clause.Interface |
|
QueryClauses []clause.Interface |
|
UpdateClauses []clause.Interface |
|
DeleteClauses []clause.Interface |
|
BeforeCreate, AfterCreate bool |
|
BeforeUpdate, AfterUpdate bool |
|
BeforeDelete, AfterDelete bool |
|
BeforeSave, AfterSave bool |
|
AfterFind bool |
|
err error |
|
initialized chan struct{} |
|
namer Namer |
|
cacheStore *sync.Map |
|
} |
|
|
|
func (schema Schema) String() string { |
|
if schema.ModelType.Name() == "" { |
|
return fmt.Sprintf("%s(%s)", schema.Name, schema.Table) |
|
} |
|
return fmt.Sprintf("%s.%s", schema.ModelType.PkgPath(), schema.ModelType.Name()) |
|
} |
|
|
|
func (schema Schema) MakeSlice() reflect.Value { |
|
slice := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(schema.ModelType)), 0, 20) |
|
results := reflect.New(slice.Type()) |
|
results.Elem().Set(slice) |
|
return results |
|
} |
|
|
|
func (schema Schema) LookUpField(name string) *Field { |
|
if field, ok := schema.FieldsByDBName[name]; ok { |
|
return field |
|
} |
|
if field, ok := schema.FieldsByName[name]; ok { |
|
return field |
|
} |
|
return nil |
|
} |
|
|
|
// Tabler is implemented by models that choose their own table name. When the
// model type implements it, the schema parser uses TableName() in place of the
// name produced by the Namer (later overrides, such as an explicit special
// table name, can still replace it).
type Tabler interface {
	// TableName returns the table name to use for the model.
	TableName() string
}
|
|
|
// Parse get data type from dialector |
|
func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) { |
|
return ParseWithSpecialTableName(dest, cacheStore, namer, "") |
|
} |
|
|
|
// ParseWithSpecialTableName get data type from dialector with extra schema table |
|
func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Namer, specialTableName string) (*Schema, error) { |
|
if dest == nil { |
|
return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) |
|
} |
|
|
|
value := reflect.ValueOf(dest) |
|
if value.Kind() == reflect.Ptr && value.IsNil() { |
|
value = reflect.New(value.Type().Elem()) |
|
} |
|
modelType := reflect.Indirect(value).Type() |
|
|
|
if modelType.Kind() == reflect.Interface { |
|
modelType = reflect.Indirect(reflect.ValueOf(dest)).Elem().Type() |
|
} |
|
|
|
for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr { |
|
modelType = modelType.Elem() |
|
} |
|
|
|
if modelType.Kind() != reflect.Struct { |
|
if modelType.PkgPath() == "" { |
|
return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) |
|
} |
|
return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name()) |
|
} |
|
|
|
// Cache the Schema for performance, |
|
// Use the modelType or modelType + schemaTable (if it present) as cache key. |
|
var schemaCacheKey interface{} |
|
if specialTableName != "" { |
|
schemaCacheKey = fmt.Sprintf("%p-%s", modelType, specialTableName) |
|
} else { |
|
schemaCacheKey = modelType |
|
} |
|
|
|
// Load exist schmema cache, return if exists |
|
if v, ok := cacheStore.Load(schemaCacheKey); ok { |
|
s := v.(*Schema) |
|
// Wait for the initialization of other goroutines to complete |
|
<-s.initialized |
|
return s, s.err |
|
} |
|
|
|
modelValue := reflect.New(modelType) |
|
tableName := namer.TableName(modelType.Name()) |
|
if tabler, ok := modelValue.Interface().(Tabler); ok { |
|
tableName = tabler.TableName() |
|
} |
|
if en, ok := namer.(embeddedNamer); ok { |
|
tableName = en.Table |
|
} |
|
if specialTableName != "" && specialTableName != tableName { |
|
tableName = specialTableName |
|
} |
|
|
|
schema := &Schema{ |
|
Name: modelType.Name(), |
|
ModelType: modelType, |
|
Table: tableName, |
|
FieldsByName: map[string]*Field{}, |
|
FieldsByDBName: map[string]*Field{}, |
|
Relationships: Relationships{Relations: map[string]*Relationship{}}, |
|
cacheStore: cacheStore, |
|
namer: namer, |
|
initialized: make(chan struct{}), |
|
} |
|
// When the schema initialization is completed, the channel will be closed |
|
defer close(schema.initialized) |
|
|
|
// Load exist schmema cache, return if exists |
|
if v, ok := cacheStore.Load(schemaCacheKey); ok { |
|
s := v.(*Schema) |
|
// Wait for the initialization of other goroutines to complete |
|
<-s.initialized |
|
return s, s.err |
|
} |
|
|
|
for i := 0; i < modelType.NumField(); i++ { |
|
if fieldStruct := modelType.Field(i); ast.IsExported(fieldStruct.Name) { |
|
if field := schema.ParseField(fieldStruct); field.EmbeddedSchema != nil { |
|
schema.Fields = append(schema.Fields, field.EmbeddedSchema.Fields...) |
|
} else { |
|
schema.Fields = append(schema.Fields, field) |
|
} |
|
} |
|
} |
|
|
|
for _, field := range schema.Fields { |
|
if field.DBName == "" && field.DataType != "" { |
|
field.DBName = namer.ColumnName(schema.Table, field.Name) |
|
} |
|
|
|
if field.DBName != "" { |
|
// nonexistence or shortest path or first appear prioritized if has permission |
|
if v, ok := schema.FieldsByDBName[field.DBName]; !ok || ((field.Creatable || field.Updatable || field.Readable) && len(field.BindNames) < len(v.BindNames)) { |
|
if _, ok := schema.FieldsByDBName[field.DBName]; !ok { |
|
schema.DBNames = append(schema.DBNames, field.DBName) |
|
} |
|
schema.FieldsByDBName[field.DBName] = field |
|
schema.FieldsByName[field.Name] = field |
|
|
|
if v != nil && v.PrimaryKey { |
|
for idx, f := range schema.PrimaryFields { |
|
if f == v { |
|
schema.PrimaryFields = append(schema.PrimaryFields[0:idx], schema.PrimaryFields[idx+1:]...) |
|
} |
|
} |
|
} |
|
|
|
if field.PrimaryKey { |
|
schema.PrimaryFields = append(schema.PrimaryFields, field) |
|
} |
|
} |
|
} |
|
|
|
if of, ok := schema.FieldsByName[field.Name]; !ok || of.TagSettings["-"] == "-" { |
|
schema.FieldsByName[field.Name] = field |
|
} |
|
|
|
field.setupValuerAndSetter() |
|
} |
|
|
|
prioritizedPrimaryField := schema.LookUpField("id") |
|
if prioritizedPrimaryField == nil { |
|
prioritizedPrimaryField = schema.LookUpField("ID") |
|
} |
|
|
|
if prioritizedPrimaryField != nil { |
|
if prioritizedPrimaryField.PrimaryKey { |
|
schema.PrioritizedPrimaryField = prioritizedPrimaryField |
|
} else if len(schema.PrimaryFields) == 0 { |
|
prioritizedPrimaryField.PrimaryKey = true |
|
schema.PrioritizedPrimaryField = prioritizedPrimaryField |
|
schema.PrimaryFields = append(schema.PrimaryFields, prioritizedPrimaryField) |
|
} |
|
} |
|
|
|
if schema.PrioritizedPrimaryField == nil && len(schema.PrimaryFields) == 1 { |
|
schema.PrioritizedPrimaryField = schema.PrimaryFields[0] |
|
} |
|
|
|
for _, field := range schema.PrimaryFields { |
|
schema.PrimaryFieldDBNames = append(schema.PrimaryFieldDBNames, field.DBName) |
|
} |
|
|
|
for _, field := range schema.Fields { |
|
if field.HasDefaultValue && field.DefaultValueInterface == nil { |
|
schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field) |
|
} |
|
} |
|
|
|
if field := schema.PrioritizedPrimaryField; field != nil { |
|
switch field.GORMDataType { |
|
case Int, Uint: |
|
if _, ok := field.TagSettings["AUTOINCREMENT"]; !ok { |
|
if !field.HasDefaultValue || field.DefaultValueInterface != nil { |
|
schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field) |
|
} |
|
|
|
field.HasDefaultValue = true |
|
field.AutoIncrement = true |
|
} |
|
} |
|
} |
|
|
|
callbacks := []string{"BeforeCreate", "AfterCreate", "BeforeUpdate", "AfterUpdate", "BeforeSave", "AfterSave", "BeforeDelete", "AfterDelete", "AfterFind"} |
|
for _, name := range callbacks { |
|
if methodValue := modelValue.MethodByName(name); methodValue.IsValid() { |
|
switch methodValue.Type().String() { |
|
case "func(*gorm.DB) error": // TODO hack |
|
reflect.Indirect(reflect.ValueOf(schema)).FieldByName(name).SetBool(true) |
|
default: |
|
logger.Default.Warn(context.Background(), "Model %v don't match %vInterface, should be `%v(*gorm.DB) error`. Please see https://gorm.io/docs/hooks.html", schema, name, name) |
|
} |
|
} |
|
} |
|
|
|
// Cache the schema |
|
if v, loaded := cacheStore.LoadOrStore(schemaCacheKey, schema); loaded { |
|
s := v.(*Schema) |
|
// Wait for the initialization of other goroutines to complete |
|
<-s.initialized |
|
return s, s.err |
|
} |
|
|
|
defer func() { |
|
if schema.err != nil { |
|
logger.Default.Error(context.Background(), schema.err.Error()) |
|
cacheStore.Delete(modelType) |
|
} |
|
}() |
|
|
|
if _, embedded := schema.cacheStore.Load(embeddedCacheKey); !embedded { |
|
for _, field := range schema.Fields { |
|
if field.DataType == "" && (field.Creatable || field.Updatable || field.Readable) { |
|
if schema.parseRelation(field); schema.err != nil { |
|
return schema, schema.err |
|
} else { |
|
schema.FieldsByName[field.Name] = field |
|
} |
|
} |
|
|
|
fieldValue := reflect.New(field.IndirectFieldType) |
|
fieldInterface := fieldValue.Interface() |
|
if fc, ok := fieldInterface.(CreateClausesInterface); ok { |
|
field.Schema.CreateClauses = append(field.Schema.CreateClauses, fc.CreateClauses(field)...) |
|
} |
|
|
|
if fc, ok := fieldInterface.(QueryClausesInterface); ok { |
|
field.Schema.QueryClauses = append(field.Schema.QueryClauses, fc.QueryClauses(field)...) |
|
} |
|
|
|
if fc, ok := fieldInterface.(UpdateClausesInterface); ok { |
|
field.Schema.UpdateClauses = append(field.Schema.UpdateClauses, fc.UpdateClauses(field)...) |
|
} |
|
|
|
if fc, ok := fieldInterface.(DeleteClausesInterface); ok { |
|
field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, fc.DeleteClauses(field)...) |
|
} |
|
} |
|
} |
|
|
|
return schema, schema.err |
|
} |
|
|
|
func getOrParse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) { |
|
modelType := reflect.ValueOf(dest).Type() |
|
for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr { |
|
modelType = modelType.Elem() |
|
} |
|
|
|
if modelType.Kind() != reflect.Struct { |
|
if modelType.PkgPath() == "" { |
|
return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) |
|
} |
|
return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name()) |
|
} |
|
|
|
if v, ok := cacheStore.Load(modelType); ok { |
|
return v.(*Schema), nil |
|
} |
|
|
|
return Parse(dest, cacheStore, namer) |
|
}
|
|
|