You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
477 lines
16 KiB
477 lines
16 KiB
package postgres |
|
|
|
import ( |
|
"database/sql" |
|
"fmt" |
|
"regexp" |
|
"strings" |
|
|
|
"gorm.io/gorm" |
|
"gorm.io/gorm/clause" |
|
"gorm.io/gorm/migrator" |
|
"gorm.io/gorm/schema" |
|
) |
|
|
|
type Migrator struct { |
|
migrator.Migrator |
|
} |
|
|
|
func (m Migrator) CurrentDatabase() (name string) { |
|
m.DB.Raw("SELECT CURRENT_DATABASE()").Scan(&name) |
|
return |
|
} |
|
|
|
func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) { |
|
for _, opt := range opts { |
|
str := stmt.Quote(opt.DBName) |
|
if opt.Expression != "" { |
|
str = opt.Expression |
|
} |
|
|
|
if opt.Collate != "" { |
|
str += " COLLATE " + opt.Collate |
|
} |
|
|
|
if opt.Sort != "" { |
|
str += " " + opt.Sort |
|
} |
|
results = append(results, clause.Expr{SQL: str}) |
|
} |
|
return |
|
} |
|
|
|
func (m Migrator) HasIndex(value interface{}, name string) bool { |
|
var count int64 |
|
m.RunWithValue(value, func(stmt *gorm.Statement) error { |
|
if idx := stmt.Schema.LookIndex(name); idx != nil { |
|
name = idx.Name |
|
} |
|
currentSchema, curTable := m.CurrentSchema(stmt, stmt.Table) |
|
return m.DB.Raw( |
|
"SELECT count(*) FROM pg_indexes WHERE tablename = ? AND indexname = ? AND schemaname = ?", curTable, name, currentSchema, |
|
).Scan(&count).Error |
|
}) |
|
|
|
return count > 0 |
|
} |
|
|
|
func (m Migrator) CreateIndex(value interface{}, name string) error { |
|
return m.RunWithValue(value, func(stmt *gorm.Statement) error { |
|
if idx := stmt.Schema.LookIndex(name); idx != nil { |
|
opts := m.BuildIndexOptions(idx.Fields, stmt) |
|
values := []interface{}{clause.Column{Name: idx.Name}, m.CurrentTable(stmt), opts} |
|
|
|
createIndexSQL := "CREATE " |
|
if idx.Class != "" { |
|
createIndexSQL += idx.Class + " " |
|
} |
|
createIndexSQL += "INDEX " |
|
|
|
if strings.TrimSpace(strings.ToUpper(idx.Option)) == "CONCURRENTLY" { |
|
createIndexSQL += "CONCURRENTLY " |
|
} |
|
|
|
createIndexSQL += "IF NOT EXISTS ? ON ?" |
|
|
|
if idx.Type != "" { |
|
createIndexSQL += " USING " + idx.Type + "(?)" |
|
} else { |
|
createIndexSQL += " ?" |
|
} |
|
|
|
if idx.Where != "" { |
|
createIndexSQL += " WHERE " + idx.Where |
|
} |
|
|
|
return m.DB.Exec(createIndexSQL, values...).Error |
|
} |
|
|
|
return fmt.Errorf("failed to create index with name %v", name) |
|
}) |
|
} |
|
|
|
func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error { |
|
return m.RunWithValue(value, func(stmt *gorm.Statement) error { |
|
return m.DB.Exec( |
|
"ALTER INDEX ? RENAME TO ?", |
|
clause.Column{Name: oldName}, clause.Column{Name: newName}, |
|
).Error |
|
}) |
|
} |
|
|
|
func (m Migrator) DropIndex(value interface{}, name string) error { |
|
return m.RunWithValue(value, func(stmt *gorm.Statement) error { |
|
if idx := stmt.Schema.LookIndex(name); idx != nil { |
|
name = idx.Name |
|
} |
|
|
|
return m.DB.Exec("DROP INDEX ?", clause.Column{Name: name}).Error |
|
}) |
|
} |
|
|
|
func (m Migrator) GetTables() (tableList []string, err error) { |
|
currentSchema, _ := m.CurrentSchema(m.DB.Statement, "") |
|
return tableList, m.DB.Raw("SELECT table_name FROM information_schema.tables WHERE table_schema = ? AND table_type = ?", currentSchema, "BASE TABLE").Scan(&tableList).Error |
|
} |
|
|
|
func (m Migrator) CreateTable(values ...interface{}) (err error) { |
|
if err = m.Migrator.CreateTable(values...); err != nil { |
|
return |
|
} |
|
for _, value := range m.ReorderModels(values, false) { |
|
if err = m.RunWithValue(value, func(stmt *gorm.Statement) error { |
|
for _, field := range stmt.Schema.FieldsByDBName { |
|
if field.Comment != "" { |
|
if err := m.DB.Exec( |
|
"COMMENT ON COLUMN ?.? IS ?", |
|
m.CurrentTable(stmt), clause.Column{Name: field.DBName}, gorm.Expr(m.Migrator.Dialector.Explain("$1", field.Comment)), |
|
).Error; err != nil { |
|
return err |
|
} |
|
} |
|
} |
|
return nil |
|
}); err != nil { |
|
return |
|
} |
|
} |
|
return |
|
} |
|
|
|
func (m Migrator) HasTable(value interface{}) bool { |
|
var count int64 |
|
m.RunWithValue(value, func(stmt *gorm.Statement) error { |
|
currentSchema, curTable := m.CurrentSchema(stmt, stmt.Table) |
|
return m.DB.Raw("SELECT count(*) FROM information_schema.tables WHERE table_schema = ? AND table_name = ? AND table_type = ?", currentSchema, curTable, "BASE TABLE").Scan(&count).Error |
|
}) |
|
return count > 0 |
|
} |
|
|
|
func (m Migrator) DropTable(values ...interface{}) error { |
|
values = m.ReorderModels(values, false) |
|
tx := m.DB.Session(&gorm.Session{}) |
|
for i := len(values) - 1; i >= 0; i-- { |
|
if err := m.RunWithValue(values[i], func(stmt *gorm.Statement) error { |
|
return tx.Exec("DROP TABLE IF EXISTS ? CASCADE", m.CurrentTable(stmt)).Error |
|
}); err != nil { |
|
return err |
|
} |
|
} |
|
return nil |
|
} |
|
|
|
func (m Migrator) AddColumn(value interface{}, field string) error { |
|
if err := m.Migrator.AddColumn(value, field); err != nil { |
|
return err |
|
} |
|
return m.RunWithValue(value, func(stmt *gorm.Statement) error { |
|
if field := stmt.Schema.LookUpField(field); field != nil { |
|
if field.Comment != "" { |
|
if err := m.DB.Exec( |
|
"COMMENT ON COLUMN ?.? IS ?", |
|
m.CurrentTable(stmt), clause.Column{Name: field.DBName}, gorm.Expr(m.Migrator.Dialector.Explain("$1", field.Comment)), |
|
).Error; err != nil { |
|
return err |
|
} |
|
} |
|
} |
|
return nil |
|
}) |
|
} |
|
|
|
func (m Migrator) HasColumn(value interface{}, field string) bool { |
|
var count int64 |
|
m.RunWithValue(value, func(stmt *gorm.Statement) error { |
|
name := field |
|
if stmt.Schema != nil { |
|
if field := stmt.Schema.LookUpField(field); field != nil { |
|
name = field.DBName |
|
} |
|
} |
|
|
|
currentSchema, curTable := m.CurrentSchema(stmt, stmt.Table) |
|
return m.DB.Raw( |
|
"SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_schema = ? AND table_name = ? AND column_name = ?", |
|
currentSchema, curTable, name, |
|
).Scan(&count).Error |
|
}) |
|
|
|
return count > 0 |
|
} |
|
|
|
func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnType gorm.ColumnType) error { |
|
// skip primary field |
|
if !field.PrimaryKey { |
|
if err := m.Migrator.MigrateColumn(value, field, columnType); err != nil { |
|
return err |
|
} |
|
} |
|
|
|
return m.RunWithValue(value, func(stmt *gorm.Statement) error { |
|
var description string |
|
currentSchema, curTable := m.CurrentSchema(stmt, stmt.Table) |
|
values := []interface{}{currentSchema, curTable, field.DBName, stmt.Table, currentSchema} |
|
checkSQL := "SELECT description FROM pg_catalog.pg_description " |
|
checkSQL += "WHERE objsubid = (SELECT ordinal_position FROM information_schema.columns WHERE table_schema = ? AND table_name = ? AND column_name = ?) " |
|
checkSQL += "AND objoid = (SELECT oid FROM pg_catalog.pg_class WHERE relname = ? AND relnamespace = " |
|
checkSQL += "(SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = ?))" |
|
m.DB.Raw(checkSQL, values...).Scan(&description) |
|
comment := field.Comment |
|
if comment != "" { |
|
comment = comment[1 : len(comment)-1] |
|
} |
|
if field.Comment != "" && comment != description { |
|
if err := m.DB.Exec( |
|
"COMMENT ON COLUMN ?.? IS ?", |
|
m.CurrentTable(stmt), clause.Column{Name: field.DBName}, gorm.Expr(m.Migrator.Dialector.Explain("$1", field.Comment)), |
|
).Error; err != nil { |
|
return err |
|
} |
|
} |
|
return nil |
|
}) |
|
} |
|
|
|
// AlterColumn alter value's `field` column' type based on schema definition |
|
func (m Migrator) AlterColumn(value interface{}, field string) error { |
|
return m.RunWithValue(value, func(stmt *gorm.Statement) error { |
|
if field := stmt.Schema.LookUpField(field); field != nil { |
|
var ( |
|
columnTypes, _ = m.DB.Migrator().ColumnTypes(value) |
|
fieldColumnType *migrator.ColumnType |
|
) |
|
for _, columnType := range columnTypes { |
|
if columnType.Name() == field.DBName { |
|
fieldColumnType, _ = columnType.(*migrator.ColumnType) |
|
} |
|
} |
|
|
|
return m.DB.Connection(func(tx *gorm.DB) error { |
|
fileType := clause.Expr{SQL: m.DataTypeOf(field)} |
|
if fieldColumnType.DatabaseTypeName() != fileType.SQL { |
|
if err := tx.Exec("ALTER TABLE ? ALTER COLUMN ? TYPE ?", m.CurrentTable(stmt), clause.Column{Name: field.DBName}, fileType).Error; err != nil { |
|
return err |
|
} |
|
} |
|
|
|
if null, _ := fieldColumnType.Nullable(); null == field.NotNull { |
|
if field.NotNull { |
|
if err := tx.Exec("ALTER TABLE ? ALTER COLUMN ? SET NOT NULL", m.CurrentTable(stmt), clause.Column{Name: field.DBName}).Error; err != nil { |
|
return err |
|
} |
|
} else { |
|
if err := tx.Exec("ALTER TABLE ? ALTER COLUMN ? DROP NOT NULL", m.CurrentTable(stmt), clause.Column{Name: field.DBName}).Error; err != nil { |
|
return err |
|
} |
|
} |
|
} |
|
|
|
if uniq, _ := fieldColumnType.Unique(); uniq != field.Unique { |
|
idxName := clause.Column{Name: m.DB.Config.NamingStrategy.IndexName(stmt.Table, field.DBName)} |
|
if err := tx.Exec("ALTER TABLE ? ADD CONSTRAINT ? UNIQUE(?)", m.CurrentTable(stmt), idxName, clause.Column{Name: field.DBName}).Error; err != nil { |
|
return err |
|
} |
|
} |
|
|
|
if v, _ := fieldColumnType.DefaultValue(); v != field.DefaultValue { |
|
if field.HasDefaultValue && (field.DefaultValueInterface != nil || field.DefaultValue != "") { |
|
if field.DefaultValueInterface != nil { |
|
defaultStmt := &gorm.Statement{Vars: []interface{}{field.DefaultValueInterface}} |
|
m.Dialector.BindVarTo(defaultStmt, defaultStmt, field.DefaultValueInterface) |
|
if err := tx.Exec("ALTER TABLE ? ALTER COLUMN ? SET DEFAULT ?", m.CurrentTable(stmt), clause.Column{Name: field.DBName}, clause.Expr{SQL: m.Dialector.Explain(defaultStmt.SQL.String(), field.DefaultValueInterface)}).Error; err != nil { |
|
return err |
|
} |
|
} else if field.DefaultValue != "(-)" { |
|
if err := tx.Exec("ALTER TABLE ? ALTER COLUMN ? SET DEFAULT ?", m.CurrentTable(stmt), clause.Column{Name: field.DBName}, clause.Expr{SQL: field.DefaultValue}).Error; err != nil { |
|
return err |
|
} |
|
} else { |
|
if err := tx.Exec("ALTER TABLE ? ALTER COLUMN ? DROP DEFAULT", m.CurrentTable(stmt), clause.Column{Name: field.DBName}, clause.Expr{SQL: field.DefaultValue}).Error; err != nil { |
|
return err |
|
} |
|
} |
|
} |
|
} |
|
return nil |
|
}) |
|
} |
|
return fmt.Errorf("failed to look up field with name: %s", field) |
|
}) |
|
} |
|
|
|
func (m Migrator) HasConstraint(value interface{}, name string) bool { |
|
var count int64 |
|
m.RunWithValue(value, func(stmt *gorm.Statement) error { |
|
constraint, chk, table := m.GuessConstraintAndTable(stmt, name) |
|
currentSchema, curTable := m.CurrentSchema(stmt, table) |
|
if constraint != nil { |
|
name = constraint.Name |
|
} else if chk != nil { |
|
name = chk.Name |
|
} |
|
|
|
return m.DB.Raw( |
|
"SELECT count(*) FROM INFORMATION_SCHEMA.table_constraints WHERE table_schema = ? AND table_name = ? AND constraint_name = ?", |
|
currentSchema, curTable, name, |
|
).Scan(&count).Error |
|
}) |
|
|
|
return count > 0 |
|
} |
|
|
|
func (m Migrator) ColumnTypes(value interface{}) (columnTypes []gorm.ColumnType, err error) { |
|
columnTypes = make([]gorm.ColumnType, 0) |
|
err = m.RunWithValue(value, func(stmt *gorm.Statement) error { |
|
var ( |
|
currentDatabase = m.DB.Migrator().CurrentDatabase() |
|
currentSchema, table = m.CurrentSchema(stmt, stmt.Table) |
|
columns, err = m.DB.Raw( |
|
"SELECT c.column_name, c.is_nullable = 'YES', c.udt_name, c.character_maximum_length, c.numeric_precision, c.numeric_precision_radix, c.numeric_scale, c.datetime_precision, 8 * typlen, c.column_default, pd.description FROM information_schema.columns AS c JOIN pg_type AS pgt ON c.udt_name = pgt.typname LEFT JOIN pg_catalog.pg_description as pd ON pd.objsubid = c.ordinal_position AND pd.objoid = (SELECT oid FROM pg_catalog.pg_class WHERE relname = c.table_name AND relnamespace = (SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = c.table_schema)) where table_catalog = ? AND table_schema = ? AND table_name = ?", |
|
currentDatabase, currentSchema, table).Rows() |
|
) |
|
|
|
if err != nil { |
|
return err |
|
} |
|
|
|
for columns.Next() { |
|
var ( |
|
column = &migrator.ColumnType{ |
|
PrimaryKeyValue: sql.NullBool{Valid: true}, |
|
UniqueValue: sql.NullBool{Valid: true}, |
|
} |
|
datetimePrecision sql.NullInt64 |
|
radixValue sql.NullInt64 |
|
typeLenValue sql.NullInt64 |
|
) |
|
|
|
err = columns.Scan( |
|
&column.NameValue, &column.NullableValue, &column.DataTypeValue, &column.LengthValue, &column.DecimalSizeValue, |
|
&radixValue, &column.ScaleValue, &datetimePrecision, &typeLenValue, &column.DefaultValueValue, &column.CommentValue, |
|
) |
|
if err != nil { |
|
return err |
|
} |
|
|
|
if typeLenValue.Valid && typeLenValue.Int64 > 0 { |
|
column.LengthValue = typeLenValue |
|
} |
|
|
|
if strings.HasPrefix(column.DefaultValueValue.String, "nextval('") && strings.HasSuffix(column.DefaultValueValue.String, "seq'::regclass)") { |
|
column.AutoIncrementValue = sql.NullBool{Bool: true, Valid: true} |
|
column.DefaultValueValue = sql.NullString{} |
|
} |
|
|
|
if column.DefaultValueValue.Valid { |
|
column.DefaultValueValue.String = regexp.MustCompile("'(.*)'::[\\w]+$").ReplaceAllString(column.DefaultValueValue.String, "$1") |
|
} |
|
|
|
if datetimePrecision.Valid { |
|
column.DecimalSizeValue = datetimePrecision |
|
} |
|
|
|
columnTypes = append(columnTypes, column) |
|
} |
|
columns.Close() |
|
|
|
// assign sql column type |
|
{ |
|
rows, rowsErr := m.GetRows(currentSchema, table) |
|
if rowsErr != nil { |
|
return rowsErr |
|
} |
|
rawColumnTypes, err := rows.ColumnTypes() |
|
if err != nil { |
|
return err |
|
} |
|
for _, columnType := range columnTypes { |
|
for _, c := range rawColumnTypes { |
|
if c.Name() == columnType.Name() { |
|
columnType.(*migrator.ColumnType).SQLColumnType = c |
|
break |
|
} |
|
} |
|
} |
|
rows.Close() |
|
} |
|
|
|
// check primary, unique field |
|
{ |
|
columnTypeRows, err := m.DB.Raw("SELECT c.column_name, constraint_type FROM information_schema.table_constraints tc JOIN information_schema.constraint_column_usage AS ccu USING (constraint_schema, constraint_name) JOIN information_schema.columns AS c ON c.table_schema = tc.constraint_schema AND tc.table_name = c.table_name AND ccu.column_name = c.column_name WHERE constraint_type IN ('PRIMARY KEY', 'UNIQUE') AND c.table_catalog = ? AND c.table_schema = ? AND c.table_name = ?", currentDatabase, currentSchema, table).Rows() |
|
if err != nil { |
|
return err |
|
} |
|
|
|
for columnTypeRows.Next() { |
|
var name, columnType string |
|
columnTypeRows.Scan(&name, &columnType) |
|
for _, c := range columnTypes { |
|
mc := c.(*migrator.ColumnType) |
|
if mc.NameValue.String == name { |
|
switch columnType { |
|
case "PRIMARY KEY": |
|
mc.PrimaryKeyValue = sql.NullBool{Bool: true, Valid: true} |
|
case "UNIQUE": |
|
mc.UniqueValue = sql.NullBool{Bool: true, Valid: true} |
|
} |
|
break |
|
} |
|
} |
|
} |
|
columnTypeRows.Close() |
|
} |
|
|
|
// check column type |
|
{ |
|
dataTypeRows, err := m.DB.Raw(`SELECT a.attname as column_name, format_type(a.atttypid, a.atttypmod) AS data_type |
|
FROM pg_attribute a JOIN pg_class b ON a.attrelid = b.relfilenode AND relnamespace = (SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = ?) |
|
WHERE a.attnum > 0 -- hide internal columns |
|
AND NOT a.attisdropped -- hide deleted columns |
|
AND b.relname = ?`, currentSchema, table).Rows() |
|
if err != nil { |
|
return err |
|
} |
|
|
|
for dataTypeRows.Next() { |
|
var name, dataType string |
|
dataTypeRows.Scan(&name, &dataType) |
|
for _, c := range columnTypes { |
|
mc := c.(*migrator.ColumnType) |
|
if mc.NameValue.String == name { |
|
mc.ColumnTypeValue = sql.NullString{String: dataType, Valid: true} |
|
break |
|
} |
|
} |
|
} |
|
dataTypeRows.Close() |
|
} |
|
|
|
return err |
|
}) |
|
return |
|
} |
|
|
|
func (m Migrator) GetRows(currentSchema interface{}, table interface{}) (*sql.Rows, error) { |
|
name := table.(string) |
|
if currentSchema != nil { |
|
if _, ok := currentSchema.(string); ok { |
|
name = fmt.Sprintf("%v.%v", currentSchema, table) |
|
} |
|
} |
|
return m.DB.Session(&gorm.Session{}).Table(name).Limit(1).Rows() |
|
} |
|
|
|
func (m Migrator) CurrentSchema(stmt *gorm.Statement, table string) (interface{}, interface{}) { |
|
if strings.Contains(table, ".") { |
|
if tables := strings.Split(table, `.`); len(tables) == 2 { |
|
return tables[0], tables[1] |
|
} |
|
} |
|
|
|
if stmt.TableExpr != nil { |
|
if tables := strings.Split(stmt.TableExpr.SQL, `"."`); len(tables) == 2 { |
|
return strings.TrimPrefix(tables[0], `"`), table |
|
} |
|
} |
|
return clause.Expr{SQL: "CURRENT_SCHEMA()"}, table |
|
}
|
|
|