You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
478 lines
16 KiB
478 lines
16 KiB
3 years ago
|
package postgres
|
||
|
|
||
|
import (
|
||
|
"database/sql"
|
||
|
"fmt"
|
||
|
"regexp"
|
||
|
"strings"
|
||
|
|
||
|
"gorm.io/gorm"
|
||
|
"gorm.io/gorm/clause"
|
||
|
"gorm.io/gorm/migrator"
|
||
|
"gorm.io/gorm/schema"
|
||
|
)
|
||
|
|
||
|
type Migrator struct {
|
||
|
migrator.Migrator
|
||
|
}
|
||
|
|
||
|
func (m Migrator) CurrentDatabase() (name string) {
|
||
|
m.DB.Raw("SELECT CURRENT_DATABASE()").Scan(&name)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) {
|
||
|
for _, opt := range opts {
|
||
|
str := stmt.Quote(opt.DBName)
|
||
|
if opt.Expression != "" {
|
||
|
str = opt.Expression
|
||
|
}
|
||
|
|
||
|
if opt.Collate != "" {
|
||
|
str += " COLLATE " + opt.Collate
|
||
|
}
|
||
|
|
||
|
if opt.Sort != "" {
|
||
|
str += " " + opt.Sort
|
||
|
}
|
||
|
results = append(results, clause.Expr{SQL: str})
|
||
|
}
|
||
|
return
|
||
|
}
|
||
|
|
||
|
func (m Migrator) HasIndex(value interface{}, name string) bool {
|
||
|
var count int64
|
||
|
m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||
|
if idx := stmt.Schema.LookIndex(name); idx != nil {
|
||
|
name = idx.Name
|
||
|
}
|
||
|
currentSchema, curTable := m.CurrentSchema(stmt, stmt.Table)
|
||
|
return m.DB.Raw(
|
||
|
"SELECT count(*) FROM pg_indexes WHERE tablename = ? AND indexname = ? AND schemaname = ?", curTable, name, currentSchema,
|
||
|
).Scan(&count).Error
|
||
|
})
|
||
|
|
||
|
return count > 0
|
||
|
}
|
||
|
|
||
|
func (m Migrator) CreateIndex(value interface{}, name string) error {
|
||
|
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||
|
if idx := stmt.Schema.LookIndex(name); idx != nil {
|
||
|
opts := m.BuildIndexOptions(idx.Fields, stmt)
|
||
|
values := []interface{}{clause.Column{Name: idx.Name}, m.CurrentTable(stmt), opts}
|
||
|
|
||
|
createIndexSQL := "CREATE "
|
||
|
if idx.Class != "" {
|
||
|
createIndexSQL += idx.Class + " "
|
||
|
}
|
||
|
createIndexSQL += "INDEX "
|
||
|
|
||
|
if strings.TrimSpace(strings.ToUpper(idx.Option)) == "CONCURRENTLY" {
|
||
|
createIndexSQL += "CONCURRENTLY "
|
||
|
}
|
||
|
|
||
|
createIndexSQL += "IF NOT EXISTS ? ON ?"
|
||
|
|
||
|
if idx.Type != "" {
|
||
|
createIndexSQL += " USING " + idx.Type + "(?)"
|
||
|
} else {
|
||
|
createIndexSQL += " ?"
|
||
|
}
|
||
|
|
||
|
if idx.Where != "" {
|
||
|
createIndexSQL += " WHERE " + idx.Where
|
||
|
}
|
||
|
|
||
|
return m.DB.Exec(createIndexSQL, values...).Error
|
||
|
}
|
||
|
|
||
|
return fmt.Errorf("failed to create index with name %v", name)
|
||
|
})
|
||
|
}
|
||
|
|
||
|
func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error {
|
||
|
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||
|
return m.DB.Exec(
|
||
|
"ALTER INDEX ? RENAME TO ?",
|
||
|
clause.Column{Name: oldName}, clause.Column{Name: newName},
|
||
|
).Error
|
||
|
})
|
||
|
}
|
||
|
|
||
|
func (m Migrator) DropIndex(value interface{}, name string) error {
|
||
|
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||
|
if idx := stmt.Schema.LookIndex(name); idx != nil {
|
||
|
name = idx.Name
|
||
|
}
|
||
|
|
||
|
return m.DB.Exec("DROP INDEX ?", clause.Column{Name: name}).Error
|
||
|
})
|
||
|
}
|
||
|
|
||
|
func (m Migrator) GetTables() (tableList []string, err error) {
|
||
|
currentSchema, _ := m.CurrentSchema(m.DB.Statement, "")
|
||
|
return tableList, m.DB.Raw("SELECT table_name FROM information_schema.tables WHERE table_schema = ? AND table_type = ?", currentSchema, "BASE TABLE").Scan(&tableList).Error
|
||
|
}
|
||
|
|
||
|
func (m Migrator) CreateTable(values ...interface{}) (err error) {
|
||
|
if err = m.Migrator.CreateTable(values...); err != nil {
|
||
|
return
|
||
|
}
|
||
|
for _, value := range m.ReorderModels(values, false) {
|
||
|
if err = m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||
|
for _, field := range stmt.Schema.FieldsByDBName {
|
||
|
if field.Comment != "" {
|
||
|
if err := m.DB.Exec(
|
||
|
"COMMENT ON COLUMN ?.? IS ?",
|
||
|
m.CurrentTable(stmt), clause.Column{Name: field.DBName}, gorm.Expr(m.Migrator.Dialector.Explain("$1", field.Comment)),
|
||
|
).Error; err != nil {
|
||
|
return err
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
return nil
|
||
|
}); err != nil {
|
||
|
return
|
||
|
}
|
||
|
}
|
||
|
return
|
||
|
}
|
||
|
|
||
|
func (m Migrator) HasTable(value interface{}) bool {
|
||
|
var count int64
|
||
|
m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||
|
currentSchema, curTable := m.CurrentSchema(stmt, stmt.Table)
|
||
|
return m.DB.Raw("SELECT count(*) FROM information_schema.tables WHERE table_schema = ? AND table_name = ? AND table_type = ?", currentSchema, curTable, "BASE TABLE").Scan(&count).Error
|
||
|
})
|
||
|
return count > 0
|
||
|
}
|
||
|
|
||
|
func (m Migrator) DropTable(values ...interface{}) error {
|
||
|
values = m.ReorderModels(values, false)
|
||
|
tx := m.DB.Session(&gorm.Session{})
|
||
|
for i := len(values) - 1; i >= 0; i-- {
|
||
|
if err := m.RunWithValue(values[i], func(stmt *gorm.Statement) error {
|
||
|
return tx.Exec("DROP TABLE IF EXISTS ? CASCADE", m.CurrentTable(stmt)).Error
|
||
|
}); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (m Migrator) AddColumn(value interface{}, field string) error {
|
||
|
if err := m.Migrator.AddColumn(value, field); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||
|
if field := stmt.Schema.LookUpField(field); field != nil {
|
||
|
if field.Comment != "" {
|
||
|
if err := m.DB.Exec(
|
||
|
"COMMENT ON COLUMN ?.? IS ?",
|
||
|
m.CurrentTable(stmt), clause.Column{Name: field.DBName}, gorm.Expr(m.Migrator.Dialector.Explain("$1", field.Comment)),
|
||
|
).Error; err != nil {
|
||
|
return err
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
return nil
|
||
|
})
|
||
|
}
|
||
|
|
||
|
func (m Migrator) HasColumn(value interface{}, field string) bool {
|
||
|
var count int64
|
||
|
m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||
|
name := field
|
||
|
if stmt.Schema != nil {
|
||
|
if field := stmt.Schema.LookUpField(field); field != nil {
|
||
|
name = field.DBName
|
||
|
}
|
||
|
}
|
||
|
|
||
|
currentSchema, curTable := m.CurrentSchema(stmt, stmt.Table)
|
||
|
return m.DB.Raw(
|
||
|
"SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_schema = ? AND table_name = ? AND column_name = ?",
|
||
|
currentSchema, curTable, name,
|
||
|
).Scan(&count).Error
|
||
|
})
|
||
|
|
||
|
return count > 0
|
||
|
}
|
||
|
|
||
|
func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnType gorm.ColumnType) error {
|
||
|
// skip primary field
|
||
|
if !field.PrimaryKey {
|
||
|
if err := m.Migrator.MigrateColumn(value, field, columnType); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||
|
var description string
|
||
|
currentSchema, curTable := m.CurrentSchema(stmt, stmt.Table)
|
||
|
values := []interface{}{currentSchema, curTable, field.DBName, stmt.Table, currentSchema}
|
||
|
checkSQL := "SELECT description FROM pg_catalog.pg_description "
|
||
|
checkSQL += "WHERE objsubid = (SELECT ordinal_position FROM information_schema.columns WHERE table_schema = ? AND table_name = ? AND column_name = ?) "
|
||
|
checkSQL += "AND objoid = (SELECT oid FROM pg_catalog.pg_class WHERE relname = ? AND relnamespace = "
|
||
|
checkSQL += "(SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = ?))"
|
||
|
m.DB.Raw(checkSQL, values...).Scan(&description)
|
||
|
comment := field.Comment
|
||
|
if comment != "" {
|
||
|
comment = comment[1 : len(comment)-1]
|
||
|
}
|
||
|
if field.Comment != "" && comment != description {
|
||
|
if err := m.DB.Exec(
|
||
|
"COMMENT ON COLUMN ?.? IS ?",
|
||
|
m.CurrentTable(stmt), clause.Column{Name: field.DBName}, gorm.Expr(m.Migrator.Dialector.Explain("$1", field.Comment)),
|
||
|
).Error; err != nil {
|
||
|
return err
|
||
|
}
|
||
|
}
|
||
|
return nil
|
||
|
})
|
||
|
}
|
||
|
|
||
|
// AlterColumn alter value's `field` column' type based on schema definition
|
||
|
func (m Migrator) AlterColumn(value interface{}, field string) error {
|
||
|
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||
|
if field := stmt.Schema.LookUpField(field); field != nil {
|
||
|
var (
|
||
|
columnTypes, _ = m.DB.Migrator().ColumnTypes(value)
|
||
|
fieldColumnType *migrator.ColumnType
|
||
|
)
|
||
|
for _, columnType := range columnTypes {
|
||
|
if columnType.Name() == field.DBName {
|
||
|
fieldColumnType, _ = columnType.(*migrator.ColumnType)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return m.DB.Connection(func(tx *gorm.DB) error {
|
||
|
fileType := clause.Expr{SQL: m.DataTypeOf(field)}
|
||
|
if fieldColumnType.DatabaseTypeName() != fileType.SQL {
|
||
|
if err := tx.Exec("ALTER TABLE ? ALTER COLUMN ? TYPE ?", m.CurrentTable(stmt), clause.Column{Name: field.DBName}, fileType).Error; err != nil {
|
||
|
return err
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if null, _ := fieldColumnType.Nullable(); null == field.NotNull {
|
||
|
if field.NotNull {
|
||
|
if err := tx.Exec("ALTER TABLE ? ALTER COLUMN ? SET NOT NULL", m.CurrentTable(stmt), clause.Column{Name: field.DBName}).Error; err != nil {
|
||
|
return err
|
||
|
}
|
||
|
} else {
|
||
|
if err := tx.Exec("ALTER TABLE ? ALTER COLUMN ? DROP NOT NULL", m.CurrentTable(stmt), clause.Column{Name: field.DBName}).Error; err != nil {
|
||
|
return err
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if uniq, _ := fieldColumnType.Unique(); uniq != field.Unique {
|
||
|
idxName := clause.Column{Name: m.DB.Config.NamingStrategy.IndexName(stmt.Table, field.DBName)}
|
||
|
if err := tx.Exec("ALTER TABLE ? ADD CONSTRAINT ? UNIQUE(?)", m.CurrentTable(stmt), idxName, clause.Column{Name: field.DBName}).Error; err != nil {
|
||
|
return err
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if v, _ := fieldColumnType.DefaultValue(); v != field.DefaultValue {
|
||
|
if field.HasDefaultValue && (field.DefaultValueInterface != nil || field.DefaultValue != "") {
|
||
|
if field.DefaultValueInterface != nil {
|
||
|
defaultStmt := &gorm.Statement{Vars: []interface{}{field.DefaultValueInterface}}
|
||
|
m.Dialector.BindVarTo(defaultStmt, defaultStmt, field.DefaultValueInterface)
|
||
|
if err := tx.Exec("ALTER TABLE ? ALTER COLUMN ? SET DEFAULT ?", m.CurrentTable(stmt), clause.Column{Name: field.DBName}, clause.Expr{SQL: m.Dialector.Explain(defaultStmt.SQL.String(), field.DefaultValueInterface)}).Error; err != nil {
|
||
|
return err
|
||
|
}
|
||
|
} else if field.DefaultValue != "(-)" {
|
||
|
if err := tx.Exec("ALTER TABLE ? ALTER COLUMN ? SET DEFAULT ?", m.CurrentTable(stmt), clause.Column{Name: field.DBName}, clause.Expr{SQL: field.DefaultValue}).Error; err != nil {
|
||
|
return err
|
||
|
}
|
||
|
} else {
|
||
|
if err := tx.Exec("ALTER TABLE ? ALTER COLUMN ? DROP DEFAULT", m.CurrentTable(stmt), clause.Column{Name: field.DBName}, clause.Expr{SQL: field.DefaultValue}).Error; err != nil {
|
||
|
return err
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
return nil
|
||
|
})
|
||
|
}
|
||
|
return fmt.Errorf("failed to look up field with name: %s", field)
|
||
|
})
|
||
|
}
|
||
|
|
||
|
func (m Migrator) HasConstraint(value interface{}, name string) bool {
|
||
|
var count int64
|
||
|
m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||
|
constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
|
||
|
currentSchema, curTable := m.CurrentSchema(stmt, table)
|
||
|
if constraint != nil {
|
||
|
name = constraint.Name
|
||
|
} else if chk != nil {
|
||
|
name = chk.Name
|
||
|
}
|
||
|
|
||
|
return m.DB.Raw(
|
||
|
"SELECT count(*) FROM INFORMATION_SCHEMA.table_constraints WHERE table_schema = ? AND table_name = ? AND constraint_name = ?",
|
||
|
currentSchema, curTable, name,
|
||
|
).Scan(&count).Error
|
||
|
})
|
||
|
|
||
|
return count > 0
|
||
|
}
|
||
|
|
||
|
func (m Migrator) ColumnTypes(value interface{}) (columnTypes []gorm.ColumnType, err error) {
|
||
|
columnTypes = make([]gorm.ColumnType, 0)
|
||
|
err = m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||
|
var (
|
||
|
currentDatabase = m.DB.Migrator().CurrentDatabase()
|
||
|
currentSchema, table = m.CurrentSchema(stmt, stmt.Table)
|
||
|
columns, err = m.DB.Raw(
|
||
|
"SELECT c.column_name, c.is_nullable = 'YES', c.udt_name, c.character_maximum_length, c.numeric_precision, c.numeric_precision_radix, c.numeric_scale, c.datetime_precision, 8 * typlen, c.column_default, pd.description FROM information_schema.columns AS c JOIN pg_type AS pgt ON c.udt_name = pgt.typname LEFT JOIN pg_catalog.pg_description as pd ON pd.objsubid = c.ordinal_position AND pd.objoid = (SELECT oid FROM pg_catalog.pg_class WHERE relname = c.table_name AND relnamespace = (SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = c.table_schema)) where table_catalog = ? AND table_schema = ? AND table_name = ?",
|
||
|
currentDatabase, currentSchema, table).Rows()
|
||
|
)
|
||
|
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
for columns.Next() {
|
||
|
var (
|
||
|
column = &migrator.ColumnType{
|
||
|
PrimaryKeyValue: sql.NullBool{Valid: true},
|
||
|
UniqueValue: sql.NullBool{Valid: true},
|
||
|
}
|
||
|
datetimePrecision sql.NullInt64
|
||
|
radixValue sql.NullInt64
|
||
|
typeLenValue sql.NullInt64
|
||
|
)
|
||
|
|
||
|
err = columns.Scan(
|
||
|
&column.NameValue, &column.NullableValue, &column.DataTypeValue, &column.LengthValue, &column.DecimalSizeValue,
|
||
|
&radixValue, &column.ScaleValue, &datetimePrecision, &typeLenValue, &column.DefaultValueValue, &column.CommentValue,
|
||
|
)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
if typeLenValue.Valid && typeLenValue.Int64 > 0 {
|
||
|
column.LengthValue = typeLenValue
|
||
|
}
|
||
|
|
||
|
if strings.HasPrefix(column.DefaultValueValue.String, "nextval('") && strings.HasSuffix(column.DefaultValueValue.String, "seq'::regclass)") {
|
||
|
column.AutoIncrementValue = sql.NullBool{Bool: true, Valid: true}
|
||
|
column.DefaultValueValue = sql.NullString{}
|
||
|
}
|
||
|
|
||
|
if column.DefaultValueValue.Valid {
|
||
|
column.DefaultValueValue.String = regexp.MustCompile("'(.*)'::[\\w]+$").ReplaceAllString(column.DefaultValueValue.String, "$1")
|
||
|
}
|
||
|
|
||
|
if datetimePrecision.Valid {
|
||
|
column.DecimalSizeValue = datetimePrecision
|
||
|
}
|
||
|
|
||
|
columnTypes = append(columnTypes, column)
|
||
|
}
|
||
|
columns.Close()
|
||
|
|
||
|
// assign sql column type
|
||
|
{
|
||
|
rows, rowsErr := m.GetRows(currentSchema, table)
|
||
|
if rowsErr != nil {
|
||
|
return rowsErr
|
||
|
}
|
||
|
rawColumnTypes, err := rows.ColumnTypes()
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
for _, columnType := range columnTypes {
|
||
|
for _, c := range rawColumnTypes {
|
||
|
if c.Name() == columnType.Name() {
|
||
|
columnType.(*migrator.ColumnType).SQLColumnType = c
|
||
|
break
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
rows.Close()
|
||
|
}
|
||
|
|
||
|
// check primary, unique field
|
||
|
{
|
||
|
columnTypeRows, err := m.DB.Raw("SELECT c.column_name, constraint_type FROM information_schema.table_constraints tc JOIN information_schema.constraint_column_usage AS ccu USING (constraint_schema, constraint_name) JOIN information_schema.columns AS c ON c.table_schema = tc.constraint_schema AND tc.table_name = c.table_name AND ccu.column_name = c.column_name WHERE constraint_type IN ('PRIMARY KEY', 'UNIQUE') AND c.table_catalog = ? AND c.table_schema = ? AND c.table_name = ?", currentDatabase, currentSchema, table).Rows()
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
for columnTypeRows.Next() {
|
||
|
var name, columnType string
|
||
|
columnTypeRows.Scan(&name, &columnType)
|
||
|
for _, c := range columnTypes {
|
||
|
mc := c.(*migrator.ColumnType)
|
||
|
if mc.NameValue.String == name {
|
||
|
switch columnType {
|
||
|
case "PRIMARY KEY":
|
||
|
mc.PrimaryKeyValue = sql.NullBool{Bool: true, Valid: true}
|
||
|
case "UNIQUE":
|
||
|
mc.UniqueValue = sql.NullBool{Bool: true, Valid: true}
|
||
|
}
|
||
|
break
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
columnTypeRows.Close()
|
||
|
}
|
||
|
|
||
|
// check column type
|
||
|
{
|
||
|
dataTypeRows, err := m.DB.Raw(`SELECT a.attname as column_name, format_type(a.atttypid, a.atttypmod) AS data_type
|
||
|
FROM pg_attribute a JOIN pg_class b ON a.attrelid = b.relfilenode AND relnamespace = (SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = ?)
|
||
|
WHERE a.attnum > 0 -- hide internal columns
|
||
|
AND NOT a.attisdropped -- hide deleted columns
|
||
|
AND b.relname = ?`, currentSchema, table).Rows()
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
for dataTypeRows.Next() {
|
||
|
var name, dataType string
|
||
|
dataTypeRows.Scan(&name, &dataType)
|
||
|
for _, c := range columnTypes {
|
||
|
mc := c.(*migrator.ColumnType)
|
||
|
if mc.NameValue.String == name {
|
||
|
mc.ColumnTypeValue = sql.NullString{String: dataType, Valid: true}
|
||
|
break
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
dataTypeRows.Close()
|
||
|
}
|
||
|
|
||
|
return err
|
||
|
})
|
||
|
return
|
||
|
}
|
||
|
|
||
|
func (m Migrator) GetRows(currentSchema interface{}, table interface{}) (*sql.Rows, error) {
|
||
|
name := table.(string)
|
||
|
if currentSchema != nil {
|
||
|
if _, ok := currentSchema.(string); ok {
|
||
|
name = fmt.Sprintf("%v.%v", currentSchema, table)
|
||
|
}
|
||
|
}
|
||
|
return m.DB.Session(&gorm.Session{}).Table(name).Limit(1).Rows()
|
||
|
}
|
||
|
|
||
|
func (m Migrator) CurrentSchema(stmt *gorm.Statement, table string) (interface{}, interface{}) {
|
||
|
if strings.Contains(table, ".") {
|
||
|
if tables := strings.Split(table, `.`); len(tables) == 2 {
|
||
|
return tables[0], tables[1]
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if stmt.TableExpr != nil {
|
||
|
if tables := strings.Split(stmt.TableExpr.SQL, `"."`); len(tables) == 2 {
|
||
|
return strings.TrimPrefix(tables[0], `"`), table
|
||
|
}
|
||
|
}
|
||
|
return clause.Expr{SQL: "CURRENT_SCHEMA()"}, table
|
||
|
}
|