You've already forked openaccounting-server
forked from cybercinch/openaccounting-server
deps: update dependencies for GORM, Viper, and SQLite support
- Add GORM v1.25.12 with MySQL and SQLite drivers - Add Viper v1.19.0 for configuration management - Add UUID package for GORM model IDs - Update vendor directory with new dependencies - Update Go module requirements and checksums 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
6
vendor/gorm.io/driver/mysql/.gitignore
generated
vendored
Normal file
6
vendor/gorm.io/driver/mysql/.gitignore
generated
vendored
Normal file
@@ -0,0 +1,6 @@
|
||||
TODO*
|
||||
documents
|
||||
coverage.txt
|
||||
_book
|
||||
.idea
|
||||
vendor
|
||||
21
vendor/gorm.io/driver/mysql/License
generated
vendored
Normal file
21
vendor/gorm.io/driver/mysql/License
generated
vendored
Normal file
@@ -0,0 +1,21 @@
|
||||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2013-NOW Jinzhu <wosmvp@gmail.com>
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in
|
||||
all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
THE SOFTWARE.
|
||||
52
vendor/gorm.io/driver/mysql/README.md
generated
vendored
Normal file
52
vendor/gorm.io/driver/mysql/README.md
generated
vendored
Normal file
@@ -0,0 +1,52 @@
|
||||
# GORM MySQL Driver
|
||||
|
||||
## Quick Start
|
||||
|
||||
```go
|
||||
import (
|
||||
"gorm.io/driver/mysql"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// https://github.com/go-sql-driver/mysql
|
||||
dsn := "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True&loc=Local"
|
||||
db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{})
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
```go
|
||||
import (
|
||||
"gorm.io/driver/mysql"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var datetimePrecision = 2
|
||||
|
||||
db, err := gorm.Open(mysql.New(mysql.Config{
|
||||
DSN: "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True&loc=Local", // data source name, refer https://github.com/go-sql-driver/mysql#dsn-data-source-name
|
||||
DefaultStringSize: 256, // add default size for string fields, by default, will use db type `longtext` for fields without size, not a primary key, no index defined and don't have default values
|
||||
DisableDatetimePrecision: true, // disable datetime precision support, which not supported before MySQL 5.6
|
||||
DefaultDatetimePrecision: &datetimePrecision, // default datetime precision
|
||||
DontSupportRenameIndex: true, // drop & create index when rename index, rename index not supported before MySQL 5.7, MariaDB
|
||||
DontSupportRenameColumn: true, // use change when rename column, rename rename not supported before MySQL 8, MariaDB
|
||||
SkipInitializeWithVersion: false, // smart configure based on used version
|
||||
}), &gorm.Config{})
|
||||
```
|
||||
|
||||
## Customized Driver
|
||||
|
||||
```go
|
||||
import (
|
||||
_ "example.com/my_mysql_driver"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/driver/mysql"
|
||||
)
|
||||
|
||||
db, err := gorm.Open(mysql.New(mysql.Config{
|
||||
DriverName: "my_mysql_driver_name",
|
||||
DSN: "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True&loc=Local", // data source name, refer https://github.com/go-sql-driver/mysql#dsn-data-source-name
|
||||
})
|
||||
```
|
||||
|
||||
Checkout [https://gorm.io](https://gorm.io) for details.
|
||||
25
vendor/gorm.io/driver/mysql/error_translator.go
generated
vendored
Normal file
25
vendor/gorm.io/driver/mysql/error_translator.go
generated
vendored
Normal file
@@ -0,0 +1,25 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"github.com/go-sql-driver/mysql"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// The error codes to map mysql errors to gorm errors, here is the mysql error codes reference https://dev.mysql.com/doc/mysql-errors/8.0/en/server-error-reference.html.
|
||||
var errCodes = map[uint16]error{
|
||||
1062: gorm.ErrDuplicatedKey,
|
||||
1451: gorm.ErrForeignKeyViolated,
|
||||
1452: gorm.ErrForeignKeyViolated,
|
||||
}
|
||||
|
||||
func (dialector Dialector) Translate(err error) error {
|
||||
if mysqlErr, ok := err.(*mysql.MySQLError); ok {
|
||||
if translatedErr, found := errCodes[mysqlErr.Number]; found {
|
||||
return translatedErr
|
||||
}
|
||||
return mysqlErr
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
518
vendor/gorm.io/driver/mysql/migrator.go
generated
vendored
Normal file
518
vendor/gorm.io/driver/mysql/migrator.go
generated
vendored
Normal file
@@ -0,0 +1,518 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/migrator"
|
||||
"gorm.io/gorm/schema"
|
||||
)
|
||||
|
||||
const indexSql = `
|
||||
SELECT
|
||||
TABLE_NAME,
|
||||
COLUMN_NAME,
|
||||
INDEX_NAME,
|
||||
NON_UNIQUE
|
||||
FROM
|
||||
information_schema.STATISTICS
|
||||
WHERE
|
||||
TABLE_SCHEMA = ?
|
||||
AND TABLE_NAME = ?
|
||||
ORDER BY
|
||||
INDEX_NAME,
|
||||
SEQ_IN_INDEX`
|
||||
|
||||
var typeAliasMap = map[string][]string{
|
||||
"bool": {"tinyint"},
|
||||
"tinyint": {"bool"},
|
||||
}
|
||||
|
||||
type Migrator struct {
|
||||
migrator.Migrator
|
||||
Dialector
|
||||
}
|
||||
|
||||
func (m Migrator) FullDataTypeOf(field *schema.Field) clause.Expr {
|
||||
expr := m.Migrator.FullDataTypeOf(field)
|
||||
|
||||
if value, ok := field.TagSettings["COMMENT"]; ok {
|
||||
expr.SQL += " COMMENT " + m.Dialector.Explain("?", value)
|
||||
}
|
||||
|
||||
return expr
|
||||
}
|
||||
|
||||
// MigrateColumnUnique migrate column's UNIQUE constraint.
|
||||
// In MySQL, ColumnType's Unique is affected by UniqueIndex, so we have to take care of the UniqueIndex.
|
||||
func (m Migrator) MigrateColumnUnique(value interface{}, field *schema.Field, columnType gorm.ColumnType) error {
|
||||
unique, ok := columnType.Unique()
|
||||
if !ok || field.PrimaryKey {
|
||||
return nil // skip primary key
|
||||
}
|
||||
|
||||
queryTx, execTx := m.GetQueryAndExecTx()
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
// We're currently only receiving boolean values on `Unique` tag,
|
||||
// so the UniqueConstraint name is fixed
|
||||
constraint := m.DB.NamingStrategy.UniqueName(stmt.Table, field.DBName)
|
||||
if unique {
|
||||
// Clean up redundant unique indexes
|
||||
indexes, _ := queryTx.Migrator().GetIndexes(value)
|
||||
for _, index := range indexes {
|
||||
if uni, ok := index.Unique(); !ok || !uni {
|
||||
continue
|
||||
}
|
||||
if columns := index.Columns(); len(columns) != 1 || columns[0] != field.DBName {
|
||||
continue
|
||||
}
|
||||
if name := index.Name(); name == constraint || name == field.UniqueIndex {
|
||||
continue
|
||||
}
|
||||
if err := execTx.Migrator().DropIndex(value, index.Name()); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
hasConstraint := queryTx.Migrator().HasConstraint(value, constraint)
|
||||
switch {
|
||||
case field.Unique && !hasConstraint:
|
||||
if field.Unique {
|
||||
if err := execTx.Migrator().CreateConstraint(value, constraint); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
// field isn't Unique but ColumnType's Unique is reported by UniqueConstraint.
|
||||
case !field.Unique && hasConstraint:
|
||||
if err := execTx.Migrator().DropConstraint(value, constraint); err != nil {
|
||||
return err
|
||||
}
|
||||
if field.UniqueIndex != "" {
|
||||
if err := execTx.Migrator().CreateIndex(value, field.UniqueIndex); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if field.UniqueIndex != "" && !queryTx.Migrator().HasIndex(value, field.UniqueIndex) {
|
||||
if err := execTx.Migrator().CreateIndex(value, field.UniqueIndex); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if field.Unique {
|
||||
if err := execTx.Migrator().CreateConstraint(value, constraint); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if field.UniqueIndex != "" && !queryTx.Migrator().HasIndex(value, field.UniqueIndex) {
|
||||
if err := execTx.Migrator().CreateIndex(value, field.UniqueIndex); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (m Migrator) AddColumn(value interface{}, name string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
// avoid using the same name field
|
||||
f := stmt.Schema.LookUpField(name)
|
||||
if f == nil {
|
||||
return fmt.Errorf("failed to look up field with name: %s", name)
|
||||
}
|
||||
|
||||
if !f.IgnoreMigration {
|
||||
fieldType := m.FullDataTypeOf(f)
|
||||
columnName := clause.Column{Name: f.DBName}
|
||||
values := []interface{}{m.CurrentTable(stmt), columnName, fieldType}
|
||||
var alterSql strings.Builder
|
||||
alterSql.WriteString("ALTER TABLE ? ADD ? ?")
|
||||
if f.PrimaryKey || strings.Contains(strings.ToLower(fieldType.SQL), "auto_increment") {
|
||||
alterSql.WriteString(", ADD PRIMARY KEY (?)")
|
||||
values = append(values, columnName)
|
||||
}
|
||||
return m.DB.Exec(alterSql.String(), values...).Error
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (m Migrator) AlterColumn(value interface{}, field string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
if stmt.Schema != nil {
|
||||
if field := stmt.Schema.LookUpField(field); field != nil {
|
||||
fullDataType := m.FullDataTypeOf(field)
|
||||
if m.Dialector.DontSupportRenameColumnUnique {
|
||||
fullDataType.SQL = strings.Replace(fullDataType.SQL, " UNIQUE ", " ", 1)
|
||||
}
|
||||
|
||||
return m.DB.Exec(
|
||||
"ALTER TABLE ? MODIFY COLUMN ? ?",
|
||||
m.CurrentTable(stmt), clause.Column{Name: field.DBName}, fullDataType,
|
||||
).Error
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("failed to look up field with name: %s", field)
|
||||
})
|
||||
}
|
||||
|
||||
func (m Migrator) TiDBVersion() (isTiDB bool, major, minor, patch int, err error) {
|
||||
// TiDB version string looks like:
|
||||
// "5.7.25-TiDB-v6.5.0" or "5.7.25-TiDB-v6.4.0-serverless"
|
||||
tidbVersionArray := strings.Split(m.Dialector.ServerVersion, "-")
|
||||
if len(tidbVersionArray) < 3 || tidbVersionArray[1] != "TiDB" {
|
||||
// It isn't TiDB
|
||||
return
|
||||
}
|
||||
|
||||
rawVersion := strings.TrimPrefix(tidbVersionArray[2], "v")
|
||||
realVersionArray := strings.Split(rawVersion, ".")
|
||||
if major, err = strconv.Atoi(realVersionArray[0]); err != nil {
|
||||
err = fmt.Errorf("failed to parse the version of TiDB, the major version is: %s", realVersionArray[0])
|
||||
return
|
||||
}
|
||||
|
||||
if minor, err = strconv.Atoi(realVersionArray[1]); err != nil {
|
||||
err = fmt.Errorf("failed to parse the version of TiDB, the minor version is: %s", realVersionArray[1])
|
||||
return
|
||||
}
|
||||
|
||||
if patch, err = strconv.Atoi(realVersionArray[2]); err != nil {
|
||||
err = fmt.Errorf("failed to parse the version of TiDB, the patch version is: %s", realVersionArray[2])
|
||||
return
|
||||
}
|
||||
|
||||
isTiDB = true
|
||||
return
|
||||
}
|
||||
|
||||
func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
if !m.Dialector.DontSupportRenameColumn {
|
||||
return m.Migrator.RenameColumn(value, oldName, newName)
|
||||
}
|
||||
|
||||
var field *schema.Field
|
||||
if stmt.Schema != nil {
|
||||
if f := stmt.Schema.LookUpField(oldName); f != nil {
|
||||
oldName = f.DBName
|
||||
field = f
|
||||
}
|
||||
|
||||
if f := stmt.Schema.LookUpField(newName); f != nil {
|
||||
newName = f.DBName
|
||||
field = f
|
||||
}
|
||||
}
|
||||
|
||||
if field != nil {
|
||||
return m.DB.Exec(
|
||||
"ALTER TABLE ? CHANGE ? ? ?",
|
||||
m.CurrentTable(stmt), clause.Column{Name: oldName},
|
||||
clause.Column{Name: newName}, m.FullDataTypeOf(field),
|
||||
).Error
|
||||
}
|
||||
|
||||
return fmt.Errorf("failed to look up field with name: %s", newName)
|
||||
})
|
||||
}
|
||||
|
||||
func (m Migrator) DropConstraint(value interface{}, name string) error {
|
||||
if !m.Dialector.Config.DontSupportDropConstraint {
|
||||
return m.Migrator.DropConstraint(value, name)
|
||||
}
|
||||
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name)
|
||||
if constraint != nil {
|
||||
name = constraint.GetName()
|
||||
switch constraint.(type) {
|
||||
case *schema.Constraint:
|
||||
return m.DB.Exec("ALTER TABLE ? DROP FOREIGN KEY ?", clause.Table{Name: table}, clause.Column{Name: name}).Error
|
||||
case *schema.CheckConstraint:
|
||||
return m.DB.Exec("ALTER TABLE ? DROP CHECK ?", clause.Table{Name: table}, clause.Column{Name: name}).Error
|
||||
}
|
||||
}
|
||||
if m.HasIndex(value, name) {
|
||||
return m.DB.Exec("ALTER TABLE ? DROP INDEX ?", clause.Table{Name: table}, clause.Column{Name: name}).Error
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error {
|
||||
if !m.Dialector.DontSupportRenameIndex {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
return m.DB.Exec(
|
||||
"ALTER TABLE ? RENAME INDEX ? TO ?",
|
||||
m.CurrentTable(stmt), clause.Column{Name: oldName}, clause.Column{Name: newName},
|
||||
).Error
|
||||
})
|
||||
}
|
||||
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
err := m.DropIndex(value, oldName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if stmt.Schema != nil {
|
||||
if idx := stmt.Schema.LookIndex(newName); idx == nil {
|
||||
if idx = stmt.Schema.LookIndex(oldName); idx != nil {
|
||||
opts := m.BuildIndexOptions(idx.Fields, stmt)
|
||||
values := []interface{}{clause.Column{Name: newName}, m.CurrentTable(stmt), opts}
|
||||
|
||||
createIndexSQL := "CREATE "
|
||||
if idx.Class != "" {
|
||||
createIndexSQL += idx.Class + " "
|
||||
}
|
||||
createIndexSQL += "INDEX ? ON ??"
|
||||
|
||||
if idx.Type != "" {
|
||||
createIndexSQL += " USING " + idx.Type
|
||||
}
|
||||
|
||||
return m.DB.Exec(createIndexSQL, values...).Error
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return m.CreateIndex(value, newName)
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func (m Migrator) DropTable(values ...interface{}) error {
|
||||
values = m.ReorderModels(values, false)
|
||||
return m.DB.Connection(func(tx *gorm.DB) error {
|
||||
tx.Exec("SET FOREIGN_KEY_CHECKS = 0;")
|
||||
for i := len(values) - 1; i >= 0; i-- {
|
||||
if err := m.RunWithValue(values[i], func(stmt *gorm.Statement) error {
|
||||
return tx.Exec("DROP TABLE IF EXISTS ? CASCADE", m.CurrentTable(stmt)).Error
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return tx.Exec("SET FOREIGN_KEY_CHECKS = 1;").Error
|
||||
})
|
||||
}
|
||||
|
||||
// ColumnTypes column types return columnTypes,error
|
||||
func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
|
||||
columnTypes := make([]gorm.ColumnType, 0)
|
||||
err := m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
var (
|
||||
currentDatabase, table = m.CurrentSchema(stmt, stmt.Table)
|
||||
columnTypeSQL = "SELECT column_name, column_default, is_nullable = 'YES', data_type, character_maximum_length, column_type, column_key, extra, column_comment, numeric_precision, numeric_scale "
|
||||
rows, err = m.DB.Session(&gorm.Session{}).Table(table).Limit(1).Rows()
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rawColumnTypes, err := rows.ColumnTypes()
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := rows.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !m.DisableDatetimePrecision {
|
||||
columnTypeSQL += ", datetime_precision "
|
||||
}
|
||||
columnTypeSQL += "FROM information_schema.columns WHERE table_schema = ? AND table_name = ? ORDER BY ORDINAL_POSITION"
|
||||
|
||||
columns, rowErr := m.DB.Table(table).Raw(columnTypeSQL, currentDatabase, table).Rows()
|
||||
if rowErr != nil {
|
||||
return rowErr
|
||||
}
|
||||
|
||||
defer columns.Close()
|
||||
|
||||
for columns.Next() {
|
||||
var (
|
||||
column migrator.ColumnType
|
||||
datetimePrecision sql.NullInt64
|
||||
extraValue sql.NullString
|
||||
columnKey sql.NullString
|
||||
values = []interface{}{
|
||||
&column.NameValue, &column.DefaultValueValue, &column.NullableValue, &column.DataTypeValue, &column.LengthValue, &column.ColumnTypeValue, &columnKey, &extraValue, &column.CommentValue, &column.DecimalSizeValue, &column.ScaleValue,
|
||||
}
|
||||
)
|
||||
|
||||
if !m.DisableDatetimePrecision {
|
||||
values = append(values, &datetimePrecision)
|
||||
}
|
||||
|
||||
if scanErr := columns.Scan(values...); scanErr != nil {
|
||||
return scanErr
|
||||
}
|
||||
|
||||
column.PrimaryKeyValue = sql.NullBool{Bool: false, Valid: true}
|
||||
column.UniqueValue = sql.NullBool{Bool: false, Valid: true}
|
||||
switch columnKey.String {
|
||||
case "PRI":
|
||||
column.PrimaryKeyValue = sql.NullBool{Bool: true, Valid: true}
|
||||
case "UNI":
|
||||
column.UniqueValue = sql.NullBool{Bool: true, Valid: true}
|
||||
}
|
||||
|
||||
if strings.Contains(extraValue.String, "auto_increment") {
|
||||
column.AutoIncrementValue = sql.NullBool{Bool: true, Valid: true}
|
||||
}
|
||||
|
||||
// only trim paired single-quotes
|
||||
s := column.DefaultValueValue.String
|
||||
for (len(s) >= 3 && s[0] == '\'' && s[len(s)-1] == '\'' && s[len(s)-2] != '\\') ||
|
||||
(len(s) == 2 && s == "''") {
|
||||
s = s[1 : len(s)-1]
|
||||
}
|
||||
column.DefaultValueValue.String = s
|
||||
if m.Dialector.DontSupportNullAsDefaultValue {
|
||||
// rewrite mariadb default value like other version
|
||||
if column.DefaultValueValue.Valid && column.DefaultValueValue.String == "NULL" {
|
||||
column.DefaultValueValue.Valid = false
|
||||
column.DefaultValueValue.String = ""
|
||||
}
|
||||
}
|
||||
|
||||
if datetimePrecision.Valid {
|
||||
column.DecimalSizeValue = datetimePrecision
|
||||
}
|
||||
|
||||
for _, c := range rawColumnTypes {
|
||||
if c.Name() == column.NameValue.String {
|
||||
column.SQLColumnType = c
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
columnTypes = append(columnTypes, column)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
return columnTypes, err
|
||||
}
|
||||
|
||||
func (m Migrator) CurrentDatabase() (name string) {
|
||||
baseName := m.Migrator.CurrentDatabase()
|
||||
m.DB.Raw(
|
||||
"SELECT SCHEMA_NAME from Information_schema.SCHEMATA where SCHEMA_NAME LIKE ? ORDER BY SCHEMA_NAME=? DESC,SCHEMA_NAME limit 1",
|
||||
baseName+"%", baseName).Scan(&name)
|
||||
return
|
||||
}
|
||||
|
||||
func (m Migrator) GetTables() (tableList []string, err error) {
|
||||
err = m.DB.Raw("SELECT TABLE_NAME FROM information_schema.tables where TABLE_SCHEMA=?", m.CurrentDatabase()).
|
||||
Scan(&tableList).Error
|
||||
return
|
||||
}
|
||||
|
||||
func (m Migrator) GetIndexes(value interface{}) ([]gorm.Index, error) {
|
||||
indexes := make([]gorm.Index, 0)
|
||||
err := m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
|
||||
result := make([]*Index, 0)
|
||||
schema, table := m.CurrentSchema(stmt, stmt.Table)
|
||||
scanErr := m.DB.Table(table).Raw(indexSql, schema, table).Scan(&result).Error
|
||||
if scanErr != nil {
|
||||
return scanErr
|
||||
}
|
||||
indexMap, indexNames := groupByIndexName(result)
|
||||
|
||||
for _, name := range indexNames {
|
||||
idx := indexMap[name]
|
||||
if len(idx) == 0 {
|
||||
continue
|
||||
}
|
||||
tempIdx := &migrator.Index{
|
||||
TableName: idx[0].TableName,
|
||||
NameValue: idx[0].IndexName,
|
||||
PrimaryKeyValue: sql.NullBool{
|
||||
Bool: idx[0].IndexName == "PRIMARY",
|
||||
Valid: true,
|
||||
},
|
||||
UniqueValue: sql.NullBool{
|
||||
Bool: idx[0].NonUnique == 0,
|
||||
Valid: true,
|
||||
},
|
||||
}
|
||||
for _, x := range idx {
|
||||
tempIdx.ColumnList = append(tempIdx.ColumnList, x.ColumnName)
|
||||
}
|
||||
indexes = append(indexes, tempIdx)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return indexes, err
|
||||
}
|
||||
|
||||
// Index table index info
|
||||
type Index struct {
|
||||
TableName string `gorm:"column:TABLE_NAME"`
|
||||
ColumnName string `gorm:"column:COLUMN_NAME"`
|
||||
IndexName string `gorm:"column:INDEX_NAME"`
|
||||
NonUnique int32 `gorm:"column:NON_UNIQUE"`
|
||||
}
|
||||
|
||||
func groupByIndexName(indexList []*Index) (map[string][]*Index, []string) {
|
||||
columnIndexMap := make(map[string][]*Index, len(indexList))
|
||||
indexNames := make([]string, 0, len(indexList))
|
||||
for _, idx := range indexList {
|
||||
if _, ok := columnIndexMap[idx.IndexName]; !ok {
|
||||
indexNames = append(indexNames, idx.IndexName)
|
||||
}
|
||||
columnIndexMap[idx.IndexName] = append(columnIndexMap[idx.IndexName], idx)
|
||||
}
|
||||
return columnIndexMap, indexNames
|
||||
}
|
||||
|
||||
func (m Migrator) CurrentSchema(stmt *gorm.Statement, table string) (string, string) {
|
||||
if tables := strings.Split(table, `.`); len(tables) == 2 {
|
||||
return tables[0], tables[1]
|
||||
}
|
||||
m.DB = m.DB.Table(table)
|
||||
return m.CurrentDatabase(), table
|
||||
}
|
||||
|
||||
func (m Migrator) GetTypeAliases(databaseTypeName string) []string {
|
||||
return typeAliasMap[databaseTypeName]
|
||||
}
|
||||
|
||||
// TableType table type return tableType,error
|
||||
func (m Migrator) TableType(value interface{}) (tableType gorm.TableType, err error) {
|
||||
var table migrator.TableType
|
||||
|
||||
err = m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
var (
|
||||
values = []interface{}{
|
||||
&table.SchemaValue, &table.NameValue, &table.TypeValue, &table.CommentValue,
|
||||
}
|
||||
currentDatabase, tableName = m.CurrentSchema(stmt, stmt.Table)
|
||||
tableTypeSQL = "SELECT table_schema, table_name, table_type, table_comment FROM information_schema.tables WHERE table_schema = ? AND table_name = ?"
|
||||
)
|
||||
|
||||
row := m.DB.Table(tableName).Raw(tableTypeSQL, currentDatabase, tableName).Row()
|
||||
|
||||
if scanErr := row.Scan(values...); scanErr != nil {
|
||||
return scanErr
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
return table, err
|
||||
}
|
||||
542
vendor/gorm.io/driver/mysql/mysql.go
generated
vendored
Normal file
542
vendor/gorm.io/driver/mysql/mysql.go
generated
vendored
Normal file
@@ -0,0 +1,542 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"math"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-sql-driver/mysql"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/callbacks"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/logger"
|
||||
"gorm.io/gorm/migrator"
|
||||
"gorm.io/gorm/schema"
|
||||
"gorm.io/gorm/utils"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultDriverName = "mysql"
|
||||
|
||||
AutoRandomTag = "auto_random()" // Treated as an auto_random field for tidb
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
DriverName string
|
||||
ServerVersion string
|
||||
DSN string
|
||||
DSNConfig *mysql.Config
|
||||
Conn gorm.ConnPool
|
||||
SkipInitializeWithVersion bool
|
||||
DefaultStringSize uint
|
||||
DefaultDatetimePrecision *int
|
||||
DisableWithReturning bool
|
||||
DisableDatetimePrecision bool
|
||||
DontSupportRenameIndex bool
|
||||
DontSupportRenameColumn bool
|
||||
DontSupportForShareClause bool
|
||||
DontSupportNullAsDefaultValue bool
|
||||
DontSupportRenameColumnUnique bool
|
||||
// As of MySQL 8.0.19, ALTER TABLE permits more general (and SQL standard) syntax
|
||||
// for dropping and altering existing constraints of any type.
|
||||
// see https://dev.mysql.com/doc/refman/8.0/en/alter-table.html
|
||||
DontSupportDropConstraint bool
|
||||
}
|
||||
|
||||
type Dialector struct {
|
||||
*Config
|
||||
}
|
||||
|
||||
var (
|
||||
// CreateClauses create clauses
|
||||
CreateClauses = []string{"INSERT", "VALUES", "ON CONFLICT"}
|
||||
// QueryClauses query clauses
|
||||
QueryClauses = []string{}
|
||||
// UpdateClauses update clauses
|
||||
UpdateClauses = []string{"UPDATE", "SET", "WHERE", "ORDER BY", "LIMIT"}
|
||||
// DeleteClauses delete clauses
|
||||
DeleteClauses = []string{"DELETE", "FROM", "WHERE", "ORDER BY", "LIMIT"}
|
||||
|
||||
defaultDatetimePrecision = 3
|
||||
)
|
||||
|
||||
func Open(dsn string) gorm.Dialector {
|
||||
dsnConf, _ := mysql.ParseDSN(dsn)
|
||||
return &Dialector{Config: &Config{DSN: dsn, DSNConfig: dsnConf}}
|
||||
}
|
||||
|
||||
func New(config Config) gorm.Dialector {
|
||||
switch {
|
||||
case config.DSN == "" && config.DSNConfig != nil:
|
||||
config.DSN = config.DSNConfig.FormatDSN()
|
||||
case config.DSN != "" && config.DSNConfig == nil:
|
||||
config.DSNConfig, _ = mysql.ParseDSN(config.DSN)
|
||||
}
|
||||
return &Dialector{Config: &config}
|
||||
}
|
||||
|
||||
func (dialector Dialector) Name() string {
|
||||
return DefaultDriverName
|
||||
}
|
||||
|
||||
// NowFunc return now func
|
||||
func (dialector Dialector) NowFunc(n int) func() time.Time {
|
||||
return func() time.Time {
|
||||
round := time.Second / time.Duration(math.Pow10(n))
|
||||
return time.Now().Round(round)
|
||||
}
|
||||
}
|
||||
|
||||
func (dialector Dialector) Apply(config *gorm.Config) error {
|
||||
if config.NowFunc != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if dialector.DefaultDatetimePrecision == nil {
|
||||
dialector.DefaultDatetimePrecision = &defaultDatetimePrecision
|
||||
}
|
||||
// while maintaining the readability of the code, separate the business logic from
|
||||
// the general part and leave it to the function to do it here.
|
||||
config.NowFunc = dialector.NowFunc(*dialector.DefaultDatetimePrecision)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
|
||||
if dialector.DriverName == "" {
|
||||
dialector.DriverName = DefaultDriverName
|
||||
}
|
||||
|
||||
if dialector.DefaultDatetimePrecision == nil {
|
||||
dialector.DefaultDatetimePrecision = &defaultDatetimePrecision
|
||||
}
|
||||
|
||||
if dialector.Conn != nil {
|
||||
db.ConnPool = dialector.Conn
|
||||
} else {
|
||||
db.ConnPool, err = sql.Open(dialector.DriverName, dialector.DSN)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
withReturning := false
|
||||
if !dialector.Config.SkipInitializeWithVersion {
|
||||
err = db.ConnPool.QueryRowContext(context.Background(), "SELECT VERSION()").Scan(&dialector.ServerVersion)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if strings.Contains(dialector.ServerVersion, "MariaDB") {
|
||||
dialector.Config.DontSupportRenameIndex = true
|
||||
dialector.Config.DontSupportRenameColumn = true
|
||||
dialector.Config.DontSupportForShareClause = true
|
||||
dialector.Config.DontSupportNullAsDefaultValue = true
|
||||
withReturning = checkVersion(dialector.ServerVersion, "10.5")
|
||||
} else if strings.HasPrefix(dialector.ServerVersion, "5.6.") {
|
||||
dialector.Config.DontSupportRenameIndex = true
|
||||
dialector.Config.DontSupportRenameColumn = true
|
||||
dialector.Config.DontSupportForShareClause = true
|
||||
dialector.Config.DontSupportDropConstraint = true
|
||||
} else if strings.HasPrefix(dialector.ServerVersion, "5.7.") {
|
||||
dialector.Config.DontSupportRenameColumn = true
|
||||
dialector.Config.DontSupportForShareClause = true
|
||||
dialector.Config.DontSupportDropConstraint = true
|
||||
} else if strings.HasPrefix(dialector.ServerVersion, "5.") {
|
||||
dialector.Config.DisableDatetimePrecision = true
|
||||
dialector.Config.DontSupportRenameIndex = true
|
||||
dialector.Config.DontSupportRenameColumn = true
|
||||
dialector.Config.DontSupportForShareClause = true
|
||||
dialector.Config.DontSupportDropConstraint = true
|
||||
}
|
||||
|
||||
if strings.Contains(dialector.ServerVersion, "TiDB") {
|
||||
dialector.Config.DontSupportRenameColumnUnique = true
|
||||
}
|
||||
}
|
||||
|
||||
// register callbacks
|
||||
callbackConfig := &callbacks.Config{
|
||||
CreateClauses: CreateClauses,
|
||||
QueryClauses: QueryClauses,
|
||||
UpdateClauses: UpdateClauses,
|
||||
DeleteClauses: DeleteClauses,
|
||||
}
|
||||
|
||||
if !dialector.Config.DisableWithReturning && withReturning {
|
||||
if !utils.Contains(callbackConfig.CreateClauses, "RETURNING") {
|
||||
callbackConfig.CreateClauses = append(callbackConfig.CreateClauses, "RETURNING")
|
||||
}
|
||||
|
||||
if !utils.Contains(callbackConfig.UpdateClauses, "RETURNING") {
|
||||
callbackConfig.UpdateClauses = append(callbackConfig.UpdateClauses, "RETURNING")
|
||||
}
|
||||
|
||||
if !utils.Contains(callbackConfig.DeleteClauses, "RETURNING") {
|
||||
callbackConfig.DeleteClauses = append(callbackConfig.DeleteClauses, "RETURNING")
|
||||
}
|
||||
}
|
||||
|
||||
callbacks.RegisterDefaultCallbacks(db, callbackConfig)
|
||||
|
||||
for k, v := range dialector.ClauseBuilders() {
|
||||
if _, ok := db.ClauseBuilders[k]; !ok {
|
||||
db.ClauseBuilders[k] = v
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
const (
|
||||
// ClauseOnConflict for clause.ClauseBuilder ON CONFLICT key
|
||||
ClauseOnConflict = "ON CONFLICT"
|
||||
// ClauseValues for clause.ClauseBuilder VALUES key
|
||||
ClauseValues = "VALUES"
|
||||
// ClauseFor for clause.ClauseBuilder FOR key
|
||||
ClauseFor = "FOR"
|
||||
)
|
||||
|
||||
func (dialector Dialector) ClauseBuilders() map[string]clause.ClauseBuilder {
|
||||
clauseBuilders := map[string]clause.ClauseBuilder{
|
||||
ClauseOnConflict: func(c clause.Clause, builder clause.Builder) {
|
||||
onConflict, ok := c.Expression.(clause.OnConflict)
|
||||
if !ok {
|
||||
c.Build(builder)
|
||||
return
|
||||
}
|
||||
|
||||
builder.WriteString("ON DUPLICATE KEY UPDATE ")
|
||||
if len(onConflict.DoUpdates) == 0 {
|
||||
if s := builder.(*gorm.Statement).Schema; s != nil {
|
||||
var column clause.Column
|
||||
onConflict.DoNothing = false
|
||||
|
||||
if s.PrioritizedPrimaryField != nil {
|
||||
column = clause.Column{Name: s.PrioritizedPrimaryField.DBName}
|
||||
} else if len(s.DBNames) > 0 {
|
||||
column = clause.Column{Name: s.DBNames[0]}
|
||||
}
|
||||
|
||||
if column.Name != "" {
|
||||
onConflict.DoUpdates = []clause.Assignment{{Column: column, Value: column}}
|
||||
}
|
||||
|
||||
builder.(*gorm.Statement).AddClause(onConflict)
|
||||
}
|
||||
}
|
||||
|
||||
for idx, assignment := range onConflict.DoUpdates {
|
||||
if idx > 0 {
|
||||
builder.WriteByte(',')
|
||||
}
|
||||
|
||||
builder.WriteQuoted(assignment.Column)
|
||||
builder.WriteByte('=')
|
||||
if column, ok := assignment.Value.(clause.Column); ok && column.Table == "excluded" {
|
||||
column.Table = ""
|
||||
builder.WriteString("VALUES(")
|
||||
builder.WriteQuoted(column)
|
||||
builder.WriteByte(')')
|
||||
} else {
|
||||
builder.AddVar(builder, assignment.Value)
|
||||
}
|
||||
}
|
||||
},
|
||||
ClauseValues: func(c clause.Clause, builder clause.Builder) {
|
||||
if values, ok := c.Expression.(clause.Values); ok && len(values.Columns) == 0 {
|
||||
builder.WriteString("VALUES()")
|
||||
return
|
||||
}
|
||||
c.Build(builder)
|
||||
},
|
||||
}
|
||||
|
||||
if dialector.Config.DontSupportForShareClause {
|
||||
clauseBuilders[ClauseFor] = func(c clause.Clause, builder clause.Builder) {
|
||||
if values, ok := c.Expression.(clause.Locking); ok && strings.EqualFold(values.Strength, "SHARE") {
|
||||
builder.WriteString("LOCK IN SHARE MODE")
|
||||
return
|
||||
}
|
||||
c.Build(builder)
|
||||
}
|
||||
}
|
||||
|
||||
return clauseBuilders
|
||||
}
|
||||
|
||||
func (dialector Dialector) DefaultValueOf(field *schema.Field) clause.Expression {
|
||||
return clause.Expr{SQL: "DEFAULT"}
|
||||
}
|
||||
|
||||
func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator {
|
||||
return Migrator{
|
||||
Migrator: migrator.Migrator{
|
||||
Config: migrator.Config{
|
||||
DB: db,
|
||||
Dialector: dialector,
|
||||
},
|
||||
},
|
||||
Dialector: dialector,
|
||||
}
|
||||
}
|
||||
|
||||
// BindVarTo writes the MySQL bind-variable placeholder. MySQL uses a plain
// '?' regardless of the parameter's position or value, so stmt and v are
// intentionally unused.
func (dialector Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) {
	writer.WriteByte('?')
}
|
||||
|
||||
// QuoteTo writes str to writer as a backtick-quoted MySQL identifier.
// A '.' splits qualified names (e.g. db.table) into separately quoted parts,
// and embedded backticks are escaped by doubling ("``"). The loop is a small
// byte-level state machine; statement order is significant.
func (dialector Dialector) QuoteTo(writer clause.Writer, str string) {
	var (
		// underQuoted: an opening backtick for the current segment has been written.
		// selfQuoted: the current segment arrived already wrapped in backticks.
		underQuoted, selfQuoted bool
		// continuousBacktick counts a run of consecutive '`' bytes seen so far.
		continuousBacktick int8
		// shiftDelimiter counts bytes consumed in the current segment.
		shiftDelimiter int8
	)

	for _, v := range []byte(str) {
		switch v {
		case '`':
			continuousBacktick++
			// Two backticks in a row are emitted verbatim as an escaped backtick.
			if continuousBacktick == 2 {
				writer.WriteString("``")
				continuousBacktick = 0
			}
		case '.':
			// Close the previous segment (unless it closed itself with its own
			// trailing backtick) and reset state for the next one.
			if continuousBacktick > 0 || !selfQuoted {
				shiftDelimiter = 0
				underQuoted = false
				continuousBacktick = 0
				writer.WriteByte('`')
			}
			writer.WriteByte(v)
			continue
		default:
			// First non-backtick byte of a segment: emit the opening quote.
			if shiftDelimiter-continuousBacktick <= 0 && !underQuoted {
				writer.WriteByte('`')
				underQuoted = true
				// A leading backtick in the input means the caller pre-quoted
				// this segment; consume one backtick from the pending run.
				if selfQuoted = continuousBacktick > 0; selfQuoted {
					continuousBacktick -= 1
				}
			}

			// Any remaining pending backticks were interior ones: escape them.
			for ; continuousBacktick > 0; continuousBacktick -= 1 {
				writer.WriteString("``")
			}

			writer.WriteByte(v)
		}
		shiftDelimiter++
	}

	// Flush a trailing unescaped backtick, then close the final segment.
	if continuousBacktick > 0 && !selfQuoted {
		writer.WriteString("``")
	}
	writer.WriteByte('`')
}
|
||||
|
||||
// localTimeInterface matches values that can be converted to a specific
// time.Location (notably time.Time), without naming time.Time directly so
// wrapper types also qualify.
type localTimeInterface interface {
	In(loc *time.Location) time.Time
}
|
||||
|
||||
func (dialector Dialector) Explain(sql string, vars ...interface{}) string {
|
||||
if dialector.DSNConfig != nil && dialector.DSNConfig.Loc != nil {
|
||||
for i, v := range vars {
|
||||
if p, ok := v.(localTimeInterface); ok {
|
||||
func(i int, t localTimeInterface) {
|
||||
defer func() {
|
||||
recover()
|
||||
}()
|
||||
vars[i] = t.In(dialector.DSNConfig.Loc)
|
||||
}(i, p)
|
||||
}
|
||||
}
|
||||
}
|
||||
return logger.ExplainSQL(sql, nil, `'`, vars...)
|
||||
}
|
||||
|
||||
func (dialector Dialector) DataTypeOf(field *schema.Field) string {
|
||||
switch field.DataType {
|
||||
case schema.Bool:
|
||||
return "boolean"
|
||||
case schema.Int, schema.Uint:
|
||||
return dialector.getSchemaIntAndUnitType(field)
|
||||
case schema.Float:
|
||||
return dialector.getSchemaFloatType(field)
|
||||
case schema.String:
|
||||
return dialector.getSchemaStringType(field)
|
||||
case schema.Time:
|
||||
return dialector.getSchemaTimeType(field)
|
||||
case schema.Bytes:
|
||||
return dialector.getSchemaBytesType(field)
|
||||
default:
|
||||
return dialector.getSchemaCustomType(field)
|
||||
}
|
||||
}
|
||||
|
||||
func (dialector Dialector) getSchemaFloatType(field *schema.Field) string {
|
||||
if field.Precision > 0 {
|
||||
return fmt.Sprintf("decimal(%d, %d)", field.Precision, field.Scale)
|
||||
}
|
||||
|
||||
if field.Size <= 32 {
|
||||
return "float"
|
||||
}
|
||||
|
||||
return "double"
|
||||
}
|
||||
|
||||
func (dialector Dialector) getSchemaStringType(field *schema.Field) string {
|
||||
size := field.Size
|
||||
if size == 0 {
|
||||
if dialector.DefaultStringSize > 0 {
|
||||
size = int(dialector.DefaultStringSize)
|
||||
} else {
|
||||
hasIndex := field.TagSettings["INDEX"] != "" || field.TagSettings["UNIQUE"] != ""
|
||||
// TEXT, GEOMETRY or JSON column can't have a default value
|
||||
if field.PrimaryKey || field.HasDefaultValue || hasIndex {
|
||||
size = 191 // utf8mb4
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if size >= 65536 && size <= int(math.Pow(2, 24)) {
|
||||
return "mediumtext"
|
||||
}
|
||||
|
||||
if size > int(math.Pow(2, 24)) || size <= 0 {
|
||||
return "longtext"
|
||||
}
|
||||
|
||||
return fmt.Sprintf("varchar(%d)", size)
|
||||
}
|
||||
|
||||
func (dialector Dialector) getSchemaTimeType(field *schema.Field) string {
|
||||
if !dialector.DisableDatetimePrecision && field.Precision == 0 && field.TagSettings["PRECISION"] == "" {
|
||||
field.Precision = *dialector.DefaultDatetimePrecision
|
||||
}
|
||||
|
||||
var precision string
|
||||
if field.Precision > 0 {
|
||||
precision = fmt.Sprintf("(%d)", field.Precision)
|
||||
}
|
||||
|
||||
if field.NotNull || field.PrimaryKey {
|
||||
return "datetime" + precision
|
||||
}
|
||||
return "datetime" + precision + " NULL"
|
||||
}
|
||||
|
||||
func (dialector Dialector) getSchemaBytesType(field *schema.Field) string {
|
||||
if field.Size > 0 && field.Size < 65536 {
|
||||
return fmt.Sprintf("varbinary(%d)", field.Size)
|
||||
}
|
||||
|
||||
if field.Size >= 65536 && field.Size <= int(math.Pow(2, 24)) {
|
||||
return "mediumblob"
|
||||
}
|
||||
|
||||
return "longblob"
|
||||
}
|
||||
|
||||
// autoRandomType
|
||||
// field.DataType MUST be `schema.Int` or `schema.Uint`
|
||||
// Judgement logic:
|
||||
// 1. Is PrimaryKey;
|
||||
// 2. Has default value;
|
||||
// 3. Default value is "auto_random()";
|
||||
// 4. IGNORE the field.Size, it MUST be bigint;
|
||||
// 5. CLEAR the default tag, and return true;
|
||||
// 6. Otherwise, return false.
|
||||
func autoRandomType(field *schema.Field) (bool, string) {
|
||||
if field.PrimaryKey && field.HasDefaultValue &&
|
||||
strings.ToLower(strings.TrimSpace(field.DefaultValue)) == AutoRandomTag {
|
||||
field.DefaultValue = ""
|
||||
|
||||
sqlType := "bigint"
|
||||
if field.DataType == schema.Uint {
|
||||
sqlType += " unsigned"
|
||||
}
|
||||
sqlType += " auto_random"
|
||||
return true, sqlType
|
||||
}
|
||||
|
||||
return false, ""
|
||||
}
|
||||
|
||||
func (dialector Dialector) getSchemaIntAndUnitType(field *schema.Field) string {
|
||||
if autoRandom, typeString := autoRandomType(field); autoRandom {
|
||||
return typeString
|
||||
}
|
||||
|
||||
constraint := func(sqlType string) string {
|
||||
if field.DataType == schema.Uint {
|
||||
sqlType += " unsigned"
|
||||
}
|
||||
if field.AutoIncrement {
|
||||
sqlType += " AUTO_INCREMENT"
|
||||
}
|
||||
return sqlType
|
||||
}
|
||||
|
||||
switch {
|
||||
case field.Size <= 8:
|
||||
return constraint("tinyint")
|
||||
case field.Size <= 16:
|
||||
return constraint("smallint")
|
||||
case field.Size <= 24:
|
||||
return constraint("mediumint")
|
||||
case field.Size <= 32:
|
||||
return constraint("int")
|
||||
default:
|
||||
return constraint("bigint")
|
||||
}
|
||||
}
|
||||
|
||||
func (dialector Dialector) getSchemaCustomType(field *schema.Field) string {
|
||||
sqlType := string(field.DataType)
|
||||
|
||||
if field.AutoIncrement && !strings.Contains(strings.ToLower(sqlType), " auto_increment") {
|
||||
sqlType += " AUTO_INCREMENT"
|
||||
}
|
||||
|
||||
return sqlType
|
||||
}
|
||||
|
||||
func (dialector Dialector) SavePoint(tx *gorm.DB, name string) error {
|
||||
return tx.Exec("SAVEPOINT " + name).Error
|
||||
}
|
||||
|
||||
func (dialector Dialector) RollbackTo(tx *gorm.DB, name string) error {
|
||||
return tx.Exec("ROLLBACK TO SAVEPOINT " + name).Error
|
||||
}
|
||||
|
||||
// versionPartTrimmerRegexp extracts the leading digits of one dotted version
// segment (e.g. "28-MariaDB" -> "28"). Hoisted to package scope so it is
// compiled once instead of on every checkVersion call.
var versionPartTrimmerRegexp = regexp.MustCompile(`^(\d+).*$`)

// checkVersion reports whether newVersion is newer than or equal to
// oldVersion. Versions are compared segment-by-segment on their dotted
// numeric prefixes; non-numeric suffixes ("-MariaDB", "rc1") are ignored.
// When newVersion has extra segments beyond oldVersion it is considered
// newer; when it is a (numerically equal) prefix of oldVersion it is
// considered older.
func checkVersion(newVersion, oldVersion string) bool {
	if newVersion == oldVersion {
		return true
	}

	newParts := strings.Split(newVersion, ".")
	oldParts := strings.Split(oldVersion, ".")

	for idx, part := range newParts {
		if len(oldParts) <= idx {
			// All shared segments were equal and newVersion has more.
			return true
		}

		// Parse errors yield 0, preserving the original best-effort behavior.
		nv, _ := strconv.Atoi(versionPartTrimmerRegexp.ReplaceAllString(part, "$1"))
		ov, _ := strconv.Atoi(versionPartTrimmerRegexp.ReplaceAllString(oldParts[idx], "$1"))
		if nv != ov {
			return nv > ov
		}
	}

	// newVersion is a numerically equal prefix of a longer oldVersion.
	return false
}
|
||||
1
vendor/gorm.io/driver/sqlite/.gitignore
generated
vendored
Normal file
1
vendor/gorm.io/driver/sqlite/.gitignore
generated
vendored
Normal file
@@ -0,0 +1 @@
|
||||
.idea/
|
||||
21
vendor/gorm.io/driver/sqlite/License
generated
vendored
Normal file
21
vendor/gorm.io/driver/sqlite/License
generated
vendored
Normal file
@@ -0,0 +1,21 @@
|
||||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2013-NOW Jinzhu <wosmvp@gmail.com>
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in
|
||||
all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
THE SOFTWARE.
|
||||
30
vendor/gorm.io/driver/sqlite/README.md
generated
vendored
Normal file
30
vendor/gorm.io/driver/sqlite/README.md
generated
vendored
Normal file
@@ -0,0 +1,30 @@
|
||||
# GORM Sqlite Driver
|
||||
|
||||

|
||||
|
||||
## USAGE
|
||||
|
||||
```go
|
||||
import (
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// github.com/mattn/go-sqlite3
|
||||
db, err := gorm.Open(sqlite.Open("gorm.db"), &gorm.Config{})
|
||||
```
|
||||
|
||||
Checkout [https://gorm.io](https://gorm.io) for details.
|
||||
|
||||
### Pure go Sqlite Driver
|
||||
|
||||
checkout [https://github.com/glebarez/sqlite](https://github.com/glebarez/sqlite) for details
|
||||
|
||||
```go
|
||||
import (
|
||||
"github.com/glebarez/sqlite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
db, err := gorm.Open(sqlite.Open("gorm.db"), &gorm.Config{})
|
||||
```
|
||||
285
vendor/gorm.io/driver/sqlite/ddlmod.go
generated
vendored
Normal file
285
vendor/gorm.io/driver/sqlite/ddlmod.go
generated
vendored
Normal file
@@ -0,0 +1,285 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm/migrator"
|
||||
)
|
||||
|
||||
var (
	// sqliteSeparator is a regexp alternation of the identifier-quote
	// characters accepted around names in SQLite DDL: backtick, double
	// quote, single quote, and tab.
	sqliteSeparator = "`|\"|'|\t"
	// uniqueRegexp matches a table-level "CONSTRAINT name UNIQUE (...)"
	// clause, capturing the parenthesized column list.
	uniqueRegexp = regexp.MustCompile(fmt.Sprintf(`^CONSTRAINT [%v]?[\w-]+[%v]? UNIQUE (.*)$`, sqliteSeparator, sqliteSeparator))
	// indexRegexp matches a CREATE [UNIQUE] INDEX statement, capturing
	// everything after ON.
	indexRegexp = regexp.MustCompile(fmt.Sprintf(`(?is)CREATE(?: UNIQUE)? INDEX [%v]?[\w\d-]+[%v]?(?s:.*?)ON (.*)$`, sqliteSeparator, sqliteSeparator))
	// tableRegexp matches a CREATE TABLE statement, capturing the head
	// ("CREATE TABLE name") and the optional parenthesized body.
	tableRegexp = regexp.MustCompile(fmt.Sprintf(`(?is)(CREATE TABLE [%v]?[\w\d-]+[%v]?)(?:\s*\((.*)\))?`, sqliteSeparator, sqliteSeparator))
	// separatorRegexp matches any single quote character.
	separatorRegexp = regexp.MustCompile(fmt.Sprintf("[%v]", sqliteSeparator))
	// columnRegexp matches one column definition, capturing name, raw
	// type, and the trailing modifiers.
	columnRegexp = regexp.MustCompile(fmt.Sprintf(`^[%v]?([\w\d]+)[%v]?\s+([\w\(\)\d]+)(.*)$`, sqliteSeparator, sqliteSeparator))
	// defaultValueRegexp captures the value of a DEFAULT clause
	// (case-insensitive), stopping at COLLATE/GENERATED or end of line.
	defaultValueRegexp = regexp.MustCompile(`(?i) DEFAULT \(?(.+)?\)?( |COLLATE|GENERATED|$)`)
	// regRealDataType captures a parenthesized length inside a type name,
	// e.g. the "10" in "varchar(10)".
	regRealDataType = regexp.MustCompile(`[^\d](\d+)[^\d]?`)
)
|
||||
|
||||
// ddl is a parsed, editable representation of a CREATE TABLE statement:
// the head ("CREATE TABLE name"), the raw comma-separated body fields
// (column definitions and table constraints), and the column metadata
// extracted from those fields.
type ddl struct {
	head    string
	fields  []string
	columns []migrator.ColumnType
}
|
||||
|
||||
// parseDDL parses one or more SQLite DDL statements (a CREATE TABLE plus any
// CREATE INDEX statements for the same table) into a *ddl. The CREATE TABLE
// body is split into top-level comma-separated fields by a hand-written
// scanner that respects quoted names and nested parentheses, then each field
// is classified as a column definition or a table constraint and folded into
// the column metadata. Index statements are accepted but intentionally
// ignored. Any other statement makes the whole parse fail.
func parseDDL(strs ...string) (*ddl, error) {
	var result ddl
	for _, str := range strs {
		if sections := tableRegexp.FindStringSubmatch(str); len(sections) > 0 {
			var (
				ddlBody      = sections[2]
				ddlBodyRunes = []rune(ddlBody)
				bracketLevel int
				quote        rune
				buf          string
			)
			ddlBodyRunesLen := len(ddlBodyRunes)

			result.head = sections[1]

			// Pass 1: split the body into top-level fields on commas,
			// ignoring commas inside quotes or parentheses.
			for idx := 0; idx < ddlBodyRunesLen; idx++ {
				var (
					next rune = 0
					c    = ddlBodyRunes[idx]
				)
				if idx+1 < ddlBodyRunesLen {
					next = ddlBodyRunes[idx+1]
				}

				if sc := string(c); separatorRegexp.MatchString(sc) {
					// Quote character: a doubled quote is an escape, a single
					// one toggles quoted mode. NOTE(review): any quote char
					// closes the open quote, not only the matching one.
					if c == next {
						buf += sc // Skip escaped quote
						idx++
					} else if quote > 0 {
						quote = 0
					} else {
						quote = c
					}
				} else if quote == 0 {
					if c == '(' {
						bracketLevel++
					} else if c == ')' {
						bracketLevel--
					} else if bracketLevel == 0 {
						if c == ',' {
							// Top-level comma: one field is complete.
							result.fields = append(result.fields, strings.TrimSpace(buf))
							buf = ""
							continue
						}
					}
				}

				if bracketLevel < 0 {
					return nil, errors.New("invalid DDL, unbalanced brackets")
				}

				buf += string(c)
			}

			if bracketLevel != 0 {
				return nil, errors.New("invalid DDL, unbalanced brackets")
			}

			// Trailing field (no comma after the last one).
			if buf != "" {
				result.fields = append(result.fields, strings.TrimSpace(buf))
			}

			// Pass 2: classify each field and build column metadata.
			for _, f := range result.fields {
				fUpper := strings.ToUpper(f)
				if strings.HasPrefix(fUpper, "CHECK") {
					// CHECK constraints carry no column metadata.
					continue
				}
				if strings.HasPrefix(fUpper, "CONSTRAINT") {
					// A single-column named UNIQUE constraint marks that
					// column unique; multi-column constraints are ignored.
					matches := uniqueRegexp.FindStringSubmatch(f)
					if len(matches) > 0 {
						cols, err := parseAllColumns(matches[1])
						if err == nil && len(cols) == 1 {
							for idx, column := range result.columns {
								if column.NameValue.String == cols[0] {
									column.UniqueValue = sql.NullBool{Bool: true, Valid: true}
									result.columns[idx] = column
									break
								}
							}
						}
					}
					continue
				}
				if strings.HasPrefix(fUpper, "PRIMARY KEY") {
					// Table-level primary key: flag every listed column.
					cols, err := parseAllColumns(f)
					if err == nil {
						for _, name := range cols {
							for idx, column := range result.columns {
								if column.NameValue.String == name {
									column.PrimaryKeyValue = sql.NullBool{Bool: true, Valid: true}
									result.columns[idx] = column
									break
								}
							}
						}
					}
				} else if matches := columnRegexp.FindStringSubmatch(f); len(matches) > 0 {
					// Plain column definition: name, type, then modifiers.
					columnType := migrator.ColumnType{
						NameValue:         sql.NullString{String: matches[1], Valid: true},
						DataTypeValue:     sql.NullString{String: matches[2], Valid: true},
						ColumnTypeValue:   sql.NullString{String: matches[2], Valid: true},
						PrimaryKeyValue:   sql.NullBool{Valid: true},
						UniqueValue:       sql.NullBool{Valid: true},
						NullableValue:     sql.NullBool{Bool: true, Valid: true},
						DefaultValueValue: sql.NullString{Valid: false},
					}

					matchUpper := strings.ToUpper(matches[3])
					if strings.Contains(matchUpper, " NOT NULL") {
						columnType.NullableValue = sql.NullBool{Bool: false, Valid: true}
					} else if strings.Contains(matchUpper, " NULL") {
						columnType.NullableValue = sql.NullBool{Bool: true, Valid: true}
					}
					if strings.Contains(matchUpper, " UNIQUE") {
						columnType.UniqueValue = sql.NullBool{Bool: true, Valid: true}
					}
					if strings.Contains(matchUpper, " PRIMARY") {
						columnType.PrimaryKeyValue = sql.NullBool{Bool: true, Valid: true}
					}
					// DEFAULT NULL is treated as "no default".
					if defaultMatches := defaultValueRegexp.FindStringSubmatch(matches[3]); len(defaultMatches) > 1 {
						if strings.ToLower(defaultMatches[1]) != "null" {
							columnType.DefaultValueValue = sql.NullString{String: strings.Trim(defaultMatches[1], `"`), Valid: true}
						}
					}

					// data type length
					matches := regRealDataType.FindAllStringSubmatch(columnType.DataTypeValue.String, -1)
					if len(matches) == 1 && len(matches[0]) == 2 {
						size, _ := strconv.Atoi(matches[0][1])
						columnType.LengthValue = sql.NullInt64{Valid: true, Int64: int64(size)}
						columnType.DataTypeValue.String = strings.TrimSuffix(columnType.DataTypeValue.String, matches[0][0])
					}

					result.columns = append(result.columns, columnType)
				}
			}
		} else if matches := indexRegexp.FindStringSubmatch(str); len(matches) > 0 {
			// don't report Unique by UniqueIndex
		} else {
			return nil, errors.New("invalid DDL")
		}
	}

	return &result, nil
}
|
||||
|
||||
func (d *ddl) clone() *ddl {
|
||||
copied := new(ddl)
|
||||
*copied = *d
|
||||
|
||||
copied.fields = make([]string, len(d.fields))
|
||||
copy(copied.fields, d.fields)
|
||||
copied.columns = make([]migrator.ColumnType, len(d.columns))
|
||||
copy(copied.columns, d.columns)
|
||||
|
||||
return copied
|
||||
}
|
||||
|
||||
func (d *ddl) compile() string {
|
||||
if len(d.fields) == 0 {
|
||||
return d.head
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s (%s)", d.head, strings.Join(d.fields, ","))
|
||||
}
|
||||
|
||||
func (d *ddl) renameTable(dst, src string) error {
|
||||
tableReg, err := regexp.Compile("\\s*('|`|\")?\\b" + regexp.QuoteMeta(src) + "\\b('|`|\")?\\s*")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
replaced := tableReg.ReplaceAllString(d.head, fmt.Sprintf(" `%s` ", dst))
|
||||
if replaced == d.head {
|
||||
return fmt.Errorf("failed to look up tablename `%s` from DDL head '%s'", src, d.head)
|
||||
}
|
||||
|
||||
d.head = replaced
|
||||
return nil
|
||||
}
|
||||
|
||||
// compileConstraintRegexp builds a regexp matching a DDL field that declares
// the named constraint: a case-insensitive CONSTRAINT keyword, whitespace,
// then the (optionally quoted) name followed by a quote or whitespace.
func compileConstraintRegexp(name string) *regexp.Regexp {
	pattern := "^(?i:CONSTRAINT)\\s+[\"`]?" + regexp.QuoteMeta(name) + "[\"`\\s]"
	return regexp.MustCompile(pattern)
}
|
||||
|
||||
func (d *ddl) addConstraint(name string, sql string) {
|
||||
reg := compileConstraintRegexp(name)
|
||||
|
||||
for i := 0; i < len(d.fields); i++ {
|
||||
if reg.MatchString(d.fields[i]) {
|
||||
d.fields[i] = sql
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
d.fields = append(d.fields, sql)
|
||||
}
|
||||
|
||||
func (d *ddl) removeConstraint(name string) bool {
|
||||
reg := compileConstraintRegexp(name)
|
||||
|
||||
for i := 0; i < len(d.fields); i++ {
|
||||
if reg.MatchString(d.fields[i]) {
|
||||
d.fields = append(d.fields[:i], d.fields[i+1:]...)
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (d *ddl) hasConstraint(name string) bool {
|
||||
reg := compileConstraintRegexp(name)
|
||||
|
||||
for _, f := range d.fields {
|
||||
if reg.MatchString(f) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (d *ddl) getColumns() []string {
|
||||
res := []string{}
|
||||
|
||||
for _, f := range d.fields {
|
||||
fUpper := strings.ToUpper(f)
|
||||
if strings.HasPrefix(fUpper, "PRIMARY KEY") ||
|
||||
strings.HasPrefix(fUpper, "CHECK") ||
|
||||
strings.HasPrefix(fUpper, "CONSTRAINT") ||
|
||||
strings.Contains(fUpper, "GENERATED ALWAYS AS") {
|
||||
continue
|
||||
}
|
||||
|
||||
reg := regexp.MustCompile("^[\"`']?([\\w\\d]+)[\"`']?")
|
||||
match := reg.FindStringSubmatch(f)
|
||||
|
||||
if match != nil {
|
||||
res = append(res, "`"+match[1]+"`")
|
||||
}
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func (d *ddl) removeColumn(name string) bool {
|
||||
reg := regexp.MustCompile("^(`|'|\"| )" + regexp.QuoteMeta(name) + "(`|'|\"| ) .*?$")
|
||||
|
||||
for i := 0; i < len(d.fields); i++ {
|
||||
if reg.MatchString(d.fields[i]) {
|
||||
d.fields = append(d.fields[:i], d.fields[i+1:]...)
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
117
vendor/gorm.io/driver/sqlite/ddlmod_parse_all_columns.go
generated
vendored
Normal file
117
vendor/gorm.io/driver/sqlite/ddlmod_parse_all_columns.go
generated
vendored
Normal file
@@ -0,0 +1,117 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// parseAllColumnsState enumerates the states of the hand-written column-list
// tokenizer in parseAllColumns.
type parseAllColumnsState int

const (
	parseAllColumnsState_NONE parseAllColumnsState = iota // before the opening '('
	parseAllColumnsState_Beginning                        // expecting the next column name
	parseAllColumnsState_ReadingRawName                   // inside an unquoted name
	parseAllColumnsState_ReadingQuotedName                // inside a quoted/bracketed name
	parseAllColumnsState_EndOfName                        // name done, expecting ',' or ')'
	parseAllColumnsState_State_End                        // closing ')' consumed
)
|
||||
|
||||
// parseAllColumns extracts the column names from a parenthesized list such
// as `PRIMARY KEY ("a", b, [c])`. Names may be unquoted, quoted with
// backtick/double/single quotes (doubled quote = escape), or bracketed.
// It returns an error on stray tokens or if the list never closes.
func parseAllColumns(in string) ([]string, error) {
	s := []rune(in)
	columns := make([]string, 0)
	state := parseAllColumnsState_NONE
	quote := rune(0)
	name := make([]rune, 0)
	for i := 0; i < len(s); i++ {
		switch state {
		case parseAllColumnsState_NONE:
			// Skip everything up to the opening parenthesis.
			if s[i] == '(' {
				state = parseAllColumnsState_Beginning
			}
		case parseAllColumnsState_Beginning:
			if isSpace(s[i]) {
				continue
			}
			if isQuote(s[i]) {
				state = parseAllColumnsState_ReadingQuotedName
				quote = s[i]
				continue
			}
			// Bracketed names close with ']' rather than the opener.
			if s[i] == '[' {
				state = parseAllColumnsState_ReadingQuotedName
				quote = ']'
				continue
			} else if s[i] == ')' {
				// Empty list / dangling comma.
				return columns, fmt.Errorf("unexpected token: %s", string(s[i]))
			}
			state = parseAllColumnsState_ReadingRawName
			name = append(name, s[i])
		case parseAllColumnsState_ReadingRawName:
			if isSeparator(s[i]) {
				state = parseAllColumnsState_Beginning
				columns = append(columns, string(name))
				name = make([]rune, 0)
				continue
			}
			if s[i] == ')' {
				state = parseAllColumnsState_State_End
				columns = append(columns, string(name))
			}
			if isQuote(s[i]) {
				return nil, fmt.Errorf("unexpected token: %s", string(s[i]))
			}
			if isSpace(s[i]) {
				state = parseAllColumnsState_EndOfName
				columns = append(columns, string(name))
				name = make([]rune, 0)
				continue
			}
			// NOTE(review): after the ')' branch above there is no continue,
			// so the closing paren is also appended to name here; harmless,
			// since State_End never reads name again.
			name = append(name, s[i])
		case parseAllColumnsState_ReadingQuotedName:
			if s[i] == quote {
				// check if quote character is escaped
				if i+1 < len(s) && s[i+1] == quote {
					name = append(name, quote)
					i++
					continue
				}
				state = parseAllColumnsState_EndOfName
				columns = append(columns, string(name))
				name = make([]rune, 0)
				continue
			}
			name = append(name, s[i])
		case parseAllColumnsState_EndOfName:
			if isSpace(s[i]) {
				continue
			}
			if isSeparator(s[i]) {
				state = parseAllColumnsState_Beginning
				continue
			}
			if s[i] == ')' {
				state = parseAllColumnsState_State_End
				continue
			}
			return nil, fmt.Errorf("unexpected token: %s", string(s[i]))
		case parseAllColumnsState_State_End:
			// Ignore any trailing input after the closing parenthesis.
			break
		}
	}
	if state != parseAllColumnsState_State_End {
		return nil, errors.New("unexpected end")
	}
	return columns, nil
}
|
||||
|
||||
// isSpace reports whether r is whitespace for the column-list tokenizer
// (only ASCII space and tab; newlines are intentionally not included).
func isSpace(r rune) bool {
	switch r {
	case ' ', '\t':
		return true
	}
	return false
}
|
||||
|
||||
// isQuote reports whether r is one of the identifier-quote characters the
// tokenizer recognizes: backtick, double quote, or single quote.
func isQuote(r rune) bool {
	switch r {
	case '`', '"', '\'':
		return true
	}
	return false
}
|
||||
|
||||
// isSeparator reports whether r separates entries in a column list (comma).
func isSeparator(r rune) bool {
	switch r {
	case ',':
		return true
	}
	return false
}
|
||||
40
vendor/gorm.io/driver/sqlite/error_translator.go
generated
vendored
Normal file
40
vendor/gorm.io/driver/sqlite/error_translator.go
generated
vendored
Normal file
@@ -0,0 +1,40 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// The error codes to map sqlite errors to gorm errors, here is a reference about error codes for sqlite https://www.sqlite.org/rescode.html.
var errCodes = map[int]error{
	1555: gorm.ErrDuplicatedKey,      // SQLITE_CONSTRAINT_PRIMARYKEY
	2067: gorm.ErrDuplicatedKey,      // SQLITE_CONSTRAINT_UNIQUE
	787:  gorm.ErrForeignKeyViolated, // SQLITE_CONSTRAINT_FOREIGNKEY
}
|
||||
|
||||
// ErrMessage mirrors the JSON shape of the sqlite driver's error value so it
// can be decoded without importing the driver's error type (see Translate).
type ErrMessage struct {
	Code         int `json:"Code"`         // primary SQLite result code
	ExtendedCode int `json:"ExtendedCode"` // extended result code used for mapping
	SystemErrno  int `json:"SystemErrno"`  // OS errno, if any
}
|
||||
|
||||
// Translate it will translate the error to native gorm errors.
|
||||
// We are not using go-sqlite3 error type intentionally here because it will need the CGO_ENABLED=1 and cross-C-compiler.
|
||||
func (dialector Dialector) Translate(err error) error {
|
||||
parsedErr, marshalErr := json.Marshal(err)
|
||||
if marshalErr != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var errMsg ErrMessage
|
||||
unmarshalErr := json.Unmarshal(parsedErr, &errMsg)
|
||||
if unmarshalErr != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if translatedErr, found := errCodes[errMsg.ExtendedCode]; found {
|
||||
return translatedErr
|
||||
}
|
||||
return err
|
||||
}
|
||||
7
vendor/gorm.io/driver/sqlite/errors.go
generated
vendored
Normal file
7
vendor/gorm.io/driver/sqlite/errors.go
generated
vendored
Normal file
@@ -0,0 +1,7 @@
|
||||
package sqlite
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
	// ErrConstraintsNotImplemented is returned by constraint operations that
	// SQLite cannot perform in place; see the linked migrator notes.
	ErrConstraintsNotImplemented = errors.New("constraints not implemented on sqlite, consider using DisableForeignKeyConstraintWhenMigrating, more details https://github.com/go-gorm/gorm/wiki/GORM-V2-Release-Note-Draft#all-new-migrator")
)
|
||||
430
vendor/gorm.io/driver/sqlite/migrator.go
generated
vendored
Normal file
430
vendor/gorm.io/driver/sqlite/migrator.go
generated
vendored
Normal file
@@ -0,0 +1,430 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/migrator"
|
||||
"gorm.io/gorm/schema"
|
||||
)
|
||||
|
||||
// Migrator is the SQLite migrator: it embeds gorm's generic migrator and
// overrides the operations SQLite cannot do with plain ALTER TABLE
// (it recreates tables instead).
type Migrator struct {
	migrator.Migrator
}
|
||||
|
||||
// RunWithoutForeignKey executes fc with SQLite foreign-key enforcement
// temporarily disabled, restoring it afterwards. If the pragma is already
// off (or the query fails, leaving enabled at zero), fc runs unchanged.
func (m *Migrator) RunWithoutForeignKey(fc func() error) error {
	var enabled int
	m.DB.Raw("PRAGMA foreign_keys").Scan(&enabled)
	if enabled == 1 {
		m.DB.Exec("PRAGMA foreign_keys = OFF")
		// Deliberately registered only when we turned it off, so the pragma
		// is re-enabled after fc() returns.
		defer m.DB.Exec("PRAGMA foreign_keys = ON")
	}

	return fc()
}
|
||||
|
||||
func (m Migrator) HasTable(value interface{}) bool {
|
||||
var count int
|
||||
m.Migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
return m.DB.Raw("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", stmt.Table).Row().Scan(&count)
|
||||
})
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (m Migrator) DropTable(values ...interface{}) error {
|
||||
return m.RunWithoutForeignKey(func() error {
|
||||
values = m.ReorderModels(values, false)
|
||||
tx := m.DB.Session(&gorm.Session{})
|
||||
|
||||
for i := len(values) - 1; i >= 0; i-- {
|
||||
if err := m.RunWithValue(values[i], func(stmt *gorm.Statement) error {
|
||||
return tx.Exec("DROP TABLE IF EXISTS ?", clause.Table{Name: stmt.Table}).Error
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (m Migrator) GetTables() (tableList []string, err error) {
|
||||
return tableList, m.DB.Raw("SELECT name FROM sqlite_master where type=?", "table").Scan(&tableList).Error
|
||||
}
|
||||
|
||||
// HasColumn reports whether the named column exists on value's table. The
// field name is first resolved to its DB name, then the table's stored DDL
// in sqlite_master is searched with several LIKE patterns covering the
// quoting styles SQLite accepts (double quotes, bare, backticks, brackets,
// tabs). This is a textual heuristic over the DDL, not a schema query.
func (m Migrator) HasColumn(value interface{}, name string) bool {
	var count int
	m.Migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
		if stmt.Schema != nil {
			if field := stmt.Schema.LookUpField(name); field != nil {
				name = field.DBName
			}
		}

		if name != "" {
			// Scan error deliberately ignored: count stays 0 -> "not found".
			m.DB.Raw(
				"SELECT count(*) FROM sqlite_master WHERE type = ? AND tbl_name = ? AND (sql LIKE ? OR sql LIKE ? OR sql LIKE ? OR sql LIKE ? OR sql LIKE ?)",
				"table", stmt.Table, `%"`+name+`" %`, `%`+name+` %`, "%`"+name+"`%", "%["+name+"]%", "%\t"+name+"\t%",
			).Row().Scan(&count)
		}
		return nil
	})
	return count > 0
}
|
||||
|
||||
// AlterColumn changes a column's definition. SQLite has no ALTER COLUMN, so
// the table is recreated: the column's field in the parsed DDL is replaced
// with a placeholder filled by FullDataTypeOf, and foreign keys are disabled
// around the rebuild. It fails if the named field is unknown to the schema.
func (m Migrator) AlterColumn(value interface{}, name string) error {
	return m.RunWithoutForeignKey(func() error {
		return m.recreateTable(value, nil, func(ddl *ddl, stmt *gorm.Statement) (*ddl, []interface{}, error) {
			if field := stmt.Schema.LookUpField(name); field != nil {
				var sqlArgs []interface{}
				for i, f := range ddl.fields {
					if matches := columnRegexp.FindStringSubmatch(f); len(matches) > 1 && matches[1] == field.DBName {
						// Swap this column definition for a bind placeholder.
						ddl.fields[i] = fmt.Sprintf("`%v` ?", field.DBName)
						sqlArgs = []interface{}{m.FullDataTypeOf(field)}
						// table created by old version might look like `CREATE TABLE ? (? varchar(10) UNIQUE)`.
						// FullDataTypeOf doesn't contain UNIQUE, so we need to add unique constraint.
						if strings.Contains(strings.ToUpper(matches[3]), " UNIQUE") {
							uniName := m.DB.NamingStrategy.UniqueName(stmt.Table, field.DBName)
							uni, _ := m.GuessConstraintInterfaceAndTable(stmt, uniName)
							if uni != nil {
								uniSQL, uniArgs := uni.Build()
								ddl.addConstraint(uniName, uniSQL)
								sqlArgs = append(sqlArgs, uniArgs...)
							}
						}
						break
					}
				}
				return ddl, sqlArgs, nil
			}
			return nil, nil, fmt.Errorf("failed to alter field with name %v", name)
		})
	})
}
|
||||
|
||||
// ColumnTypes return columnTypes []gorm.ColumnType and execErr error.
// It merges two sources: the table's stored DDL in sqlite_master (parsed for
// names, nullability, defaults, keys) and the driver's column metadata from a
// LIMIT 1 query. DDL-derived entries win when names match; otherwise the raw
// driver metadata is used.
func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
	columnTypes := make([]gorm.ColumnType, 0)
	execErr := m.RunWithValue(value, func(stmt *gorm.Statement) (err error) {
		var (
			sqls   []string
			sqlDDL *ddl
		)

		// Fetch the CREATE TABLE first ("order by type = ? desc"), then any
		// CREATE INDEX statements for the same table.
		if err := m.DB.Raw("SELECT sql FROM sqlite_master WHERE type IN ? AND tbl_name = ? AND sql IS NOT NULL order by type = ? desc", []string{"table", "index"}, stmt.Table, "table").Scan(&sqls).Error; err != nil {
			return err
		}

		if sqlDDL, err = parseDDL(sqls...); err != nil {
			return err
		}

		rows, err := m.DB.Session(&gorm.Session{}).Table(stmt.Table).Limit(1).Rows()
		if err != nil {
			return err
		}
		// Named return: a Close failure is surfaced as the callback's error
		// (note it also overwrites any earlier err on the success path).
		defer func() {
			err = rows.Close()
		}()

		var rawColumnTypes []*sql.ColumnType
		rawColumnTypes, err = rows.ColumnTypes()
		if err != nil {
			return err
		}

		for _, c := range rawColumnTypes {
			columnType := migrator.ColumnType{SQLColumnType: c}
			// Prefer the richer DDL-derived metadata when names match.
			for _, column := range sqlDDL.columns {
				if column.NameValue.String == c.Name() {
					column.SQLColumnType = c
					columnType = column
					break
				}
			}
			columnTypes = append(columnTypes, columnType)
		}

		return err
	})

	return columnTypes, execErr
}
|
||||
|
||||
func (m Migrator) DropColumn(value interface{}, name string) error {
|
||||
return m.recreateTable(value, nil, func(ddl *ddl, stmt *gorm.Statement) (*ddl, []interface{}, error) {
|
||||
if field := stmt.Schema.LookUpField(name); field != nil {
|
||||
name = field.DBName
|
||||
}
|
||||
|
||||
ddl.removeColumn(name)
|
||||
return ddl, nil, nil
|
||||
})
|
||||
}
|
||||
|
||||
func (m Migrator) CreateConstraint(value interface{}, name string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name)
|
||||
|
||||
return m.recreateTable(value, &table,
|
||||
func(ddl *ddl, stmt *gorm.Statement) (*ddl, []interface{}, error) {
|
||||
var (
|
||||
constraintName string
|
||||
constraintSql string
|
||||
constraintValues []interface{}
|
||||
)
|
||||
|
||||
if constraint != nil {
|
||||
constraintName = constraint.GetName()
|
||||
constraintSql, constraintValues = constraint.Build()
|
||||
} else {
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
ddl.addConstraint(constraintName, constraintSql)
|
||||
return ddl, constraintValues, nil
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// DropConstraint removes the named constraint by rewriting the stored DDL and
// recreating the table (SQLite's ALTER TABLE cannot DROP CONSTRAINT).
func (m Migrator) DropConstraint(value interface{}, name string) error {
	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
		constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name)
		if constraint != nil {
			name = constraint.GetName()
		}

		return m.recreateTable(value, &table,
			func(ddl *ddl, stmt *gorm.Statement) (*ddl, []interface{}, error) {
				ddl.removeConstraint(name)
				return ddl, nil, nil
			})
	})
}
|
||||
|
||||
// HasConstraint reports whether the named constraint exists on the value's
// table. It searches the CREATE TABLE SQL recorded in sqlite_master, matching
// the constraint name under every quoting style SQLite accepts: double
// quotes, unquoted, backticks, square brackets, or tab-delimited.
func (m Migrator) HasConstraint(value interface{}, name string) bool {
	var count int64
	// NOTE(review): the error returned by RunWithValue is ignored; any failure
	// leaves count at zero and reports "no constraint".
	m.RunWithValue(value, func(stmt *gorm.Statement) error {
		constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name)
		if constraint != nil {
			name = constraint.GetName()
		}

		m.DB.Raw(
			"SELECT count(*) FROM sqlite_master WHERE type = ? AND tbl_name = ? AND (sql LIKE ? OR sql LIKE ? OR sql LIKE ? OR sql LIKE ? OR sql LIKE ?)",
			"table", table, `%CONSTRAINT "`+name+`" %`, `%CONSTRAINT `+name+` %`, "%CONSTRAINT `"+name+"`%", "%CONSTRAINT ["+name+"]%", "%CONSTRAINT \t"+name+"\t%",
		).Row().Scan(&count)

		return nil
	})

	return count > 0
}
|
||||
|
||||
// CurrentDatabase returns the attached database's name, read from the second
// column (seq, name, file) of PRAGMA database_list.
func (m Migrator) CurrentDatabase() (name string) {
	var null interface{}
	m.DB.Raw("PRAGMA database_list").Row().Scan(&null, &name, &null)
	return
}
|
||||
|
||||
func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) {
|
||||
for _, opt := range opts {
|
||||
str := stmt.Quote(opt.DBName)
|
||||
if opt.Expression != "" {
|
||||
str = opt.Expression
|
||||
}
|
||||
|
||||
if opt.Collate != "" {
|
||||
str += " COLLATE " + opt.Collate
|
||||
}
|
||||
|
||||
if opt.Sort != "" {
|
||||
str += " " + opt.Sort
|
||||
}
|
||||
results = append(results, clause.Expr{SQL: str})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// CreateIndex builds and executes a CREATE [class] INDEX statement for the
// index looked up by name (or field name) in the parsed schema. Returns an
// error when the schema is missing or no matching index is defined.
func (m Migrator) CreateIndex(value interface{}, name string) error {
	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
		if stmt.Schema != nil {
			if idx := stmt.Schema.LookIndex(name); idx != nil {
				opts := m.BuildIndexOptions(idx.Fields, stmt)
				values := []interface{}{clause.Column{Name: idx.Name}, clause.Table{Name: stmt.Table}, opts}

				createIndexSQL := "CREATE "
				if idx.Class != "" {
					// e.g. UNIQUE
					createIndexSQL += idx.Class + " "
				}
				createIndexSQL += "INDEX ?"

				if idx.Type != "" {
					createIndexSQL += " USING " + idx.Type
				}
				// Second ? binds the table, third expands the column options.
				createIndexSQL += " ON ??"

				if idx.Where != "" {
					// Partial index predicate.
					createIndexSQL += " WHERE " + idx.Where
				}

				return m.DB.Exec(createIndexSQL, values...).Error
			}
		}
		return fmt.Errorf("failed to create index with name %v", name)
	})
}
|
||||
|
||||
// HasIndex reports whether an index with the given name (or the name of the
// schema index it resolves to) exists on the value's table, by counting
// matching rows in sqlite_master.
func (m Migrator) HasIndex(value interface{}, name string) bool {
	var count int
	m.RunWithValue(value, func(stmt *gorm.Statement) error {
		if stmt.Schema != nil {
			// Resolve a field/index name defined in the model to its DB name.
			if idx := stmt.Schema.LookIndex(name); idx != nil {
				name = idx.Name
			}
		}

		if name != "" {
			m.DB.Raw(
				"SELECT count(*) FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "index", stmt.Table, name,
			).Row().Scan(&count)
		}
		return nil
	})
	return count > 0
}
|
||||
|
||||
// RenameIndex renames an index by fetching its original CREATE INDEX SQL from
// sqlite_master, dropping it, and re-executing the SQL with the name swapped
// (SQLite has no ALTER INDEX ... RENAME).
func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error {
	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
		var sql string
		m.DB.Raw("SELECT sql FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "index", stmt.Table, oldName).Row().Scan(&sql)
		if sql != "" {
			if err := m.DropIndex(value, oldName); err != nil {
				return err
			}
			// NOTE(review): replaces the first occurrence of oldName anywhere
			// in the DDL; a column name containing oldName before the index
			// name would be renamed instead — confirm for exotic names.
			return m.DB.Exec(strings.Replace(sql, oldName, newName, 1)).Error
		}
		return fmt.Errorf("failed to find index with name %v", oldName)
	})
}
|
||||
|
||||
// DropIndex drops the index resolved from name (field name or index name in
// the parsed schema) using DROP INDEX.
func (m Migrator) DropIndex(value interface{}, name string) error {
	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
		if stmt.Schema != nil {
			if idx := stmt.Schema.LookIndex(name); idx != nil {
				name = idx.Name
			}
		}

		return m.DB.Exec("DROP INDEX ?", clause.Column{Name: name}).Error
	})
}
|
||||
|
||||
// Index mirrors one row of SQLite's `PRAGMA index_list` output.
type Index struct {
	Seq     int    // position of the index in the pragma output
	Name    string // index name
	Unique  bool   // whether the index enforces uniqueness
	Origin  string // how it was created: "u" = UNIQUE constraint, "pk" = PRIMARY KEY (per pragma docs)
	Partial bool   // whether the index has a WHERE clause
}
|
||||
|
||||
// GetIndexes return Indexes []gorm.Index and execErr error,
|
||||
// See the [doc]
|
||||
//
|
||||
// [doc]: https://www.sqlite.org/pragma.html#pragma_index_list
|
||||
func (m Migrator) GetIndexes(value interface{}) ([]gorm.Index, error) {
|
||||
indexes := make([]gorm.Index, 0)
|
||||
err := m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
rst := make([]*Index, 0)
|
||||
if err := m.DB.Debug().Raw("SELECT * FROM PRAGMA_index_list(?)", stmt.Table).Scan(&rst).Error; err != nil { // alias `PRAGMA index_list(?)`
|
||||
return err
|
||||
}
|
||||
for _, index := range rst {
|
||||
if index.Origin == "u" { // skip the index was created by a UNIQUE constraint
|
||||
continue
|
||||
}
|
||||
var columns []string
|
||||
if err := m.DB.Raw("SELECT name FROM PRAGMA_index_info(?)", index.Name).Scan(&columns).Error; err != nil { // alias `PRAGMA index_info(?)`
|
||||
return err
|
||||
}
|
||||
indexes = append(indexes, &migrator.Index{
|
||||
TableName: stmt.Table,
|
||||
NameValue: index.Name,
|
||||
ColumnList: columns,
|
||||
PrimaryKeyValue: sql.NullBool{Bool: index.Origin == "pk", Valid: true}, // The exceptions are INTEGER PRIMARY KEY
|
||||
UniqueValue: sql.NullBool{Bool: index.Unique, Valid: true},
|
||||
})
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return indexes, err
|
||||
}
|
||||
|
||||
// getRawDDL fetches the stored CREATE TABLE statement for table from
// sqlite_master.
func (m Migrator) getRawDDL(table string) (string, error) {
	var createSQL string
	// NOTE(review): the Scan error is discarded and only m.DB.Error is
	// consulted; a nonexistent table yields ("", nil) rather than an error.
	m.DB.Raw("SELECT sql FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "table", table, table).Row().Scan(&createSQL)

	if m.DB.Error != nil {
		return "", m.DB.Error
	}
	return createSQL, nil
}
|
||||
|
||||
// recreateTable implements SQLite schema changes that ALTER TABLE cannot
// express, using the classic rebuild pattern:
//
//  1. read the table's current DDL from sqlite_master,
//  2. let getCreateSQL transform a clone of it (returning nil means no-op),
//  3. inside one transaction: create "<table>__temp" from the new DDL, copy
//     all rows across, drop the old table, and rename the temp table back.
//
// tablePtr overrides the table name derived from value when non-nil.
func (m Migrator) recreateTable(
	value interface{}, tablePtr *string,
	getCreateSQL func(ddl *ddl, stmt *gorm.Statement) (sql *ddl, sqlArgs []interface{}, err error),
) error {
	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
		table := stmt.Table
		if tablePtr != nil {
			table = *tablePtr
		}

		rawDDL, err := m.getRawDDL(table)
		if err != nil {
			return err
		}

		originDDL, err := parseDDL(rawDDL)
		if err != nil {
			return err
		}

		// The callback mutates a clone so the original DDL stays intact.
		createDDL, sqlArgs, err := getCreateSQL(originDDL.clone(), stmt)
		if err != nil {
			return err
		}
		if createDDL == nil {
			// Callback signalled there is nothing to change.
			return nil
		}

		newTableName := table + "__temp"
		if err := createDDL.renameTable(newTableName, table); err != nil {
			return err
		}

		columns := createDDL.getColumns()
		createSQL := createDDL.compile()

		return m.DB.Transaction(func(tx *gorm.DB) error {
			if err := tx.Exec(createSQL, sqlArgs...).Error; err != nil {
				return err
			}

			// Copy data, then swap the rebuilt table into place.
			queries := []string{
				fmt.Sprintf("INSERT INTO `%v`(%v) SELECT %v FROM `%v`", newTableName, strings.Join(columns, ","), strings.Join(columns, ","), table),
				fmt.Sprintf("DROP TABLE `%v`", table),
				fmt.Sprintf("ALTER TABLE `%v` RENAME TO `%v`", newTableName, table),
			}
			for _, query := range queries {
				if err := tx.Exec(query).Error; err != nil {
					return err
				}
			}
			return nil
		})
	})
}
|
||||
270
vendor/gorm.io/driver/sqlite/sqlite.go
generated
vendored
Normal file
270
vendor/gorm.io/driver/sqlite/sqlite.go
generated
vendored
Normal file
@@ -0,0 +1,270 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"strconv"
|
||||
|
||||
"gorm.io/gorm/callbacks"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/logger"
|
||||
"gorm.io/gorm/migrator"
|
||||
"gorm.io/gorm/schema"
|
||||
)
|
||||
|
||||
// DriverName is the default driver name for SQLite.
const DriverName = "sqlite3"

// Dialector implements gorm.Dialector for SQLite on top of mattn/go-sqlite3.
type Dialector struct {
	DriverName string        // database/sql driver name; defaults to DriverName when empty
	DSN        string        // data source name passed to sql.Open
	Conn       gorm.ConnPool // optional pre-opened connection pool; used instead of DSN when set
}

// Config carries the options accepted by New; fields mirror Dialector.
type Config struct {
	DriverName string
	DSN        string
	Conn       gorm.ConnPool
}
|
||||
|
||||
// Open returns a SQLite gorm.Dialector for the given DSN.
func Open(dsn string) gorm.Dialector {
	return &Dialector{DSN: dsn}
}
|
||||
|
||||
// New returns a SQLite gorm.Dialector configured from config.
func New(config Config) gorm.Dialector {
	return &Dialector{DSN: config.DSN, DriverName: config.DriverName, Conn: config.Conn}
}
|
||||
|
||||
// Name returns the dialect name, "sqlite".
func (dialector Dialector) Name() string {
	return "sqlite"
}
|
||||
|
||||
// Initialize sets up the connection pool (opening one from DSN unless a
// pre-built Conn was supplied), registers GORM's default callbacks — with
// RETURNING clauses when the linked SQLite supports them — and installs the
// dialect's clause builders without overriding user-registered ones.
func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
	if dialector.DriverName == "" {
		dialector.DriverName = DriverName
	}

	if dialector.Conn != nil {
		db.ConnPool = dialector.Conn
	} else {
		conn, err := sql.Open(dialector.DriverName, dialector.DSN)
		if err != nil {
			return err
		}
		db.ConnPool = conn
	}

	var version string
	if err := db.ConnPool.QueryRowContext(context.Background(), "select sqlite_version()").Scan(&version); err != nil {
		return err
	}
	// RETURNING support was added in SQLite 3.35.0.
	// https://www.sqlite.org/releaselog/3_35_0.html
	if compareVersion(version, "3.35.0") >= 0 {
		callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{
			CreateClauses:        []string{"INSERT", "VALUES", "ON CONFLICT", "RETURNING"},
			UpdateClauses:        []string{"UPDATE", "SET", "FROM", "WHERE", "RETURNING"},
			DeleteClauses:        []string{"DELETE", "FROM", "WHERE", "RETURNING"},
			LastInsertIDReversed: true,
		})
	} else {
		callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{
			LastInsertIDReversed: true,
		})
	}

	// Only fill in builders the user has not already registered.
	for k, v := range dialector.ClauseBuilders() {
		if _, ok := db.ClauseBuilders[k]; !ok {
			db.ClauseBuilders[k] = v
		}
	}
	return
}
|
||||
|
||||
// ClauseBuilders returns SQLite-specific builders for the INSERT, LIMIT and
// FOR clauses, overriding GORM's generic rendering where SQLite differs.
func (dialector Dialector) ClauseBuilders() map[string]clause.ClauseBuilder {
	return map[string]clause.ClauseBuilder{
		// INSERT: render "INSERT [modifier] INTO <table>" directly so modifiers
		// such as OR IGNORE appear in SQLite's expected position.
		"INSERT": func(c clause.Clause, builder clause.Builder) {
			if insert, ok := c.Expression.(clause.Insert); ok {
				if stmt, ok := builder.(*gorm.Statement); ok {
					stmt.WriteString("INSERT ")
					if insert.Modifier != "" {
						stmt.WriteString(insert.Modifier)
						stmt.WriteByte(' ')
					}

					stmt.WriteString("INTO ")
					if insert.Table.Name == "" {
						stmt.WriteQuoted(stmt.Table)
					} else {
						stmt.WriteQuoted(insert.Table)
					}
					return
				}
			}

			// Fall back to the default rendering for unexpected expressions.
			c.Build(builder)
		},
		// LIMIT: SQLite requires a LIMIT whenever OFFSET is used, so emit
		// "LIMIT -1" (no limit) when only an offset was requested.
		"LIMIT": func(c clause.Clause, builder clause.Builder) {
			if limit, ok := c.Expression.(clause.Limit); ok {
				var lmt = -1
				if limit.Limit != nil && *limit.Limit >= 0 {
					lmt = *limit.Limit
				}
				if lmt >= 0 || limit.Offset > 0 {
					builder.WriteString("LIMIT ")
					builder.WriteString(strconv.Itoa(lmt))
				}
				if limit.Offset > 0 {
					builder.WriteString(" OFFSET ")
					builder.WriteString(strconv.Itoa(limit.Offset))
				}
			}
		},
		// FOR: drop locking clauses entirely.
		"FOR": func(c clause.Clause, builder clause.Builder) {
			if _, ok := c.Expression.(clause.Locking); ok {
				// SQLite3 does not support row-level locking.
				return
			}
			c.Build(builder)
		},
	}
}
|
||||
|
||||
// DefaultValueOf returns the expression inserted for a field with no explicit
// value. Auto-increment columns get NULL so SQLite assigns the next rowid.
func (dialector Dialector) DefaultValueOf(field *schema.Field) clause.Expression {
	if field.AutoIncrement {
		return clause.Expr{SQL: "NULL"}
	}

	// doesn't work, will raise error
	return clause.Expr{SQL: "DEFAULT"}
}
|
||||
|
||||
// Migrator returns the SQLite-specific migrator. Indexes are created after
// CREATE TABLE since SQLite has no inline index syntax.
func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator {
	return Migrator{migrator.Migrator{Config: migrator.Config{
		DB:                          db,
		Dialector:                   dialector,
		CreateIndexAfterCreateTable: true,
	}}}
}
|
||||
|
||||
// BindVarTo writes SQLite's positional placeholder, a plain '?'.
func (dialector Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) {
	writer.WriteByte('?')
}
|
||||
|
||||
// QuoteTo writes str to writer as a backtick-quoted identifier. Dotted names
// (e.g. table.column) are quoted per segment, embedded backticks are escaped
// by doubling, and identifiers the caller already wrapped in backticks are
// normalized rather than double-quoted. The state machine below tracks runs
// of consecutive backticks to tell quoting characters from literal ones.
func (dialector Dialector) QuoteTo(writer clause.Writer, str string) {
	var (
		underQuoted, selfQuoted bool // inside an emitted quote / input was pre-quoted
		continuousBacktick      int8 // length of current backtick run
		shiftDelimiter          int8 // bytes seen since the current segment began
	)

	for _, v := range []byte(str) {
		switch v {
		case '`':
			continuousBacktick++
			if continuousBacktick == 2 {
				// Two literal backticks in the input: emit an escaped pair.
				writer.WriteString("``")
				continuousBacktick = 0
			}
		case '.':
			// Segment boundary: close the current quote (unless the segment was
			// self-quoted with no pending backtick) and reset state.
			if continuousBacktick > 0 || !selfQuoted {
				shiftDelimiter = 0
				underQuoted = false
				continuousBacktick = 0
				writer.WriteString("`")
			}
			writer.WriteByte(v)
			continue
		default:
			if shiftDelimiter-continuousBacktick <= 0 && !underQuoted {
				// First ordinary byte of a segment: open the quote.
				writer.WriteString("`")
				underQuoted = true
				if selfQuoted = continuousBacktick > 0; selfQuoted {
					// One pending backtick was the caller's own opening quote.
					continuousBacktick -= 1
				}
			}

			// Any remaining pending backticks are literal: escape them.
			for ; continuousBacktick > 0; continuousBacktick -= 1 {
				writer.WriteString("``")
			}

			writer.WriteByte(v)
		}
		shiftDelimiter++
	}

	// A trailing backtick that was not the caller's closing quote is literal.
	if continuousBacktick > 0 && !selfQuoted {
		writer.WriteString("``")
	}
	writer.WriteString("`")
}
|
||||
|
||||
// Explain interpolates vars into sql for logging, quoting strings with '"'.
func (dialector Dialector) Explain(sql string, vars ...interface{}) string {
	return logger.ExplainSQL(sql, nil, `"`, vars...)
}
|
||||
|
||||
// DataTypeOf maps a schema field's abstract data type to a SQLite column
// type. Unrecognized types are passed through verbatim (SQLite accepts
// arbitrary type names thanks to type affinity).
func (dialector Dialector) DataTypeOf(field *schema.Field) string {
	switch field.DataType {
	case schema.Bool:
		return "numeric"
	case schema.Int, schema.Uint:
		if field.AutoIncrement {
			// doesn't check `PrimaryKey`, to keep backward compatibility
			// https://www.sqlite.org/autoinc.html
			return "integer PRIMARY KEY AUTOINCREMENT"
		} else {
			return "integer"
		}
	case schema.Float:
		return "real"
	case schema.String:
		return "text"
	case schema.Time:
		// Distinguish between schema.Time and tag time
		if val, ok := field.TagSettings["TYPE"]; ok {
			return val
		} else {
			return "datetime"
		}
	case schema.Bytes:
		return "blob"
	}

	return string(field.DataType)
}
|
||||
|
||||
func (dialectopr Dialector) SavePoint(tx *gorm.DB, name string) error {
|
||||
tx.Exec("SAVEPOINT " + name)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (dialectopr Dialector) RollbackTo(tx *gorm.DB, name string) error {
|
||||
tx.Exec("ROLLBACK TO SAVEPOINT " + name)
|
||||
return nil
|
||||
}
|
||||
|
||||
// compareVersion compares two dotted numeric version strings, returning
// 1 when version1 > version2, -1 when version1 < version2, and 0 when equal.
// Missing trailing components count as zero ("3.35" equals "3.35.0").
func compareVersion(version1, version2 string) int {
	len1, len2 := len(version1), len(version2)
	p1, p2 := 0, 0
	for p1 < len1 || p2 < len2 {
		// Parse the next numeric component of each version.
		num1 := 0
		for ; p1 < len1 && version1[p1] != '.'; p1++ {
			num1 = num1*10 + int(version1[p1]-'0')
		}
		p1++

		num2 := 0
		for ; p2 < len2 && version2[p2] != '.'; p2++ {
			num2 = num2*10 + int(version2[p2]-'0')
		}
		p2++

		switch {
		case num1 > num2:
			return 1
		case num1 < num2:
			return -1
		}
	}
	return 0
}
|
||||
7
vendor/gorm.io/gorm/.gitignore
generated
vendored
Normal file
7
vendor/gorm.io/gorm/.gitignore
generated
vendored
Normal file
@@ -0,0 +1,7 @@
|
||||
TODO*
|
||||
documents
|
||||
coverage.txt
|
||||
_book
|
||||
.idea
|
||||
vendor
|
||||
.vscode
|
||||
19
vendor/gorm.io/gorm/.golangci.yml
generated
vendored
Normal file
19
vendor/gorm.io/gorm/.golangci.yml
generated
vendored
Normal file
@@ -0,0 +1,19 @@
|
||||
version: "2"
|
||||
|
||||
linters:
|
||||
default: standard
|
||||
enable:
|
||||
- cyclop
|
||||
- gocritic
|
||||
- gosec
|
||||
- ineffassign
|
||||
- misspell
|
||||
- prealloc
|
||||
- unconvert
|
||||
- unparam
|
||||
- whitespace
|
||||
|
||||
formatters:
|
||||
enable:
|
||||
- gofumpt
|
||||
- goimports
|
||||
128
vendor/gorm.io/gorm/CODE_OF_CONDUCT.md
generated
vendored
Normal file
128
vendor/gorm.io/gorm/CODE_OF_CONDUCT.md
generated
vendored
Normal file
@@ -0,0 +1,128 @@
|
||||
# Contributor Covenant Code of Conduct
|
||||
|
||||
## Our Pledge
|
||||
|
||||
We as members, contributors, and leaders pledge to make participation in our
community a harassment-free experience for everyone, regardless of age, body
|
||||
size, visible or invisible disability, ethnicity, sex characteristics, gender
|
||||
identity and expression, level of experience, education, socio-economic status,
|
||||
nationality, personal appearance, race, religion, or sexual identity
|
||||
and orientation.
|
||||
|
||||
We pledge to act and interact in ways that contribute to an open, welcoming,
|
||||
diverse, inclusive, and healthy community.
|
||||
|
||||
## Our Standards
|
||||
|
||||
Examples of behavior that contributes to a positive environment for our
|
||||
community include:
|
||||
|
||||
* Demonstrating empathy and kindness toward other people
|
||||
* Being respectful of differing opinions, viewpoints, and experiences
|
||||
* Giving and gracefully accepting constructive feedback
|
||||
* Accepting responsibility and apologizing to those affected by our mistakes,
|
||||
and learning from the experience
|
||||
* Focusing on what is best not just for us as individuals, but for the
|
||||
overall community
|
||||
|
||||
Examples of unacceptable behavior include:
|
||||
|
||||
* The use of sexualized language or imagery, and sexual attention or
|
||||
advances of any kind
|
||||
* Trolling, insulting or derogatory comments, and personal or political attacks
|
||||
* Public or private harassment
|
||||
* Publishing others' private information, such as a physical or email
|
||||
address, without their explicit permission
|
||||
* Other conduct which could reasonably be considered inappropriate in a
|
||||
professional setting
|
||||
|
||||
## Enforcement Responsibilities
|
||||
|
||||
Community leaders are responsible for clarifying and enforcing our standards of
|
||||
acceptable behavior and will take appropriate and fair corrective action in
|
||||
response to any behavior that they deem inappropriate, threatening, offensive,
|
||||
or harmful.
|
||||
|
||||
Community leaders have the right and responsibility to remove, edit, or reject
|
||||
comments, commits, code, wiki edits, issues, and other contributions that are
|
||||
not aligned to this Code of Conduct, and will communicate reasons for moderation
|
||||
decisions when appropriate.
|
||||
|
||||
## Scope
|
||||
|
||||
This Code of Conduct applies within all community spaces and also applies when
|
||||
an individual is officially representing the community in public spaces.
|
||||
Examples of representing our community include using an official e-mail address,
|
||||
posting via an official social media account, or acting as an appointed
|
||||
representative at an online or offline event.
|
||||
|
||||
## Enforcement
|
||||
|
||||
Instances of abusive, harassing, or otherwise unacceptable behavior may be
|
||||
reported to the community leaders responsible for enforcement at
|
||||
.
|
||||
All complaints will be reviewed and investigated promptly and fairly.
|
||||
|
||||
All community leaders are obligated to respect the privacy and security of the
|
||||
reporter of any incident.
|
||||
|
||||
## Enforcement Guidelines
|
||||
|
||||
Community leaders will follow these Community Impact Guidelines in determining
|
||||
the consequences for any action they deem in violation of this Code of Conduct:
|
||||
|
||||
### 1. Correction
|
||||
|
||||
**Community Impact**: Use of inappropriate language or other behavior deemed
|
||||
unprofessional or unwelcome in the community.
|
||||
|
||||
**Consequence**: A private, written warning from community leaders, providing
|
||||
clarity around the nature of the violation and an explanation of why the
|
||||
behavior was inappropriate. A public apology may be requested.
|
||||
|
||||
### 2. Warning
|
||||
|
||||
**Community Impact**: A violation through a single incident or series
|
||||
of actions.
|
||||
|
||||
**Consequence**: A warning with consequences for continued behavior. No
|
||||
interaction with the people involved, including unsolicited interaction with
|
||||
those enforcing the Code of Conduct, for a specified period. This
|
||||
includes avoiding interactions in community spaces and external channels
|
||||
like social media. Violating these terms may lead to a temporary or
|
||||
permanent ban.
|
||||
|
||||
### 3. Temporary Ban
|
||||
|
||||
**Community Impact**: A serious violation of community standards, including
|
||||
sustained inappropriate behavior.
|
||||
|
||||
**Consequence**: A temporary ban from any interaction or public
|
||||
communication with the community for a specified period. No public or
|
||||
private interaction with the people involved, including unsolicited interaction
|
||||
with those enforcing the Code of Conduct, is allowed during this period.
|
||||
Violating these terms may lead to a permanent ban.
|
||||
|
||||
### 4. Permanent Ban
|
||||
|
||||
**Community Impact**: Demonstrating a pattern of violation of community
|
||||
standards, including sustained inappropriate behavior, harassment of an
|
||||
individual, or aggression toward or disparagement of classes of individuals.
|
||||
|
||||
**Consequence**: A permanent ban from any sort of public interaction within
|
||||
the community.
|
||||
|
||||
## Attribution
|
||||
|
||||
This Code of Conduct is adapted from the [Contributor Covenant][homepage],
|
||||
version 2.0, available at
|
||||
https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
|
||||
|
||||
Community Impact Guidelines were inspired by [Mozilla's code of conduct
|
||||
enforcement ladder](https://github.com/mozilla/diversity).
|
||||
|
||||
[homepage]: https://www.contributor-covenant.org
|
||||
|
||||
For answers to common questions about this code of conduct, see the FAQ at
|
||||
https://www.contributor-covenant.org/faq. Translations are available at
|
||||
https://www.contributor-covenant.org/translations.
|
||||
21
vendor/gorm.io/gorm/LICENSE
generated
vendored
Normal file
21
vendor/gorm.io/gorm/LICENSE
generated
vendored
Normal file
@@ -0,0 +1,21 @@
|
||||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2013-present Jinzhu <wosmvp@gmail.com>
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in
|
||||
all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
THE SOFTWARE.
|
||||
44
vendor/gorm.io/gorm/README.md
generated
vendored
Normal file
44
vendor/gorm.io/gorm/README.md
generated
vendored
Normal file
@@ -0,0 +1,44 @@
|
||||
# GORM
|
||||
|
||||
The fantastic ORM library for Golang, aims to be developer friendly.
|
||||
|
||||
[](https://goreportcard.com/report/github.com/go-gorm/gorm)
|
||||
[](https://github.com/go-gorm/gorm/actions)
|
||||
[](https://opensource.org/licenses/MIT)
|
||||
[](https://pkg.go.dev/gorm.io/gorm?tab=doc)
|
||||
|
||||
## Overview
|
||||
|
||||
* Full-Featured ORM
|
||||
* Associations (Has One, Has Many, Belongs To, Many To Many, Polymorphism, Single-table inheritance)
|
||||
* Hooks (Before/After Create/Save/Update/Delete/Find)
|
||||
* Eager loading with `Preload`, `Joins`
|
||||
* Transactions, Nested Transactions, Save Point, RollbackTo to Saved Point
|
||||
* Context, Prepared Statement Mode, DryRun Mode
|
||||
* Batch Insert, FindInBatches, Find To Map
|
||||
* SQL Builder, Upsert, Locking, Optimizer/Index/Comment Hints, NamedArg, Search/Update/Create with SQL Expr
|
||||
* Composite Primary Key
|
||||
* Auto Migrations
|
||||
* Logger
|
||||
* Extendable, flexible plugin API: Database Resolver (Multiple Databases, Read/Write Splitting) / Prometheus…
|
||||
* Every feature comes with tests
|
||||
* Developer Friendly
|
||||
|
||||
## Getting Started
|
||||
|
||||
* GORM Guides [https://gorm.io](https://gorm.io)
|
||||
* Gen Guides [https://gorm.io/gen/index.html](https://gorm.io/gen/index.html)
|
||||
|
||||
## Contributing
|
||||
|
||||
[You can help to deliver a better GORM, check out things you can do](https://gorm.io/contribute.html)
|
||||
|
||||
## Contributors
|
||||
|
||||
[Thank you](https://github.com/go-gorm/gorm/graphs/contributors) for contributing to the GORM framework!
|
||||
|
||||
## License
|
||||
|
||||
© Jinzhu, 2013~time.Now
|
||||
|
||||
Released under the [MIT License](https://github.com/go-gorm/gorm/blob/master/LICENSE)
|
||||
593
vendor/gorm.io/gorm/association.go
generated
vendored
Normal file
593
vendor/gorm.io/gorm/association.go
generated
vendored
Normal file
@@ -0,0 +1,593 @@
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/schema"
|
||||
"gorm.io/gorm/utils"
|
||||
)
|
||||
|
||||
// Association Mode contains some helper methods to handle relationship things easily.
type Association struct {
	DB           *DB                  // session the association operates on
	Relationship *schema.Relationship // parsed relationship for the selected column
	Unscope      bool                 // when true, Replace/Delete hard-delete detached rows
	Error        error                // first error encountered; later calls short-circuit on it
}
|
||||
|
||||
// Association enters association mode for the relationship named by column on
// the statement's model. Parse errors and unknown relationship names are
// recorded on the returned Association's Error field rather than returned.
func (db *DB) Association(column string) *Association {
	association := &Association{DB: db}
	table := db.Statement.Table

	if err := db.Statement.Parse(db.Statement.Model); err == nil {
		// Parse may overwrite the table name; restore any explicit override.
		db.Statement.Table = table
		association.Relationship = db.Statement.Schema.Relationships.Relations[column]

		if association.Relationship == nil {
			association.Error = fmt.Errorf("%w: %s", ErrUnsupportedRelation, column)
		}

		// Dereference pointers so later reflection works on the concrete value.
		db.Statement.ReflectValue = reflect.ValueOf(db.Statement.Model)
		for db.Statement.ReflectValue.Kind() == reflect.Ptr {
			db.Statement.ReflectValue = db.Statement.ReflectValue.Elem()
		}
	} else {
		association.Error = err
	}

	return association
}
|
||||
|
||||
// Unscoped returns a copy of the association with Unscope enabled, which
// makes Replace physically delete detached records instead of only clearing
// their foreign keys.
func (association *Association) Unscoped() *Association {
	return &Association{
		DB:           association.DB,
		Relationship: association.Relationship,
		Error:        association.Error,
		Unscope:      true,
	}
}
|
||||
|
||||
// Find loads the associated records into out, applying any extra conditions.
// It is a no-op when a previous step already recorded an error.
func (association *Association) Find(out interface{}, conds ...interface{}) error {
	if association.Error == nil {
		association.Error = association.buildCondition().Find(out, conds...).Error
	}
	return association.Error
}
|
||||
|
||||
// Append adds values to the association. For single-valued relationships
// (has one / belongs to) appending is equivalent to replacing; for the rest
// the values are saved without clearing existing associations.
func (association *Association) Append(values ...interface{}) error {
	if association.Error == nil {
		switch association.Relationship.Type {
		case schema.HasOne, schema.BelongsTo:
			if len(values) > 0 {
				association.Error = association.Replace(values...)
			}
		default:
			association.saveAssociation( /*clear*/ false, values...)
		}
	}

	return association.Error
}
|
||||
|
||||
// Replace swaps the current association targets for values: the new values
// are saved, then the previously associated records are detached — their
// foreign keys are nulled, or (for many2many) their join rows deleted.
// With Unscope set, detached records are hard-deleted instead.
func (association *Association) Replace(values ...interface{}) error {
	if association.Error == nil {
		reflectValue := association.DB.Statement.ReflectValue
		rel := association.Relationship

		var oldBelongsToExpr clause.Expression
		// we have to record the old BelongsTo value
		if association.Unscope && rel.Type == schema.BelongsTo {
			var foreignFields []*schema.Field
			for _, ref := range rel.References {
				if !ref.OwnPrimaryKey {
					foreignFields = append(foreignFields, ref.ForeignKey)
				}
			}
			if _, fvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, foreignFields); len(fvs) > 0 {
				column, values := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, fvs)
				oldBelongsToExpr = clause.IN{Column: column, Values: values}
			}
		}

		// save associations
		if association.saveAssociation( /*clear*/ true, values...); association.Error != nil {
			return association.Error
		}

		// set old associations's foreign key to null
		switch rel.Type {
		case schema.BelongsTo:
			// Empty values means "detach": zero the field and null the FKs.
			if len(values) == 0 {
				updateMap := map[string]interface{}{}
				switch reflectValue.Kind() {
				case reflect.Slice, reflect.Array:
					for i := 0; i < reflectValue.Len(); i++ {
						association.Error = rel.Field.Set(association.DB.Statement.Context, reflectValue.Index(i), reflect.Zero(rel.Field.FieldType).Interface())
					}
				case reflect.Struct:
					association.Error = rel.Field.Set(association.DB.Statement.Context, reflectValue, reflect.Zero(rel.Field.FieldType).Interface())
				}

				for _, ref := range rel.References {
					updateMap[ref.ForeignKey.DBName] = nil
				}

				association.Error = association.DB.UpdateColumns(updateMap).Error
			}
			if association.Unscope && oldBelongsToExpr != nil {
				// Hard-delete the previously referenced record(s).
				association.Error = association.DB.Model(nil).Where(oldBelongsToExpr).Delete(reflect.New(rel.FieldSchema.ModelType).Interface()).Error
			}
		case schema.HasOne, schema.HasMany:
			var (
				primaryFields []*schema.Field
				foreignKeys   []string
				updateMap     = map[string]interface{}{}
				relValues     = schema.GetRelationsValues(association.DB.Statement.Context, reflectValue, []*schema.Relationship{rel})
				modelValue    = reflect.New(rel.FieldSchema.ModelType).Interface()
				tx            = association.DB.Model(modelValue)
			)

			// Exclude the records that are still associated after the save.
			if _, rvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, relValues, rel.FieldSchema.PrimaryFields); len(rvs) > 0 {
				if column, values := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs); len(values) > 0 {
					tx.Not(clause.IN{Column: column, Values: values})
				}
			}

			for _, ref := range rel.References {
				if ref.OwnPrimaryKey {
					primaryFields = append(primaryFields, ref.PrimaryKey)
					foreignKeys = append(foreignKeys, ref.ForeignKey.DBName)
					updateMap[ref.ForeignKey.DBName] = nil
				} else if ref.PrimaryValue != "" {
					// Polymorphic discriminator value.
					tx.Where(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue})
				}
			}

			if _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields); len(pvs) > 0 {
				column, values := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs)
				if association.Unscope {
					association.Error = tx.Where(clause.IN{Column: column, Values: values}).Delete(modelValue).Error
				} else {
					association.Error = tx.Where(clause.IN{Column: column, Values: values}).UpdateColumns(updateMap).Error
				}
			}
		case schema.Many2Many:
			var (
				primaryFields, relPrimaryFields     []*schema.Field
				joinPrimaryKeys, joinRelPrimaryKeys []string
				modelValue                          = reflect.New(rel.JoinTable.ModelType).Interface()
				tx                                  = association.DB.Model(modelValue)
			)

			for _, ref := range rel.References {
				if ref.PrimaryValue == "" {
					if ref.OwnPrimaryKey {
						primaryFields = append(primaryFields, ref.PrimaryKey)
						joinPrimaryKeys = append(joinPrimaryKeys, ref.ForeignKey.DBName)
					} else {
						relPrimaryFields = append(relPrimaryFields, ref.PrimaryKey)
						joinRelPrimaryKeys = append(joinRelPrimaryKeys, ref.ForeignKey.DBName)
					}
				} else {
					tx.Clauses(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue})
				}
			}

			// Join rows must be scoped to this owner's primary key(s).
			_, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields)
			if column, values := schema.ToQueryValues(rel.JoinTable.Table, joinPrimaryKeys, pvs); len(values) > 0 {
				tx.Where(clause.IN{Column: column, Values: values})
			} else {
				return ErrPrimaryKeyRequired
			}

			// Keep join rows that point at the newly saved values.
			_, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, relPrimaryFields)
			if relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs); len(relValues) > 0 {
				tx.Where(clause.Not(clause.IN{Column: relColumn, Values: relValues}))
			}

			association.Error = tx.Delete(modelValue).Error
		}
	}
	return association.Error
}
|
||||
|
||||
// Delete unlinks the given associated values from the source model(s).
// For belongs-to / has-one / has-many relations the foreign key columns are
// set to NULL (or the related rows are deleted when association.Unscope is
// set); for many2many relations the matching join-table rows are deleted.
// On success, the in-memory association fields of the source values are
// pruned so they no longer reference the deleted values.
func (association *Association) Delete(values ...interface{}) error {
	if association.Error == nil {
		var (
			reflectValue  = association.DB.Statement.ReflectValue
			rel           = association.Relationship
			primaryFields []*schema.Field
			foreignKeys   []string
			updateAttrs   = map[string]interface{}{}
			conds         []clause.Expression
		)

		// Split references into key pairs to match (PrimaryValue == "")
		// versus constant polymorphic conditions (PrimaryValue != "").
		for _, ref := range rel.References {
			if ref.PrimaryValue == "" {
				primaryFields = append(primaryFields, ref.PrimaryKey)
				foreignKeys = append(foreignKeys, ref.ForeignKey.DBName)
				updateAttrs[ref.ForeignKey.DBName] = nil
			} else {
				conds = append(conds, clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue})
			}
		}

		switch rel.Type {
		case schema.BelongsTo:
			// The foreign key lives on the owner's table, so the UPDATE
			// targets rel.Schema (the owner), not the related schema.
			associationDB := association.DB.Session(&Session{})
			tx := associationDB.Model(reflect.New(rel.Schema.ModelType).Interface())

			_, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, rel.Schema.PrimaryFields)
			if pcolumn, pvalues := schema.ToQueryValues(rel.Schema.Table, rel.Schema.PrimaryFieldDBNames, pvs); len(pvalues) > 0 {
				conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues})
			} else {
				// Source models have no primary key values; refuse to issue
				// an unbounded UPDATE.
				return ErrPrimaryKeyRequired
			}

			_, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, primaryFields)
			relColumn, relValues := schema.ToQueryValues(rel.Schema.Table, foreignKeys, rvs)
			conds = append(conds, clause.IN{Column: relColumn, Values: relValues})

			association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error
			if association.Unscope {
				// Unscope additionally deletes the now-orphaned related rows,
				// looked up via the owner's foreign key values.
				var foreignFields []*schema.Field
				for _, ref := range rel.References {
					if !ref.OwnPrimaryKey {
						foreignFields = append(foreignFields, ref.ForeignKey)
					}
				}
				if _, fvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, foreignFields); len(fvs) > 0 {
					column, values := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, fvs)
					association.Error = associationDB.Model(nil).Where(clause.IN{Column: column, Values: values}).Delete(reflect.New(rel.FieldSchema.ModelType).Interface()).Error
				}
			}
		case schema.HasOne, schema.HasMany:
			// The foreign key lives on the related table; either NULL it out
			// or delete the rows entirely when Unscope is set.
			model := reflect.New(rel.FieldSchema.ModelType).Interface()
			tx := association.DB.Model(model)

			_, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields)
			if pcolumn, pvalues := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs); len(pvalues) > 0 {
				conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues})
			} else {
				return ErrPrimaryKeyRequired
			}

			_, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, rel.FieldSchema.PrimaryFields)
			relColumn, relValues := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs)
			conds = append(conds, clause.IN{Column: relColumn, Values: relValues})

			if association.Unscope {
				association.Error = tx.Clauses(conds...).Delete(model).Error
			} else {
				association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error
			}
		case schema.Many2Many:
			// Only join-table rows are removed; the related records themselves
			// are left untouched.
			var (
				primaryFields, relPrimaryFields     []*schema.Field
				joinPrimaryKeys, joinRelPrimaryKeys []string
				joinValue                           = reflect.New(rel.JoinTable.ModelType).Interface()
			)

			for _, ref := range rel.References {
				if ref.PrimaryValue == "" {
					if ref.OwnPrimaryKey {
						primaryFields = append(primaryFields, ref.PrimaryKey)
						joinPrimaryKeys = append(joinPrimaryKeys, ref.ForeignKey.DBName)
					} else {
						relPrimaryFields = append(relPrimaryFields, ref.PrimaryKey)
						joinRelPrimaryKeys = append(joinRelPrimaryKeys, ref.ForeignKey.DBName)
					}
				} else {
					conds = append(conds, clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue})
				}
			}

			_, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields)
			if pcolumn, pvalues := schema.ToQueryValues(rel.JoinTable.Table, joinPrimaryKeys, pvs); len(pvalues) > 0 {
				conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues})
			} else {
				return ErrPrimaryKeyRequired
			}

			_, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, relPrimaryFields)
			relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs)
			conds = append(conds, clause.IN{Column: relColumn, Values: relValues})

			association.Error = association.DB.Where(clause.Where{Exprs: conds}).Model(nil).Delete(joinValue).Error
		}

		if association.Error == nil {
			// clean up deleted values's foreign key
			relValuesMap, _ := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, rel.FieldSchema.PrimaryFields)

			// cleanUpDeletedRelations removes the just-deleted values from
			// data's association field (and zeroes the in-memory foreign keys
			// for non-join-table relations) so memory matches the database.
			cleanUpDeletedRelations := func(data reflect.Value) {
				if _, zero := rel.Field.ValueOf(association.DB.Statement.Context, data); !zero {
					fieldValue := reflect.Indirect(rel.Field.ReflectValueOf(association.DB.Statement.Context, data))
					primaryValues := make([]interface{}, len(rel.FieldSchema.PrimaryFields))

					switch fieldValue.Kind() {
					case reflect.Slice, reflect.Array:
						// Keep only the elements whose primary key is NOT in
						// the deleted set.
						validFieldValues := reflect.Zero(rel.Field.IndirectFieldType)
						for i := 0; i < fieldValue.Len(); i++ {
							for idx, field := range rel.FieldSchema.PrimaryFields {
								primaryValues[idx], _ = field.ValueOf(association.DB.Statement.Context, fieldValue.Index(i))
							}

							if _, ok := relValuesMap[utils.ToStringKey(primaryValues...)]; !ok {
								validFieldValues = reflect.Append(validFieldValues, fieldValue.Index(i))
							}
						}

						association.Error = rel.Field.Set(association.DB.Statement.Context, data, validFieldValues.Interface())
					case reflect.Struct:
						for idx, field := range rel.FieldSchema.PrimaryFields {
							primaryValues[idx], _ = field.ValueOf(association.DB.Statement.Context, fieldValue)
						}

						if _, ok := relValuesMap[utils.ToStringKey(primaryValues...)]; ok {
							if association.Error = rel.Field.Set(association.DB.Statement.Context, data, reflect.Zero(rel.FieldSchema.ModelType).Interface()); association.Error != nil {
								break
							}

							if rel.JoinTable == nil {
								for _, ref := range rel.References {
									if ref.OwnPrimaryKey || ref.PrimaryValue != "" {
										// FK lives on the related value.
										association.Error = ref.ForeignKey.Set(association.DB.Statement.Context, fieldValue, reflect.Zero(ref.ForeignKey.FieldType).Interface())
									} else {
										// FK lives on the owner (belongs-to).
										association.Error = ref.ForeignKey.Set(association.DB.Statement.Context, data, reflect.Zero(ref.ForeignKey.FieldType).Interface())
									}
								}
							}
						}
					}
				}
			}

			switch reflectValue.Kind() {
			case reflect.Slice, reflect.Array:
				for i := 0; i < reflectValue.Len(); i++ {
					cleanUpDeletedRelations(reflect.Indirect(reflectValue.Index(i)))
				}
			case reflect.Struct:
				cleanUpDeletedRelations(reflectValue)
			}
		}
	}

	return association.Error
}
|
||||
|
||||
// Clear removes all records currently linked through the association by
// delegating to Replace with no replacement values.
func (association *Association) Clear() error {
	return association.Replace()
}
|
||||
|
||||
func (association *Association) Count() (count int64) {
|
||||
if association.Error == nil {
|
||||
association.Error = association.buildCondition().Count(&count).Error
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// assignBack records a caller-supplied argument that must receive the saved
// association value after saveAssociation runs, so that fields written during
// the save (e.g. generated primary keys) become visible to the caller.
type assignBack struct {
	Source reflect.Value // model whose association field holds the saved value
	Index  int           // 1-based index into a slice field; 0 means Dest receives the whole field
	Dest   reflect.Value // caller's argument to copy the saved value back into
}
|
||||
|
||||
// saveAssociation links values to the source model(s) and persists the
// result. When clear is true, existing associated data is detached first
// (Replace semantics); otherwise values are appended. For a slice source,
// len(values) must equal the source length (one value per element) unless
// the call is a pure clear. After saving, struct-typed arguments receive the
// saved data back via the collected assignBacks.
func (association *Association) saveAssociation(clear bool, values ...interface{}) {
	var (
		reflectValue = association.DB.Statement.ReflectValue
		assignBacks  []assignBack // assign association values back to arguments after save
	)

	// appendToRelations wires rv into source's association field in memory
	// (the database write happens later via associationDB.Updates).
	appendToRelations := func(source, rv reflect.Value, clear bool) {
		switch association.Relationship.Type {
		case schema.HasOne, schema.BelongsTo:
			switch rv.Kind() {
			case reflect.Slice, reflect.Array:
				if rv.Len() > 0 {
					// Single-valued relation: only the first element is used.
					association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, source, rv.Index(0).Addr().Interface())

					if association.Relationship.Field.FieldType.Kind() == reflect.Struct {
						assignBacks = append(assignBacks, assignBack{Source: source, Dest: rv.Index(0)})
					}
				}
			case reflect.Struct:
				// Field.Set needs an addressable value to take a pointer to.
				if !rv.CanAddr() {
					association.Error = ErrInvalidValue
					return
				}
				association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, source, rv.Addr().Interface())

				if association.Relationship.Field.FieldType.Kind() == reflect.Struct {
					assignBacks = append(assignBacks, assignBack{Source: source, Dest: rv})
				}
			}
		case schema.HasMany, schema.Many2Many:
			elemType := association.Relationship.Field.IndirectFieldType.Elem()
			oldFieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(association.DB.Statement.Context, source))
			var fieldValue reflect.Value
			if clear {
				// Start from an empty slice, discarding existing elements.
				fieldValue = reflect.MakeSlice(oldFieldValue.Type(), 0, oldFieldValue.Cap())
			} else {
				// Start from a copy of the existing elements and append.
				fieldValue = reflect.MakeSlice(oldFieldValue.Type(), oldFieldValue.Len(), oldFieldValue.Cap())
				reflect.Copy(fieldValue, oldFieldValue)
			}

			// appendToFieldValues appends ev (always a pointer) to fieldValue,
			// dereferencing when the slice holds values rather than pointers.
			appendToFieldValues := func(ev reflect.Value) {
				if ev.Type().AssignableTo(elemType) {
					fieldValue = reflect.Append(fieldValue, ev)
				} else if ev.Type().Elem().AssignableTo(elemType) {
					fieldValue = reflect.Append(fieldValue, ev.Elem())
				} else {
					association.Error = fmt.Errorf("unsupported data type: %v for relation %s", ev.Type(), association.Relationship.Name)
				}

				if elemType.Kind() == reflect.Struct {
					// Index is 1-based; the assign-back loop subtracts one.
					assignBacks = append(assignBacks, assignBack{Source: source, Dest: ev, Index: fieldValue.Len()})
				}
			}

			switch rv.Kind() {
			case reflect.Slice, reflect.Array:
				for i := 0; i < rv.Len(); i++ {
					appendToFieldValues(reflect.Indirect(rv.Index(i)).Addr())
				}
			case reflect.Struct:
				if !rv.CanAddr() {
					association.Error = ErrInvalidValue
					return
				}
				appendToFieldValues(rv.Addr())
			}

			if association.Error == nil {
				association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, source, fieldValue.Interface())
			}
		}
	}

	// Translate the statement's Select/Omit settings into the column set
	// used for the association save: entries prefixed with the relation name
	// (or the generic clause.Associations wildcard) are carried over.
	selectedSaveColumns := []string{association.Relationship.Name}
	omitColumns := []string{}
	selectColumns, _ := association.DB.Statement.SelectAndOmitColumns(true, false)
	for name, ok := range selectColumns {
		columnName := ""
		if strings.HasPrefix(name, association.Relationship.Name) {
			if columnName = strings.TrimPrefix(name, association.Relationship.Name); columnName == ".*" {
				columnName = name
			}
		} else if strings.HasPrefix(name, clause.Associations) {
			columnName = name
		}

		if columnName != "" {
			if ok {
				selectedSaveColumns = append(selectedSaveColumns, columnName)
			} else {
				omitColumns = append(omitColumns, columnName)
			}
		}
	}

	// Foreign keys held by the source must always be saved so the link
	// itself is written.
	for _, ref := range association.Relationship.References {
		if !ref.OwnPrimaryKey {
			selectedSaveColumns = append(selectedSaveColumns, ref.ForeignKey.Name)
		}
	}

	associationDB := association.DB.Session(&Session{}).Model(nil)
	if !association.DB.FullSaveAssociations {
		associationDB.Select(selectedSaveColumns)
	}
	if len(omitColumns) > 0 {
		associationDB.Omit(omitColumns...)
	}
	associationDB = associationDB.Session(&Session{})

	switch reflectValue.Kind() {
	case reflect.Slice, reflect.Array:
		if len(values) != reflectValue.Len() {
			// clear old data
			if clear && len(values) == 0 {
				for i := 0; i < reflectValue.Len(); i++ {
					if err := association.Relationship.Field.Set(association.DB.Statement.Context, reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType).Interface()); err != nil {
						association.Error = err
						break
					}

					if association.Relationship.JoinTable == nil {
						// Zero the in-memory foreign keys too (no join table
						// means the FK lives on one of the two structs).
						for _, ref := range association.Relationship.References {
							if !ref.OwnPrimaryKey && ref.PrimaryValue == "" {
								if err := ref.ForeignKey.Set(association.DB.Statement.Context, reflectValue.Index(i), reflect.Zero(ref.ForeignKey.FieldType).Interface()); err != nil {
									association.Error = err
									break
								}
							}
						}
					}
				}
				break
			}

			// A slice source requires exactly one value argument per element.
			association.Error = ErrInvalidValueOfLength
			return
		}

		for i := 0; i < reflectValue.Len(); i++ {
			appendToRelations(reflectValue.Index(i), reflect.Indirect(reflect.ValueOf(values[i])), clear)
			if association.Error != nil {
				return
			}

			// TODO support save slice data, sql with case?
			association.Error = associationDB.Updates(reflectValue.Index(i).Addr().Interface()).Error
		}
	case reflect.Struct:
		// clear old data
		if clear && len(values) == 0 {
			association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType).Interface())

			if association.Relationship.JoinTable == nil && association.Error == nil {
				for _, ref := range association.Relationship.References {
					if !ref.OwnPrimaryKey && ref.PrimaryValue == "" {
						association.Error = ref.ForeignKey.Set(association.DB.Statement.Context, reflectValue, reflect.Zero(ref.ForeignKey.FieldType).Interface())
					}
				}
			}
		}

		for idx, value := range values {
			rv := reflect.Indirect(reflect.ValueOf(value))
			// Only the first value clears; the rest append to the result.
			appendToRelations(reflectValue, rv, clear && idx == 0)
			if association.Error != nil {
				return
			}
		}

		if len(values) > 0 {
			association.Error = associationDB.Updates(reflectValue.Addr().Interface()).Error
		}
	}

	// Copy saved data back into the caller's struct arguments.
	for _, assignBack := range assignBacks {
		fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(association.DB.Statement.Context, assignBack.Source))
		if assignBack.Index > 0 {
			reflect.Indirect(assignBack.Dest).Set(fieldValue.Index(assignBack.Index - 1))
		} else {
			reflect.Indirect(assignBack.Dest).Set(fieldValue)
		}
	}
}
|
||||
|
||||
// buildCondition returns a *DB scoped to the associated records of the
// current source value(s): the related model with the relationship's query
// conditions applied, joining through the join table for many2many
// relations (including any soft-delete/query clauses the join table
// declares, unless the session is Unscoped).
func (association *Association) buildCondition() *DB {
	var (
		queryConds = association.Relationship.ToQueryConditions(association.DB.Statement.Context, association.DB.Statement.ReflectValue)
		modelValue = reflect.New(association.Relationship.FieldSchema.ModelType).Interface()
		tx         = association.DB.Model(modelValue)
	)

	if association.Relationship.JoinTable != nil {
		if !tx.Statement.Unscoped && len(association.Relationship.JoinTable.QueryClauses) > 0 {
			// Render the join table's own query clauses (e.g. soft delete)
			// into a WHERE fragment and splice it into this statement.
			joinStmt := Statement{DB: tx, Context: tx.Statement.Context, Schema: association.Relationship.JoinTable, Table: association.Relationship.JoinTable.Table, Clauses: map[string]clause.Clause{}}
			for _, queryClause := range association.Relationship.JoinTable.QueryClauses {
				joinStmt.AddClause(queryClause)
			}
			joinStmt.Build("WHERE")
			if len(joinStmt.SQL.String()) > 0 {
				// Strip the leading "WHERE " keyword; only the condition body
				// is embedded.
				tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars})
			}
		}

		// Join against the join table with the relationship conditions as
		// the ON clause; QueryFields avoids SELECT * ambiguity across tables.
		tx = tx.Session(&Session{QueryFields: true}).Clauses(clause.From{Joins: []clause.Join{{
			Table: clause.Table{Name: association.Relationship.JoinTable.Table},
			ON:    clause.Where{Exprs: queryConds},
		}}})
	} else {
		tx.Clauses(clause.Where{Exprs: queryConds})
	}

	return tx
}
|
||||
360
vendor/gorm.io/gorm/callbacks.go
generated
vendored
Normal file
360
vendor/gorm.io/gorm/callbacks.go
generated
vendored
Normal file
@@ -0,0 +1,360 @@
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm/schema"
|
||||
"gorm.io/gorm/utils"
|
||||
)
|
||||
|
||||
func initializeCallbacks(db *DB) *callbacks {
|
||||
return &callbacks{
|
||||
processors: map[string]*processor{
|
||||
"create": {db: db},
|
||||
"query": {db: db},
|
||||
"update": {db: db},
|
||||
"delete": {db: db},
|
||||
"row": {db: db},
|
||||
"raw": {db: db},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// callbacks gorm callbacks manager; it owns one processor (callback chain)
// per database operation kind.
type callbacks struct {
	// processors maps operation name ("create", "query", "update", "delete",
	// "row", "raw") to the processor running that operation's callbacks.
	processors map[string]*processor
}
|
||||
|
||||
// processor holds the callback chain for a single database operation and the
// compiled, dependency-ordered handler list derived from it.
type processor struct {
	db *DB // owning session; used for logging and match predicates
	// Clauses are the default build clauses installed into statements that
	// have none of their own (see Execute).
	Clauses []string
	// fns is the compiled, sorted list of handlers produced by compile().
	fns []func(*DB)
	// callbacks is the raw registration list (including pending removals and
	// replacements) before compilation.
	callbacks []*callback
}
|
||||
|
||||
// callback is one registered callback plus its ordering constraints and
// pending registration state within its processor.
type callback struct {
	name      string         // callback name, unique within the processor
	before    string         // run before the named callback ("*" = before all)
	after     string         // run after the named callback ("*" = after all)
	remove    bool           // this entry removes the callback named `name`
	replace   bool           // this entry replaces the callback named `name`
	match     func(*DB) bool // optional predicate; entry kept only if it returns true
	handler   func(*DB)      // the callback body invoked during Execute
	processor *processor     // owning processor
}
|
||||
|
||||
// Create returns the processor for the create callback chain.
func (cs *callbacks) Create() *processor {
	return cs.processors["create"]
}
|
||||
|
||||
// Query returns the processor for the query callback chain.
func (cs *callbacks) Query() *processor {
	return cs.processors["query"]
}
|
||||
|
||||
// Update returns the processor for the update callback chain.
func (cs *callbacks) Update() *processor {
	return cs.processors["update"]
}
|
||||
|
||||
// Delete returns the processor for the delete callback chain.
func (cs *callbacks) Delete() *processor {
	return cs.processors["delete"]
}
|
||||
|
||||
// Row returns the processor for the row callback chain.
func (cs *callbacks) Row() *processor {
	return cs.processors["row"]
}
|
||||
|
||||
// Raw returns the processor for the raw-SQL callback chain.
func (cs *callbacks) Raw() *processor {
	return cs.processors["raw"]
}
|
||||
|
||||
// Execute runs the processor's compiled callback chain against db: it applies
// pending scopes, prepares the statement (model/dest assignment, schema
// parsing, ReflectValue), invokes every handler in order, traces the
// resulting SQL, and resets per-execution statement state.
func (p *processor) Execute(db *DB) *DB {
	// call scopes
	for len(db.Statement.scopes) > 0 {
		db = db.executeScopes()
	}

	var (
		curTime           = time.Now() // start time for query tracing
		stmt              = db.Statement
		resetBuildClauses bool
	)

	// Install the processor's default build clauses only when the statement
	// doesn't define its own, and remember to undo that at the end.
	if len(stmt.BuildClauses) == 0 {
		stmt.BuildClauses = p.Clauses
		resetBuildClauses = true
	}

	// Dest may hook into statement construction (e.g. query hints).
	if optimizer, ok := db.Statement.Dest.(StatementModifier); ok {
		optimizer.ModifyStatement(stmt)
	}

	// assign model values: Model and Dest default to each other when only
	// one of them is set.
	if stmt.Model == nil {
		stmt.Model = stmt.Dest
	} else if stmt.Dest == nil {
		stmt.Dest = stmt.Model
	}

	// parse model values into a schema; an unsupported data type is only an
	// error when no table/raw SQL is available as a fallback.
	if stmt.Model != nil {
		if err := stmt.Parse(stmt.Model); err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || (stmt.Table == "" && stmt.TableExpr == nil && stmt.SQL.Len() == 0)) {
			if errors.Is(err, schema.ErrUnsupportedDataType) && stmt.Table == "" && stmt.TableExpr == nil {
				db.AddError(fmt.Errorf("%w: Table not set, please set it like: db.Model(&user) or db.Table(\"users\")", err))
			} else {
				db.AddError(err)
			}
		}
	}

	// assign stmt.ReflectValue: dereference Dest fully, allocating through
	// nil pointers so callbacks always see an addressable value.
	if stmt.Dest != nil {
		stmt.ReflectValue = reflect.ValueOf(stmt.Dest)
		for stmt.ReflectValue.Kind() == reflect.Ptr {
			if stmt.ReflectValue.IsNil() && stmt.ReflectValue.CanAddr() {
				stmt.ReflectValue.Set(reflect.New(stmt.ReflectValue.Type().Elem()))
			}

			stmt.ReflectValue = stmt.ReflectValue.Elem()
		}
		if !stmt.ReflectValue.IsValid() {
			db.AddError(ErrInvalidValue)
		}
	}

	// Run the compiled callback chain.
	for _, f := range p.fns {
		f(db)
	}

	// Trace the generated SQL (lazily formatted; logger may filter params).
	if stmt.SQL.Len() > 0 {
		db.Logger.Trace(stmt.Context, curTime, func() (string, int64) {
			sql, vars := stmt.SQL.String(), stmt.Vars
			if filter, ok := db.Logger.(ParamsFilter); ok {
				sql, vars = filter.ParamsFilter(stmt.Context, stmt.SQL.String(), stmt.Vars...)
			}
			return db.Dialector.Explain(sql, vars...), db.RowsAffected
		}, db.Error)
	}

	// DryRun keeps SQL/Vars so the caller can inspect them.
	if !stmt.DB.DryRun {
		stmt.SQL.Reset()
		stmt.Vars = nil
	}

	if resetBuildClauses {
		stmt.BuildClauses = nil
	}

	return db
}
|
||||
|
||||
func (p *processor) Get(name string) func(*DB) {
|
||||
for i := len(p.callbacks) - 1; i >= 0; i-- {
|
||||
if v := p.callbacks[i]; v.name == name && !v.remove {
|
||||
return v.handler
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Before starts a callback registration that will run before the callback
// with the given name.
func (p *processor) Before(name string) *callback {
	return &callback{before: name, processor: p}
}
|
||||
|
||||
// After starts a callback registration that will run after the callback
// with the given name.
func (p *processor) After(name string) *callback {
	return &callback{after: name, processor: p}
}
|
||||
|
||||
// Match starts a callback registration that is only compiled into the chain
// when fc returns true for the processor's DB.
func (p *processor) Match(fc func(*DB) bool) *callback {
	return &callback{match: fc, processor: p}
}
|
||||
|
||||
// Register adds a named callback to the processor and recompiles the chain.
func (p *processor) Register(name string, fn func(*DB)) error {
	return (&callback{processor: p}).Register(name, fn)
}
|
||||
|
||||
// Remove deletes the named callback from the processor and recompiles the
// chain.
func (p *processor) Remove(name string) error {
	return (&callback{processor: p}).Remove(name)
}
|
||||
|
||||
// Replace swaps the named callback's handler for fn and recompiles the chain.
func (p *processor) Replace(name string, fn func(*DB)) error {
	return (&callback{processor: p}).Replace(name, fn)
}
|
||||
|
||||
// compile rebuilds the processor's executable handler list: it drops entries
// whose match predicate rejects the DB, applies pending removals, and
// topologically sorts the survivors by their before/after constraints.
// Sort errors are logged and also returned.
func (p *processor) compile() (err error) {
	var callbacks []*callback
	removedMap := map[string]bool{}
	for _, callback := range p.callbacks {
		// Keep only entries with no predicate or a passing predicate.
		if callback.match == nil || callback.match(p.db) {
			callbacks = append(callbacks, callback)
		}
		// Collect removal markers regardless of match filtering.
		if callback.remove {
			removedMap[callback.name] = true
		}
	}

	if len(removedMap) > 0 {
		callbacks = removeCallbacks(callbacks, removedMap)
	}
	p.callbacks = callbacks

	if p.fns, err = sortCallbacks(p.callbacks); err != nil {
		p.db.Logger.Error(context.Background(), "Got error when compile callbacks, got %v", err)
	}
	return
}
|
||||
|
||||
// Before sets the name of the callback this one must run before and returns
// the callback for chaining.
func (c *callback) Before(name string) *callback {
	c.before = name
	return c
}
|
||||
|
||||
// After sets the name of the callback this one must run after and returns
// the callback for chaining.
func (c *callback) After(name string) *callback {
	c.after = name
	return c
}
|
||||
|
||||
// Register finalizes the callback with the given name and handler, appends it
// to its processor, and recompiles the chain.
func (c *callback) Register(name string, fn func(*DB)) error {
	c.name = name
	c.handler = fn
	c.processor.callbacks = append(c.processor.callbacks, c)
	return c.processor.compile()
}
|
||||
|
||||
// Remove marks the named callback for removal, appends the marker to its
// processor, and recompiles the chain. A warning is logged since removing
// callbacks usually changes framework behavior.
func (c *callback) Remove(name string) error {
	c.processor.db.Logger.Warn(context.Background(), "removing callback `%s` from %s\n", name, utils.FileWithLineNum())
	c.name = name
	c.remove = true
	c.processor.callbacks = append(c.processor.callbacks, c)
	return c.processor.compile()
}
|
||||
|
||||
// Replace registers fn as a replacement for the named callback, appends the
// entry to its processor, and recompiles the chain.
func (c *callback) Replace(name string, fn func(*DB)) error {
	c.processor.db.Logger.Info(context.Background(), "replacing callback `%s` from %s\n", name, utils.FileWithLineNum())
	c.name = name
	c.handler = fn
	c.replace = true
	c.processor.callbacks = append(c.processor.callbacks, c)
	return c.processor.compile()
}
|
||||
|
||||
// getRIndex returns the index of the last occurrence of str in strs,
// or -1 when str is not present.
func getRIndex(strs []string, str string) int {
	last := -1
	for i, s := range strs {
		if s == str {
			last = i
		}
	}
	return last
}
|
||||
|
||||
// sortCallbacks orders the callbacks by their before/after constraints and
// returns the handlers in execution order. "*" pins a callback before or
// after everything else; named constraints are resolved recursively, and a
// cycle/conflict produces an error. Entries marked remove are excluded from
// the returned handler list.
func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {
	var (
		names, sorted []string
		sortCallback  func(*callback) error
	)
	// Push wildcard-constrained callbacks toward the edges first so the
	// named-constraint resolution below sees them in a stable position.
	sort.SliceStable(cs, func(i, j int) bool {
		if cs[j].before == "*" && cs[i].before != "*" {
			return true
		}
		if cs[j].after == "*" && cs[i].after != "*" {
			return true
		}
		return false
	})

	for _, c := range cs {
		// show warning message the callback name already exists
		if idx := getRIndex(names, c.name); idx > -1 && !c.replace && !c.remove && !cs[idx].remove {
			c.processor.db.Logger.Warn(context.Background(), "duplicated callback `%s` from %s\n", c.name, utils.FileWithLineNum())
		}
		names = append(names, c.name)
	}

	// sortCallback places c into `sorted` respecting its constraints,
	// recursing into not-yet-sorted dependencies as needed.
	sortCallback = func(c *callback) error {
		if c.before != "" { // if defined before callback
			if c.before == "*" && len(sorted) > 0 {
				if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
					sorted = append([]string{c.name}, sorted...)
				}
			} else if sortedIdx := getRIndex(sorted, c.before); sortedIdx != -1 {
				if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
					// if before callback already sorted, append current callback just after it
					sorted = append(sorted[:sortedIdx], append([]string{c.name}, sorted[sortedIdx:]...)...)
				} else if curIdx > sortedIdx {
					// c already placed AFTER its "before" target: impossible
					// to satisfy.
					return fmt.Errorf("conflicting callback %s with before %s", c.name, c.before)
				}
			} else if idx := getRIndex(names, c.before); idx != -1 {
				// if before callback exists
				cs[idx].after = c.name
			}
		}

		if c.after != "" { // if defined after callback
			if c.after == "*" && len(sorted) > 0 {
				if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
					sorted = append(sorted, c.name)
				}
			} else if sortedIdx := getRIndex(sorted, c.after); sortedIdx != -1 {
				if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
					// if after callback sorted, append current callback to last
					sorted = append(sorted, c.name)
				} else if curIdx < sortedIdx {
					// c already placed BEFORE its "after" target: conflict.
					return fmt.Errorf("conflicting callback %s with before %s", c.name, c.after)
				}
			} else if idx := getRIndex(names, c.after); idx != -1 {
				// if after callback exists but haven't sorted
				// set after callback's before callback to current callback
				after := cs[idx]

				if after.before == "" {
					after.before = c.name
				}

				if err := sortCallback(after); err != nil {
					return err
				}

				if err := sortCallback(c); err != nil {
					return err
				}
			}
		}

		// if current callback haven't been sorted, append it to last
		if getRIndex(sorted, c.name) == -1 {
			sorted = append(sorted, c.name)
		}

		return nil
	}

	for _, c := range cs {
		if err = sortCallback(c); err != nil {
			return
		}
	}

	// Materialize handlers in sorted order, skipping removal markers.
	for _, name := range sorted {
		if idx := getRIndex(names, name); !cs[idx].remove {
			fns = append(fns, cs[idx].handler)
		}
	}

	return
}
|
||||
|
||||
func removeCallbacks(cs []*callback, nameMap map[string]bool) []*callback {
|
||||
callbacks := make([]*callback, 0, len(cs))
|
||||
for _, callback := range cs {
|
||||
if nameMap[callback.name] {
|
||||
continue
|
||||
}
|
||||
callbacks = append(callbacks, callback)
|
||||
}
|
||||
return callbacks
|
||||
}
|
||||
453
vendor/gorm.io/gorm/callbacks/associations.go
generated
vendored
Normal file
453
vendor/gorm.io/gorm/callbacks/associations.go
generated
vendored
Normal file
@@ -0,0 +1,453 @@
|
||||
package callbacks
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/schema"
|
||||
"gorm.io/gorm/utils"
|
||||
)
|
||||
|
||||
// SaveBeforeAssociations returns a callback that persists belongs-to
// associations before the owning record is saved, so the owner's foreign
// key columns can be populated from the freshly saved related records.
// create selects whether the statement's Select/Omit columns are evaluated
// for a create (true) or update (false).
func SaveBeforeAssociations(create bool) func(db *gorm.DB) {
	return func(db *gorm.DB) {
		if db.Error == nil && db.Statement.Schema != nil {
			selectColumns, restricted := db.Statement.SelectAndOmitColumns(create, !create)

			// Save Belongs To associations
			for _, rel := range db.Statement.Schema.Relationships.BelongsTo {
				// Skip relations explicitly deselected, or not selected when
				// the column set is restricted.
				if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) {
					continue
				}

				// setupReferences copies the saved relation's primary key
				// values into obj's foreign key fields (and mirrors them into
				// a map destination when Dest is a map).
				setupReferences := func(obj reflect.Value, elem reflect.Value) {
					for _, ref := range rel.References {
						if !ref.OwnPrimaryKey {
							pv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, elem)
							db.AddError(ref.ForeignKey.Set(db.Statement.Context, obj, pv))

							if dest, ok := db.Statement.Dest.(map[string]interface{}); ok {
								dest[ref.ForeignKey.DBName] = pv
								if _, ok := dest[rel.Name]; ok {
									dest[rel.Name] = elem.Interface()
								}
							}
						}
					}
				}

				switch db.Statement.ReflectValue.Kind() {
				case reflect.Slice, reflect.Array:
					var (
						rValLen   = db.Statement.ReflectValue.Len()
						objs      = make([]reflect.Value, 0, rValLen)
						fieldType = rel.Field.FieldType
						isPtr     = fieldType.Kind() == reflect.Ptr
					)

					// Collect relation values as pointers so a single save
					// can batch them.
					if !isPtr {
						fieldType = reflect.PointerTo(fieldType)
					}

					elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
					distinctElems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
					identityMap := map[string]bool{}
					for i := 0; i < rValLen; i++ {
						obj := db.Statement.ReflectValue.Index(i)
						if reflect.Indirect(obj).Kind() != reflect.Struct {
							break
						}
						if _, zero := rel.Field.ValueOf(db.Statement.Context, obj); !zero { // check belongs to relation value
							rv := rel.Field.ReflectValueOf(db.Statement.Context, obj) // relation reflect value
							if !isPtr {
								rv = rv.Addr()
							}
							objs = append(objs, obj)
							elems = reflect.Append(elems, rv)

							// Deduplicate by the relation's non-zero primary
							// key values so each related record is saved once.
							relPrimaryValues := make([]interface{}, 0, len(rel.FieldSchema.PrimaryFields))
							for _, pf := range rel.FieldSchema.PrimaryFields {
								if pfv, ok := pf.ValueOf(db.Statement.Context, rv); !ok {
									relPrimaryValues = append(relPrimaryValues, pfv)
								}
							}
							cacheKey := utils.ToStringKey(relPrimaryValues...)
							if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] {
								if cacheKey != "" { // has primary fields
									identityMap[cacheKey] = true
								}

								distinctElems = reflect.Append(distinctElems, rv)
							}
						}
					}

					if elems.Len() > 0 {
						// Save the distinct set, then propagate keys to every
						// owner (including duplicates).
						if saveAssociations(db, rel, distinctElems, selectColumns, restricted, nil) == nil {
							for i := 0; i < elems.Len(); i++ {
								setupReferences(objs[i], elems.Index(i))
							}
						}
					}
				case reflect.Struct:
					if _, zero := rel.Field.ValueOf(db.Statement.Context, db.Statement.ReflectValue); !zero {
						rv := rel.Field.ReflectValueOf(db.Statement.Context, db.Statement.ReflectValue) // relation reflect value
						if rv.Kind() != reflect.Ptr {
							rv = rv.Addr()
						}

						if saveAssociations(db, rel, rv, selectColumns, restricted, nil) == nil {
							setupReferences(db.Statement.ReflectValue, rv)
						}
					}
				}
			}
		}
	}
}
|
||||
|
||||
func SaveAfterAssociations(create bool) func(db *gorm.DB) {
|
||||
return func(db *gorm.DB) {
|
||||
if db.Error == nil && db.Statement.Schema != nil {
|
||||
selectColumns, restricted := db.Statement.SelectAndOmitColumns(create, !create)
|
||||
|
||||
// Save Has One associations
|
||||
for _, rel := range db.Statement.Schema.Relationships.HasOne {
|
||||
if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) {
|
||||
continue
|
||||
}
|
||||
|
||||
switch db.Statement.ReflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
var (
|
||||
fieldType = rel.Field.FieldType
|
||||
isPtr = fieldType.Kind() == reflect.Ptr
|
||||
)
|
||||
|
||||
if !isPtr {
|
||||
fieldType = reflect.PointerTo(fieldType)
|
||||
}
|
||||
|
||||
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
|
||||
|
||||
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
|
||||
obj := db.Statement.ReflectValue.Index(i)
|
||||
|
||||
if reflect.Indirect(obj).Kind() == reflect.Struct {
|
||||
if _, zero := rel.Field.ValueOf(db.Statement.Context, obj); !zero {
|
||||
rv := rel.Field.ReflectValueOf(db.Statement.Context, obj)
|
||||
if rv.Kind() != reflect.Ptr {
|
||||
rv = rv.Addr()
|
||||
}
|
||||
|
||||
for _, ref := range rel.References {
|
||||
if ref.OwnPrimaryKey {
|
||||
fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, obj)
|
||||
db.AddError(ref.ForeignKey.Set(db.Statement.Context, rv, fv))
|
||||
} else if ref.PrimaryValue != "" {
|
||||
db.AddError(ref.ForeignKey.Set(db.Statement.Context, rv, ref.PrimaryValue))
|
||||
}
|
||||
}
|
||||
|
||||
elems = reflect.Append(elems, rv)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if elems.Len() > 0 {
|
||||
assignmentColumns := make([]string, 0, len(rel.References))
|
||||
for _, ref := range rel.References {
|
||||
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
|
||||
}
|
||||
|
||||
saveAssociations(db, rel, elems, selectColumns, restricted, assignmentColumns)
|
||||
}
|
||||
case reflect.Struct:
|
||||
if _, zero := rel.Field.ValueOf(db.Statement.Context, db.Statement.ReflectValue); !zero {
|
||||
f := rel.Field.ReflectValueOf(db.Statement.Context, db.Statement.ReflectValue)
|
||||
if f.Kind() != reflect.Ptr {
|
||||
f = f.Addr()
|
||||
}
|
||||
|
||||
assignmentColumns := make([]string, 0, len(rel.References))
|
||||
for _, ref := range rel.References {
|
||||
if ref.OwnPrimaryKey {
|
||||
fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, db.Statement.ReflectValue)
|
||||
db.AddError(ref.ForeignKey.Set(db.Statement.Context, f, fv))
|
||||
} else if ref.PrimaryValue != "" {
|
||||
db.AddError(ref.ForeignKey.Set(db.Statement.Context, f, ref.PrimaryValue))
|
||||
}
|
||||
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
|
||||
}
|
||||
|
||||
saveAssociations(db, rel, f, selectColumns, restricted, assignmentColumns)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Save Has Many associations
|
||||
for _, rel := range db.Statement.Schema.Relationships.HasMany {
|
||||
if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) {
|
||||
continue
|
||||
}
|
||||
|
||||
fieldType := rel.Field.IndirectFieldType.Elem()
|
||||
isPtr := fieldType.Kind() == reflect.Ptr
|
||||
if !isPtr {
|
||||
fieldType = reflect.PointerTo(fieldType)
|
||||
}
|
||||
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
|
||||
identityMap := map[string]bool{}
|
||||
appendToElems := func(v reflect.Value) {
|
||||
if _, zero := rel.Field.ValueOf(db.Statement.Context, v); !zero {
|
||||
f := reflect.Indirect(rel.Field.ReflectValueOf(db.Statement.Context, v))
|
||||
|
||||
for i := 0; i < f.Len(); i++ {
|
||||
elem := f.Index(i)
|
||||
for _, ref := range rel.References {
|
||||
if ref.OwnPrimaryKey {
|
||||
pv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, v)
|
||||
db.AddError(ref.ForeignKey.Set(db.Statement.Context, elem, pv))
|
||||
} else if ref.PrimaryValue != "" {
|
||||
db.AddError(ref.ForeignKey.Set(db.Statement.Context, elem, ref.PrimaryValue))
|
||||
}
|
||||
}
|
||||
|
||||
relPrimaryValues := make([]interface{}, 0, len(rel.FieldSchema.PrimaryFields))
|
||||
for _, pf := range rel.FieldSchema.PrimaryFields {
|
||||
if pfv, ok := pf.ValueOf(db.Statement.Context, elem); !ok {
|
||||
relPrimaryValues = append(relPrimaryValues, pfv)
|
||||
}
|
||||
}
|
||||
|
||||
cacheKey := utils.ToStringKey(relPrimaryValues...)
|
||||
if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] {
|
||||
if cacheKey != "" { // has primary fields
|
||||
identityMap[cacheKey] = true
|
||||
}
|
||||
|
||||
if isPtr {
|
||||
elems = reflect.Append(elems, elem)
|
||||
} else {
|
||||
elems = reflect.Append(elems, elem.Addr())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
switch db.Statement.ReflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
|
||||
obj := db.Statement.ReflectValue.Index(i)
|
||||
if reflect.Indirect(obj).Kind() == reflect.Struct {
|
||||
appendToElems(obj)
|
||||
}
|
||||
}
|
||||
case reflect.Struct:
|
||||
appendToElems(db.Statement.ReflectValue)
|
||||
}
|
||||
|
||||
if elems.Len() > 0 {
|
||||
assignmentColumns := make([]string, 0, len(rel.References))
|
||||
for _, ref := range rel.References {
|
||||
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
|
||||
}
|
||||
|
||||
saveAssociations(db, rel, elems, selectColumns, restricted, assignmentColumns)
|
||||
}
|
||||
}
|
||||
|
||||
// Save Many2Many associations
|
||||
for _, rel := range db.Statement.Schema.Relationships.Many2Many {
|
||||
if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) {
|
||||
continue
|
||||
}
|
||||
|
||||
fieldType := rel.Field.IndirectFieldType.Elem()
|
||||
isPtr := fieldType.Kind() == reflect.Ptr
|
||||
if !isPtr {
|
||||
fieldType = reflect.PointerTo(fieldType)
|
||||
}
|
||||
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
|
||||
distinctElems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
|
||||
joins := reflect.MakeSlice(reflect.SliceOf(reflect.PointerTo(rel.JoinTable.ModelType)), 0, 10)
|
||||
objs := []reflect.Value{}
|
||||
|
||||
appendToJoins := func(obj reflect.Value, elem reflect.Value) {
|
||||
joinValue := reflect.New(rel.JoinTable.ModelType)
|
||||
for _, ref := range rel.References {
|
||||
if ref.OwnPrimaryKey {
|
||||
fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, obj)
|
||||
db.AddError(ref.ForeignKey.Set(db.Statement.Context, joinValue, fv))
|
||||
} else if ref.PrimaryValue != "" {
|
||||
db.AddError(ref.ForeignKey.Set(db.Statement.Context, joinValue, ref.PrimaryValue))
|
||||
} else {
|
||||
fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, elem)
|
||||
db.AddError(ref.ForeignKey.Set(db.Statement.Context, joinValue, fv))
|
||||
}
|
||||
}
|
||||
joins = reflect.Append(joins, joinValue)
|
||||
}
|
||||
|
||||
identityMap := map[string]bool{}
|
||||
appendToElems := func(v reflect.Value) {
|
||||
if _, zero := rel.Field.ValueOf(db.Statement.Context, v); !zero {
|
||||
f := reflect.Indirect(rel.Field.ReflectValueOf(db.Statement.Context, v))
|
||||
for i := 0; i < f.Len(); i++ {
|
||||
elem := f.Index(i)
|
||||
if !isPtr {
|
||||
elem = elem.Addr()
|
||||
}
|
||||
objs = append(objs, v)
|
||||
elems = reflect.Append(elems, elem)
|
||||
|
||||
relPrimaryValues := make([]interface{}, 0, len(rel.FieldSchema.PrimaryFields))
|
||||
for _, pf := range rel.FieldSchema.PrimaryFields {
|
||||
if pfv, ok := pf.ValueOf(db.Statement.Context, elem); !ok {
|
||||
relPrimaryValues = append(relPrimaryValues, pfv)
|
||||
}
|
||||
}
|
||||
|
||||
cacheKey := utils.ToStringKey(relPrimaryValues...)
|
||||
if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] {
|
||||
if cacheKey != "" { // has primary fields
|
||||
identityMap[cacheKey] = true
|
||||
}
|
||||
|
||||
distinctElems = reflect.Append(distinctElems, elem)
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
switch db.Statement.ReflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
|
||||
obj := db.Statement.ReflectValue.Index(i)
|
||||
if reflect.Indirect(obj).Kind() == reflect.Struct {
|
||||
appendToElems(obj)
|
||||
}
|
||||
}
|
||||
case reflect.Struct:
|
||||
appendToElems(db.Statement.ReflectValue)
|
||||
}
|
||||
|
||||
// optimize elems of reflect value length
|
||||
if elemLen := elems.Len(); elemLen > 0 {
|
||||
if v, ok := selectColumns[rel.Name+".*"]; !ok || v {
|
||||
saveAssociations(db, rel, distinctElems, selectColumns, restricted, nil)
|
||||
}
|
||||
|
||||
for i := 0; i < elemLen; i++ {
|
||||
appendToJoins(objs[i], elems.Index(i))
|
||||
}
|
||||
}
|
||||
|
||||
if joins.Len() > 0 {
|
||||
db.AddError(db.Session(&gorm.Session{NewDB: true}).Clauses(clause.OnConflict{DoNothing: true}).Session(&gorm.Session{
|
||||
SkipHooks: db.Statement.SkipHooks,
|
||||
DisableNestedTransaction: true,
|
||||
}).Create(joins.Interface()).Error)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func onConflictOption(stmt *gorm.Statement, s *schema.Schema, defaultUpdatingColumns []string) (onConflict clause.OnConflict) {
|
||||
if len(defaultUpdatingColumns) > 0 || stmt.DB.FullSaveAssociations {
|
||||
onConflict.Columns = make([]clause.Column, 0, len(s.PrimaryFieldDBNames))
|
||||
for _, dbName := range s.PrimaryFieldDBNames {
|
||||
onConflict.Columns = append(onConflict.Columns, clause.Column{Name: dbName})
|
||||
}
|
||||
|
||||
onConflict.UpdateAll = stmt.DB.FullSaveAssociations
|
||||
if !onConflict.UpdateAll {
|
||||
onConflict.DoUpdates = clause.AssignmentColumns(defaultUpdatingColumns)
|
||||
}
|
||||
} else {
|
||||
onConflict.DoNothing = true
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func saveAssociations(db *gorm.DB, rel *schema.Relationship, rValues reflect.Value, selectColumns map[string]bool, restricted bool, defaultUpdatingColumns []string) error {
|
||||
// stop save association loop
|
||||
if checkAssociationsSaved(db, rValues) {
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
selects, omits []string
|
||||
onConflict = onConflictOption(db.Statement, rel.FieldSchema, defaultUpdatingColumns)
|
||||
refName = rel.Name + "."
|
||||
values = rValues.Interface()
|
||||
)
|
||||
|
||||
for name, ok := range selectColumns {
|
||||
columnName := ""
|
||||
if strings.HasPrefix(name, refName) {
|
||||
columnName = strings.TrimPrefix(name, refName)
|
||||
}
|
||||
|
||||
if columnName != "" {
|
||||
if ok {
|
||||
selects = append(selects, columnName)
|
||||
} else {
|
||||
omits = append(omits, columnName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tx := db.Session(&gorm.Session{NewDB: true}).Clauses(onConflict).Session(&gorm.Session{
|
||||
FullSaveAssociations: db.FullSaveAssociations,
|
||||
SkipHooks: db.Statement.SkipHooks,
|
||||
DisableNestedTransaction: true,
|
||||
})
|
||||
|
||||
db.Statement.Settings.Range(func(k, v interface{}) bool {
|
||||
tx.Statement.Settings.Store(k, v)
|
||||
return true
|
||||
})
|
||||
|
||||
if tx.Statement.FullSaveAssociations {
|
||||
tx = tx.Set("gorm:update_track_time", true)
|
||||
}
|
||||
|
||||
if len(selects) > 0 {
|
||||
tx = tx.Select(selects)
|
||||
} else if restricted && len(omits) == 0 {
|
||||
tx = tx.Omit(clause.Associations)
|
||||
}
|
||||
|
||||
if len(omits) > 0 {
|
||||
tx = tx.Omit(omits...)
|
||||
}
|
||||
|
||||
return db.AddError(tx.Create(values).Error)
|
||||
}
|
||||
|
||||
// check association values has been saved
|
||||
// if values kind is Struct, check it has been saved
|
||||
// if values kind is Slice/Array, check all items have been saved
|
||||
var visitMapStoreKey = "gorm:saved_association_map"
|
||||
|
||||
func checkAssociationsSaved(db *gorm.DB, values reflect.Value) bool {
|
||||
if visit, ok := db.Get(visitMapStoreKey); ok {
|
||||
if v, ok := visit.(*visitMap); ok {
|
||||
if loadOrStoreVisitMap(v, values) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
} else {
|
||||
vistMap := make(visitMap)
|
||||
loadOrStoreVisitMap(&vistMap, values)
|
||||
db.Set(visitMapStoreKey, &vistMap)
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
83
vendor/gorm.io/gorm/callbacks/callbacks.go
generated
vendored
Normal file
83
vendor/gorm.io/gorm/callbacks/callbacks.go
generated
vendored
Normal file
@@ -0,0 +1,83 @@
|
||||
package callbacks
|
||||
|
||||
import (
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var (
|
||||
createClauses = []string{"INSERT", "VALUES", "ON CONFLICT"}
|
||||
queryClauses = []string{"SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR"}
|
||||
updateClauses = []string{"UPDATE", "SET", "WHERE"}
|
||||
deleteClauses = []string{"DELETE", "FROM", "WHERE"}
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
LastInsertIDReversed bool
|
||||
CreateClauses []string
|
||||
QueryClauses []string
|
||||
UpdateClauses []string
|
||||
DeleteClauses []string
|
||||
}
|
||||
|
||||
func RegisterDefaultCallbacks(db *gorm.DB, config *Config) {
|
||||
enableTransaction := func(db *gorm.DB) bool {
|
||||
return !db.SkipDefaultTransaction
|
||||
}
|
||||
|
||||
if len(config.CreateClauses) == 0 {
|
||||
config.CreateClauses = createClauses
|
||||
}
|
||||
if len(config.QueryClauses) == 0 {
|
||||
config.QueryClauses = queryClauses
|
||||
}
|
||||
if len(config.DeleteClauses) == 0 {
|
||||
config.DeleteClauses = deleteClauses
|
||||
}
|
||||
if len(config.UpdateClauses) == 0 {
|
||||
config.UpdateClauses = updateClauses
|
||||
}
|
||||
|
||||
createCallback := db.Callback().Create()
|
||||
createCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction)
|
||||
createCallback.Register("gorm:before_create", BeforeCreate)
|
||||
createCallback.Register("gorm:save_before_associations", SaveBeforeAssociations(true))
|
||||
createCallback.Register("gorm:create", Create(config))
|
||||
createCallback.Register("gorm:save_after_associations", SaveAfterAssociations(true))
|
||||
createCallback.Register("gorm:after_create", AfterCreate)
|
||||
createCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
|
||||
createCallback.Clauses = config.CreateClauses
|
||||
|
||||
queryCallback := db.Callback().Query()
|
||||
queryCallback.Register("gorm:query", Query)
|
||||
queryCallback.Register("gorm:preload", Preload)
|
||||
queryCallback.Register("gorm:after_query", AfterQuery)
|
||||
queryCallback.Clauses = config.QueryClauses
|
||||
|
||||
deleteCallback := db.Callback().Delete()
|
||||
deleteCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction)
|
||||
deleteCallback.Register("gorm:before_delete", BeforeDelete)
|
||||
deleteCallback.Register("gorm:delete_before_associations", DeleteBeforeAssociations)
|
||||
deleteCallback.Register("gorm:delete", Delete(config))
|
||||
deleteCallback.Register("gorm:after_delete", AfterDelete)
|
||||
deleteCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
|
||||
deleteCallback.Clauses = config.DeleteClauses
|
||||
|
||||
updateCallback := db.Callback().Update()
|
||||
updateCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction)
|
||||
updateCallback.Register("gorm:setup_reflect_value", SetupUpdateReflectValue)
|
||||
updateCallback.Register("gorm:before_update", BeforeUpdate)
|
||||
updateCallback.Register("gorm:save_before_associations", SaveBeforeAssociations(false))
|
||||
updateCallback.Register("gorm:update", Update(config))
|
||||
updateCallback.Register("gorm:save_after_associations", SaveAfterAssociations(false))
|
||||
updateCallback.Register("gorm:after_update", AfterUpdate)
|
||||
updateCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
|
||||
updateCallback.Clauses = config.UpdateClauses
|
||||
|
||||
rowCallback := db.Callback().Row()
|
||||
rowCallback.Register("gorm:row", RowQuery)
|
||||
rowCallback.Clauses = config.QueryClauses
|
||||
|
||||
rawCallback := db.Callback().Raw()
|
||||
rawCallback.Register("gorm:raw", RawExec)
|
||||
rawCallback.Clauses = config.QueryClauses
|
||||
}
|
||||
32
vendor/gorm.io/gorm/callbacks/callmethod.go
generated
vendored
Normal file
32
vendor/gorm.io/gorm/callbacks/callmethod.go
generated
vendored
Normal file
@@ -0,0 +1,32 @@
|
||||
package callbacks
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func callMethod(db *gorm.DB, fc func(value interface{}, tx *gorm.DB) bool) {
|
||||
tx := db.Session(&gorm.Session{NewDB: true})
|
||||
if called := fc(db.Statement.ReflectValue.Interface(), tx); !called {
|
||||
switch db.Statement.ReflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
db.Statement.CurDestIndex = 0
|
||||
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
|
||||
if value := reflect.Indirect(db.Statement.ReflectValue.Index(i)); value.CanAddr() {
|
||||
fc(value.Addr().Interface(), tx)
|
||||
} else {
|
||||
db.AddError(gorm.ErrInvalidValue)
|
||||
return
|
||||
}
|
||||
db.Statement.CurDestIndex++
|
||||
}
|
||||
case reflect.Struct:
|
||||
if db.Statement.ReflectValue.CanAddr() {
|
||||
fc(db.Statement.ReflectValue.Addr().Interface(), tx)
|
||||
} else {
|
||||
db.AddError(gorm.ErrInvalidValue)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
406
vendor/gorm.io/gorm/callbacks/create.go
generated
vendored
Normal file
406
vendor/gorm.io/gorm/callbacks/create.go
generated
vendored
Normal file
@@ -0,0 +1,406 @@
|
||||
package callbacks
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/schema"
|
||||
"gorm.io/gorm/utils"
|
||||
)
|
||||
|
||||
// BeforeCreate before create hooks
|
||||
func BeforeCreate(db *gorm.DB) {
|
||||
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeCreate) {
|
||||
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
|
||||
if db.Statement.Schema.BeforeSave {
|
||||
if i, ok := value.(BeforeSaveInterface); ok {
|
||||
called = true
|
||||
db.AddError(i.BeforeSave(tx))
|
||||
}
|
||||
}
|
||||
|
||||
if db.Statement.Schema.BeforeCreate {
|
||||
if i, ok := value.(BeforeCreateInterface); ok {
|
||||
called = true
|
||||
db.AddError(i.BeforeCreate(tx))
|
||||
}
|
||||
}
|
||||
return called
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Create create hook
|
||||
func Create(config *Config) func(db *gorm.DB) {
|
||||
supportReturning := utils.Contains(config.CreateClauses, "RETURNING")
|
||||
|
||||
return func(db *gorm.DB) {
|
||||
if db.Error != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if db.Statement.Schema != nil {
|
||||
if !db.Statement.Unscoped {
|
||||
for _, c := range db.Statement.Schema.CreateClauses {
|
||||
db.Statement.AddClause(c)
|
||||
}
|
||||
}
|
||||
|
||||
if supportReturning && len(db.Statement.Schema.FieldsWithDefaultDBValue) > 0 {
|
||||
if _, ok := db.Statement.Clauses["RETURNING"]; !ok {
|
||||
fromColumns := make([]clause.Column, 0, len(db.Statement.Schema.FieldsWithDefaultDBValue))
|
||||
for _, field := range db.Statement.Schema.FieldsWithDefaultDBValue {
|
||||
fromColumns = append(fromColumns, clause.Column{Name: field.DBName})
|
||||
}
|
||||
db.Statement.AddClause(clause.Returning{Columns: fromColumns})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if db.Statement.SQL.Len() == 0 {
|
||||
db.Statement.SQL.Grow(180)
|
||||
db.Statement.AddClauseIfNotExists(clause.Insert{})
|
||||
db.Statement.AddClause(ConvertToCreateValues(db.Statement))
|
||||
|
||||
db.Statement.Build(db.Statement.BuildClauses...)
|
||||
}
|
||||
|
||||
isDryRun := !db.DryRun && db.Error == nil
|
||||
if !isDryRun {
|
||||
return
|
||||
}
|
||||
|
||||
ok, mode := hasReturning(db, supportReturning)
|
||||
if ok {
|
||||
if c, ok := db.Statement.Clauses["ON CONFLICT"]; ok {
|
||||
if onConflict, _ := c.Expression.(clause.OnConflict); onConflict.DoNothing {
|
||||
mode |= gorm.ScanOnConflictDoNothing
|
||||
}
|
||||
}
|
||||
|
||||
rows, err := db.Statement.ConnPool.QueryContext(
|
||||
db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...,
|
||||
)
|
||||
if db.AddError(err) == nil {
|
||||
defer func() {
|
||||
db.AddError(rows.Close())
|
||||
}()
|
||||
gorm.Scan(rows, db, mode)
|
||||
|
||||
if db.Statement.Result != nil {
|
||||
db.Statement.Result.RowsAffected = db.RowsAffected
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
result, err := db.Statement.ConnPool.ExecContext(
|
||||
db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...,
|
||||
)
|
||||
if err != nil {
|
||||
db.AddError(err)
|
||||
return
|
||||
}
|
||||
|
||||
db.RowsAffected, _ = result.RowsAffected()
|
||||
|
||||
if db.Statement.Result != nil {
|
||||
db.Statement.Result.Result = result
|
||||
db.Statement.Result.RowsAffected = db.RowsAffected
|
||||
}
|
||||
|
||||
if db.RowsAffected == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
var (
|
||||
pkField *schema.Field
|
||||
pkFieldName = "@id"
|
||||
)
|
||||
|
||||
insertID, err := result.LastInsertId()
|
||||
insertOk := err == nil && insertID > 0
|
||||
|
||||
if !insertOk {
|
||||
if !supportReturning {
|
||||
db.AddError(err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if db.Statement.Schema != nil {
|
||||
if db.Statement.Schema.PrioritizedPrimaryField == nil || !db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue {
|
||||
return
|
||||
}
|
||||
pkField = db.Statement.Schema.PrioritizedPrimaryField
|
||||
pkFieldName = db.Statement.Schema.PrioritizedPrimaryField.DBName
|
||||
}
|
||||
|
||||
// append @id column with value for auto-increment primary key
|
||||
// the @id value is correct, when: 1. without setting auto-increment primary key, 2. database AutoIncrementIncrement = 1
|
||||
switch values := db.Statement.Dest.(type) {
|
||||
case map[string]interface{}:
|
||||
values[pkFieldName] = insertID
|
||||
case *map[string]interface{}:
|
||||
(*values)[pkFieldName] = insertID
|
||||
case []map[string]interface{}, *[]map[string]interface{}:
|
||||
mapValues, ok := values.([]map[string]interface{})
|
||||
if !ok {
|
||||
if v, ok := values.(*[]map[string]interface{}); ok {
|
||||
if *v != nil {
|
||||
mapValues = *v
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if config.LastInsertIDReversed {
|
||||
insertID -= int64(len(mapValues)-1) * schema.DefaultAutoIncrementIncrement
|
||||
}
|
||||
|
||||
for _, mapValue := range mapValues {
|
||||
if mapValue != nil {
|
||||
mapValue[pkFieldName] = insertID
|
||||
}
|
||||
insertID += schema.DefaultAutoIncrementIncrement
|
||||
}
|
||||
default:
|
||||
if pkField == nil {
|
||||
return
|
||||
}
|
||||
|
||||
switch db.Statement.ReflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
if config.LastInsertIDReversed {
|
||||
for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- {
|
||||
rv := db.Statement.ReflectValue.Index(i)
|
||||
if reflect.Indirect(rv).Kind() != reflect.Struct {
|
||||
break
|
||||
}
|
||||
|
||||
_, isZero := pkField.ValueOf(db.Statement.Context, rv)
|
||||
if isZero {
|
||||
db.AddError(pkField.Set(db.Statement.Context, rv, insertID))
|
||||
insertID -= pkField.AutoIncrementIncrement
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
|
||||
rv := db.Statement.ReflectValue.Index(i)
|
||||
if reflect.Indirect(rv).Kind() != reflect.Struct {
|
||||
break
|
||||
}
|
||||
|
||||
if _, isZero := pkField.ValueOf(db.Statement.Context, rv); isZero {
|
||||
db.AddError(pkField.Set(db.Statement.Context, rv, insertID))
|
||||
insertID += pkField.AutoIncrementIncrement
|
||||
}
|
||||
}
|
||||
}
|
||||
case reflect.Struct:
|
||||
_, isZero := pkField.ValueOf(db.Statement.Context, db.Statement.ReflectValue)
|
||||
if isZero {
|
||||
db.AddError(pkField.Set(db.Statement.Context, db.Statement.ReflectValue, insertID))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// AfterCreate after create hooks
|
||||
func AfterCreate(db *gorm.DB) {
|
||||
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) {
|
||||
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
|
||||
if db.Statement.Schema.AfterCreate {
|
||||
if i, ok := value.(AfterCreateInterface); ok {
|
||||
called = true
|
||||
db.AddError(i.AfterCreate(tx))
|
||||
}
|
||||
}
|
||||
|
||||
if db.Statement.Schema.AfterSave {
|
||||
if i, ok := value.(AfterSaveInterface); ok {
|
||||
called = true
|
||||
db.AddError(i.AfterSave(tx))
|
||||
}
|
||||
}
|
||||
return called
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ConvertToCreateValues convert to create values
|
||||
func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
|
||||
curTime := stmt.DB.NowFunc()
|
||||
|
||||
switch value := stmt.Dest.(type) {
|
||||
case map[string]interface{}:
|
||||
values = ConvertMapToValuesForCreate(stmt, value)
|
||||
case *map[string]interface{}:
|
||||
values = ConvertMapToValuesForCreate(stmt, *value)
|
||||
case []map[string]interface{}:
|
||||
values = ConvertSliceOfMapToValuesForCreate(stmt, value)
|
||||
case *[]map[string]interface{}:
|
||||
values = ConvertSliceOfMapToValuesForCreate(stmt, *value)
|
||||
default:
|
||||
var (
|
||||
selectColumns, restricted = stmt.SelectAndOmitColumns(true, false)
|
||||
_, updateTrackTime = stmt.Get("gorm:update_track_time")
|
||||
isZero bool
|
||||
)
|
||||
stmt.Settings.Delete("gorm:update_track_time")
|
||||
|
||||
values = clause.Values{Columns: make([]clause.Column, 0, len(stmt.Schema.DBNames))}
|
||||
|
||||
for _, db := range stmt.Schema.DBNames {
|
||||
if field := stmt.Schema.FieldsByDBName[db]; !field.HasDefaultValue || field.DefaultValueInterface != nil {
|
||||
if v, ok := selectColumns[db]; (ok && v) || (!ok && (!restricted || field.AutoCreateTime > 0 || field.AutoUpdateTime > 0)) {
|
||||
values.Columns = append(values.Columns, clause.Column{Name: db})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
switch stmt.ReflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
rValLen := stmt.ReflectValue.Len()
|
||||
if rValLen == 0 {
|
||||
stmt.AddError(gorm.ErrEmptySlice)
|
||||
return
|
||||
}
|
||||
|
||||
stmt.SQL.Grow(rValLen * 18)
|
||||
stmt.Vars = make([]interface{}, 0, rValLen*len(values.Columns))
|
||||
values.Values = make([][]interface{}, rValLen)
|
||||
|
||||
defaultValueFieldsHavingValue := map[*schema.Field][]interface{}{}
|
||||
for i := 0; i < rValLen; i++ {
|
||||
rv := reflect.Indirect(stmt.ReflectValue.Index(i))
|
||||
if !rv.IsValid() {
|
||||
stmt.AddError(fmt.Errorf("slice data #%v is invalid: %w", i, gorm.ErrInvalidData))
|
||||
return
|
||||
}
|
||||
|
||||
values.Values[i] = make([]interface{}, len(values.Columns))
|
||||
for idx, column := range values.Columns {
|
||||
field := stmt.Schema.FieldsByDBName[column.Name]
|
||||
if values.Values[i][idx], isZero = field.ValueOf(stmt.Context, rv); isZero {
|
||||
if field.DefaultValueInterface != nil {
|
||||
values.Values[i][idx] = field.DefaultValueInterface
|
||||
stmt.AddError(field.Set(stmt.Context, rv, field.DefaultValueInterface))
|
||||
} else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 {
|
||||
stmt.AddError(field.Set(stmt.Context, rv, curTime))
|
||||
values.Values[i][idx], _ = field.ValueOf(stmt.Context, rv)
|
||||
}
|
||||
} else if field.AutoUpdateTime > 0 && updateTrackTime {
|
||||
stmt.AddError(field.Set(stmt.Context, rv, curTime))
|
||||
values.Values[i][idx], _ = field.ValueOf(stmt.Context, rv)
|
||||
}
|
||||
}
|
||||
|
||||
for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
|
||||
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
|
||||
if rvOfvalue, isZero := field.ValueOf(stmt.Context, rv); !isZero {
|
||||
if len(defaultValueFieldsHavingValue[field]) == 0 {
|
||||
defaultValueFieldsHavingValue[field] = make([]interface{}, rValLen)
|
||||
}
|
||||
defaultValueFieldsHavingValue[field][i] = rvOfvalue
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
|
||||
if vs, ok := defaultValueFieldsHavingValue[field]; ok {
|
||||
values.Columns = append(values.Columns, clause.Column{Name: field.DBName})
|
||||
for idx := range values.Values {
|
||||
if vs[idx] == nil {
|
||||
values.Values[idx] = append(values.Values[idx], stmt.DefaultValueOf(field))
|
||||
} else {
|
||||
values.Values[idx] = append(values.Values[idx], vs[idx])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
case reflect.Struct:
|
||||
values.Values = [][]interface{}{make([]interface{}, len(values.Columns))}
|
||||
for idx, column := range values.Columns {
|
||||
field := stmt.Schema.FieldsByDBName[column.Name]
|
||||
if values.Values[0][idx], isZero = field.ValueOf(stmt.Context, stmt.ReflectValue); isZero {
|
||||
if field.DefaultValueInterface != nil {
|
||||
values.Values[0][idx] = field.DefaultValueInterface
|
||||
stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, field.DefaultValueInterface))
|
||||
} else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 {
|
||||
stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, curTime))
|
||||
values.Values[0][idx], _ = field.ValueOf(stmt.Context, stmt.ReflectValue)
|
||||
}
|
||||
} else if field.AutoUpdateTime > 0 && updateTrackTime {
|
||||
stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, curTime))
|
||||
values.Values[0][idx], _ = field.ValueOf(stmt.Context, stmt.ReflectValue)
|
||||
}
|
||||
}
|
||||
|
||||
for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
|
||||
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) && field.DefaultValueInterface == nil {
|
||||
if rvOfvalue, isZero := field.ValueOf(stmt.Context, stmt.ReflectValue); !isZero {
|
||||
values.Columns = append(values.Columns, clause.Column{Name: field.DBName})
|
||||
values.Values[0] = append(values.Values[0], rvOfvalue)
|
||||
}
|
||||
}
|
||||
}
|
||||
default:
|
||||
stmt.AddError(gorm.ErrInvalidData)
|
||||
}
|
||||
}
|
||||
|
||||
if c, ok := stmt.Clauses["ON CONFLICT"]; ok {
|
||||
if onConflict, _ := c.Expression.(clause.OnConflict); onConflict.UpdateAll {
|
||||
if stmt.Schema != nil && len(values.Columns) >= 1 {
|
||||
selectColumns, restricted := stmt.SelectAndOmitColumns(true, true)
|
||||
|
||||
columns := make([]string, 0, len(values.Columns)-1)
|
||||
for _, column := range values.Columns {
|
||||
if field := stmt.Schema.LookUpField(column.Name); field != nil {
|
||||
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
|
||||
if !field.PrimaryKey && (!field.HasDefaultValue || field.DefaultValueInterface != nil ||
|
||||
strings.EqualFold(field.DefaultValue, "NULL")) && field.AutoCreateTime == 0 {
|
||||
if field.AutoUpdateTime > 0 {
|
||||
assignment := clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: curTime}
|
||||
switch field.AutoUpdateTime {
|
||||
case schema.UnixNanosecond:
|
||||
assignment.Value = curTime.UnixNano()
|
||||
case schema.UnixMillisecond:
|
||||
assignment.Value = curTime.UnixMilli()
|
||||
case schema.UnixSecond:
|
||||
assignment.Value = curTime.Unix()
|
||||
}
|
||||
|
||||
onConflict.DoUpdates = append(onConflict.DoUpdates, assignment)
|
||||
} else {
|
||||
columns = append(columns, column.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
onConflict.DoUpdates = append(onConflict.DoUpdates, clause.AssignmentColumns(columns)...)
|
||||
if len(onConflict.DoUpdates) == 0 {
|
||||
onConflict.DoNothing = true
|
||||
}
|
||||
|
||||
// use primary fields as default OnConflict columns
|
||||
if len(onConflict.Columns) == 0 {
|
||||
for _, field := range stmt.Schema.PrimaryFields {
|
||||
onConflict.Columns = append(onConflict.Columns, clause.Column{Name: field.DBName})
|
||||
}
|
||||
}
|
||||
stmt.AddClause(onConflict)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return values
|
||||
}
|
||||
195
vendor/gorm.io/gorm/callbacks/delete.go
generated
vendored
Normal file
195
vendor/gorm.io/gorm/callbacks/delete.go
generated
vendored
Normal file
@@ -0,0 +1,195 @@
|
||||
package callbacks
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/schema"
|
||||
"gorm.io/gorm/utils"
|
||||
)
|
||||
|
||||
func BeforeDelete(db *gorm.DB) {
|
||||
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.BeforeDelete {
|
||||
callMethod(db, func(value interface{}, tx *gorm.DB) bool {
|
||||
if i, ok := value.(BeforeDeleteInterface); ok {
|
||||
db.AddError(i.BeforeDelete(tx))
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// DeleteBeforeAssociations deletes rows of associated models before the main
// delete runs, for every relation the user explicitly selected (e.g. via
// Select(clause.Associations) or Select("Orders")). It only acts when the
// statement's column selection is restricted; otherwise nothing is deleted.
func DeleteBeforeAssociations(db *gorm.DB) {
	if db.Error == nil && db.Statement.Schema != nil {
		selectColumns, restricted := db.Statement.SelectAndOmitColumns(true, false)
		if !restricted {
			// No explicit selection: no associations to cascade-delete.
			return
		}

		for column, v := range selectColumns {
			if !v {
				// Column was omitted, not selected.
				continue
			}

			rel, ok := db.Statement.Schema.Relationships.Relations[column]
			if !ok {
				// Selected name is a plain column, not a relation.
				continue
			}

			switch rel.Type {
			case schema.HasOne, schema.HasMany:
				// Build conditions matching the child rows of the records being deleted.
				queryConds := rel.ToQueryConditions(db.Statement.Context, db.Statement.ReflectValue)
				modelValue := reflect.New(rel.FieldSchema.ModelType).Interface()
				tx := db.Session(&gorm.Session{NewDB: true}).Model(modelValue)
				withoutConditions := false
				if db.Statement.Unscoped {
					tx = tx.Unscoped()
				}

				if len(db.Statement.Selects) > 0 {
					// Forward nested selects (e.g. "Orders.Items") to the child delete,
					// stripped of the current relation's prefix.
					selects := make([]string, 0, len(db.Statement.Selects))
					for _, s := range db.Statement.Selects {
						if s == clause.Associations {
							selects = append(selects, s)
						} else if columnPrefix := column + "."; strings.HasPrefix(s, columnPrefix) {
							selects = append(selects, strings.TrimPrefix(s, columnPrefix))
						}
					}

					if len(selects) > 0 {
						tx = tx.Select(selects)
					}
				}

				// An IN clause with no values would match nothing (or build invalid
				// SQL), so skip the delete entirely in that case.
				for _, cond := range queryConds {
					if c, ok := cond.(clause.IN); ok && len(c.Values) == 0 {
						withoutConditions = true
						break
					}
				}

				if !withoutConditions && db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil {
					return
				}
			case schema.Many2Many:
				// For many-to-many, delete the join-table rows that reference the
				// records being deleted (not the associated records themselves).
				var (
					queryConds     = make([]clause.Expression, 0, len(rel.References))
					foreignFields  = make([]*schema.Field, 0, len(rel.References))
					relForeignKeys = make([]string, 0, len(rel.References))
					modelValue     = reflect.New(rel.JoinTable.ModelType).Interface()
					table          = rel.JoinTable.Table
					tx             = db.Session(&gorm.Session{NewDB: true}).Model(modelValue).Table(table)
				)

				for _, ref := range rel.References {
					if ref.OwnPrimaryKey {
						foreignFields = append(foreignFields, ref.PrimaryKey)
						relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName)
					} else if ref.PrimaryValue != "" {
						// Polymorphic-style fixed value reference.
						queryConds = append(queryConds, clause.Eq{
							Column: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName},
							Value:  ref.PrimaryValue,
						})
					}
				}

				_, foreignValues := schema.GetIdentityFieldValuesMap(db.Statement.Context, db.Statement.ReflectValue, foreignFields)
				column, values := schema.ToQueryValues(table, relForeignKeys, foreignValues)
				queryConds = append(queryConds, clause.IN{Column: column, Values: values})

				if db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil {
					return
				}
			}
		}

	}
}
|
||||
|
||||
// Delete returns the callback that builds and executes the DELETE statement.
// The returned closure adds schema delete clauses (e.g. soft delete), derives
// WHERE conditions from the destination's primary-key values, guards against
// deletes with no conditions, and runs the statement — via QueryContext when a
// usable RETURNING clause is present, otherwise via ExecContext.
func Delete(config *Config) func(db *gorm.DB) {
	// Whether the dialect declares RETURNING support for deletes.
	supportReturning := utils.Contains(config.DeleteClauses, "RETURNING")

	return func(db *gorm.DB) {
		if db.Error != nil {
			return
		}

		if db.Statement.Schema != nil {
			// Schema-level clauses, e.g. soft-delete's DELETE override.
			for _, c := range db.Statement.Schema.DeleteClauses {
				db.Statement.AddClause(c)
			}
		}

		if db.Statement.SQL.Len() == 0 {
			db.Statement.SQL.Grow(100)
			db.Statement.AddClauseIfNotExists(clause.Delete{})

			if db.Statement.Schema != nil {
				// Constrain the delete to the primary-key values found on Dest.
				_, queryValues := schema.GetIdentityFieldValuesMap(db.Statement.Context, db.Statement.ReflectValue, db.Statement.Schema.PrimaryFields)
				column, values := schema.ToQueryValues(db.Statement.Table, db.Statement.Schema.PrimaryFieldDBNames, queryValues)

				if len(values) > 0 {
					db.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})
				}

				// When Model differs from Dest, also take primary keys from Model.
				if db.Statement.ReflectValue.CanAddr() && db.Statement.Dest != db.Statement.Model && db.Statement.Model != nil {
					_, queryValues = schema.GetIdentityFieldValuesMap(db.Statement.Context, reflect.ValueOf(db.Statement.Model), db.Statement.Schema.PrimaryFields)
					column, values = schema.ToQueryValues(db.Statement.Table, db.Statement.Schema.PrimaryFieldDBNames, queryValues)

					if len(values) > 0 {
						db.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})
					}
				}
			}

			db.Statement.AddClauseIfNotExists(clause.From{})

			db.Statement.Build(db.Statement.BuildClauses...)
		}

		// Refuse a conditionless delete unless AllowGlobalUpdate is set.
		checkMissingWhereConditions(db)

		if !db.DryRun && db.Error == nil {
			ok, mode := hasReturning(db, supportReturning)
			if !ok {
				// Plain execution path: no RETURNING rows to scan.
				result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)

				if db.AddError(err) == nil {
					db.RowsAffected, _ = result.RowsAffected()

					if db.Statement.Result != nil {
						db.Statement.Result.Result = result
						db.Statement.Result.RowsAffected = db.RowsAffected
					}
				}

				return
			}

			// RETURNING path: scan the returned rows back into the destination.
			if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil {
				gorm.Scan(rows, db, mode)

				if db.Statement.Result != nil {
					db.Statement.Result.RowsAffected = db.RowsAffected
				}
				db.AddError(rows.Close())
			}
		}
	}
}
|
||||
|
||||
func AfterDelete(db *gorm.DB) {
|
||||
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.AfterDelete {
|
||||
callMethod(db, func(value interface{}, tx *gorm.DB) bool {
|
||||
if i, ok := value.(AfterDeleteInterface); ok {
|
||||
db.AddError(i.AfterDelete(tx))
|
||||
return true
|
||||
}
|
||||
return false
|
||||
})
|
||||
}
|
||||
}
|
||||
152
vendor/gorm.io/gorm/callbacks/helper.go
generated
vendored
Normal file
152
vendor/gorm.io/gorm/callbacks/helper.go
generated
vendored
Normal file
@@ -0,0 +1,152 @@
|
||||
package callbacks
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"sort"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
// ConvertMapToValuesForCreate convert map to values
|
||||
func ConvertMapToValuesForCreate(stmt *gorm.Statement, mapValue map[string]interface{}) (values clause.Values) {
|
||||
values.Columns = make([]clause.Column, 0, len(mapValue))
|
||||
selectColumns, restricted := stmt.SelectAndOmitColumns(true, false)
|
||||
|
||||
keys := make([]string, 0, len(mapValue))
|
||||
for k := range mapValue {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
|
||||
for _, k := range keys {
|
||||
value := mapValue[k]
|
||||
if stmt.Schema != nil {
|
||||
if field := stmt.Schema.LookUpField(k); field != nil {
|
||||
k = field.DBName
|
||||
}
|
||||
}
|
||||
|
||||
if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) {
|
||||
values.Columns = append(values.Columns, clause.Column{Name: k})
|
||||
if len(values.Values) == 0 {
|
||||
values.Values = [][]interface{}{{}}
|
||||
}
|
||||
|
||||
values.Values[0] = append(values.Values[0], value)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// ConvertSliceOfMapToValuesForCreate converts a slice of maps into a
// clause.Values for a multi-row INSERT. Column order is sorted and stable;
// rows that lack a column get a nil slot so every row has the same width.
func ConvertSliceOfMapToValuesForCreate(stmt *gorm.Statement, mapValues []map[string]interface{}) (values clause.Values) {
	columns := make([]string, 0, len(mapValues))

	// when the length of mapValues is zero,return directly here
	// no need to call stmt.SelectAndOmitColumns method
	if len(mapValues) == 0 {
		stmt.AddError(gorm.ErrEmptySlice)
		return
	}

	var (
		// result holds, per column, one value slot per input row.
		result                    = make(map[string][]interface{}, len(mapValues))
		selectColumns, restricted = stmt.SelectAndOmitColumns(true, false)
	)

	for idx, mapValue := range mapValues {
		for k, v := range mapValue {
			// Resolve the map key to the schema's DB column name when known.
			if stmt.Schema != nil {
				if field := stmt.Schema.LookUpField(k); field != nil {
					k = field.DBName
				}
			}

			// First time this column is seen: admit it only if Select/Omit allows it.
			if _, ok := result[k]; !ok {
				if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) {
					result[k] = make([]interface{}, len(mapValues))
					columns = append(columns, k)
				} else {
					continue
				}
			}

			result[k][idx] = v
		}
	}

	// Transpose column-major result into row-major values, sorted by column name.
	sort.Strings(columns)
	values.Values = make([][]interface{}, len(mapValues))
	values.Columns = make([]clause.Column, len(columns))
	for idx, column := range columns {
		values.Columns[idx] = clause.Column{Name: column}

		for i, v := range result[column] {
			if len(values.Values[i]) == 0 {
				values.Values[i] = make([]interface{}, len(columns))
			}

			values.Values[i][idx] = v
		}
	}
	return
}
|
||||
|
||||
func hasReturning(tx *gorm.DB, supportReturning bool) (bool, gorm.ScanMode) {
|
||||
if supportReturning {
|
||||
if c, ok := tx.Statement.Clauses["RETURNING"]; ok {
|
||||
returning, _ := c.Expression.(clause.Returning)
|
||||
if len(returning.Columns) == 0 || (len(returning.Columns) == 1 && returning.Columns[0].Name == "*") {
|
||||
return true, 0
|
||||
}
|
||||
return true, gorm.ScanUpdate
|
||||
}
|
||||
}
|
||||
return false, 0
|
||||
}
|
||||
|
||||
func checkMissingWhereConditions(db *gorm.DB) {
|
||||
if !db.AllowGlobalUpdate && db.Error == nil {
|
||||
where, withCondition := db.Statement.Clauses["WHERE"]
|
||||
if withCondition {
|
||||
if _, withSoftDelete := db.Statement.Clauses["soft_delete_enabled"]; withSoftDelete {
|
||||
whereClause, _ := where.Expression.(clause.Where)
|
||||
withCondition = len(whereClause.Exprs) > 1
|
||||
}
|
||||
}
|
||||
if !withCondition {
|
||||
db.AddError(gorm.ErrMissingWhereClause)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
type visitMap = map[reflect.Value]bool
|
||||
|
||||
// Check if circular values, return true if loaded
|
||||
func loadOrStoreVisitMap(visitMap *visitMap, v reflect.Value) (loaded bool) {
|
||||
if v.Kind() == reflect.Ptr {
|
||||
v = v.Elem()
|
||||
}
|
||||
|
||||
switch v.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
loaded = true
|
||||
for i := 0; i < v.Len(); i++ {
|
||||
if !loadOrStoreVisitMap(visitMap, v.Index(i)) {
|
||||
loaded = false
|
||||
}
|
||||
}
|
||||
case reflect.Struct, reflect.Interface:
|
||||
if v.CanAddr() {
|
||||
p := v.Addr()
|
||||
if _, ok := (*visitMap)[p]; ok {
|
||||
return true
|
||||
}
|
||||
(*visitMap)[p] = true
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
39
vendor/gorm.io/gorm/callbacks/interfaces.go
generated
vendored
Normal file
39
vendor/gorm.io/gorm/callbacks/interfaces.go
generated
vendored
Normal file
@@ -0,0 +1,39 @@
|
||||
package callbacks
|
||||
|
||||
import "gorm.io/gorm"
|
||||
|
||||
// BeforeCreateInterface is implemented by models wanting a hook before create.
type BeforeCreateInterface interface {
	BeforeCreate(*gorm.DB) error
}

// AfterCreateInterface is implemented by models wanting a hook after create.
type AfterCreateInterface interface {
	AfterCreate(*gorm.DB) error
}

// BeforeUpdateInterface is implemented by models wanting a hook before update.
type BeforeUpdateInterface interface {
	BeforeUpdate(*gorm.DB) error
}

// AfterUpdateInterface is implemented by models wanting a hook after update.
type AfterUpdateInterface interface {
	AfterUpdate(*gorm.DB) error
}

// BeforeSaveInterface is implemented by models wanting a hook before save.
type BeforeSaveInterface interface {
	BeforeSave(*gorm.DB) error
}

// AfterSaveInterface is implemented by models wanting a hook after save.
type AfterSaveInterface interface {
	AfterSave(*gorm.DB) error
}

// BeforeDeleteInterface is implemented by models wanting a hook before delete.
type BeforeDeleteInterface interface {
	BeforeDelete(*gorm.DB) error
}

// AfterDeleteInterface is implemented by models wanting a hook after delete.
type AfterDeleteInterface interface {
	AfterDelete(*gorm.DB) error
}

// AfterFindInterface is implemented by models wanting a hook after a query.
type AfterFindInterface interface {
	AfterFind(*gorm.DB) error
}
|
||||
351
vendor/gorm.io/gorm/callbacks/preload.go
generated
vendored
Normal file
351
vendor/gorm.io/gorm/callbacks/preload.go
generated
vendored
Normal file
@@ -0,0 +1,351 @@
|
||||
package callbacks
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/schema"
|
||||
"gorm.io/gorm/utils"
|
||||
)
|
||||
|
||||
// parsePreloadMap extracts nested preloads. e.g.
//
//	// schema has a "k0" relation and a "k7.k8" embedded relation
//	parsePreloadMap(schema, map[string][]interface{}{
//		clause.Associations: {"arg1"},
//		"k1":                {"arg2"},
//		"k2.k3":             {"arg3"},
//		"k4.k5.k6":          {"arg4"},
//	})
//	// preloadMap is
//	map[string]map[string][]interface{}{
//		"k0": {},
//		"k7": {
//			"k8": {},
//		},
//		"k1": {},
//		"k2": {
//			"k3": {"arg3"},
//		},
//		"k4": {
//			"k5.k6": {"arg4"},
//		},
//	}
func parsePreloadMap(s *schema.Schema, preloads map[string][]interface{}) map[string]map[string][]interface{} {
	preloadMap := map[string]map[string][]interface{}{}
	// setPreloadMap registers a top-level relation name and, when value is
	// non-empty, the remaining nested path with its args.
	setPreloadMap := func(name, value string, args []interface{}) {
		if _, ok := preloadMap[name]; !ok {
			preloadMap[name] = map[string][]interface{}{}
		}
		if value != "" {
			preloadMap[name][value] = args
		}
	}

	for name, args := range preloads {
		// Split "a.b.c" into head "a" and tail "b.c".
		preloadFields := strings.Split(name, ".")
		value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), ".")
		if preloadFields[0] == clause.Associations {
			// clause.Associations expands to every relation declared directly
			// on this schema...
			for _, relation := range s.Relationships.Relations {
				if relation.Schema == s {
					setPreloadMap(relation.Name, value, args)
				}
			}

			// ...plus relations reachable through embedded structs.
			for embedded, embeddedRelations := range s.Relationships.EmbeddedRelations {
				for _, value := range embeddedValues(embeddedRelations) {
					setPreloadMap(embedded, value, args)
				}
			}
		} else {
			setPreloadMap(preloadFields[0], value, args)
		}
	}
	return preloadMap
}
|
||||
|
||||
func embeddedValues(embeddedRelations *schema.Relationships) []string {
|
||||
if embeddedRelations == nil {
|
||||
return nil
|
||||
}
|
||||
names := make([]string, 0, len(embeddedRelations.Relations)+len(embeddedRelations.EmbeddedRelations))
|
||||
for _, relation := range embeddedRelations.Relations {
|
||||
// skip first struct name
|
||||
names = append(names, strings.Join(relation.Field.EmbeddedBindNames[1:], "."))
|
||||
}
|
||||
for _, relations := range embeddedRelations.EmbeddedRelations {
|
||||
names = append(names, embeddedValues(relations)...)
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
||||
// preloadEntryPoint enters layer by layer. It will call real preload if it finds the right entry point.
// If the current relationship is embedded or joined, current query will be ignored.
//
//nolint:cyclop
func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relationships, preloads map[string][]interface{}, associationsConds []interface{}) error {
	preloadMap := parsePreloadMap(db.Statement.Schema, preloads)

	// avoid random traversal of the map
	preloadNames := make([]string, 0, len(preloadMap))
	for key := range preloadMap {
		preloadNames = append(preloadNames, key)
	}
	sort.Strings(preloadNames)

	// isJoined reports whether name is already covered by an inline Join, and
	// returns the nested join paths ("Manager.Company" → "Company") below it.
	isJoined := func(name string) (joined bool, nestedJoins []string) {
		for _, join := range joins {
			if _, ok := relationships.Relations[join]; ok && name == join {
				joined = true
				continue
			}
			join0, join1, cut := strings.Cut(join, ".")
			if cut {
				if _, ok := relationships.Relations[join0]; ok && name == join0 {
					joined = true
					nestedJoins = append(nestedJoins, join1)
				}
			}
		}
		return joined, nestedJoins
	}

	for _, name := range preloadNames {
		if relations := relationships.EmbeddedRelations[name]; relations != nil {
			// Embedded struct: descend into its relationships with the same joins.
			if err := preloadEntryPoint(db, joins, relations, preloadMap[name], associationsConds); err != nil {
				return err
			}
		} else if rel := relationships.Relations[name]; rel != nil {
			if joined, nestedJoins := isJoined(name); joined {
				// The relation's rows were already fetched by the Join; only the
				// nested preloads below it still need queries.
				switch rv := db.Statement.ReflectValue; rv.Kind() {
				case reflect.Slice, reflect.Array:
					if rv.Len() > 0 {
						// Collect pointers to each element's relation field.
						reflectValue := rel.FieldSchema.MakeSlice().Elem()
						for i := 0; i < rv.Len(); i++ {
							frv := rel.Field.ReflectValueOf(db.Statement.Context, rv.Index(i))
							if frv.Kind() != reflect.Ptr {
								reflectValue = reflect.Append(reflectValue, frv.Addr())
							} else {
								if frv.IsNil() {
									// Join produced no row for this element.
									continue
								}
								reflectValue = reflect.Append(reflectValue, frv)
							}
						}

						tx := preloadDB(db, reflectValue, reflectValue.Interface())
						if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil {
							return err
						}
					}
				case reflect.Struct, reflect.Pointer:
					reflectValue := rel.Field.ReflectValueOf(db.Statement.Context, rv)
					tx := preloadDB(db, reflectValue, reflectValue.Interface())
					if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil {
						return err
					}
				default:
					return gorm.ErrInvalidData
				}
			} else {
				// Not joined: run a real preload query for this relation.
				tx := db.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks})
				tx.Statement.ReflectValue = db.Statement.ReflectValue
				tx.Statement.Unscoped = db.Statement.Unscoped
				if err := preload(tx, rel, append(preloads[name], associationsConds...), preloadMap[name]); err != nil {
					return err
				}
			}
		} else {
			return fmt.Errorf("%s: %w for schema %s", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name)
		}
	}
	return nil
}
|
||||
|
||||
// preloadDB builds a fresh session for preloading dest, copying the parent
// statement's settings, context, hook-skipping, and Unscoped flag, and
// parsing dest's schema. Parse errors are recorded on the returned *gorm.DB.
func preloadDB(db *gorm.DB, reflectValue reflect.Value, dest interface{}) *gorm.DB {
	tx := db.Session(&gorm.Session{Context: db.Statement.Context, NewDB: true, SkipHooks: db.Statement.SkipHooks, Initialized: true})
	// Carry over per-statement settings (Set/InstanceSet values).
	db.Statement.Settings.Range(func(k, v interface{}) bool {
		tx.Statement.Settings.Store(k, v)
		return true
	})

	if err := tx.Statement.Parse(dest); err != nil {
		tx.AddError(err)
		return tx
	}
	// Point the new statement at the caller-collected values to fill.
	tx.Statement.ReflectValue = reflectValue
	tx.Statement.Unscoped = db.Statement.Unscoped
	return tx
}
|
||||
|
||||
// preload loads one relationship for the records in tx.Statement.ReflectValue:
// it collects the parents' foreign-key values, queries the related table (via
// the join table for many-to-many), then assigns each fetched row back onto
// its parent(s) through the identity map. conds are extra query conditions;
// preloads are nested preloads applied to the relation query.
func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}) error {
	var (
		reflectValue     = tx.Statement.ReflectValue
		relForeignKeys   []string
		relForeignFields []*schema.Field
		foreignFields    []*schema.Field
		foreignValues    [][]interface{}
		// identityMap maps a string key built from foreign-key values to the
		// parent records that carry those values.
		identityMap = map[string][]reflect.Value{}
		inlineConds []interface{}
	)

	if rel.JoinTable != nil {
		// Many-to-many: first query the join table, then the related table.
		var (
			joinForeignFields    = make([]*schema.Field, 0, len(rel.References))
			joinRelForeignFields = make([]*schema.Field, 0, len(rel.References))
			joinForeignKeys      = make([]string, 0, len(rel.References))
		)

		for _, ref := range rel.References {
			if ref.OwnPrimaryKey {
				// Join-table column pointing back at the parent.
				joinForeignKeys = append(joinForeignKeys, ref.ForeignKey.DBName)
				joinForeignFields = append(joinForeignFields, ref.ForeignKey)
				foreignFields = append(foreignFields, ref.PrimaryKey)
			} else if ref.PrimaryValue != "" {
				// Fixed-value (polymorphic-style) reference becomes a filter.
				tx = tx.Where(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue})
			} else {
				// Join-table column pointing at the related record.
				joinRelForeignFields = append(joinRelForeignFields, ref.ForeignKey)
				relForeignKeys = append(relForeignKeys, ref.PrimaryKey.DBName)
				relForeignFields = append(relForeignFields, ref.PrimaryKey)
			}
		}

		joinIdentityMap, joinForeignValues := schema.GetIdentityFieldValuesMap(tx.Statement.Context, reflectValue, foreignFields)
		if len(joinForeignValues) == 0 {
			// No parents with key values: nothing to preload.
			return nil
		}

		joinResults := rel.JoinTable.MakeSlice().Elem()
		column, values := schema.ToQueryValues(clause.CurrentTable, joinForeignKeys, joinForeignValues)
		if err := tx.Where(clause.IN{Column: column, Values: values}).Find(joinResults.Addr().Interface()).Error; err != nil {
			return err
		}

		// convert join identity map to relation identity map
		fieldValues := make([]interface{}, len(joinForeignFields))
		joinFieldValues := make([]interface{}, len(joinRelForeignFields))
		for i := 0; i < joinResults.Len(); i++ {
			joinIndexValue := joinResults.Index(i)
			for idx, field := range joinForeignFields {
				fieldValues[idx], _ = field.ValueOf(tx.Statement.Context, joinIndexValue)
			}

			for idx, field := range joinRelForeignFields {
				joinFieldValues[idx], _ = field.ValueOf(tx.Statement.Context, joinIndexValue)
			}

			if results, ok := joinIdentityMap[utils.ToStringKey(fieldValues...)]; ok {
				joinKey := utils.ToStringKey(joinFieldValues...)
				identityMap[joinKey] = append(identityMap[joinKey], results...)
			}
		}

		_, foreignValues = schema.GetIdentityFieldValuesMap(tx.Statement.Context, joinResults, joinRelForeignFields)
	} else {
		// Direct relation (has one/has many/belongs to).
		for _, ref := range rel.References {
			if ref.OwnPrimaryKey {
				relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName)
				relForeignFields = append(relForeignFields, ref.ForeignKey)
				foreignFields = append(foreignFields, ref.PrimaryKey)
			} else if ref.PrimaryValue != "" {
				tx = tx.Where(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue})
			} else {
				relForeignKeys = append(relForeignKeys, ref.PrimaryKey.DBName)
				relForeignFields = append(relForeignFields, ref.PrimaryKey)
				foreignFields = append(foreignFields, ref.ForeignKey)
			}
		}

		identityMap, foreignValues = schema.GetIdentityFieldValuesMap(tx.Statement.Context, reflectValue, foreignFields)
		if len(foreignValues) == 0 {
			return nil
		}
	}

	// nested preload
	for p, pvs := range preloads {
		tx = tx.Preload(p, pvs...)
	}

	reflectResults := rel.FieldSchema.MakeSlice().Elem()
	column, values := schema.ToQueryValues(clause.CurrentTable, relForeignKeys, foreignValues)

	if len(values) != 0 {
		tx = tx.Model(reflectResults.Addr().Interface()).Where(clause.IN{Column: column, Values: values})

		// Function conditions modify the query; everything else becomes an
		// inline Where condition.
		for _, cond := range conds {
			if fc, ok := cond.(func(*gorm.DB) *gorm.DB); ok {
				tx = fc(tx)
			} else {
				inlineConds = append(inlineConds, cond)
			}
		}

		if len(inlineConds) > 0 {
			tx = tx.Where(inlineConds[0], inlineConds[1:]...)
		}

		if err := tx.Find(reflectResults.Addr().Interface()).Error; err != nil {
			return err
		}
	}

	fieldValues := make([]interface{}, len(relForeignFields))

	// clean up old values before preloading
	switch reflectValue.Kind() {
	case reflect.Struct:
		switch rel.Type {
		case schema.HasMany, schema.Many2Many:
			// Reset collection relations to an empty slice.
			tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue, reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface()))
		default:
			// Reset single relations to a zero value.
			tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue, reflect.New(rel.Field.FieldType).Interface()))
		}
	case reflect.Slice, reflect.Array:
		for i := 0; i < reflectValue.Len(); i++ {
			switch rel.Type {
			case schema.HasMany, schema.Many2Many:
				tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue.Index(i), reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface()))
			default:
				tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue.Index(i), reflect.New(rel.Field.FieldType).Interface()))
			}
		}
	}

	// Assign each fetched row to the parent(s) recorded in identityMap.
	for i := 0; i < reflectResults.Len(); i++ {
		elem := reflectResults.Index(i)
		for idx, field := range relForeignFields {
			fieldValues[idx], _ = field.ValueOf(tx.Statement.Context, elem)
		}

		datas, ok := identityMap[utils.ToStringKey(fieldValues...)]
		if !ok {
			return fmt.Errorf("failed to assign association %#v, make sure foreign fields exists", elem.Interface())
		}

		for _, data := range datas {
			reflectFieldValue := rel.Field.ReflectValueOf(tx.Statement.Context, data)
			if reflectFieldValue.Kind() == reflect.Ptr && reflectFieldValue.IsNil() {
				// Allocate a target for a nil pointer relation field.
				reflectFieldValue.Set(reflect.New(rel.Field.FieldType.Elem()))
			}

			reflectFieldValue = reflect.Indirect(reflectFieldValue)
			switch reflectFieldValue.Kind() {
			case reflect.Struct:
				tx.AddError(rel.Field.Set(tx.Statement.Context, data, elem.Interface()))
			case reflect.Slice, reflect.Array:
				// Append as pointer or value depending on the slice's element kind.
				if reflectFieldValue.Type().Elem().Kind() == reflect.Ptr {
					tx.AddError(rel.Field.Set(tx.Statement.Context, data, reflect.Append(reflectFieldValue, elem).Interface()))
				} else {
					tx.AddError(rel.Field.Set(tx.Statement.Context, data, reflect.Append(reflectFieldValue, elem.Elem()).Interface()))
				}
			}
		}
	}

	return tx.Error
}
|
||||
314
vendor/gorm.io/gorm/callbacks/query.go
generated
vendored
Normal file
314
vendor/gorm.io/gorm/callbacks/query.go
generated
vendored
Normal file
@@ -0,0 +1,314 @@
|
||||
package callbacks
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/schema"
|
||||
"gorm.io/gorm/utils"
|
||||
)
|
||||
|
||||
// Query is the callback that builds the SELECT statement, executes it, and
// scans the resulting rows into the statement's destination.
func Query(db *gorm.DB) {
	if db.Error == nil {
		BuildQuerySQL(db)

		if !db.DryRun && db.Error == nil {
			rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
			if err != nil {
				db.AddError(err)
				return
			}
			// Close after scanning; a close error is also surfaced on db.
			defer func() {
				db.AddError(rows.Close())
			}()
			gorm.Scan(rows, db, 0)

			if db.Statement.Result != nil {
				db.Statement.Result.RowsAffected = db.RowsAffected
			}
		}
	}
}
|
||||
|
||||
func BuildQuerySQL(db *gorm.DB) {
|
||||
if db.Statement.Schema != nil {
|
||||
for _, c := range db.Statement.Schema.QueryClauses {
|
||||
db.Statement.AddClause(c)
|
||||
}
|
||||
}
|
||||
|
||||
if db.Statement.SQL.Len() == 0 {
|
||||
db.Statement.SQL.Grow(100)
|
||||
clauseSelect := clause.Select{Distinct: db.Statement.Distinct}
|
||||
|
||||
if db.Statement.ReflectValue.Kind() == reflect.Struct && db.Statement.ReflectValue.Type() == db.Statement.Schema.ModelType {
|
||||
var conds []clause.Expression
|
||||
for _, primaryField := range db.Statement.Schema.PrimaryFields {
|
||||
if v, isZero := primaryField.ValueOf(db.Statement.Context, db.Statement.ReflectValue); !isZero {
|
||||
conds = append(conds, clause.Eq{Column: clause.Column{Table: db.Statement.Table, Name: primaryField.DBName}, Value: v})
|
||||
}
|
||||
}
|
||||
|
||||
if len(conds) > 0 {
|
||||
db.Statement.AddClause(clause.Where{Exprs: conds})
|
||||
}
|
||||
}
|
||||
|
||||
if len(db.Statement.Selects) > 0 {
|
||||
clauseSelect.Columns = make([]clause.Column, len(db.Statement.Selects))
|
||||
for idx, name := range db.Statement.Selects {
|
||||
if db.Statement.Schema == nil {
|
||||
clauseSelect.Columns[idx] = clause.Column{Name: name, Raw: true}
|
||||
} else if f := db.Statement.Schema.LookUpField(name); f != nil {
|
||||
clauseSelect.Columns[idx] = clause.Column{Name: f.DBName}
|
||||
} else {
|
||||
clauseSelect.Columns[idx] = clause.Column{Name: name, Raw: true}
|
||||
}
|
||||
}
|
||||
} else if db.Statement.Schema != nil && len(db.Statement.Omits) > 0 {
|
||||
selectColumns, _ := db.Statement.SelectAndOmitColumns(false, false)
|
||||
clauseSelect.Columns = make([]clause.Column, 0, len(db.Statement.Schema.DBNames))
|
||||
for _, dbName := range db.Statement.Schema.DBNames {
|
||||
if v, ok := selectColumns[dbName]; (ok && v) || !ok {
|
||||
clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{Table: db.Statement.Table, Name: dbName})
|
||||
}
|
||||
}
|
||||
} else if db.Statement.Schema != nil && db.Statement.ReflectValue.IsValid() {
|
||||
queryFields := db.QueryFields
|
||||
if !queryFields {
|
||||
switch db.Statement.ReflectValue.Kind() {
|
||||
case reflect.Struct:
|
||||
queryFields = db.Statement.ReflectValue.Type() != db.Statement.Schema.ModelType
|
||||
case reflect.Slice:
|
||||
queryFields = db.Statement.ReflectValue.Type().Elem() != db.Statement.Schema.ModelType
|
||||
}
|
||||
}
|
||||
|
||||
if queryFields {
|
||||
stmt := gorm.Statement{DB: db}
|
||||
// smaller struct
|
||||
if err := stmt.Parse(db.Statement.Dest); err == nil && (db.QueryFields || stmt.Schema.ModelType != db.Statement.Schema.ModelType) {
|
||||
clauseSelect.Columns = make([]clause.Column, len(stmt.Schema.DBNames))
|
||||
|
||||
for idx, dbName := range stmt.Schema.DBNames {
|
||||
clauseSelect.Columns[idx] = clause.Column{Table: db.Statement.Table, Name: dbName}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// inline joins
|
||||
fromClause := clause.From{}
|
||||
if v, ok := db.Statement.Clauses["FROM"].Expression.(clause.From); ok {
|
||||
fromClause = v
|
||||
}
|
||||
|
||||
if len(db.Statement.Joins) != 0 || len(fromClause.Joins) != 0 {
|
||||
if len(db.Statement.Selects) == 0 && len(db.Statement.Omits) == 0 && db.Statement.Schema != nil {
|
||||
clauseSelect.Columns = make([]clause.Column, len(db.Statement.Schema.DBNames))
|
||||
for idx, dbName := range db.Statement.Schema.DBNames {
|
||||
clauseSelect.Columns[idx] = clause.Column{Table: db.Statement.Table, Name: dbName}
|
||||
}
|
||||
}
|
||||
|
||||
specifiedRelationsName := map[string]string{clause.CurrentTable: clause.CurrentTable}
|
||||
for _, join := range db.Statement.Joins {
|
||||
if db.Statement.Schema != nil {
|
||||
var isRelations bool // is relations or raw sql
|
||||
var relations []*schema.Relationship
|
||||
relation, ok := db.Statement.Schema.Relationships.Relations[join.Name]
|
||||
if ok {
|
||||
isRelations = true
|
||||
relations = append(relations, relation)
|
||||
} else {
|
||||
// handle nested join like "Manager.Company"
|
||||
nestedJoinNames := strings.Split(join.Name, ".")
|
||||
if len(nestedJoinNames) > 1 {
|
||||
isNestedJoin := true
|
||||
guessNestedRelations := make([]*schema.Relationship, 0, len(nestedJoinNames))
|
||||
currentRelations := db.Statement.Schema.Relationships.Relations
|
||||
for _, relname := range nestedJoinNames {
|
||||
// incomplete match, only treated as raw sql
|
||||
if relation, ok = currentRelations[relname]; ok {
|
||||
guessNestedRelations = append(guessNestedRelations, relation)
|
||||
currentRelations = relation.FieldSchema.Relationships.Relations
|
||||
} else {
|
||||
isNestedJoin = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if isNestedJoin {
|
||||
isRelations = true
|
||||
relations = guessNestedRelations
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if isRelations {
|
||||
genJoinClause := func(joinType clause.JoinType, tableAliasName string, parentTableName string, relation *schema.Relationship) clause.Join {
|
||||
columnStmt := gorm.Statement{
|
||||
Table: tableAliasName, DB: db, Schema: relation.FieldSchema,
|
||||
Selects: join.Selects, Omits: join.Omits,
|
||||
}
|
||||
|
||||
selectColumns, restricted := columnStmt.SelectAndOmitColumns(false, false)
|
||||
for _, s := range relation.FieldSchema.DBNames {
|
||||
if v, ok := selectColumns[s]; (ok && v) || (!ok && !restricted) {
|
||||
clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{
|
||||
Table: tableAliasName,
|
||||
Name: s,
|
||||
Alias: utils.NestedRelationName(tableAliasName, s),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if join.Expression != nil {
|
||||
return clause.Join{
|
||||
Type: join.JoinType,
|
||||
Expression: join.Expression,
|
||||
}
|
||||
}
|
||||
|
||||
exprs := make([]clause.Expression, len(relation.References))
|
||||
for idx, ref := range relation.References {
|
||||
if ref.OwnPrimaryKey {
|
||||
exprs[idx] = clause.Eq{
|
||||
Column: clause.Column{Table: parentTableName, Name: ref.PrimaryKey.DBName},
|
||||
Value: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName},
|
||||
}
|
||||
} else {
|
||||
if ref.PrimaryValue == "" {
|
||||
exprs[idx] = clause.Eq{
|
||||
Column: clause.Column{Table: parentTableName, Name: ref.ForeignKey.DBName},
|
||||
Value: clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName},
|
||||
}
|
||||
} else {
|
||||
exprs[idx] = clause.Eq{
|
||||
Column: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName},
|
||||
Value: ref.PrimaryValue,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
onStmt := gorm.Statement{Table: tableAliasName, DB: db, Clauses: map[string]clause.Clause{}}
|
||||
for _, c := range relation.FieldSchema.QueryClauses {
|
||||
onStmt.AddClause(c)
|
||||
}
|
||||
|
||||
if join.On != nil {
|
||||
onStmt.AddClause(join.On)
|
||||
}
|
||||
|
||||
if cs, ok := onStmt.Clauses["WHERE"]; ok {
|
||||
if where, ok := cs.Expression.(clause.Where); ok {
|
||||
where.Build(&onStmt)
|
||||
|
||||
if onSQL := onStmt.SQL.String(); onSQL != "" {
|
||||
vars := onStmt.Vars
|
||||
for idx, v := range vars {
|
||||
bindvar := strings.Builder{}
|
||||
onStmt.Vars = vars[0 : idx+1]
|
||||
db.Dialector.BindVarTo(&bindvar, &onStmt, v)
|
||||
onSQL = strings.Replace(onSQL, bindvar.String(), "?", 1)
|
||||
}
|
||||
|
||||
exprs = append(exprs, clause.Expr{SQL: onSQL, Vars: vars})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return clause.Join{
|
||||
Type: joinType,
|
||||
Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName},
|
||||
ON: clause.Where{Exprs: exprs},
|
||||
}
|
||||
}
|
||||
|
||||
parentTableName := clause.CurrentTable
|
||||
for idx, rel := range relations {
|
||||
// joins table alias like "Manager, Company, Manager__Company"
|
||||
curAliasName := rel.Name
|
||||
if parentTableName != clause.CurrentTable {
|
||||
curAliasName = utils.NestedRelationName(parentTableName, curAliasName)
|
||||
}
|
||||
|
||||
if _, ok := specifiedRelationsName[curAliasName]; !ok {
|
||||
aliasName := curAliasName
|
||||
if idx == len(relations)-1 && join.Alias != "" {
|
||||
aliasName = join.Alias
|
||||
}
|
||||
|
||||
fromClause.Joins = append(fromClause.Joins, genJoinClause(join.JoinType, aliasName, specifiedRelationsName[parentTableName], rel))
|
||||
specifiedRelationsName[curAliasName] = aliasName
|
||||
}
|
||||
|
||||
parentTableName = curAliasName
|
||||
}
|
||||
} else {
|
||||
fromClause.Joins = append(fromClause.Joins, clause.Join{
|
||||
Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds},
|
||||
})
|
||||
}
|
||||
} else {
|
||||
fromClause.Joins = append(fromClause.Joins, clause.Join{
|
||||
Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
db.Statement.AddClause(fromClause)
|
||||
} else {
|
||||
db.Statement.AddClauseIfNotExists(clause.From{})
|
||||
}
|
||||
|
||||
db.Statement.AddClauseIfNotExists(clauseSelect)
|
||||
|
||||
db.Statement.Build(db.Statement.BuildClauses...)
|
||||
}
|
||||
}
|
||||
|
||||
// Preload runs the association queries registered via Statement.Preloads
// after the primary query has completed. It requires a parsed schema; when
// none is present it records gorm.ErrModelValueRequired and returns.
func Preload(db *gorm.DB) {
	if db.Error == nil && len(db.Statement.Preloads) > 0 {
		if db.Statement.Schema == nil {
			db.AddError(fmt.Errorf("%w when using preload", gorm.ErrModelValueRequired))
			return
		}

		// Collect the names of explicit JOINs so the preload walker can
		// skip relations that were already loaded via inline joins.
		joins := make([]string, 0, len(db.Statement.Joins))
		for _, join := range db.Statement.Joins {
			joins = append(joins, join.Name)
		}

		// preloadDB prepares a derived session pointed at the query result;
		// its error (if any) has already been recorded on db, so just stop.
		tx := preloadDB(db, db.Statement.ReflectValue, db.Statement.Dest)
		if tx.Error != nil {
			return
		}

		// Walk the preload tree; clause.Associations carries the conditions
		// attached to a wildcard `Preload(clause.Associations)` call.
		db.AddError(preloadEntryPoint(tx, joins, &tx.Statement.Schema.Relationships, db.Statement.Preloads, db.Statement.Preloads[clause.Associations]))
	}
}
|
||||
|
||||
// AfterQuery runs after a query finishes: it restores the FROM clause's join
// list and invokes AfterFind hooks on the fetched records.
func AfterQuery(db *gorm.DB) {
	// clear the joins after query because preload need it
	// (the inline joins appended by BuildQuerySQL are trimmed off the end so
	// a later reuse of this statement does not duplicate them; the joins
	// that were part of the original From expression are kept).
	if v, ok := db.Statement.Clauses["FROM"].Expression.(clause.From); ok {
		fromClause := db.Statement.Clauses["FROM"]
		fromClause.Expression = clause.From{Tables: v.Tables, Joins: utils.RTrimSlice(v.Joins, len(db.Statement.Joins))} // keep the original From Joins
		db.Statement.Clauses["FROM"] = fromClause
	}
	// Fire AfterFind hooks only when the query succeeded, hooks are enabled,
	// the schema declares an AfterFind method, and rows were actually found.
	if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.AfterFind && db.RowsAffected > 0 {
		callMethod(db, func(value interface{}, tx *gorm.DB) bool {
			if i, ok := value.(AfterFindInterface); ok {
				db.AddError(i.AfterFind(tx))
				return true
			}
			return false
		})
	}
}
|
||||
22
vendor/gorm.io/gorm/callbacks/raw.go
generated
vendored
Normal file
22
vendor/gorm.io/gorm/callbacks/raw.go
generated
vendored
Normal file
@@ -0,0 +1,22 @@
|
||||
package callbacks
|
||||
|
||||
import (
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func RawExec(db *gorm.DB) {
|
||||
if db.Error == nil && !db.DryRun {
|
||||
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||
if err != nil {
|
||||
db.AddError(err)
|
||||
return
|
||||
}
|
||||
|
||||
db.RowsAffected, _ = result.RowsAffected()
|
||||
|
||||
if db.Statement.Result != nil {
|
||||
db.Statement.Result.Result = result
|
||||
db.Statement.Result.RowsAffected = db.RowsAffected
|
||||
}
|
||||
}
|
||||
}
|
||||
23
vendor/gorm.io/gorm/callbacks/row.go
generated
vendored
Normal file
23
vendor/gorm.io/gorm/callbacks/row.go
generated
vendored
Normal file
@@ -0,0 +1,23 @@
|
||||
package callbacks
|
||||
|
||||
import (
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// RowQuery builds and runs the query for the Row/Rows APIs, storing the
// resulting *sql.Row or *sql.Rows in Statement.Dest.
func RowQuery(db *gorm.DB) {
	if db.Error == nil {
		BuildQuerySQL(db)
		if db.DryRun || db.Error != nil {
			return
		}

		// The "rows" setting distinguishes db.Rows() (multi-row cursor)
		// from db.Row() (single row); it is consumed here exactly once.
		if isRows, ok := db.Get("rows"); ok && isRows.(bool) {
			db.Statement.Settings.Delete("rows")
			db.Statement.Dest, db.Error = db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
		} else {
			db.Statement.Dest = db.Statement.ConnPool.QueryRowContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
		}

		// -1 signals "unknown" — the rows have not been consumed yet.
		db.RowsAffected = -1
	}
}
|
||||
32
vendor/gorm.io/gorm/callbacks/transaction.go
generated
vendored
Normal file
32
vendor/gorm.io/gorm/callbacks/transaction.go
generated
vendored
Normal file
@@ -0,0 +1,32 @@
|
||||
package callbacks
|
||||
|
||||
import (
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func BeginTransaction(db *gorm.DB) {
|
||||
if !db.Config.SkipDefaultTransaction && db.Error == nil {
|
||||
if tx := db.Begin(); tx.Error == nil {
|
||||
db.Statement.ConnPool = tx.Statement.ConnPool
|
||||
db.InstanceSet("gorm:started_transaction", true)
|
||||
} else if tx.Error == gorm.ErrInvalidTransaction {
|
||||
tx.Error = nil
|
||||
} else {
|
||||
db.Error = tx.Error
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func CommitOrRollbackTransaction(db *gorm.DB) {
|
||||
if !db.Config.SkipDefaultTransaction {
|
||||
if _, ok := db.InstanceGet("gorm:started_transaction"); ok {
|
||||
if db.Error != nil {
|
||||
db.Rollback()
|
||||
} else {
|
||||
db.Commit()
|
||||
}
|
||||
|
||||
db.Statement.ConnPool = db.ConnPool
|
||||
}
|
||||
}
|
||||
}
|
||||
313
vendor/gorm.io/gorm/callbacks/update.go
generated
vendored
Normal file
313
vendor/gorm.io/gorm/callbacks/update.go
generated
vendored
Normal file
@@ -0,0 +1,313 @@
|
||||
package callbacks
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"sort"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/schema"
|
||||
"gorm.io/gorm/utils"
|
||||
)
|
||||
|
||||
// SetupUpdateReflectValue points Statement.ReflectValue at the model when the
// destination is a separate value (e.g. `db.Model(&user).Updates(map…)`), so
// hooks and assignments mutate the model instance, and copies belongs-to
// association values from a map destination onto the model.
func SetupUpdateReflectValue(db *gorm.DB) {
	if db.Error == nil && db.Statement.Schema != nil {
		if !db.Statement.ReflectValue.CanAddr() || db.Statement.Model != db.Statement.Dest {
			db.Statement.ReflectValue = reflect.ValueOf(db.Statement.Model)
			// Unwrap nested pointers until we reach the concrete value.
			for db.Statement.ReflectValue.Kind() == reflect.Ptr {
				db.Statement.ReflectValue = db.Statement.ReflectValue.Elem()
			}

			// When updating via a map, belongs-to association objects keyed by
			// relation name are set on the model so their foreign keys persist.
			if dest, ok := db.Statement.Dest.(map[string]interface{}); ok {
				for _, rel := range db.Statement.Schema.Relationships.BelongsTo {
					if _, ok := dest[rel.Name]; ok {
						db.AddError(rel.Field.Set(db.Statement.Context, db.Statement.ReflectValue, dest[rel.Name]))
					}
				}
			}
		}
	}
}
|
||||
|
||||
// BeforeUpdate invokes BeforeSave and then BeforeUpdate hooks (in that order)
// on each record being updated, unless hooks are skipped for this statement.
func BeforeUpdate(db *gorm.DB) {
	if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) {
		callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
			// BeforeSave fires before BeforeUpdate, mirroring save semantics.
			if db.Statement.Schema.BeforeSave {
				if i, ok := value.(BeforeSaveInterface); ok {
					called = true
					db.AddError(i.BeforeSave(tx))
				}
			}

			if db.Statement.Schema.BeforeUpdate {
				if i, ok := value.(BeforeUpdateInterface); ok {
					called = true
					db.AddError(i.BeforeUpdate(tx))
				}
			}

			// called reports whether any hook ran, so callMethod can fall
			// back to per-element dispatch when needed.
			return called
		})
	}
}
|
||||
|
||||
// Update returns the callback that builds and executes the UPDATE statement.
// When the dialector supports RETURNING (per config.UpdateClauses) the
// returned rows are scanned back into the model; otherwise a plain Exec is
// issued and only the affected-row count is recorded.
func Update(config *Config) func(db *gorm.DB) {
	supportReturning := utils.Contains(config.UpdateClauses, "RETURNING")

	return func(db *gorm.DB) {
		if db.Error != nil {
			return
		}

		// Apply schema-level update clauses (e.g. soft-delete conditions).
		if db.Statement.Schema != nil {
			for _, c := range db.Statement.Schema.UpdateClauses {
				db.Statement.AddClause(c)
			}
		}

		if db.Statement.SQL.Len() == 0 {
			db.Statement.SQL.Grow(180)
			db.Statement.AddClauseIfNotExists(clause.Update{})
			if _, ok := db.Statement.Clauses["SET"]; !ok {
				if set := ConvertToAssignments(db.Statement); len(set) != 0 {
					// The SET clause is statement-specific; remove it after
					// building so the statement can be safely reused.
					defer delete(db.Statement.Clauses, "SET")
					db.Statement.AddClause(set)
				} else {
					// Nothing to assign — skip executing an empty UPDATE.
					return
				}
			}

			db.Statement.Build(db.Statement.BuildClauses...)
		}

		// Guard against a global UPDATE without any WHERE condition.
		checkMissingWhereConditions(db)

		if !db.DryRun && db.Error == nil {
			if ok, mode := hasReturning(db, supportReturning); ok {
				if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil {
					// Temporarily retarget Dest at the model value so the
					// RETURNING columns are scanned into it, then restore.
					dest := db.Statement.Dest
					db.Statement.Dest = db.Statement.ReflectValue.Addr().Interface()
					gorm.Scan(rows, db, mode)
					db.Statement.Dest = dest
					db.AddError(rows.Close())

					if db.Statement.Result != nil {
						db.Statement.Result.RowsAffected = db.RowsAffected
					}
				}
			} else {
				result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)

				if db.AddError(err) == nil {
					db.RowsAffected, _ = result.RowsAffected()
				}

				if db.Statement.Result != nil {
					db.Statement.Result.Result = result
					db.Statement.Result.RowsAffected = db.RowsAffected
				}
			}
		}
	}
}
|
||||
|
||||
// AfterUpdate invokes AfterUpdate and then AfterSave hooks (in that order —
// the mirror image of the Before* ordering) on each updated record.
func AfterUpdate(db *gorm.DB) {
	if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) {
		callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
			if db.Statement.Schema.AfterUpdate {
				if i, ok := value.(AfterUpdateInterface); ok {
					called = true
					db.AddError(i.AfterUpdate(tx))
				}
			}

			if db.Statement.Schema.AfterSave {
				if i, ok := value.(AfterSaveInterface); ok {
					called = true
					db.AddError(i.AfterSave(tx))
				}
			}

			// called reports whether any hook ran for this value.
			return called
		})
	}
}
|
||||
|
||||
// ConvertToAssignments convert to update assignments
//
// It derives the SET clause from stmt.Dest (either a map[string]interface{}
// or a struct), honoring Select/Omit restrictions, writing the assigned
// values back onto the model via assignValue, adding primary-key WHERE
// conditions when the destination differs from the model, and appending
// auto-update timestamp columns unless hooks are skipped.
func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
	var (
		// selectColumns/restricted encode the Select/Omit configuration:
		// restricted means only explicitly selected columns may be set.
		selectColumns, restricted = stmt.SelectAndOmitColumns(false, true)
		assignValue               func(field *schema.Field, value interface{})
	)

	// assignValue mirrors each SQL assignment onto the in-memory model so the
	// Go value stays in sync with the database after the update.
	switch stmt.ReflectValue.Kind() {
	case reflect.Slice, reflect.Array:
		assignValue = func(field *schema.Field, value interface{}) {
			for i := 0; i < stmt.ReflectValue.Len(); i++ {
				if stmt.ReflectValue.CanAddr() {
					field.Set(stmt.Context, stmt.ReflectValue.Index(i), value)
				}
			}
		}
	case reflect.Struct:
		assignValue = func(field *schema.Field, value interface{}) {
			if stmt.ReflectValue.CanAddr() {
				field.Set(stmt.Context, stmt.ReflectValue, value)
			}
		}
	default:
		// Non-addressable / unsupported kinds: assignments are not mirrored.
		assignValue = func(field *schema.Field, value interface{}) {
		}
	}

	// Unwrap pointer chains around the update payload.
	updatingValue := reflect.ValueOf(stmt.Dest)
	for updatingValue.Kind() == reflect.Ptr {
		updatingValue = updatingValue.Elem()
	}

	// When Dest is not the model itself, derive WHERE conditions from the
	// model's primary key(s) so the update targets the right rows.
	if !updatingValue.CanAddr() || stmt.Dest != stmt.Model {
		switch stmt.ReflectValue.Kind() {
		case reflect.Slice, reflect.Array:
			if size := stmt.ReflectValue.Len(); size > 0 {
				var isZero bool
				for i := 0; i < size; i++ {
					for _, field := range stmt.Schema.PrimaryFields {
						_, isZero = field.ValueOf(stmt.Context, stmt.ReflectValue.Index(i))
						if !isZero {
							break
						}
					}
				}

				// NOTE(review): isZero reflects the last element inspected;
				// upstream behavior — preserved as-is.
				if !isZero {
					_, primaryValues := schema.GetIdentityFieldValuesMap(stmt.Context, stmt.ReflectValue, stmt.Schema.PrimaryFields)
					column, values := schema.ToQueryValues("", stmt.Schema.PrimaryFieldDBNames, primaryValues)
					stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})
				}
			}
		case reflect.Struct:
			for _, field := range stmt.Schema.PrimaryFields {
				if value, isZero := field.ValueOf(stmt.Context, stmt.ReflectValue); !isZero {
					stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}})
				}
			}
		}
	}

	switch value := updatingValue.Interface().(type) {
	case map[string]interface{}:
		set = make([]clause.Assignment, 0, len(value))

		// Sort keys for deterministic SQL output.
		keys := make([]string, 0, len(value))
		for k := range value {
			keys = append(keys, k)
		}
		sort.Strings(keys)

		for _, k := range keys {
			kv := value[k]
			// Wrap subquery values so they are rendered as expressions.
			if _, ok := kv.(*gorm.DB); ok {
				kv = []interface{}{kv}
			}

			if stmt.Schema != nil {
				if field := stmt.Schema.LookUpField(k); field != nil {
					if field.DBName != "" {
						if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
							set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: kv})
							assignValue(field, value[k])
						}
					} else if v, ok := selectColumns[field.Name]; (ok && v) || (!ok && !restricted) {
						// Field without a DB column (e.g. association):
						// only mirror onto the model, no SQL assignment.
						assignValue(field, value[k])
					}
					continue
				}
			}

			// Key does not match a schema field: pass it through verbatim.
			if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) {
				set = append(set, clause.Assignment{Column: clause.Column{Name: k}, Value: kv})
			}
		}

		// Append auto-update timestamps for columns not supplied in the map.
		if !stmt.SkipHooks && stmt.Schema != nil {
			for _, dbName := range stmt.Schema.DBNames {
				field := stmt.Schema.LookUpField(dbName)
				if field.AutoUpdateTime > 0 && value[field.Name] == nil && value[field.DBName] == nil {
					if v, ok := selectColumns[field.DBName]; (ok && v) || !ok {
						now := stmt.DB.NowFunc()
						assignValue(field, now)

						// Encode the timestamp at the precision the field declares.
						if field.AutoUpdateTime == schema.UnixNanosecond {
							set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano()})
						} else if field.AutoUpdateTime == schema.UnixMillisecond {
							set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixMilli()})
						} else if field.AutoUpdateTime == schema.UnixSecond {
							set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.Unix()})
						} else {
							set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now})
						}
					}
				}
			}
		}
	default:
		updatingSchema := stmt.Schema
		var isDiffSchema bool
		if !updatingValue.CanAddr() || stmt.Dest != stmt.Model {
			// different schema: Dest may be a smaller struct than the model;
			// parse it so field lookups use the destination's schema.
			updatingStmt := &gorm.Statement{DB: stmt.DB}
			if err := updatingStmt.Parse(stmt.Dest); err == nil {
				updatingSchema = updatingStmt.Schema
				isDiffSchema = true
			}
		}

		switch updatingValue.Kind() {
		case reflect.Struct:
			set = make([]clause.Assignment, 0, len(stmt.Schema.FieldsByDBName))
			for _, dbName := range stmt.Schema.DBNames {
				if field := updatingSchema.LookUpField(dbName); field != nil {
					if !field.PrimaryKey || !updatingValue.CanAddr() || stmt.Dest != stmt.Model {
						if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && (!restricted || (!stmt.SkipHooks && field.AutoUpdateTime > 0))) {
							value, isZero := field.ValueOf(stmt.Context, updatingValue)
							// Auto-update timestamps always overwrite the
							// struct-provided value with "now".
							if !stmt.SkipHooks && field.AutoUpdateTime > 0 {
								if field.AutoUpdateTime == schema.UnixNanosecond {
									value = stmt.DB.NowFunc().UnixNano()
								} else if field.AutoUpdateTime == schema.UnixMillisecond {
									value = stmt.DB.NowFunc().UnixMilli()
								} else if field.AutoUpdateTime == schema.UnixSecond {
									value = stmt.DB.NowFunc().Unix()
								} else {
									value = stmt.DB.NowFunc()
								}
								isZero = false
							}

							// Zero values are skipped unless explicitly selected.
							if (ok || !isZero) && field.Updatable {
								set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value})
								assignField := field
								if isDiffSchema {
									// Mirror onto the model's own field, not
									// the destination struct's.
									if originField := stmt.Schema.LookUpField(dbName); originField != nil {
										assignField = originField
									}
								}
								assignValue(assignField, value)
							}
						}
					} else {
						// Primary key on the model itself: use it as a
						// WHERE condition rather than assigning it.
						if value, isZero := field.ValueOf(stmt.Context, updatingValue); !isZero {
							stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}})
						}
					}
				}
			}
		default:
			stmt.AddError(gorm.ErrInvalidData)
		}
	}

	return
}
|
||||
471
vendor/gorm.io/gorm/chainable_api.go
generated
vendored
Normal file
471
vendor/gorm.io/gorm/chainable_api.go
generated
vendored
Normal file
@@ -0,0 +1,471 @@
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/utils"
|
||||
)
|
||||
|
||||
// Model specify the model you would like to run db operations
//
//	// update all users's name to `hello`
//	db.Model(&User{}).Update("name", "hello")
//	// if user's primary key is non-blank, will use it as condition, then will only update that user's name to `hello`
//	db.Model(&user).Update("name", "hello")
func (db *DB) Model(value interface{}) (tx *DB) {
	// getInstance clones the session so chained calls don't mutate db.
	tx = db.getInstance()
	tx.Statement.Model = value
	return
}
|
||||
|
||||
// Clauses Add clauses
//
// This supports both standard clauses (clause.OrderBy, clause.Limit, clause.Where) and more
// advanced techniques like specifying lock strength and optimizer hints. See the
// [docs] for more depth.
//
//	// add a simple limit clause
//	db.Clauses(clause.Limit{Limit: 1}).Find(&User{})
//	// tell the optimizer to use the `idx_user_name` index
//	db.Clauses(hints.UseIndex("idx_user_name")).Find(&User{})
//	// specify the lock strength to UPDATE
//	db.Clauses(clause.Locking{Strength: "UPDATE"}).Find(&users)
//
// [docs]: https://gorm.io/docs/sql_builder.html#Clauses
func (db *DB) Clauses(conds ...clause.Expression) (tx *DB) {
	tx = db.getInstance()
	var whereConds []interface{}

	// Dispatch each condition by capability: named clauses merge into the
	// statement, statement modifiers rewrite it, anything else becomes a
	// WHERE condition.
	for _, cond := range conds {
		if c, ok := cond.(clause.Interface); ok {
			tx.Statement.AddClause(c)
		} else if optimizer, ok := cond.(StatementModifier); ok {
			optimizer.ModifyStatement(tx.Statement)
		} else {
			whereConds = append(whereConds, cond)
		}
	}

	if len(whereConds) > 0 {
		tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondition(whereConds[0], whereConds[1:]...)})
	}
	return
}
|
||||
|
||||
// tableRegexp extracts the alias from expressions like "users AS u" or
// "users u" so Statement.Table reflects the alias, not the raw expression.
var tableRegexp = regexp.MustCompile(`(?i)(?:.+? AS (\w+)\s*(?:$|,)|^\w+\s+(\w+)$)`)

// Table specify the table you would like to run db operations
//
//	// Get a user
//	db.Table("users").Take(&result)
func (db *DB) Table(name string, args ...interface{}) (tx *DB) {
	tx = db.getInstance()
	// A name containing spaces/backticks or carrying args is treated as a
	// raw SQL expression; otherwise it is quoted as an identifier.
	if strings.Contains(name, " ") || strings.Contains(name, "`") || len(args) > 0 {
		tx.Statement.TableExpr = &clause.Expr{SQL: name, Vars: args}
		// Two capture groups: group 1 matches "… AS alias", group 2 "name alias".
		if results := tableRegexp.FindStringSubmatch(name); len(results) == 3 {
			if results[1] != "" {
				tx.Statement.Table = results[1]
			} else {
				tx.Statement.Table = results[2]
			}
		}
	} else if tables := strings.Split(name, "."); len(tables) == 2 {
		// "schema.table": quote the full name but track only the table part.
		tx.Statement.TableExpr = &clause.Expr{SQL: tx.Statement.Quote(name)}
		tx.Statement.Table = tables[1]
	} else if name != "" {
		tx.Statement.TableExpr = &clause.Expr{SQL: tx.Statement.Quote(name)}
		tx.Statement.Table = name
	} else {
		// Empty name clears any previously-set table.
		tx.Statement.TableExpr = nil
		tx.Statement.Table = ""
	}
	return
}
|
||||
|
||||
// Distinct specify distinct fields that you want querying
|
||||
//
|
||||
// // Select distinct names of users
|
||||
// db.Distinct("name").Find(&results)
|
||||
// // Select distinct name/age pairs from users
|
||||
// db.Distinct("name", "age").Find(&results)
|
||||
func (db *DB) Distinct(args ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.Statement.Distinct = true
|
||||
if len(args) > 0 {
|
||||
tx = tx.Select(args[0], args[1:]...)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Select specify fields that you want when querying, creating, updating
//
// Use Select when you only want a subset of the fields. By default, GORM will select all fields.
// Select accepts both string arguments and arrays.
//
//	// Select name and age of user using multiple arguments
//	db.Select("name", "age").Find(&users)
//	// Select name and age of user using an array
//	db.Select([]string{"name", "age"}).Find(&users)
func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) {
	tx = db.getInstance()

	switch v := query.(type) {
	case []string:
		tx.Statement.Selects = v

		// Extra args may mix strings and string slices; anything else is an error.
		for _, arg := range args {
			switch arg := arg.(type) {
			case string:
				tx.Statement.Selects = append(tx.Statement.Selects, arg)
			case []string:
				tx.Statement.Selects = append(tx.Statement.Selects, arg...)
			default:
				tx.AddError(fmt.Errorf("unsupported select args %v %v", query, args))
				return
			}
		}

		// Reset any previously built SELECT expression so the new column
		// list takes effect (the local `clause` shadows the package here).
		if clause, ok := tx.Statement.Clauses["SELECT"]; ok {
			clause.Expression = nil
			tx.Statement.Clauses["SELECT"] = clause
		}
	case string:
		if strings.Count(v, "?") >= len(args) && len(args) > 0 {
			// Positional-placeholder expression, e.g. "COALESCE(age, ?)".
			tx.Statement.AddClause(clause.Select{
				Distinct:   db.Statement.Distinct,
				Expression: clause.Expr{SQL: v, Vars: args},
			})
		} else if strings.Count(v, "@") > 0 && len(args) > 0 {
			// Named-placeholder expression, e.g. "name = @name".
			tx.Statement.AddClause(clause.Select{
				Distinct:   db.Statement.Distinct,
				Expression: clause.NamedExpr{SQL: v, Vars: args},
			})
		} else {
			// Plain column name, optionally followed by more columns in args.
			tx.Statement.Selects = []string{v}

			for _, arg := range args {
				switch arg := arg.(type) {
				case string:
					tx.Statement.Selects = append(tx.Statement.Selects, arg)
				case []string:
					tx.Statement.Selects = append(tx.Statement.Selects, arg...)
				default:
					// Non-string arg: fall back to treating v as raw SQL.
					tx.Statement.AddClause(clause.Select{
						Distinct:   db.Statement.Distinct,
						Expression: clause.Expr{SQL: v, Vars: args},
					})
					return
				}
			}

			// Same reset as the []string branch (local `clause` shadows the package).
			if clause, ok := tx.Statement.Clauses["SELECT"]; ok {
				clause.Expression = nil
				tx.Statement.Clauses["SELECT"] = clause
			}
		}
	default:
		tx.AddError(fmt.Errorf("unsupported select args %v %v", query, args))
	}

	return
}
|
||||
|
||||
// Omit specify fields that you want to ignore when creating, updating and querying
|
||||
func (db *DB) Omit(columns ...string) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
|
||||
if len(columns) == 1 && strings.ContainsRune(columns[0], ',') {
|
||||
tx.Statement.Omits = strings.FieldsFunc(columns[0], utils.IsValidDBNameChar)
|
||||
} else {
|
||||
tx.Statement.Omits = columns
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// MapColumns modify the column names in the query results to facilitate align to the corresponding structural fields
// (m maps a result-set column name to the destination field's column name).
func (db *DB) MapColumns(m map[string]string) (tx *DB) {
	tx = db.getInstance()
	tx.Statement.ColumnMapping = m
	return
}
|
||||
|
||||
// Where add conditions
|
||||
//
|
||||
// See the [docs] for details on the various formats that where clauses can take. By default, where clauses chain with AND.
|
||||
//
|
||||
// // Find the first user with name jinzhu
|
||||
// db.Where("name = ?", "jinzhu").First(&user)
|
||||
// // Find the first user with name jinzhu and age 20
|
||||
// db.Where(&User{Name: "jinzhu", Age: 20}).First(&user)
|
||||
// // Find the first user with name jinzhu and age not equal to 20
|
||||
// db.Where("name = ?", "jinzhu").Where("age <> ?", "20").First(&user)
|
||||
//
|
||||
// [docs]: https://gorm.io/docs/query.html#Conditions
|
||||
func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 {
|
||||
tx.Statement.AddClause(clause.Where{Exprs: conds})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Not add NOT conditions
|
||||
//
|
||||
// Not works similarly to where, and has the same syntax.
|
||||
//
|
||||
// // Find the first user with name not equal to jinzhu
|
||||
// db.Not("name = ?", "jinzhu").First(&user)
|
||||
func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 {
|
||||
tx.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.Not(conds...)}})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Or add OR conditions
|
||||
//
|
||||
// Or is used to chain together queries with an OR.
|
||||
//
|
||||
// // Find the first user with name equal to jinzhu or john
|
||||
// db.Where("name = ?", "jinzhu").Or("name = ?", "john").First(&user)
|
||||
func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 {
|
||||
tx.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(clause.And(conds...))}})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Joins specify Joins conditions
//
//	db.Joins("Account").Find(&user)
//	db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user)
//	db.Joins("Account", DB.Select("id").Where("user_id = users.id AND name = ?", "someName").Model(&Account{}))
//
// Joins issues a LEFT JOIN; use InnerJoins for an INNER JOIN.
func (db *DB) Joins(query string, args ...interface{}) (tx *DB) {
	return joins(db, clause.LeftJoin, query, args...)
}
|
||||
|
||||
// InnerJoins specify inner joins conditions
// db.InnerJoins("Account").Find(&user)
func (db *DB) InnerJoins(query string, args ...interface{}) (tx *DB) {
	return joins(db, clause.InnerJoin, query, args...)
}
|
||||
|
||||
// joins records a join of the given type on the statement. When the single
// argument is a *DB, the join is an association join whose column selection,
// omissions, and ON conditions come from that sub-session; otherwise query is
// either a relation name or raw JOIN SQL with args as bind variables.
func joins(db *DB, joinType clause.JoinType, query string, args ...interface{}) (tx *DB) {
	tx = db.getInstance()

	if len(args) == 1 {
		if db, ok := args[0].(*DB); ok {
			// Association join configured via a sub-session (shadows outer db).
			j := join{
				Name: query, Conds: args, Selects: db.Statement.Selects,
				Omits: db.Statement.Omits, JoinType: joinType,
			}
			// The sub-session's WHERE conditions become the join's ON clause.
			if where, ok := db.Statement.Clauses["WHERE"].Expression.(clause.Where); ok {
				j.On = &where
			}
			tx.Statement.Joins = append(tx.Statement.Joins, j)
			return
		}
	}

	tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args, JoinType: joinType})
	return
}
|
||||
|
||||
// Group specify the group method on the find
|
||||
//
|
||||
// // Select the sum age of users with given names
|
||||
// db.Model(&User{}).Select("name, sum(age) as total").Group("name").Find(&results)
|
||||
func (db *DB) Group(name string) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
|
||||
fields := strings.FieldsFunc(name, utils.IsValidDBNameChar)
|
||||
tx.Statement.AddClause(clause.GroupBy{
|
||||
Columns: []clause.Column{{Name: name, Raw: len(fields) != 1}},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Having specify HAVING conditions for GROUP BY
|
||||
//
|
||||
// // Select the sum age of users with name jinzhu
|
||||
// db.Model(&User{}).Select("name, sum(age) as total").Group("name").Having("name = ?", "jinzhu").Find(&result)
|
||||
func (db *DB) Having(query interface{}, args ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.Statement.AddClause(clause.GroupBy{
|
||||
Having: tx.Statement.BuildCondition(query, args...),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Order specify order when retrieving records from database
//
// db.Order("name DESC")
// db.Order(clause.OrderByColumn{Column: clause.Column{Name: "name"}, Desc: true})
// db.Order(clause.OrderBy{Columns: []clause.OrderByColumn{
//	{Column: clause.Column{Name: "name"}, Desc: true},
//	{Column: clause.Column{Name: "age"}, Desc: true},
// }})
func (db *DB) Order(value interface{}) (tx *DB) {
	tx = db.getInstance()

	switch v := value.(type) {
	case clause.OrderBy:
		tx.Statement.AddClause(v)
	case clause.OrderByColumn:
		tx.Statement.AddClause(clause.OrderBy{
			Columns: []clause.OrderByColumn{v},
		})
	case string:
		// a non-empty string is passed through verbatim as a raw column expression
		if v != "" {
			tx.Statement.AddClause(clause.OrderBy{
				Columns: []clause.OrderByColumn{{
					Column: clause.Column{Name: v, Raw: true},
				}},
			})
		}
	}
	return
}
|
||||
|
||||
// Limit specify the number of records to be retrieved
//
// Limit conditions can be cancelled by using `Limit(-1)`.
//
// // retrieve 3 users
// db.Limit(3).Find(&users)
// // retrieve 3 users into users1, and all users into users2
// db.Limit(3).Find(&users1).Limit(-1).Find(&users2)
func (db *DB) Limit(limit int) (tx *DB) {
	tx = db.getInstance()
	tx.Statement.AddClause(clause.Limit{Limit: &limit})
	return
}

// Offset specify the number of records to skip before starting to return the records
//
// Offset conditions can be cancelled by using `Offset(-1)`.
//
// // select the third user
// db.Offset(2).First(&user)
// // select the first user by cancelling an earlier chained offset
// db.Offset(5).Offset(-1).First(&user)
func (db *DB) Offset(offset int) (tx *DB) {
	tx = db.getInstance()
	tx.Statement.AddClause(clause.Limit{Offset: offset})
	return
}
|
||||
|
||||
// Scopes pass current database connection to arguments `func(DB) DB`, which could be used to add conditions dynamically
//
// func AmountGreaterThan1000(db *gorm.DB) *gorm.DB {
//	return db.Where("amount > ?", 1000)
// }
//
// func OrderStatus(status []string) func (db *gorm.DB) *gorm.DB {
//	return func (db *gorm.DB) *gorm.DB {
//		return db.Scopes(AmountGreaterThan1000).Where("status in (?)", status)
//	}
// }
//
// db.Scopes(AmountGreaterThan1000, OrderStatus([]string{"paid", "shipped"})).Find(&orders)
func (db *DB) Scopes(funcs ...func(*DB) *DB) (tx *DB) {
	tx = db.getInstance()
	tx.Statement.scopes = append(tx.Statement.scopes, funcs...)
	return tx
}

// executeScopes applies all pending scope functions to the statement and
// clears them first so a scope calling back into the DB cannot recurse.
func (db *DB) executeScopes() (tx *DB) {
	scopes := db.Statement.scopes
	db.Statement.scopes = nil
	for _, scope := range scopes {
		db = scope(db)
	}
	return db
}
|
||||
|
||||
// Preload preload associations with given conditions
//
// // get all users, and preload all non-cancelled orders
// db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users)
func (db *DB) Preload(query string, args ...interface{}) (tx *DB) {
	tx = db.getInstance()
	// lazily initialize the preload map; later calls for the same
	// association path overwrite earlier ones
	if tx.Statement.Preloads == nil {
		tx.Statement.Preloads = map[string][]interface{}{}
	}
	tx.Statement.Preloads[query] = args
	return
}
|
||||
|
||||
// Attrs provide attributes used in [FirstOrCreate] or [FirstOrInit]
//
// Attrs only adds attributes if the record is not found.
//
// // assign an email if the record is not found
// db.Where(User{Name: "non_existing"}).Attrs(User{Email: "fake@fake.org"}).FirstOrInit(&user)
// // user -> User{Name: "non_existing", Email: "fake@fake.org"}
//
// // assign an email if the record is not found, otherwise ignore provided email
// db.Where(User{Name: "jinzhu"}).Attrs(User{Email: "fake@fake.org"}).FirstOrInit(&user)
// // user -> User{Name: "jinzhu", Age: 20}
//
// [FirstOrCreate]: https://gorm.io/docs/advanced_query.html#FirstOrCreate
// [FirstOrInit]: https://gorm.io/docs/advanced_query.html#FirstOrInit
func (db *DB) Attrs(attrs ...interface{}) (tx *DB) {
	tx = db.getInstance()
	tx.Statement.attrs = attrs
	return
}

// Assign provide attributes used in [FirstOrCreate] or [FirstOrInit]
//
// Assign adds attributes even if the record is found. If using FirstOrCreate, this means that
// records will be updated even if they are found.
//
// // assign an email regardless of if the record is not found
// db.Where(User{Name: "non_existing"}).Assign(User{Email: "fake@fake.org"}).FirstOrInit(&user)
// // user -> User{Name: "non_existing", Email: "fake@fake.org"}
//
// // assign email regardless of if record is found
// db.Where(User{Name: "jinzhu"}).Assign(User{Email: "fake@fake.org"}).FirstOrInit(&user)
// // user -> User{Name: "jinzhu", Age: 20, Email: "fake@fake.org"}
//
// [FirstOrCreate]: https://gorm.io/docs/advanced_query.html#FirstOrCreate
// [FirstOrInit]: https://gorm.io/docs/advanced_query.html#FirstOrInit
func (db *DB) Assign(attrs ...interface{}) (tx *DB) {
	tx = db.getInstance()
	tx.Statement.assigns = attrs
	return
}
|
||||
|
||||
// Unscoped disables the global scope of soft deletion in a query.
// By default, GORM uses soft deletion, marking records as "deleted"
// by setting a timestamp on a specific field (e.g., `deleted_at`).
// Unscoped allows queries to include records marked as deleted,
// overriding the soft deletion behavior.
// Example:
//
//	var users []User
//	db.Unscoped().Find(&users)
//	// Retrieves all users, including deleted ones.
func (db *DB) Unscoped() (tx *DB) {
	tx = db.getInstance()
	tx.Statement.Unscoped = true
	return
}

// Raw replaces the statement SQL with the given raw SQL and values. A "@" in
// the SQL selects named-argument expansion; otherwise "?" placeholders are used.
func (db *DB) Raw(sql string, values ...interface{}) (tx *DB) {
	tx = db.getInstance()
	// discard any previously built SQL before writing the raw statement
	tx.Statement.SQL = strings.Builder{}

	if strings.Contains(sql, "@") {
		clause.NamedExpr{SQL: sql, Vars: values}.Build(tx.Statement)
	} else {
		clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement)
	}
	return
}
|
||||
89
vendor/gorm.io/gorm/clause/clause.go
generated
vendored
Normal file
89
vendor/gorm.io/gorm/clause/clause.go
generated
vendored
Normal file
@@ -0,0 +1,89 @@
|
||||
package clause

// Interface clause interface
type Interface interface {
	Name() string
	Build(Builder)
	MergeClause(*Clause)
}

// ClauseBuilder clause builder, allows to customize how to build clause
type ClauseBuilder func(Clause, Builder)

// Writer is the minimal byte/string sink a clause writes SQL into.
type Writer interface {
	WriteByte(byte) error
	WriteString(string) (int, error)
}

// Builder builder interface
type Builder interface {
	Writer
	WriteQuoted(field interface{})
	AddVar(Writer, ...interface{})
	AddError(error) error
}

// Clause
type Clause struct {
	Name                string // WHERE
	BeforeExpression    Expression
	AfterNameExpression Expression
	AfterExpression     Expression
	Expression          Expression
	Builder             ClauseBuilder
}

// Build build clause; a custom Builder takes precedence, otherwise the
// parts are emitted in order: before-expr, name, after-name-expr,
// expression, after-expr.
func (c Clause) Build(builder Builder) {
	if c.Builder != nil {
		c.Builder(c, builder)
	} else if c.Expression != nil {
		if c.BeforeExpression != nil {
			c.BeforeExpression.Build(builder)
			builder.WriteByte(' ')
		}

		if c.Name != "" {
			builder.WriteString(c.Name)
			builder.WriteByte(' ')
		}

		if c.AfterNameExpression != nil {
			c.AfterNameExpression.Build(builder)
			builder.WriteByte(' ')
		}

		c.Expression.Build(builder)

		if c.AfterExpression != nil {
			builder.WriteByte(' ')
			c.AfterExpression.Build(builder)
		}
	}
}

// Placeholder names replaced by the dialect at build time.
const (
	PrimaryKey   string = "~~~py~~~" // primary key
	CurrentTable string = "~~~ct~~~" // current table
	Associations string = "~~~as~~~" // associations
)

var (
	currentTable  = Table{Name: CurrentTable}
	PrimaryColumn = Column{Table: CurrentTable, Name: PrimaryKey}
)

// Column quote with name
type Column struct {
	Table string
	Name  string
	Alias string
	Raw   bool
}

// Table quote with name
type Table struct {
	Name  string
	Alias string
	Raw   bool
}
|
||||
23
vendor/gorm.io/gorm/clause/delete.go
generated
vendored
Normal file
23
vendor/gorm.io/gorm/clause/delete.go
generated
vendored
Normal file
@@ -0,0 +1,23 @@
|
||||
package clause

// Delete is the DELETE clause; Modifier holds an optional keyword such as
// a dialect-specific hint placed after DELETE.
type Delete struct {
	Modifier string
}

func (d Delete) Name() string {
	return "DELETE"
}

func (d Delete) Build(builder Builder) {
	builder.WriteString("DELETE")

	if d.Modifier != "" {
		builder.WriteByte(' ')
		builder.WriteString(d.Modifier)
	}
}

func (d Delete) MergeClause(clause *Clause) {
	// the expression already writes "DELETE"; clear the clause name to
	// avoid emitting it twice
	clause.Name = ""
	clause.Expression = d
}
|
||||
385
vendor/gorm.io/gorm/clause/expression.go
generated
vendored
Normal file
385
vendor/gorm.io/gorm/clause/expression.go
generated
vendored
Normal file
@@ -0,0 +1,385 @@
|
||||
package clause

import (
	"database/sql"
	"database/sql/driver"
	"go/ast"
	"reflect"
)

// Expression expression interface
type Expression interface {
	Build(builder Builder)
}

// NegationExpressionBuilder negation expression builder
type NegationExpressionBuilder interface {
	NegationBuild(builder Builder)
}

// Expr raw expression
type Expr struct {
	SQL                string
	Vars               []interface{}
	WithoutParentheses bool
}

// Build build raw expression, replacing each '?' with the corresponding var.
// A '?' immediately after '(' (or when WithoutParentheses is set) expands a
// slice/array var into a comma-separated list of placeholders.
func (expr Expr) Build(builder Builder) {
	var (
		afterParenthesis bool
		idx              int
	)

	for _, v := range []byte(expr.SQL) {
		if v == '?' && len(expr.Vars) > idx {
			if afterParenthesis || expr.WithoutParentheses {
				if _, ok := expr.Vars[idx].(driver.Valuer); ok {
					builder.AddVar(builder, expr.Vars[idx])
				} else {
					switch rv := reflect.ValueOf(expr.Vars[idx]); rv.Kind() {
					case reflect.Slice, reflect.Array:
						if rv.Len() == 0 {
							builder.AddVar(builder, nil)
						} else {
							for i := 0; i < rv.Len(); i++ {
								if i > 0 {
									builder.WriteByte(',')
								}
								builder.AddVar(builder, rv.Index(i).Interface())
							}
						}
					default:
						builder.AddVar(builder, expr.Vars[idx])
					}
				}
			} else {
				builder.AddVar(builder, expr.Vars[idx])
			}

			idx++
		} else {
			if v == '(' {
				afterParenthesis = true
			} else {
				afterParenthesis = false
			}
			builder.WriteByte(v)
		}
	}

	// vars left over after all '?' are consumed are emitted as named args
	if idx < len(expr.Vars) {
		for _, v := range expr.Vars[idx:] {
			builder.AddVar(builder, sql.NamedArg{Value: v})
		}
	}
}
|
||||
|
||||
// NamedExpr raw expression for named expr
type NamedExpr struct {
	SQL  string
	Vars []interface{}
}

// Build build raw expression. Vars may be sql.NamedArg values, maps, or
// structs (exported fields, including embedded ones, become named values);
// "@name" tokens in SQL are replaced from that map, and '?' placeholders
// are handled like in Expr.
func (expr NamedExpr) Build(builder Builder) {
	var (
		idx              int
		inName           bool
		afterParenthesis bool
		namedMap         = make(map[string]interface{}, len(expr.Vars))
	)

	// collect name -> value pairs from all supported var kinds
	for _, v := range expr.Vars {
		switch value := v.(type) {
		case sql.NamedArg:
			namedMap[value.Name] = value.Value
		case map[string]interface{}:
			for k, v := range value {
				namedMap[k] = v
			}
		default:
			var appendFieldsToMap func(reflect.Value)
			appendFieldsToMap = func(reflectValue reflect.Value) {
				reflectValue = reflect.Indirect(reflectValue)
				switch reflectValue.Kind() {
				case reflect.Struct:
					modelType := reflectValue.Type()
					for i := 0; i < modelType.NumField(); i++ {
						if fieldStruct := modelType.Field(i); ast.IsExported(fieldStruct.Name) {
							namedMap[fieldStruct.Name] = reflectValue.Field(i).Interface()

							// recurse into embedded structs so their fields
							// are addressable by name too
							if fieldStruct.Anonymous {
								appendFieldsToMap(reflectValue.Field(i))
							}
						}
					}
				}
			}

			appendFieldsToMap(reflect.ValueOf(value))
		}
	}

	name := make([]byte, 0, 10)

	// scan the SQL byte-by-byte: '@' starts a name, a delimiter ends it
	for _, v := range []byte(expr.SQL) {
		if v == '@' && !inName {
			inName = true
			name = name[:0]
		} else if v == ' ' || v == ',' || v == ')' || v == '"' || v == '\'' || v == '`' || v == '\r' || v == '\n' || v == ';' {
			if inName {
				if nv, ok := namedMap[string(name)]; ok {
					builder.AddVar(builder, nv)
				} else {
					// unknown name: emit the token back verbatim
					builder.WriteByte('@')
					builder.WriteString(string(name))
				}
				inName = false
			}

			afterParenthesis = false
			builder.WriteByte(v)
		} else if v == '?' && len(expr.Vars) > idx {
			if afterParenthesis {
				if _, ok := expr.Vars[idx].(driver.Valuer); ok {
					builder.AddVar(builder, expr.Vars[idx])
				} else {
					switch rv := reflect.ValueOf(expr.Vars[idx]); rv.Kind() {
					case reflect.Slice, reflect.Array:
						if rv.Len() == 0 {
							builder.AddVar(builder, nil)
						} else {
							for i := 0; i < rv.Len(); i++ {
								if i > 0 {
									builder.WriteByte(',')
								}
								builder.AddVar(builder, rv.Index(i).Interface())
							}
						}
					default:
						builder.AddVar(builder, expr.Vars[idx])
					}
				}
			} else {
				builder.AddVar(builder, expr.Vars[idx])
			}

			idx++
		} else if inName {
			name = append(name, v)
		} else {
			if v == '(' {
				afterParenthesis = true
			} else {
				afterParenthesis = false
			}
			builder.WriteByte(v)
		}
	}

	// flush a name that runs to the end of the SQL string
	if inName {
		if nv, ok := namedMap[string(name)]; ok {
			builder.AddVar(builder, nv)
		} else {
			builder.WriteByte('@')
			builder.WriteString(string(name))
		}
	}
}
|
||||
|
||||
// IN Whether a value is within a set of values
type IN struct {
	Column interface{}
	Values []interface{}
}

func (in IN) Build(builder Builder) {
	builder.WriteQuoted(in.Column)

	switch len(in.Values) {
	case 0:
		// empty set matches nothing: IN (NULL) is never true
		builder.WriteString(" IN (NULL)")
	case 1:
		// a single scalar collapses to "=", but a nested slice still
		// falls through to the IN form
		if _, ok := in.Values[0].([]interface{}); !ok {
			builder.WriteString(" = ")
			builder.AddVar(builder, in.Values[0])
			break
		}

		fallthrough
	default:
		builder.WriteString(" IN (")
		builder.AddVar(builder, in.Values...)
		builder.WriteByte(')')
	}
}

func (in IN) NegationBuild(builder Builder) {
	builder.WriteQuoted(in.Column)
	switch len(in.Values) {
	case 0:
		builder.WriteString(" IS NOT NULL")
	case 1:
		if _, ok := in.Values[0].([]interface{}); !ok {
			builder.WriteString(" <> ")
			builder.AddVar(builder, in.Values[0])
			break
		}

		fallthrough
	default:
		builder.WriteString(" NOT IN (")
		builder.AddVar(builder, in.Values...)
		builder.WriteByte(')')
	}
}
|
||||
|
||||
// Eq equal to for where
type Eq struct {
	Column interface{}
	Value  interface{}
}

func (eq Eq) Build(builder Builder) {
	builder.WriteQuoted(eq.Column)

	switch eq.Value.(type) {
	case []string, []int, []int32, []int64, []uint, []uint32, []uint64, []interface{}:
		// slice values become an IN list
		rv := reflect.ValueOf(eq.Value)
		if rv.Len() == 0 {
			builder.WriteString(" IN (NULL)")
		} else {
			builder.WriteString(" IN (")
			for i := 0; i < rv.Len(); i++ {
				if i > 0 {
					builder.WriteByte(',')
				}
				builder.AddVar(builder, rv.Index(i).Interface())
			}
			builder.WriteByte(')')
		}
	default:
		if eqNil(eq.Value) {
			builder.WriteString(" IS NULL")
		} else {
			builder.WriteString(" = ")
			builder.AddVar(builder, eq.Value)
		}
	}
}

func (eq Eq) NegationBuild(builder Builder) {
	Neq(eq).Build(builder)
}

// Neq not equal to for where
type Neq Eq

func (neq Neq) Build(builder Builder) {
	builder.WriteQuoted(neq.Column)

	switch neq.Value.(type) {
	case []string, []int, []int32, []int64, []uint, []uint32, []uint64, []interface{}:
		builder.WriteString(" NOT IN (")
		rv := reflect.ValueOf(neq.Value)
		for i := 0; i < rv.Len(); i++ {
			if i > 0 {
				builder.WriteByte(',')
			}
			builder.AddVar(builder, rv.Index(i).Interface())
		}
		builder.WriteByte(')')
	default:
		if eqNil(neq.Value) {
			builder.WriteString(" IS NOT NULL")
		} else {
			builder.WriteString(" <> ")
			builder.AddVar(builder, neq.Value)
		}
	}
}

func (neq Neq) NegationBuild(builder Builder) {
	Eq(neq).Build(builder)
}
|
||||
|
||||
// Gt greater than for where
type Gt Eq

func (gt Gt) Build(builder Builder) {
	builder.WriteQuoted(gt.Column)
	builder.WriteString(" > ")
	builder.AddVar(builder, gt.Value)
}

func (gt Gt) NegationBuild(builder Builder) {
	Lte(gt).Build(builder)
}

// Gte greater than or equal to for where
type Gte Eq

func (gte Gte) Build(builder Builder) {
	builder.WriteQuoted(gte.Column)
	builder.WriteString(" >= ")
	builder.AddVar(builder, gte.Value)
}

func (gte Gte) NegationBuild(builder Builder) {
	Lt(gte).Build(builder)
}

// Lt less than for where
type Lt Eq

func (lt Lt) Build(builder Builder) {
	builder.WriteQuoted(lt.Column)
	builder.WriteString(" < ")
	builder.AddVar(builder, lt.Value)
}

func (lt Lt) NegationBuild(builder Builder) {
	Gte(lt).Build(builder)
}

// Lte less than or equal to for where
type Lte Eq

func (lte Lte) Build(builder Builder) {
	builder.WriteQuoted(lte.Column)
	builder.WriteString(" <= ")
	builder.AddVar(builder, lte.Value)
}

func (lte Lte) NegationBuild(builder Builder) {
	Gt(lte).Build(builder)
}

// Like whether string matches regular expression
type Like Eq

func (like Like) Build(builder Builder) {
	builder.WriteQuoted(like.Column)
	builder.WriteString(" LIKE ")
	builder.AddVar(builder, like.Value)
}

func (like Like) NegationBuild(builder Builder) {
	builder.WriteQuoted(like.Column)
	builder.WriteString(" NOT LIKE ")
	builder.AddVar(builder, like.Value)
}
|
||||
|
||||
// eqNil reports whether value should be rendered as SQL NULL: either it is
// nil, a nil pointer, or a non-nil driver.Valuer whose Value() yields nil.
func eqNil(value interface{}) bool {
	valuer, isValuer := value.(driver.Valuer)
	if isValuer && !eqNilReflect(valuer) {
		value, _ = valuer.Value()
	}

	if value == nil {
		return true
	}
	return eqNilReflect(value)
}

// eqNilReflect reports whether value is a typed nil pointer, which compares
// unequal to untyped nil when boxed in an interface.
func eqNilReflect(value interface{}) bool {
	rv := reflect.ValueOf(value)
	return rv.Kind() == reflect.Ptr && rv.IsNil()
}
|
||||
37
vendor/gorm.io/gorm/clause/from.go
generated
vendored
Normal file
37
vendor/gorm.io/gorm/clause/from.go
generated
vendored
Normal file
@@ -0,0 +1,37 @@
|
||||
package clause

// From from clause
type From struct {
	Tables []Table
	Joins  []Join
}

// Name from clause name
func (from From) Name() string {
	return "FROM"
}

// Build build from clause; with no explicit tables, the current model's
// table placeholder is used.
func (from From) Build(builder Builder) {
	if len(from.Tables) > 0 {
		for idx, table := range from.Tables {
			if idx > 0 {
				builder.WriteByte(',')
			}

			builder.WriteQuoted(table)
		}
	} else {
		builder.WriteQuoted(currentTable)
	}

	for _, join := range from.Joins {
		builder.WriteByte(' ')
		join.Build(builder)
	}
}

// MergeClause merge from clause
func (from From) MergeClause(clause *Clause) {
	clause.Expression = from
}
|
||||
48
vendor/gorm.io/gorm/clause/group_by.go
generated
vendored
Normal file
48
vendor/gorm.io/gorm/clause/group_by.go
generated
vendored
Normal file
@@ -0,0 +1,48 @@
|
||||
package clause

// GroupBy group by clause
type GroupBy struct {
	Columns []Column
	Having  []Expression
}

// Name from clause name
func (groupBy GroupBy) Name() string {
	return "GROUP BY"
}

// Build build group by clause
func (groupBy GroupBy) Build(builder Builder) {
	for idx, column := range groupBy.Columns {
		if idx > 0 {
			builder.WriteByte(',')
		}

		builder.WriteQuoted(column)
	}

	if len(groupBy.Having) > 0 {
		builder.WriteString(" HAVING ")
		Where{Exprs: groupBy.Having}.Build(builder)
	}
}

// MergeClause merge group by clause; earlier columns/having conditions are
// preserved (copied, not aliased) and the new ones appended after them.
func (groupBy GroupBy) MergeClause(clause *Clause) {
	if v, ok := clause.Expression.(GroupBy); ok {
		copiedColumns := make([]Column, len(v.Columns))
		copy(copiedColumns, v.Columns)
		groupBy.Columns = append(copiedColumns, groupBy.Columns...)

		copiedHaving := make([]Expression, len(v.Having))
		copy(copiedHaving, v.Having)
		groupBy.Having = append(copiedHaving, groupBy.Having...)
	}
	clause.Expression = groupBy

	// suppress the "GROUP BY" keyword when only HAVING conditions exist
	if len(groupBy.Columns) == 0 {
		clause.Name = ""
	} else {
		clause.Name = groupBy.Name()
	}
}
|
||||
39
vendor/gorm.io/gorm/clause/insert.go
generated
vendored
Normal file
39
vendor/gorm.io/gorm/clause/insert.go
generated
vendored
Normal file
@@ -0,0 +1,39 @@
|
||||
package clause

// Insert is the INSERT clause; Modifier holds an optional keyword such as
// IGNORE that is written between INSERT and INTO.
type Insert struct {
	Table    Table
	Modifier string
}

// Name insert clause name
func (insert Insert) Name() string {
	return "INSERT"
}

// Build build insert clause
func (insert Insert) Build(builder Builder) {
	if insert.Modifier != "" {
		builder.WriteString(insert.Modifier)
		builder.WriteByte(' ')
	}

	builder.WriteString("INTO ")
	if insert.Table.Name == "" {
		builder.WriteQuoted(currentTable)
	} else {
		builder.WriteQuoted(insert.Table)
	}
}

// MergeClause merge insert clause; unset fields inherit from the previous
// Insert expression.
func (insert Insert) MergeClause(clause *Clause) {
	if v, ok := clause.Expression.(Insert); ok {
		if insert.Modifier == "" {
			insert.Modifier = v.Modifier
		}
		if insert.Table.Name == "" {
			insert.Table = v.Table
		}
	}
	clause.Expression = insert
}
|
||||
79
vendor/gorm.io/gorm/clause/joins.go
generated
vendored
Normal file
79
vendor/gorm.io/gorm/clause/joins.go
generated
vendored
Normal file
@@ -0,0 +1,79 @@
|
||||
package clause

import "gorm.io/gorm/utils"

type JoinType string

const (
	CrossJoin JoinType = "CROSS"
	InnerJoin JoinType = "INNER"
	LeftJoin  JoinType = "LEFT"
	RightJoin JoinType = "RIGHT"
)

// JoinTarget describes an association-based join, optionally sourced from a
// subquery and/or aliased to a table name.
type JoinTarget struct {
	Type        JoinType
	Association string
	Subquery    Expression
	Table       string
}

// Has returns an inner-join target for the named association.
func Has(name string) JoinTarget {
	return JoinTarget{Type: InnerJoin, Association: name}
}

// Association returns a join target of this join type for the named association.
func (jt JoinType) Association(name string) JoinTarget {
	return JoinTarget{Type: jt, Association: name}
}

// AssociationFrom returns a join target backed by the given subquery.
func (jt JoinType) AssociationFrom(name string, subquery Expression) JoinTarget {
	return JoinTarget{Type: jt, Association: name, Subquery: subquery}
}

// As sets the table alias for the join target.
func (jt JoinTarget) As(name string) JoinTarget {
	jt.Table = name
	return jt
}

// Join clause for from
type Join struct {
	Type       JoinType
	Table      Table
	ON         Where
	Using      []string
	Expression Expression
}

// JoinTable builds a Table whose name is the joined nested relation path.
func JoinTable(names ...string) Table {
	return Table{
		Name: utils.JoinNestedRelationNames(names),
	}
}

// Build writes the join; a raw Expression overrides the structured form,
// and ON takes precedence over USING when both are present.
func (join Join) Build(builder Builder) {
	if join.Expression != nil {
		join.Expression.Build(builder)
	} else {
		if join.Type != "" {
			builder.WriteString(string(join.Type))
			builder.WriteByte(' ')
		}

		builder.WriteString("JOIN ")
		builder.WriteQuoted(join.Table)

		if len(join.ON.Exprs) > 0 {
			builder.WriteString(" ON ")
			join.ON.Build(builder)
		} else if len(join.Using) > 0 {
			builder.WriteString(" USING (")
			for idx, c := range join.Using {
				if idx > 0 {
					builder.WriteByte(',')
				}
				builder.WriteQuoted(c)
			}
			builder.WriteByte(')')
		}
	}
}
|
||||
46
vendor/gorm.io/gorm/clause/limit.go
generated
vendored
Normal file
46
vendor/gorm.io/gorm/clause/limit.go
generated
vendored
Normal file
@@ -0,0 +1,46 @@
|
||||
package clause

// Limit limit clause
type Limit struct {
	Limit  *int
	Offset int
}

// Name where clause name
func (limit Limit) Name() string {
	return "LIMIT"
}

// Build build where clause; a negative limit or zero/negative offset is
// omitted from the SQL.
func (limit Limit) Build(builder Builder) {
	if limit.Limit != nil && *limit.Limit >= 0 {
		builder.WriteString("LIMIT ")
		builder.AddVar(builder, *limit.Limit)
	}
	if limit.Offset > 0 {
		if limit.Limit != nil && *limit.Limit >= 0 {
			builder.WriteByte(' ')
		}
		builder.WriteString("OFFSET ")
		builder.AddVar(builder, limit.Offset)
	}
}

// MergeClause merge order by clauses; unset limit/offset values inherit
// from the previous Limit expression, and a negative offset cancels it.
func (limit Limit) MergeClause(clause *Clause) {
	clause.Name = ""

	if v, ok := clause.Expression.(Limit); ok {
		if (limit.Limit == nil || *limit.Limit == 0) && v.Limit != nil {
			limit.Limit = v.Limit
		}

		if limit.Offset == 0 && v.Offset > 0 {
			limit.Offset = v.Offset
		} else if limit.Offset < 0 {
			limit.Offset = 0
		}
	}

	clause.Expression = limit
}
|
||||
38
vendor/gorm.io/gorm/clause/locking.go
generated
vendored
Normal file
38
vendor/gorm.io/gorm/clause/locking.go
generated
vendored
Normal file
@@ -0,0 +1,38 @@
|
||||
package clause

const (
	LockingStrengthUpdate    = "UPDATE"
	LockingStrengthShare     = "SHARE"
	LockingOptionsSkipLocked = "SKIP LOCKED"
	LockingOptionsNoWait     = "NOWAIT"
)

// Locking is the row-locking clause (FOR UPDATE / FOR SHARE ... ).
type Locking struct {
	Strength string
	Table    Table
	Options  string
}

// Name where clause name
func (locking Locking) Name() string {
	return "FOR"
}

// Build build where clause
func (locking Locking) Build(builder Builder) {
	builder.WriteString(locking.Strength)
	if locking.Table.Name != "" {
		builder.WriteString(" OF ")
		builder.WriteQuoted(locking.Table)
	}

	if locking.Options != "" {
		builder.WriteByte(' ')
		builder.WriteString(locking.Options)
	}
}

// MergeClause merge order by clauses
func (locking Locking) MergeClause(clause *Clause) {
	clause.Expression = locking
}
|
||||
59
vendor/gorm.io/gorm/clause/on_conflict.go
generated
vendored
Normal file
59
vendor/gorm.io/gorm/clause/on_conflict.go
generated
vendored
Normal file
@@ -0,0 +1,59 @@
|
||||
package clause

// OnConflict is the ON CONFLICT clause used for upserts.
type OnConflict struct {
	Columns      []Column
	Where        Where
	TargetWhere  Where
	OnConstraint string
	DoNothing    bool
	DoUpdates    Set
	UpdateAll    bool
}

func (OnConflict) Name() string {
	return "ON CONFLICT"
}

// Build build onConflict clause
func (onConflict OnConflict) Build(builder Builder) {
	// conflict target: either a named constraint or a column list with an
	// optional partial-index WHERE
	if onConflict.OnConstraint != "" {
		builder.WriteString("ON CONSTRAINT ")
		builder.WriteString(onConflict.OnConstraint)
		builder.WriteByte(' ')
	} else {
		if len(onConflict.Columns) > 0 {
			builder.WriteByte('(')
			for idx, column := range onConflict.Columns {
				if idx > 0 {
					builder.WriteByte(',')
				}
				builder.WriteQuoted(column)
			}
			builder.WriteString(`) `)
		}

		if len(onConflict.TargetWhere.Exprs) > 0 {
			builder.WriteString(" WHERE ")
			onConflict.TargetWhere.Build(builder)
			builder.WriteByte(' ')
		}
	}

	if onConflict.DoNothing {
		builder.WriteString("DO NOTHING")
	} else {
		builder.WriteString("DO UPDATE SET ")
		onConflict.DoUpdates.Build(builder)
	}

	if len(onConflict.Where.Exprs) > 0 {
		builder.WriteString(" WHERE ")
		onConflict.Where.Build(builder)
		builder.WriteByte(' ')
	}
}

// MergeClause merge onConflict clauses
func (onConflict OnConflict) MergeClause(clause *Clause) {
	clause.Expression = onConflict
}
|
||||
54
vendor/gorm.io/gorm/clause/order_by.go
generated
vendored
Normal file
54
vendor/gorm.io/gorm/clause/order_by.go
generated
vendored
Normal file
@@ -0,0 +1,54 @@
|
||||
package clause

type OrderByColumn struct {
	Column  Column
	Desc    bool
	Reorder bool
}

type OrderBy struct {
	Columns    []OrderByColumn
	Expression Expression
}

// Name where clause name
func (orderBy OrderBy) Name() string {
	return "ORDER BY"
}

// Build build where clause; a raw Expression overrides the column list.
func (orderBy OrderBy) Build(builder Builder) {
	if orderBy.Expression != nil {
		orderBy.Expression.Build(builder)
	} else {
		for idx, column := range orderBy.Columns {
			if idx > 0 {
				builder.WriteByte(',')
			}

			builder.WriteQuoted(column.Column)
			if column.Desc {
				builder.WriteString(" DESC")
			}
		}
	}
}

// MergeClause merge order by clauses; a column flagged Reorder discards
// all previously accumulated columns, otherwise the earlier columns are
// copied and the new ones appended after them.
func (orderBy OrderBy) MergeClause(clause *Clause) {
	if v, ok := clause.Expression.(OrderBy); ok {
		for i := len(orderBy.Columns) - 1; i >= 0; i-- {
			if orderBy.Columns[i].Reorder {
				orderBy.Columns = orderBy.Columns[i:]
				clause.Expression = orderBy
				return
			}
		}

		copiedColumns := make([]OrderByColumn, len(v.Columns))
		copy(copiedColumns, v.Columns)
		orderBy.Columns = append(copiedColumns, orderBy.Columns...)
	}

	clause.Expression = orderBy
}
|
||||
37
vendor/gorm.io/gorm/clause/returning.go
generated
vendored
Normal file
37
vendor/gorm.io/gorm/clause/returning.go
generated
vendored
Normal file
@@ -0,0 +1,37 @@
|
||||
package clause

// Returning is the RETURNING clause; an empty Columns list means RETURNING *.
type Returning struct {
	Columns []Column
}

// Name returns the clause keyword "RETURNING".
func (returning Returning) Name() string {
	return "RETURNING"
}

// Build writes the comma-separated quoted columns, or '*' when no columns
// were specified.
func (returning Returning) Build(builder Builder) {
	if len(returning.Columns) > 0 {
		for idx, column := range returning.Columns {
			if idx > 0 {
				builder.WriteByte(',')
			}

			builder.WriteQuoted(column)
		}
	}
	} else {
		builder.WriteByte('*')
	}
}

// MergeClause merges returning clauses: the new columns are appended after the
// existing ones; if either side requested all columns (nil / empty list with a
// prior nil), the merged clause falls back to RETURNING *.
func (returning Returning) MergeClause(clause *Clause) {
	if v, ok := clause.Expression.(Returning); ok && len(returning.Columns) > 0 {
		if v.Columns != nil {
			returning.Columns = append(v.Columns, returning.Columns...)
		} else {
			returning.Columns = nil
		}
	}
	clause.Expression = returning
}
|
||||
59
vendor/gorm.io/gorm/clause/select.go
generated
vendored
Normal file
59
vendor/gorm.io/gorm/clause/select.go
generated
vendored
Normal file
@@ -0,0 +1,59 @@
|
||||
package clause

// Select select attrs when querying, updating, creating.
// Expression, when set, replaces the column list entirely.
type Select struct {
	Distinct   bool
	Columns    []Column
	Expression Expression
}

// Name returns the clause keyword "SELECT".
func (s Select) Name() string {
	return "SELECT"
}

// Build writes "DISTINCT " (if requested) followed by the comma-separated
// quoted columns, or '*' when no columns were specified.
func (s Select) Build(builder Builder) {
	if len(s.Columns) > 0 {
		if s.Distinct {
			builder.WriteString("DISTINCT ")
		}

		for idx, column := range s.Columns {
			if idx > 0 {
				builder.WriteByte(',')
			}
			builder.WriteQuoted(column)
		}
	} else {
		builder.WriteByte('*')
	}
}

// MergeClause installs this SELECT into the clause. A raw Expression wins over
// the column list; when both Distinct and a plain Expr are set, "DISTINCT " is
// prefixed directly onto the Expr's SQL.
func (s Select) MergeClause(clause *Clause) {
	if s.Expression != nil {
		if s.Distinct {
			if expr, ok := s.Expression.(Expr); ok {
				expr.SQL = "DISTINCT " + expr.SQL
				clause.Expression = expr
				return
			}
		}

		clause.Expression = s.Expression
	} else {
		clause.Expression = s
	}
}

// CommaExpression represents a group of expressions separated by commas.
type CommaExpression struct {
	Exprs []Expression
}

// Build writes each sub-expression in order, separated by ", ".
func (comma CommaExpression) Build(builder Builder) {
	for idx, expr := range comma.Exprs {
		if idx > 0 {
			_, _ = builder.WriteString(", ")
		}
		expr.Build(builder)
	}
}
|
||||
60
vendor/gorm.io/gorm/clause/set.go
generated
vendored
Normal file
60
vendor/gorm.io/gorm/clause/set.go
generated
vendored
Normal file
@@ -0,0 +1,60 @@
|
||||
package clause

import "sort"

// Set is the SET clause of an UPDATE: an ordered list of column assignments.
type Set []Assignment

// Assignment is a single "column = value" pair.
type Assignment struct {
	Column Column
	Value  interface{}
}

// Name returns the clause keyword "SET".
func (set Set) Name() string {
	return "SET"
}

// Build writes comma-separated "col=?" pairs, binding each value through
// AddVar. With no assignments it emits the no-op "pk=pk" so the statement
// stays syntactically valid.
func (set Set) Build(builder Builder) {
	if len(set) > 0 {
		for idx, assignment := range set {
			if idx > 0 {
				builder.WriteByte(',')
			}
			builder.WriteQuoted(assignment.Column)
			builder.WriteByte('=')
			builder.AddVar(builder, assignment.Value)
		}
	} else {
		builder.WriteQuoted(Column{Name: PrimaryKey})
		builder.WriteByte('=')
		builder.WriteQuoted(Column{Name: PrimaryKey})
	}
}

// MergeClause merge assignments clauses: the newest Set replaces the previous
// expression; it is copied first so later mutation of the caller's slice
// cannot leak into the stored clause.
func (set Set) MergeClause(clause *Clause) {
	copiedAssignments := make([]Assignment, len(set))
	copy(copiedAssignments, set)
	clause.Expression = Set(copiedAssignments)
}

// Assignments converts a column-name -> value map into a Set, with keys
// sorted so the generated SQL is deterministic.
func Assignments(values map[string]interface{}) Set {
	keys := make([]string, 0, len(values))
	for key := range values {
		keys = append(keys, key)
	}
	sort.Strings(keys)

	assignments := make([]Assignment, len(keys))
	for idx, key := range keys {
		assignments[idx] = Assignment{Column: Column{Name: key}, Value: values[key]}
	}
	return assignments
}

// AssignmentColumns builds a Set that assigns each named column from the
// "excluded" pseudo-table (the ON CONFLICT ... DO UPDATE upsert form).
func AssignmentColumns(values []string) Set {
	assignments := make([]Assignment, len(values))
	for idx, value := range values {
		assignments[idx] = Assignment{Column: Column{Name: value}, Value: Column{Table: "excluded", Name: value}}
	}
	return assignments
}
|
||||
38
vendor/gorm.io/gorm/clause/update.go
generated
vendored
Normal file
38
vendor/gorm.io/gorm/clause/update.go
generated
vendored
Normal file
@@ -0,0 +1,38 @@
|
||||
package clause

// Update is the UPDATE clause; Modifier is an optional keyword rendered before
// the table name (e.g. LOW_PRIORITY / IGNORE on MySQL).
type Update struct {
	Modifier string
	Table    Table
}

// Name update clause name.
func (update Update) Name() string {
	return "UPDATE"
}

// Build writes the optional modifier and the quoted table name, defaulting to
// the statement's current table when none was given.
func (update Update) Build(builder Builder) {
	if update.Modifier != "" {
		builder.WriteString(update.Modifier)
		builder.WriteByte(' ')
	}

	if update.Table.Name == "" {
		builder.WriteQuoted(currentTable)
	} else {
		builder.WriteQuoted(update.Table)
	}
}

// MergeClause merge update clause: unset fields on the incoming clause inherit
// the values from the previously registered one.
func (update Update) MergeClause(clause *Clause) {
	if v, ok := clause.Expression.(Update); ok {
		if update.Modifier == "" {
			update.Modifier = v.Modifier
		}
		if update.Table.Name == "" {
			update.Table = v.Table
		}
	}
	clause.Expression = update
}
|
||||
45
vendor/gorm.io/gorm/clause/values.go
generated
vendored
Normal file
45
vendor/gorm.io/gorm/clause/values.go
generated
vendored
Normal file
@@ -0,0 +1,45 @@
|
||||
package clause

// Values is the VALUES clause of an INSERT: the column list plus one value
// row per record.
type Values struct {
	Columns []Column
	Values  [][]interface{}
}

// Name returns the clause keyword "VALUES".
func (Values) Name() string {
	return "VALUES"
}

// Build writes "(col,...) VALUES (?,...),(?,...)" binding every value through
// AddVar, or "DEFAULT VALUES" when no columns were given.
func (values Values) Build(builder Builder) {
	if len(values.Columns) > 0 {
		builder.WriteByte('(')
		for idx, column := range values.Columns {
			if idx > 0 {
				builder.WriteByte(',')
			}
			builder.WriteQuoted(column)
		}
		builder.WriteByte(')')

		builder.WriteString(" VALUES ")

		for idx, value := range values.Values {
			if idx > 0 {
				builder.WriteByte(',')
			}

			builder.WriteByte('(')
			builder.AddVar(builder, value...)
			builder.WriteByte(')')
		}
	} else {
		builder.WriteString("DEFAULT VALUES")
	}
}

// MergeClause merge values clauses: the clause keyword is cleared because
// Build already emits the "VALUES" token itself.
func (values Values) MergeClause(clause *Clause) {
	clause.Name = ""
	clause.Expression = values
}
|
||||
245
vendor/gorm.io/gorm/clause/where.go
generated
vendored
Normal file
245
vendor/gorm.io/gorm/clause/where.go
generated
vendored
Normal file
@@ -0,0 +1,245 @@
|
||||
package clause

import (
	"strings"
)

// Connector tokens used when joining condition expressions.
const (
	AndWithSpace = " AND "
	OrWithSpace  = " OR "
)

// Where where clause: a list of condition expressions joined with AND by
// default (OrConditions members are joined with OR).
type Where struct {
	Exprs []Expression
}

// Name where clause name.
func (where Where) Name() string {
	return "WHERE"
}

// Build build where clause.
func (where Where) Build(builder Builder) {
	// A single wrapping AndConditions is flattened so it doesn't get
	// needlessly parenthesized.
	if len(where.Exprs) == 1 {
		if andCondition, ok := where.Exprs[0].(AndConditions); ok {
			where.Exprs = andCondition.Exprs
		}
	}

	// Switch position if the first query expression is a single Or condition:
	// the clause must not start with " OR ", so move the first non-OR
	// expression to the front.
	for idx, expr := range where.Exprs {
		if v, ok := expr.(OrConditions); !ok || len(v.Exprs) > 1 {
			if idx != 0 {
				where.Exprs[0], where.Exprs[idx] = where.Exprs[idx], where.Exprs[0]
			}
			break
		}
	}

	buildExprs(where.Exprs, builder, AndWithSpace)
}

// buildExprs writes exprs joined by joinCond (single-expression OrConditions
// are joined with OR instead), wrapping any raw SQL fragment that itself
// contains " AND "/" OR " in parentheses to preserve precedence.
func buildExprs(exprs []Expression, builder Builder, joinCond string) {
	wrapInParentheses := false

	for idx, expr := range exprs {
		if idx > 0 {
			if v, ok := expr.(OrConditions); ok && len(v.Exprs) == 1 {
				builder.WriteString(OrWithSpace)
			} else {
				builder.WriteString(joinCond)
			}
		}

		if len(exprs) > 1 {
			switch v := expr.(type) {
			case OrConditions:
				if len(v.Exprs) == 1 {
					if e, ok := v.Exprs[0].(Expr); ok {
						sql := strings.ToUpper(e.SQL)
						wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace)
					}
				}
			case AndConditions:
				if len(v.Exprs) == 1 {
					if e, ok := v.Exprs[0].(Expr); ok {
						sql := strings.ToUpper(e.SQL)
						wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace)
					}
				}
			case Expr:
				sql := strings.ToUpper(v.SQL)
				wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace)
			case NamedExpr:
				sql := strings.ToUpper(v.SQL)
				wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace)
			}
		}

		if wrapInParentheses {
			builder.WriteByte('(')
			expr.Build(builder)
			builder.WriteByte(')')
			wrapInParentheses = false
		} else {
			expr.Build(builder)
		}
	}
}

// MergeClause merge where clauses: the incoming expressions are appended
// after the existing ones (into a fresh slice so neither source is mutated).
func (where Where) MergeClause(clause *Clause) {
	if w, ok := clause.Expression.(Where); ok {
		exprs := make([]Expression, len(w.Exprs)+len(where.Exprs))
		copy(exprs, w.Exprs)
		copy(exprs[len(w.Exprs):], where.Exprs)
		where.Exprs = exprs
	}

	clause.Expression = where
}

// And combines exprs with AND. Returns nil for no input, and a single
// non-OR expression unchanged (no wrapper needed).
func And(exprs ...Expression) Expression {
	if len(exprs) == 0 {
		return nil
	}

	if len(exprs) == 1 {
		if _, ok := exprs[0].(OrConditions); !ok {
			return exprs[0]
		}
	}

	return AndConditions{Exprs: exprs}
}

// AndConditions is a group of expressions joined with AND.
type AndConditions struct {
	Exprs []Expression
}

// Build writes the AND-joined group, parenthesized when it holds more than
// one expression.
func (and AndConditions) Build(builder Builder) {
	if len(and.Exprs) > 1 {
		builder.WriteByte('(')
		buildExprs(and.Exprs, builder, AndWithSpace)
		builder.WriteByte(')')
	} else {
		buildExprs(and.Exprs, builder, AndWithSpace)
	}
}

// Or combines exprs with OR; returns nil for no input.
func Or(exprs ...Expression) Expression {
	if len(exprs) == 0 {
		return nil
	}
	return OrConditions{Exprs: exprs}
}

// OrConditions is a group of expressions joined with OR.
type OrConditions struct {
	Exprs []Expression
}

// Build writes the OR-joined group, parenthesized when it holds more than
// one expression.
func (or OrConditions) Build(builder Builder) {
	if len(or.Exprs) > 1 {
		builder.WriteByte('(')
		buildExprs(or.Exprs, builder, OrWithSpace)
		builder.WriteByte(')')
	} else {
		buildExprs(or.Exprs, builder, OrWithSpace)
	}
}

// Not negates exprs; a single wrapping AndConditions is flattened first.
// Returns nil for no input.
func Not(exprs ...Expression) Expression {
	if len(exprs) == 0 {
		return nil
	}
	if len(exprs) == 1 {
		if andCondition, ok := exprs[0].(AndConditions); ok {
			exprs = andCondition.Exprs
		}
	}
	return NotConditions{Exprs: exprs}
}

// NotConditions is a negated group of expressions.
type NotConditions struct {
	Exprs []Expression
}

// Build writes the negated group. If any member implements
// NegationExpressionBuilder, each member is negated individually (using its
// NegationBuild when available, else a "NOT " prefix) and joined with AND;
// otherwise a single "NOT " prefixes the whole group, whose members are
// joined with AND (OR for OrConditions members). Raw Expr fragments that
// contain " AND "/" OR " are parenthesized either way.
func (not NotConditions) Build(builder Builder) {
	anyNegationBuilder := false
	for _, c := range not.Exprs {
		if _, ok := c.(NegationExpressionBuilder); ok {
			anyNegationBuilder = true
			break
		}
	}

	if anyNegationBuilder {
		if len(not.Exprs) > 1 {
			builder.WriteByte('(')
		}

		for idx, c := range not.Exprs {
			if idx > 0 {
				builder.WriteString(AndWithSpace)
			}

			if negationBuilder, ok := c.(NegationExpressionBuilder); ok {
				negationBuilder.NegationBuild(builder)
			} else {
				builder.WriteString("NOT ")
				e, wrapInParentheses := c.(Expr)
				if wrapInParentheses {
					sql := strings.ToUpper(e.SQL)
					if wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace); wrapInParentheses {
						builder.WriteByte('(')
					}
				}

				c.Build(builder)

				if wrapInParentheses {
					builder.WriteByte(')')
				}
			}
		}

		if len(not.Exprs) > 1 {
			builder.WriteByte(')')
		}
	} else {
		builder.WriteString("NOT ")
		if len(not.Exprs) > 1 {
			builder.WriteByte('(')
		}

		for idx, c := range not.Exprs {
			if idx > 0 {
				switch c.(type) {
				case OrConditions:
					builder.WriteString(OrWithSpace)
				default:
					builder.WriteString(AndWithSpace)
				}
			}

			e, wrapInParentheses := c.(Expr)
			if wrapInParentheses {
				sql := strings.ToUpper(e.SQL)
				if wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace); wrapInParentheses {
					builder.WriteByte('(')
				}
			}

			c.Build(builder)

			if wrapInParentheses {
				builder.WriteByte(')')
			}
		}

		if len(not.Exprs) > 1 {
			builder.WriteByte(')')
		}
	}
}
|
||||
3
vendor/gorm.io/gorm/clause/with.go
generated
vendored
Normal file
3
vendor/gorm.io/gorm/clause/with.go
generated
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
package clause

// With is an empty marker type; presumably a placeholder for WITH (CTE)
// clause support — it defines no fields or methods in this file.
type With struct{}
|
||||
54
vendor/gorm.io/gorm/errors.go
generated
vendored
Normal file
54
vendor/gorm.io/gorm/errors.go
generated
vendored
Normal file
@@ -0,0 +1,54 @@
|
||||
package gorm

import (
	"errors"

	"gorm.io/gorm/logger"
)

// Sentinel errors returned by GORM APIs; compare with errors.Is.
var (
	// ErrRecordNotFound record not found error (aliased from the logger
	// package so both packages report the same sentinel)
	ErrRecordNotFound = logger.ErrRecordNotFound
	// ErrInvalidTransaction invalid transaction when you are trying to `Commit` or `Rollback`
	ErrInvalidTransaction = errors.New("invalid transaction")
	// ErrNotImplemented not implemented
	ErrNotImplemented = errors.New("not implemented")
	// ErrMissingWhereClause missing where clause
	ErrMissingWhereClause = errors.New("WHERE conditions required")
	// ErrUnsupportedRelation unsupported relations
	ErrUnsupportedRelation = errors.New("unsupported relations")
	// ErrPrimaryKeyRequired primary keys required
	ErrPrimaryKeyRequired = errors.New("primary key required")
	// ErrModelValueRequired model value required
	ErrModelValueRequired = errors.New("model value required")
	// ErrModelAccessibleFieldsRequired model accessible fields required
	ErrModelAccessibleFieldsRequired = errors.New("model accessible fields required")
	// ErrSubQueryRequired sub query required
	ErrSubQueryRequired = errors.New("sub query required")
	// ErrInvalidData unsupported data
	ErrInvalidData = errors.New("unsupported data")
	// ErrUnsupportedDriver unsupported driver
	ErrUnsupportedDriver = errors.New("unsupported driver")
	// ErrRegistered registered
	ErrRegistered = errors.New("registered")
	// ErrInvalidField invalid field
	ErrInvalidField = errors.New("invalid field")
	// ErrEmptySlice empty slice found
	ErrEmptySlice = errors.New("empty slice found")
	// ErrDryRunModeUnsupported dry run mode unsupported
	ErrDryRunModeUnsupported = errors.New("dry run mode unsupported")
	// ErrInvalidDB invalid db
	ErrInvalidDB = errors.New("invalid db")
	// ErrInvalidValue invalid value
	ErrInvalidValue = errors.New("invalid value, should be pointer to struct or slice")
	// ErrInvalidValueOfLength invalid values do not match length
	ErrInvalidValueOfLength = errors.New("invalid association values, length doesn't match")
	// ErrPreloadNotAllowed preload is not allowed when count is used
	ErrPreloadNotAllowed = errors.New("preload is not allowed when count is used")
	// ErrDuplicatedKey occurs when there is a unique key constraint violation
	ErrDuplicatedKey = errors.New("duplicated key not allowed")
	// ErrForeignKeyViolated occurs when there is a foreign key constraint violation
	ErrForeignKeyViolated = errors.New("violates foreign key constraint")
	// ErrCheckConstraintViolated occurs when there is a check constraint violation
	ErrCheckConstraintViolated = errors.New("violates check constraint")
)
|
||||
780
vendor/gorm.io/gorm/finisher_api.go
generated
vendored
Normal file
780
vendor/gorm.io/gorm/finisher_api.go
generated
vendored
Normal file
@@ -0,0 +1,780 @@
|
||||
package gorm

import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"hash/maphash"
	"reflect"
	"strings"

	"gorm.io/gorm/clause"
	"gorm.io/gorm/logger"
	"gorm.io/gorm/schema"
	"gorm.io/gorm/utils"
)

// Create inserts value, returning the inserted data's primary key in value's id
func (db *DB) Create(value interface{}) (tx *DB) {
	// Honor the session-wide CreateBatchSize setting transparently.
	if db.CreateBatchSize > 0 {
		return db.CreateInBatches(value, db.CreateBatchSize)
	}

	tx = db.getInstance()
	tx.Statement.Dest = value
	return tx.callbacks.Create().Execute(tx)
}

// CreateInBatches inserts value in batches of batchSize. Slices/arrays are
// chunked (wrapped in one transaction when more than one batch is needed and
// default transactions aren't skipped); any other kind falls back to a single
// plain create.
func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) {
	reflectValue := reflect.Indirect(reflect.ValueOf(value))

	switch reflectValue.Kind() {
	case reflect.Slice, reflect.Array:
		var rowsAffected int64
		tx = db.getInstance()

		// the reflection length judgment of the optimized value
		reflectLen := reflectValue.Len()

		callFc := func(tx *DB) error {
			for i := 0; i < reflectLen; i += batchSize {
				ends := i + batchSize
				if ends > reflectLen {
					ends = reflectLen
				}

				subtx := tx.getInstance()
				subtx.Statement.Dest = reflectValue.Slice(i, ends).Interface()
				subtx.callbacks.Create().Execute(subtx)
				if subtx.Error != nil {
					return subtx.Error
				}
				rowsAffected += subtx.RowsAffected
			}
			return nil
		}

		if tx.SkipDefaultTransaction || reflectLen <= batchSize {
			tx.AddError(callFc(tx.Session(&Session{})))
		} else {
			tx.AddError(tx.Transaction(callFc))
		}

		tx.RowsAffected = rowsAffected
	default:
		tx = db.getInstance()
		tx.Statement.Dest = value
		tx = tx.callbacks.Create().Execute(tx)
	}
	return
}

// Save updates value in database. If value doesn't contain a matching primary key, value is inserted.
func (db *DB) Save(value interface{}) (tx *DB) {
	tx = db.getInstance()
	tx.Statement.Dest = value

	reflectValue := reflect.Indirect(reflect.ValueOf(value))
	for reflectValue.Kind() == reflect.Ptr || reflectValue.Kind() == reflect.Interface {
		reflectValue = reflect.Indirect(reflectValue)
	}

	switch reflectValue.Kind() {
	case reflect.Slice, reflect.Array:
		// Slices are saved via an upsert (ON CONFLICT ... update all).
		if _, ok := tx.Statement.Clauses["ON CONFLICT"]; !ok {
			tx = tx.Clauses(clause.OnConflict{UpdateAll: true})
		}
		tx = tx.callbacks.Create().Execute(tx.Set("gorm:update_track_time", true))
	case reflect.Struct:
		// A struct with any zero primary key is treated as new and created.
		if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil {
			for _, pf := range tx.Statement.Schema.PrimaryFields {
				if _, isZero := pf.ValueOf(tx.Statement.Context, reflectValue); isZero {
					return tx.callbacks.Create().Execute(tx)
				}
			}
		}

		fallthrough
	default:
		selectedUpdate := len(tx.Statement.Selects) != 0
		// when updating, use all fields including those zero-value fields
		if !selectedUpdate {
			tx.Statement.Selects = append(tx.Statement.Selects, "*")
		}

		updateTx := tx.callbacks.Update().Execute(tx.Session(&Session{Initialized: true}))

		// Nothing updated (and not a dry run / explicit select): the row
		// doesn't exist yet, so fall back to an upsert-style create.
		if updateTx.Error == nil && updateTx.RowsAffected == 0 && !updateTx.DryRun && !selectedUpdate {
			return tx.Session(&Session{SkipHooks: true}).Clauses(clause.OnConflict{UpdateAll: true}).Create(value)
		}

		return updateTx
	}

	return
}

// First finds the first record ordered by primary key, matching given conditions conds
func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) {
	tx = db.Limit(1).Order(clause.OrderByColumn{
		Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
	})
	if len(conds) > 0 {
		if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
			tx.Statement.AddClause(clause.Where{Exprs: exprs})
		}
	}
	tx.Statement.RaiseErrorOnNotFound = true
	tx.Statement.Dest = dest
	return tx.callbacks.Query().Execute(tx)
}

// Take finds the first record returned by the database in no specified order, matching given conditions conds
func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) {
	tx = db.Limit(1)
	if len(conds) > 0 {
		if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
			tx.Statement.AddClause(clause.Where{Exprs: exprs})
		}
	}
	tx.Statement.RaiseErrorOnNotFound = true
	tx.Statement.Dest = dest
	return tx.callbacks.Query().Execute(tx)
}

// Last finds the last record ordered by primary key, matching given conditions conds
func (db *DB) Last(dest interface{}, conds ...interface{}) (tx *DB) {
	tx = db.Limit(1).Order(clause.OrderByColumn{
		Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
		Desc:   true,
	})
	if len(conds) > 0 {
		if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
			tx.Statement.AddClause(clause.Where{Exprs: exprs})
		}
	}
	tx.Statement.RaiseErrorOnNotFound = true
	tx.Statement.Dest = dest
	return tx.callbacks.Query().Execute(tx)
}

// Find finds all records matching given conditions conds
func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) {
	tx = db.getInstance()
	if len(conds) > 0 {
		if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
			tx.Statement.AddClause(clause.Where{Exprs: exprs})
		}
	}
	tx.Statement.Dest = dest
	return tx.callbacks.Query().Execute(tx)
}

// FindInBatches finds all records in batches of batchSize, invoking fc once
// per batch. Batches are paged by "primary key > last seen value" (keyset
// pagination), so a prioritized primary field is required.
func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, batch int) error) *DB {
	var (
		tx = db.Order(clause.OrderByColumn{
			Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
		}).Session(&Session{})
		queryDB      = tx
		rowsAffected int64
		batch        int
	)

	// user specified offset or limit
	var totalSize int
	if c, ok := tx.Statement.Clauses["LIMIT"]; ok {
		if limit, ok := c.Expression.(clause.Limit); ok {
			if limit.Limit != nil {
				totalSize = *limit.Limit
			}

			if totalSize > 0 && batchSize > totalSize {
				batchSize = totalSize
			}

			// reset to offset to 0 in next batch
			tx = tx.Offset(-1).Session(&Session{})
		}
	}

	for {
		result := queryDB.Limit(batchSize).Find(dest)
		rowsAffected += result.RowsAffected
		batch++

		if result.Error == nil && result.RowsAffected != 0 {
			fcTx := result.Session(&Session{NewDB: true})
			fcTx.RowsAffected = result.RowsAffected
			tx.AddError(fc(fcTx, batch))
		} else if result.Error != nil {
			tx.AddError(result.Error)
		}

		// A short batch means the table is exhausted.
		if tx.Error != nil || int(result.RowsAffected) < batchSize {
			break
		}

		if totalSize > 0 {
			if totalSize <= int(rowsAffected) {
				break
			}
			// Shrink the final batch so the user-specified LIMIT is honored.
			if totalSize/batchSize == batch {
				batchSize = totalSize % batchSize
			}
		}

		// Optimize for-break
		resultsValue := reflect.Indirect(reflect.ValueOf(dest))
		if result.Statement.Schema.PrioritizedPrimaryField == nil {
			tx.AddError(ErrPrimaryKeyRequired)
			break
		}

		primaryValue, zero := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1))
		if zero {
			tx.AddError(ErrPrimaryKeyRequired)
			break
		}
		queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue})
	}

	tx.RowsAffected = rowsAffected
	return tx
}

// assignInterfacesToValue copies values from condition expressions, structs,
// or maps onto the statement's reflect value — used by FirstOrInit /
// FirstOrCreate to initialize dest from the query conditions and
// Attrs/Assign arguments.
func (db *DB) assignInterfacesToValue(values ...interface{}) {
	for _, value := range values {
		switch v := value.(type) {
		case []clause.Expression:
			for _, expr := range v {
				if eq, ok := expr.(clause.Eq); ok {
					switch column := eq.Column.(type) {
					case string:
						if field := db.Statement.Schema.LookUpField(column); field != nil {
							db.AddError(field.Set(db.Statement.Context, db.Statement.ReflectValue, eq.Value))
						}
					case clause.Column:
						if field := db.Statement.Schema.LookUpField(column.Name); field != nil {
							db.AddError(field.Set(db.Statement.Context, db.Statement.ReflectValue, eq.Value))
						}
					}
				} else if andCond, ok := expr.(clause.AndConditions); ok {
					db.assignInterfacesToValue(andCond.Exprs)
				}
			}
		case clause.Expression, map[string]string, map[interface{}]interface{}, map[string]interface{}:
			if exprs := db.Statement.BuildCondition(value); len(exprs) > 0 {
				db.assignInterfacesToValue(exprs)
			}
		default:
			if s, err := schema.Parse(value, db.cacheStore, db.NamingStrategy); err == nil {
				reflectValue := reflect.Indirect(reflect.ValueOf(value))
				switch reflectValue.Kind() {
				case reflect.Struct:
					// Copy only readable, non-zero source fields.
					for _, f := range s.Fields {
						if f.Readable {
							if v, isZero := f.ValueOf(db.Statement.Context, reflectValue); !isZero {
								if field := db.Statement.Schema.LookUpField(f.Name); field != nil {
									db.AddError(field.Set(db.Statement.Context, db.Statement.ReflectValue, v))
								}
							}
						}
					}
				}
			} else if len(values) > 0 {
				// Unparseable value: treat the whole argument list as a
				// condition set instead.
				if exprs := db.Statement.BuildCondition(values[0], values[1:]...); len(exprs) > 0 {
					db.assignInterfacesToValue(exprs)
				}
				return
			}
		}
	}
}

// FirstOrInit finds the first matching record, otherwise if not found initializes a new instance with given conds.
// Each conds must be a struct or map.
//
// FirstOrInit never modifies the database. It is often used with Assign and Attrs.
//
//	// assign an email if the record is not found
//	db.Where(User{Name: "non_existing"}).Attrs(User{Email: "fake@fake.org"}).FirstOrInit(&user)
//	// user -> User{Name: "non_existing", Email: "fake@fake.org"}
//
//	// assign email regardless of if record is found
//	db.Where(User{Name: "jinzhu"}).Assign(User{Email: "fake@fake.org"}).FirstOrInit(&user)
//	// user -> User{Name: "jinzhu", Age: 20, Email: "fake@fake.org"}
func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) {
	queryTx := db.Limit(1).Order(clause.OrderByColumn{
		Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
	})

	if tx = queryTx.Find(dest, conds...); tx.RowsAffected == 0 {
		if c, ok := tx.Statement.Clauses["WHERE"]; ok {
			if where, ok := c.Expression.(clause.Where); ok {
				tx.assignInterfacesToValue(where.Exprs)
			}
		}

		// initialize with attrs, conds
		if len(tx.Statement.attrs) > 0 {
			tx.assignInterfacesToValue(tx.Statement.attrs...)
		}
	}

	// initialize with attrs, conds (assigns apply whether or not a row was found)
	if len(tx.Statement.assigns) > 0 {
		tx.assignInterfacesToValue(tx.Statement.assigns...)
	}
	return
}

// FirstOrCreate finds the first matching record, otherwise if not found creates a new instance with given conds.
// Each conds must be a struct or map.
//
// Using FirstOrCreate in conjunction with Assign will result in an update to the database even if the record exists.
//
//	// assign an email if the record is not found
//	result := db.Where(User{Name: "non_existing"}).Attrs(User{Email: "fake@fake.org"}).FirstOrCreate(&user)
//	// user -> User{Name: "non_existing", Email: "fake@fake.org"}
//	// result.RowsAffected -> 1
//
//	// assign email regardless of if record is found
//	result := db.Where(User{Name: "jinzhu"}).Assign(User{Email: "fake@fake.org"}).FirstOrCreate(&user)
//	// user -> User{Name: "jinzhu", Age: 20, Email: "fake@fake.org"}
//	// result.RowsAffected -> 1
func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) {
	tx = db.getInstance()
	queryTx := db.Session(&Session{}).Limit(1).Order(clause.OrderByColumn{
		Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
	})

	result := queryTx.Find(dest, conds...)
	if result.Error != nil {
		tx.Error = result.Error
		return tx
	}

	if result.RowsAffected == 0 {
		// Not found: seed dest from the WHERE conditions, attrs and assigns,
		// then insert it.
		if c, ok := result.Statement.Clauses["WHERE"]; ok {
			if where, ok := c.Expression.(clause.Where); ok {
				result.assignInterfacesToValue(where.Exprs)
			}
		}

		// initialize with attrs, conds
		if len(db.Statement.attrs) > 0 {
			result.assignInterfacesToValue(db.Statement.attrs...)
		}

		// initialize with attrs, conds
		if len(db.Statement.assigns) > 0 {
			result.assignInterfacesToValue(db.Statement.assigns...)
		}

		return tx.Create(dest)
	} else if len(db.Statement.assigns) > 0 {
		// Found: flatten the Assign conditions (including nested AND groups)
		// into a column->value map and run an update.
		exprs := tx.Statement.BuildCondition(db.Statement.assigns[0], db.Statement.assigns[1:]...)
		assigns := map[string]interface{}{}
		for i := 0; i < len(exprs); i++ {
			expr := exprs[i]

			if eq, ok := expr.(clause.AndConditions); ok {
				exprs = append(exprs, eq.Exprs...)
			} else if eq, ok := expr.(clause.Eq); ok {
				switch column := eq.Column.(type) {
				case string:
					assigns[column] = eq.Value
				case clause.Column:
					assigns[column.Name] = eq.Value
				}
			}
		}

		return tx.Model(dest).Updates(assigns)
	}

	return tx
}

// Update updates column with value using callbacks. Reference: https://gorm.io/docs/update.html#Update-Changed-Fields
func (db *DB) Update(column string, value interface{}) (tx *DB) {
	tx = db.getInstance()
	tx.Statement.Dest = map[string]interface{}{column: value}
	return tx.callbacks.Update().Execute(tx)
}

// Updates updates attributes using callbacks. values must be a struct or map. Reference: https://gorm.io/docs/update.html#Update-Changed-Fields
func (db *DB) Updates(values interface{}) (tx *DB) {
	tx = db.getInstance()
	tx.Statement.Dest = values
	return tx.callbacks.Update().Execute(tx)
}

// UpdateColumn updates a single column like Update, but with hooks skipped.
func (db *DB) UpdateColumn(column string, value interface{}) (tx *DB) {
	tx = db.getInstance()
	tx.Statement.Dest = map[string]interface{}{column: value}
	tx.Statement.SkipHooks = true
	return tx.callbacks.Update().Execute(tx)
}

// UpdateColumns updates attributes like Updates, but with hooks skipped.
func (db *DB) UpdateColumns(values interface{}) (tx *DB) {
	tx = db.getInstance()
	tx.Statement.Dest = values
	tx.Statement.SkipHooks = true
	return tx.callbacks.Update().Execute(tx)
}

// Delete deletes value matching given conditions. If value contains primary key it is included in the conditions. If
// value includes a deleted_at field, then Delete performs a soft delete instead by setting deleted_at with the current
// time if null.
func (db *DB) Delete(value interface{}, conds ...interface{}) (tx *DB) {
	tx = db.getInstance()
	if len(conds) > 0 {
		if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
			tx.Statement.AddClause(clause.Where{Exprs: exprs})
		}
	}
	tx.Statement.Dest = value
	return tx.callbacks.Delete().Execute(tx)
}
|
||||
|
||||
func (db *DB) Count(count *int64) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
if tx.Statement.Model == nil {
|
||||
tx.Statement.Model = tx.Statement.Dest
|
||||
defer func() {
|
||||
tx.Statement.Model = nil
|
||||
}()
|
||||
}
|
||||
|
||||
if selectClause, ok := db.Statement.Clauses["SELECT"]; ok {
|
||||
defer func() {
|
||||
tx.Statement.Clauses["SELECT"] = selectClause
|
||||
}()
|
||||
} else {
|
||||
defer delete(tx.Statement.Clauses, "SELECT")
|
||||
}
|
||||
|
||||
if len(tx.Statement.Selects) == 0 {
|
||||
tx.Statement.AddClause(clause.Select{Expression: clause.Expr{SQL: "count(*)"}})
|
||||
} else if !strings.HasPrefix(strings.TrimSpace(strings.ToLower(tx.Statement.Selects[0])), "count(") {
|
||||
expr := clause.Expr{SQL: "count(*)"}
|
||||
|
||||
if len(tx.Statement.Selects) == 1 {
|
||||
dbName := tx.Statement.Selects[0]
|
||||
fields := strings.FieldsFunc(dbName, utils.IsValidDBNameChar)
|
||||
if len(fields) == 1 || (len(fields) == 3 && (strings.ToUpper(fields[1]) == "AS" || fields[1] == ".")) {
|
||||
if tx.Statement.Parse(tx.Statement.Model) == nil {
|
||||
if f := tx.Statement.Schema.LookUpField(dbName); f != nil {
|
||||
dbName = f.DBName
|
||||
}
|
||||
}
|
||||
|
||||
if tx.Statement.Distinct {
|
||||
expr = clause.Expr{SQL: "COUNT(DISTINCT(?))", Vars: []interface{}{clause.Column{Name: dbName}}}
|
||||
} else if dbName != "*" {
|
||||
expr = clause.Expr{SQL: "COUNT(?)", Vars: []interface{}{clause.Column{Name: dbName}}}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tx.Statement.AddClause(clause.Select{Expression: expr})
|
||||
}
|
||||
|
||||
if orderByClause, ok := db.Statement.Clauses["ORDER BY"]; ok {
|
||||
if _, ok := db.Statement.Clauses["GROUP BY"]; !ok {
|
||||
delete(tx.Statement.Clauses, "ORDER BY")
|
||||
defer func() {
|
||||
tx.Statement.Clauses["ORDER BY"] = orderByClause
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
tx.Statement.Dest = count
|
||||
tx = tx.callbacks.Query().Execute(tx)
|
||||
|
||||
if _, ok := db.Statement.Clauses["GROUP BY"]; ok || tx.RowsAffected != 1 {
|
||||
*count = tx.RowsAffected
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (db *DB) Row() *sql.Row {
|
||||
tx := db.getInstance().Set("rows", false)
|
||||
tx = tx.callbacks.Row().Execute(tx)
|
||||
row, ok := tx.Statement.Dest.(*sql.Row)
|
||||
if !ok && tx.DryRun {
|
||||
db.Logger.Error(tx.Statement.Context, ErrDryRunModeUnsupported.Error())
|
||||
}
|
||||
return row
|
||||
}
|
||||
|
||||
func (db *DB) Rows() (*sql.Rows, error) {
|
||||
tx := db.getInstance().Set("rows", true)
|
||||
tx = tx.callbacks.Row().Execute(tx)
|
||||
rows, ok := tx.Statement.Dest.(*sql.Rows)
|
||||
if !ok && tx.DryRun && tx.Error == nil {
|
||||
tx.Error = ErrDryRunModeUnsupported
|
||||
}
|
||||
return rows, tx.Error
|
||||
}
|
||||
|
||||
// Scan scans selected value to the struct dest
|
||||
func (db *DB) Scan(dest interface{}) (tx *DB) {
|
||||
config := *db.Config
|
||||
currentLogger, newLogger := config.Logger, logger.Recorder.New()
|
||||
config.Logger = newLogger
|
||||
|
||||
tx = db.getInstance()
|
||||
tx.Config = &config
|
||||
|
||||
if rows, err := tx.Rows(); err == nil {
|
||||
if rows.Next() {
|
||||
tx.ScanRows(rows, dest)
|
||||
} else {
|
||||
tx.RowsAffected = 0
|
||||
tx.AddError(rows.Err())
|
||||
}
|
||||
tx.AddError(rows.Close())
|
||||
}
|
||||
|
||||
currentLogger.Trace(tx.Statement.Context, newLogger.BeginAt, func() (string, int64) {
|
||||
return newLogger.SQL, tx.RowsAffected
|
||||
}, tx.Error)
|
||||
tx.Logger = currentLogger
|
||||
return
|
||||
}
|
||||
|
||||
// Pluck queries a single column from a model, returning in the slice dest. E.g.:
|
||||
//
|
||||
// var ages []int64
|
||||
// db.Model(&users).Pluck("age", &ages)
|
||||
func (db *DB) Pluck(column string, dest interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
if tx.Statement.Model != nil {
|
||||
if tx.Statement.Parse(tx.Statement.Model) == nil {
|
||||
if f := tx.Statement.Schema.LookUpField(column); f != nil {
|
||||
column = f.DBName
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(tx.Statement.Selects) != 1 {
|
||||
fields := strings.FieldsFunc(column, utils.IsValidDBNameChar)
|
||||
tx.Statement.AddClauseIfNotExists(clause.Select{
|
||||
Distinct: tx.Statement.Distinct,
|
||||
Columns: []clause.Column{{Name: column, Raw: len(fields) != 1}},
|
||||
})
|
||||
}
|
||||
tx.Statement.Dest = dest
|
||||
return tx.callbacks.Query().Execute(tx)
|
||||
}
|
||||
|
||||
func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error {
|
||||
tx := db.getInstance()
|
||||
if err := tx.Statement.Parse(dest); !errors.Is(err, schema.ErrUnsupportedDataType) {
|
||||
tx.AddError(err)
|
||||
}
|
||||
tx.Statement.Dest = dest
|
||||
tx.Statement.ReflectValue = reflect.ValueOf(dest)
|
||||
for tx.Statement.ReflectValue.Kind() == reflect.Ptr {
|
||||
elem := tx.Statement.ReflectValue.Elem()
|
||||
if !elem.IsValid() {
|
||||
elem = reflect.New(tx.Statement.ReflectValue.Type().Elem())
|
||||
tx.Statement.ReflectValue.Set(elem)
|
||||
}
|
||||
tx.Statement.ReflectValue = elem
|
||||
}
|
||||
Scan(rows, tx, ScanInitialized)
|
||||
return tx.Error
|
||||
}
|
||||
|
||||
// Connection uses a db connection to execute an arbitrary number of commands in fc. When finished, the connection is
|
||||
// returned to the connection pool.
|
||||
func (db *DB) Connection(fc func(tx *DB) error) (err error) {
|
||||
if db.Error != nil {
|
||||
return db.Error
|
||||
}
|
||||
|
||||
tx := db.getInstance()
|
||||
sqlDB, err := tx.DB()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
conn, err := sqlDB.Conn(tx.Statement.Context)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
defer conn.Close()
|
||||
tx.Statement.ConnPool = conn
|
||||
return fc(tx)
|
||||
}
|
||||
|
||||
// Transaction start a transaction as a block, return error will rollback, otherwise to commit. Transaction executes an
|
||||
// arbitrary number of commands in fc within a transaction. On success the changes are committed; if an error occurs
|
||||
// they are rolled back.
|
||||
func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err error) {
|
||||
panicked := true
|
||||
|
||||
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil {
|
||||
// nested transaction
|
||||
if !db.DisableNestedTransaction {
|
||||
spID := new(maphash.Hash).Sum64()
|
||||
err = db.SavePoint(fmt.Sprintf("sp%d", spID)).Error
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
// Make sure to rollback when panic, Block error or Commit error
|
||||
if panicked || err != nil {
|
||||
db.RollbackTo(fmt.Sprintf("sp%d", spID))
|
||||
}
|
||||
}()
|
||||
}
|
||||
err = fc(db.Session(&Session{NewDB: db.clone == 1}))
|
||||
} else {
|
||||
tx := db.Begin(opts...)
|
||||
if tx.Error != nil {
|
||||
return tx.Error
|
||||
}
|
||||
|
||||
defer func() {
|
||||
// Make sure to rollback when panic, Block error or Commit error
|
||||
if panicked || err != nil {
|
||||
tx.Rollback()
|
||||
}
|
||||
}()
|
||||
|
||||
if err = fc(tx); err == nil {
|
||||
panicked = false
|
||||
return tx.Commit().Error
|
||||
}
|
||||
}
|
||||
|
||||
panicked = false
|
||||
return
|
||||
}
|
||||
|
||||
// Begin begins a transaction with any transaction options opts
|
||||
func (db *DB) Begin(opts ...*sql.TxOptions) *DB {
|
||||
var (
|
||||
// clone statement
|
||||
tx = db.getInstance().Session(&Session{Context: db.Statement.Context, NewDB: db.clone == 1})
|
||||
opt *sql.TxOptions
|
||||
err error
|
||||
)
|
||||
|
||||
if len(opts) > 0 {
|
||||
opt = opts[0]
|
||||
}
|
||||
|
||||
ctx := tx.Statement.Context
|
||||
if _, ok := ctx.Deadline(); !ok {
|
||||
if db.Config.DefaultTransactionTimeout > 0 {
|
||||
ctx, _ = context.WithTimeout(ctx, db.Config.DefaultTransactionTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
switch beginner := tx.Statement.ConnPool.(type) {
|
||||
case TxBeginner:
|
||||
tx.Statement.ConnPool, err = beginner.BeginTx(ctx, opt)
|
||||
case ConnPoolBeginner:
|
||||
tx.Statement.ConnPool, err = beginner.BeginTx(ctx, opt)
|
||||
default:
|
||||
err = ErrInvalidTransaction
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
tx.AddError(err)
|
||||
}
|
||||
|
||||
return tx
|
||||
}
|
||||
|
||||
// Commit commits the changes in a transaction
|
||||
func (db *DB) Commit() *DB {
|
||||
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil && !reflect.ValueOf(committer).IsNil() {
|
||||
db.AddError(committer.Commit())
|
||||
} else {
|
||||
db.AddError(ErrInvalidTransaction)
|
||||
}
|
||||
return db
|
||||
}
|
||||
|
||||
// Rollback rollbacks the changes in a transaction
|
||||
func (db *DB) Rollback() *DB {
|
||||
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil {
|
||||
if !reflect.ValueOf(committer).IsNil() {
|
||||
db.AddError(committer.Rollback())
|
||||
}
|
||||
} else {
|
||||
db.AddError(ErrInvalidTransaction)
|
||||
}
|
||||
return db
|
||||
}
|
||||
|
||||
func (db *DB) SavePoint(name string) *DB {
|
||||
if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok {
|
||||
// close prepared statement, because SavePoint not support prepared statement.
|
||||
// e.g. mysql8.0 doc: https://dev.mysql.com/doc/refman/8.0/en/sql-prepared-statements.html
|
||||
var (
|
||||
preparedStmtTx *PreparedStmtTX
|
||||
isPreparedStmtTx bool
|
||||
)
|
||||
// close prepared statement, because SavePoint not support prepared statement.
|
||||
if preparedStmtTx, isPreparedStmtTx = db.Statement.ConnPool.(*PreparedStmtTX); isPreparedStmtTx {
|
||||
db.Statement.ConnPool = preparedStmtTx.Tx
|
||||
}
|
||||
db.AddError(savePointer.SavePoint(db, name))
|
||||
// restore prepared statement
|
||||
if isPreparedStmtTx {
|
||||
db.Statement.ConnPool = preparedStmtTx
|
||||
}
|
||||
} else {
|
||||
db.AddError(ErrUnsupportedDriver)
|
||||
}
|
||||
return db
|
||||
}
|
||||
|
||||
func (db *DB) RollbackTo(name string) *DB {
|
||||
if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok {
|
||||
// close prepared statement, because RollbackTo not support prepared statement.
|
||||
// e.g. mysql8.0 doc: https://dev.mysql.com/doc/refman/8.0/en/sql-prepared-statements.html
|
||||
var (
|
||||
preparedStmtTx *PreparedStmtTX
|
||||
isPreparedStmtTx bool
|
||||
)
|
||||
// close prepared statement, because SavePoint not support prepared statement.
|
||||
if preparedStmtTx, isPreparedStmtTx = db.Statement.ConnPool.(*PreparedStmtTX); isPreparedStmtTx {
|
||||
db.Statement.ConnPool = preparedStmtTx.Tx
|
||||
}
|
||||
db.AddError(savePointer.RollbackTo(db, name))
|
||||
// restore prepared statement
|
||||
if isPreparedStmtTx {
|
||||
db.Statement.ConnPool = preparedStmtTx
|
||||
}
|
||||
} else {
|
||||
db.AddError(ErrUnsupportedDriver)
|
||||
}
|
||||
return db
|
||||
}
|
||||
|
||||
// Exec executes raw sql
|
||||
func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.Statement.SQL = strings.Builder{}
|
||||
|
||||
if strings.Contains(sql, "@") {
|
||||
clause.NamedExpr{SQL: sql, Vars: values}.Build(tx.Statement)
|
||||
} else {
|
||||
clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement)
|
||||
}
|
||||
|
||||
return tx.callbacks.Raw().Execute(tx)
|
||||
}
|
||||
605
vendor/gorm.io/gorm/generics.go
generated
vendored
Normal file
605
vendor/gorm.io/gorm/generics.go
generated
vendored
Normal file
@@ -0,0 +1,605 @@
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
type result struct {
|
||||
Result sql.Result
|
||||
RowsAffected int64
|
||||
}
|
||||
|
||||
func (info *result) ModifyStatement(stmt *Statement) {
|
||||
stmt.Result = info
|
||||
}
|
||||
|
||||
// Build implements clause.Expression interface
|
||||
func (result) Build(clause.Builder) {
|
||||
}
|
||||
|
||||
func WithResult() *result {
|
||||
return &result{}
|
||||
}
|
||||
|
||||
type Interface[T any] interface {
|
||||
Raw(sql string, values ...interface{}) ExecInterface[T]
|
||||
Exec(ctx context.Context, sql string, values ...interface{}) error
|
||||
CreateInterface[T]
|
||||
}
|
||||
|
||||
type CreateInterface[T any] interface {
|
||||
ChainInterface[T]
|
||||
Table(name string, args ...interface{}) CreateInterface[T]
|
||||
Create(ctx context.Context, r *T) error
|
||||
CreateInBatches(ctx context.Context, r *[]T, batchSize int) error
|
||||
}
|
||||
|
||||
type ChainInterface[T any] interface {
|
||||
ExecInterface[T]
|
||||
Scopes(scopes ...func(db *Statement)) ChainInterface[T]
|
||||
Where(query interface{}, args ...interface{}) ChainInterface[T]
|
||||
Not(query interface{}, args ...interface{}) ChainInterface[T]
|
||||
Or(query interface{}, args ...interface{}) ChainInterface[T]
|
||||
Limit(offset int) ChainInterface[T]
|
||||
Offset(offset int) ChainInterface[T]
|
||||
Joins(query clause.JoinTarget, on func(db JoinBuilder, joinTable clause.Table, curTable clause.Table) error) ChainInterface[T]
|
||||
Preload(association string, query func(db PreloadBuilder) error) ChainInterface[T]
|
||||
Select(query string, args ...interface{}) ChainInterface[T]
|
||||
Omit(columns ...string) ChainInterface[T]
|
||||
MapColumns(m map[string]string) ChainInterface[T]
|
||||
Distinct(args ...interface{}) ChainInterface[T]
|
||||
Group(name string) ChainInterface[T]
|
||||
Having(query interface{}, args ...interface{}) ChainInterface[T]
|
||||
Order(value interface{}) ChainInterface[T]
|
||||
|
||||
Build(builder clause.Builder)
|
||||
|
||||
Delete(ctx context.Context) (rowsAffected int, err error)
|
||||
Update(ctx context.Context, name string, value any) (rowsAffected int, err error)
|
||||
Updates(ctx context.Context, t T) (rowsAffected int, err error)
|
||||
Count(ctx context.Context, column string) (result int64, err error)
|
||||
}
|
||||
|
||||
type ExecInterface[T any] interface {
|
||||
Scan(ctx context.Context, r interface{}) error
|
||||
First(context.Context) (T, error)
|
||||
Last(ctx context.Context) (T, error)
|
||||
Take(context.Context) (T, error)
|
||||
Find(ctx context.Context) ([]T, error)
|
||||
FindInBatches(ctx context.Context, batchSize int, fc func(data []T, batch int) error) error
|
||||
Row(ctx context.Context) *sql.Row
|
||||
Rows(ctx context.Context) (*sql.Rows, error)
|
||||
}
|
||||
|
||||
type JoinBuilder interface {
|
||||
Select(...string) JoinBuilder
|
||||
Omit(...string) JoinBuilder
|
||||
Where(query interface{}, args ...interface{}) JoinBuilder
|
||||
Not(query interface{}, args ...interface{}) JoinBuilder
|
||||
Or(query interface{}, args ...interface{}) JoinBuilder
|
||||
}
|
||||
|
||||
type PreloadBuilder interface {
|
||||
Select(...string) PreloadBuilder
|
||||
Omit(...string) PreloadBuilder
|
||||
Where(query interface{}, args ...interface{}) PreloadBuilder
|
||||
Not(query interface{}, args ...interface{}) PreloadBuilder
|
||||
Or(query interface{}, args ...interface{}) PreloadBuilder
|
||||
Limit(offset int) PreloadBuilder
|
||||
Offset(offset int) PreloadBuilder
|
||||
Order(value interface{}) PreloadBuilder
|
||||
LimitPerRecord(num int) PreloadBuilder
|
||||
}
|
||||
|
||||
type op func(*DB) *DB
|
||||
|
||||
func G[T any](db *DB, opts ...clause.Expression) Interface[T] {
|
||||
v := &g[T]{
|
||||
db: db,
|
||||
ops: make([]op, 0, 5),
|
||||
}
|
||||
|
||||
if len(opts) > 0 {
|
||||
v.ops = append(v.ops, func(db *DB) *DB {
|
||||
return db.Clauses(opts...)
|
||||
})
|
||||
}
|
||||
|
||||
v.createG = &createG[T]{
|
||||
chainG: chainG[T]{
|
||||
execG: execG[T]{g: v},
|
||||
},
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
type g[T any] struct {
|
||||
*createG[T]
|
||||
db *DB
|
||||
ops []op
|
||||
}
|
||||
|
||||
func (g *g[T]) apply(ctx context.Context) *DB {
|
||||
db := g.db
|
||||
if !db.DryRun {
|
||||
db = db.Session(&Session{NewDB: true, Context: ctx}).getInstance()
|
||||
}
|
||||
|
||||
for _, op := range g.ops {
|
||||
db = op(db)
|
||||
}
|
||||
return db
|
||||
}
|
||||
|
||||
func (c *g[T]) Raw(sql string, values ...interface{}) ExecInterface[T] {
|
||||
return execG[T]{g: &g[T]{
|
||||
db: c.db,
|
||||
ops: append(c.ops, func(db *DB) *DB {
|
||||
return db.Raw(sql, values...)
|
||||
}),
|
||||
}}
|
||||
}
|
||||
|
||||
func (c *g[T]) Exec(ctx context.Context, sql string, values ...interface{}) error {
|
||||
return c.apply(ctx).Exec(sql, values...).Error
|
||||
}
|
||||
|
||||
type createG[T any] struct {
|
||||
chainG[T]
|
||||
}
|
||||
|
||||
func (c createG[T]) Table(name string, args ...interface{}) CreateInterface[T] {
|
||||
return createG[T]{c.with(func(db *DB) *DB {
|
||||
return db.Table(name, args...)
|
||||
})}
|
||||
}
|
||||
|
||||
func (c createG[T]) Create(ctx context.Context, r *T) error {
|
||||
return c.g.apply(ctx).Create(r).Error
|
||||
}
|
||||
|
||||
func (c createG[T]) CreateInBatches(ctx context.Context, r *[]T, batchSize int) error {
|
||||
return c.g.apply(ctx).CreateInBatches(r, batchSize).Error
|
||||
}
|
||||
|
||||
type chainG[T any] struct {
|
||||
execG[T]
|
||||
}
|
||||
|
||||
func (c chainG[T]) getInstance() *DB {
|
||||
var r T
|
||||
return c.g.apply(context.Background()).Model(r).getInstance()
|
||||
}
|
||||
|
||||
func (c chainG[T]) with(v op) chainG[T] {
|
||||
return chainG[T]{
|
||||
execG: execG[T]{g: &g[T]{
|
||||
db: c.g.db,
|
||||
ops: append(append([]op(nil), c.g.ops...), v),
|
||||
}},
|
||||
}
|
||||
}
|
||||
|
||||
func (c chainG[T]) Scopes(scopes ...func(db *Statement)) ChainInterface[T] {
|
||||
return c.with(func(db *DB) *DB {
|
||||
for _, fc := range scopes {
|
||||
fc(db.Statement)
|
||||
}
|
||||
return db
|
||||
})
|
||||
}
|
||||
|
||||
func (c chainG[T]) Table(name string, args ...interface{}) ChainInterface[T] {
|
||||
return c.with(func(db *DB) *DB {
|
||||
return db.Table(name, args...)
|
||||
})
|
||||
}
|
||||
|
||||
func (c chainG[T]) Where(query interface{}, args ...interface{}) ChainInterface[T] {
|
||||
return c.with(func(db *DB) *DB {
|
||||
return db.Where(query, args...)
|
||||
})
|
||||
}
|
||||
|
||||
func (c chainG[T]) Not(query interface{}, args ...interface{}) ChainInterface[T] {
|
||||
return c.with(func(db *DB) *DB {
|
||||
return db.Not(query, args...)
|
||||
})
|
||||
}
|
||||
|
||||
func (c chainG[T]) Or(query interface{}, args ...interface{}) ChainInterface[T] {
|
||||
return c.with(func(db *DB) *DB {
|
||||
return db.Or(query, args...)
|
||||
})
|
||||
}
|
||||
|
||||
func (c chainG[T]) Limit(offset int) ChainInterface[T] {
|
||||
return c.with(func(db *DB) *DB {
|
||||
return db.Limit(offset)
|
||||
})
|
||||
}
|
||||
|
||||
func (c chainG[T]) Offset(offset int) ChainInterface[T] {
|
||||
return c.with(func(db *DB) *DB {
|
||||
return db.Offset(offset)
|
||||
})
|
||||
}
|
||||
|
||||
type joinBuilder struct {
|
||||
db *DB
|
||||
}
|
||||
|
||||
func (q *joinBuilder) Where(query interface{}, args ...interface{}) JoinBuilder {
|
||||
q.db.Where(query, args...)
|
||||
return q
|
||||
}
|
||||
|
||||
func (q *joinBuilder) Or(query interface{}, args ...interface{}) JoinBuilder {
|
||||
q.db.Where(query, args...)
|
||||
return q
|
||||
}
|
||||
|
||||
func (q *joinBuilder) Not(query interface{}, args ...interface{}) JoinBuilder {
|
||||
q.db.Where(query, args...)
|
||||
return q
|
||||
}
|
||||
|
||||
func (q *joinBuilder) Select(columns ...string) JoinBuilder {
|
||||
q.db.Select(columns)
|
||||
return q
|
||||
}
|
||||
|
||||
func (q *joinBuilder) Omit(columns ...string) JoinBuilder {
|
||||
q.db.Omit(columns...)
|
||||
return q
|
||||
}
|
||||
|
||||
type preloadBuilder struct {
|
||||
limitPerRecord int
|
||||
db *DB
|
||||
}
|
||||
|
||||
func (q *preloadBuilder) Where(query interface{}, args ...interface{}) PreloadBuilder {
|
||||
q.db.Where(query, args...)
|
||||
return q
|
||||
}
|
||||
|
||||
func (q *preloadBuilder) Or(query interface{}, args ...interface{}) PreloadBuilder {
|
||||
q.db.Where(query, args...)
|
||||
return q
|
||||
}
|
||||
|
||||
func (q *preloadBuilder) Not(query interface{}, args ...interface{}) PreloadBuilder {
|
||||
q.db.Where(query, args...)
|
||||
return q
|
||||
}
|
||||
|
||||
func (q *preloadBuilder) Select(columns ...string) PreloadBuilder {
|
||||
q.db.Select(columns)
|
||||
return q
|
||||
}
|
||||
|
||||
func (q *preloadBuilder) Omit(columns ...string) PreloadBuilder {
|
||||
q.db.Omit(columns...)
|
||||
return q
|
||||
}
|
||||
|
||||
func (q *preloadBuilder) Limit(limit int) PreloadBuilder {
|
||||
q.db.Limit(limit)
|
||||
return q
|
||||
}
|
||||
|
||||
func (q *preloadBuilder) Offset(offset int) PreloadBuilder {
|
||||
q.db.Offset(offset)
|
||||
return q
|
||||
}
|
||||
|
||||
func (q *preloadBuilder) Order(value interface{}) PreloadBuilder {
|
||||
q.db.Order(value)
|
||||
return q
|
||||
}
|
||||
|
||||
func (q *preloadBuilder) LimitPerRecord(num int) PreloadBuilder {
|
||||
q.limitPerRecord = num
|
||||
return q
|
||||
}
|
||||
|
||||
func (c chainG[T]) Joins(jt clause.JoinTarget, on func(db JoinBuilder, joinTable clause.Table, curTable clause.Table) error) ChainInterface[T] {
|
||||
return c.with(func(db *DB) *DB {
|
||||
if jt.Table == "" {
|
||||
jt.Table = clause.JoinTable(strings.Split(jt.Association, ".")...).Name
|
||||
}
|
||||
|
||||
q := joinBuilder{db: db.Session(&Session{NewDB: true, Initialized: true}).Table(jt.Table)}
|
||||
if on != nil {
|
||||
if err := on(&q, clause.Table{Name: jt.Table}, clause.Table{Name: clause.CurrentTable}); err != nil {
|
||||
db.AddError(err)
|
||||
}
|
||||
}
|
||||
|
||||
j := join{
|
||||
Name: jt.Association,
|
||||
Alias: jt.Table,
|
||||
Selects: q.db.Statement.Selects,
|
||||
Omits: q.db.Statement.Omits,
|
||||
JoinType: jt.Type,
|
||||
}
|
||||
|
||||
if where, ok := q.db.Statement.Clauses["WHERE"].Expression.(clause.Where); ok {
|
||||
j.On = &where
|
||||
}
|
||||
|
||||
if jt.Subquery != nil {
|
||||
joinType := j.JoinType
|
||||
if joinType == "" {
|
||||
joinType = clause.LeftJoin
|
||||
}
|
||||
|
||||
if db, ok := jt.Subquery.(interface{ getInstance() *DB }); ok {
|
||||
stmt := db.getInstance().Statement
|
||||
if len(j.Selects) == 0 {
|
||||
j.Selects = stmt.Selects
|
||||
}
|
||||
if len(j.Omits) == 0 {
|
||||
j.Omits = stmt.Omits
|
||||
}
|
||||
}
|
||||
|
||||
expr := clause.NamedExpr{SQL: fmt.Sprintf("%s JOIN (?) AS ?", joinType), Vars: []interface{}{jt.Subquery, clause.Table{Name: j.Alias}}}
|
||||
|
||||
if j.On != nil {
|
||||
expr.SQL += " ON ?"
|
||||
expr.Vars = append(expr.Vars, clause.AndConditions{Exprs: j.On.Exprs})
|
||||
}
|
||||
|
||||
j.Expression = expr
|
||||
}
|
||||
|
||||
db.Statement.Joins = append(db.Statement.Joins, j)
|
||||
sort.Slice(db.Statement.Joins, func(i, j int) bool {
|
||||
return db.Statement.Joins[i].Name < db.Statement.Joins[j].Name
|
||||
})
|
||||
return db
|
||||
})
|
||||
}
|
||||
|
||||
func (c chainG[T]) Select(query string, args ...interface{}) ChainInterface[T] {
|
||||
return c.with(func(db *DB) *DB {
|
||||
return db.Select(query, args...)
|
||||
})
|
||||
}
|
||||
|
||||
func (c chainG[T]) Omit(columns ...string) ChainInterface[T] {
|
||||
return c.with(func(db *DB) *DB {
|
||||
return db.Omit(columns...)
|
||||
})
|
||||
}
|
||||
|
||||
func (c chainG[T]) MapColumns(m map[string]string) ChainInterface[T] {
|
||||
return c.with(func(db *DB) *DB {
|
||||
return db.MapColumns(m)
|
||||
})
|
||||
}
|
||||
|
||||
func (c chainG[T]) Distinct(args ...interface{}) ChainInterface[T] {
|
||||
return c.with(func(db *DB) *DB {
|
||||
return db.Distinct(args...)
|
||||
})
|
||||
}
|
||||
|
||||
func (c chainG[T]) Group(name string) ChainInterface[T] {
|
||||
return c.with(func(db *DB) *DB {
|
||||
return db.Group(name)
|
||||
})
|
||||
}
|
||||
|
||||
func (c chainG[T]) Having(query interface{}, args ...interface{}) ChainInterface[T] {
|
||||
return c.with(func(db *DB) *DB {
|
||||
return db.Having(query, args...)
|
||||
})
|
||||
}
|
||||
|
||||
func (c chainG[T]) Order(value interface{}) ChainInterface[T] {
|
||||
return c.with(func(db *DB) *DB {
|
||||
return db.Order(value)
|
||||
})
|
||||
}
|
||||
|
||||
func (c chainG[T]) Preload(association string, query func(db PreloadBuilder) error) ChainInterface[T] {
|
||||
return c.with(func(db *DB) *DB {
|
||||
return db.Preload(association, func(tx *DB) *DB {
|
||||
q := preloadBuilder{db: tx.getInstance()}
|
||||
if query != nil {
|
||||
if err := query(&q); err != nil {
|
||||
db.AddError(err)
|
||||
}
|
||||
}
|
||||
|
||||
relation, ok := db.Statement.Schema.Relationships.Relations[association]
|
||||
if !ok {
|
||||
if preloadFields := strings.Split(association, "."); len(preloadFields) > 1 {
|
||||
relationships := db.Statement.Schema.Relationships
|
||||
for _, field := range preloadFields {
|
||||
var ok bool
|
||||
relation, ok = relationships.Relations[field]
|
||||
if ok {
|
||||
relationships = relation.FieldSchema.Relationships
|
||||
} else {
|
||||
db.AddError(fmt.Errorf("relation %s not found", association))
|
||||
return nil
|
||||
}
|
||||
}
|
||||
} else {
|
||||
db.AddError(fmt.Errorf("relation %s not found", association))
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
if q.limitPerRecord > 0 {
|
||||
if relation.JoinTable != nil {
|
||||
tx.AddError(fmt.Errorf("many2many relation %s don't support LimitPerRecord", association))
|
||||
return tx
|
||||
}
|
||||
|
||||
refColumns := []clause.Column{}
|
||||
for _, rel := range relation.References {
|
||||
if rel.OwnPrimaryKey {
|
||||
refColumns = append(refColumns, clause.Column{Name: rel.ForeignKey.DBName})
|
||||
}
|
||||
}
|
||||
|
||||
if len(refColumns) != 0 {
|
||||
selectExpr := clause.CommaExpression{}
|
||||
for _, column := range q.db.Statement.Selects {
|
||||
selectExpr.Exprs = append(selectExpr.Exprs, clause.Expr{SQL: "?", Vars: []interface{}{clause.Column{Name: column}}})
|
||||
}
|
||||
|
||||
if len(selectExpr.Exprs) == 0 {
|
||||
selectExpr.Exprs = []clause.Expression{clause.Expr{SQL: "*", Vars: []interface{}{}}}
|
||||
}
|
||||
|
||||
partitionBy := clause.CommaExpression{}
|
||||
for _, column := range refColumns {
|
||||
partitionBy.Exprs = append(partitionBy.Exprs, clause.Expr{SQL: "?", Vars: []interface{}{clause.Column{Name: column.Name}}})
|
||||
}
|
||||
|
||||
rnnColumn := clause.Column{Name: "gorm_preload_rnn"}
|
||||
sql := "ROW_NUMBER() OVER (PARTITION BY ? ?)"
|
||||
vars := []interface{}{partitionBy}
|
||||
if orderBy, ok := q.db.Statement.Clauses["ORDER BY"]; ok {
|
||||
vars = append(vars, orderBy)
|
||||
} else {
|
||||
vars = append(vars, clause.Clause{Name: "ORDER BY", Expression: clause.OrderBy{
|
||||
Columns: []clause.OrderByColumn{{Column: clause.PrimaryColumn, Desc: true}},
|
||||
}})
|
||||
}
|
||||
vars = append(vars, rnnColumn)
|
||||
|
||||
selectExpr.Exprs = append(selectExpr.Exprs, clause.Expr{SQL: sql + " AS ?", Vars: vars})
|
||||
|
||||
q.db.Clauses(clause.Select{Expression: selectExpr})
|
||||
|
||||
return q.db.Session(&Session{NewDB: true}).Unscoped().Table("(?) t", q.db).Where("? <= ?", rnnColumn, q.limitPerRecord)
|
||||
}
|
||||
}
|
||||
|
||||
return q.db
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func (c chainG[T]) Delete(ctx context.Context) (rowsAffected int, err error) {
|
||||
r := new(T)
|
||||
res := c.g.apply(ctx).Delete(r)
|
||||
return int(res.RowsAffected), res.Error
|
||||
}
|
||||
|
||||
func (c chainG[T]) Update(ctx context.Context, name string, value any) (rowsAffected int, err error) {
|
||||
var r T
|
||||
res := c.g.apply(ctx).Model(r).Update(name, value)
|
||||
return int(res.RowsAffected), res.Error
|
||||
}
|
||||
|
||||
func (c chainG[T]) Updates(ctx context.Context, t T) (rowsAffected int, err error) {
|
||||
res := c.g.apply(ctx).Updates(t)
|
||||
return int(res.RowsAffected), res.Error
|
||||
}
|
||||
|
||||
func (c chainG[T]) Count(ctx context.Context, column string) (result int64, err error) {
|
||||
var r T
|
||||
err = c.g.apply(ctx).Model(r).Select(column).Count(&result).Error
|
||||
return
|
||||
}
|
||||
|
||||
func (c chainG[T]) Build(builder clause.Builder) {
|
||||
subdb := c.getInstance()
|
||||
subdb.Logger = logger.Discard
|
||||
subdb.DryRun = true
|
||||
|
||||
if stmt, ok := builder.(*Statement); ok {
|
||||
if subdb.Statement.SQL.Len() > 0 {
|
||||
var (
|
||||
vars = subdb.Statement.Vars
|
||||
sql = subdb.Statement.SQL.String()
|
||||
)
|
||||
|
||||
subdb.Statement.Vars = make([]interface{}, 0, len(vars))
|
||||
for _, vv := range vars {
|
||||
subdb.Statement.Vars = append(subdb.Statement.Vars, vv)
|
||||
bindvar := strings.Builder{}
|
||||
subdb.BindVarTo(&bindvar, subdb.Statement, vv)
|
||||
sql = strings.Replace(sql, bindvar.String(), "?", 1)
|
||||
}
|
||||
|
||||
subdb.Statement.SQL.Reset()
|
||||
subdb.Statement.Vars = stmt.Vars
|
||||
if strings.Contains(sql, "@") {
|
||||
clause.NamedExpr{SQL: sql, Vars: vars}.Build(subdb.Statement)
|
||||
} else {
|
||||
clause.Expr{SQL: sql, Vars: vars}.Build(subdb.Statement)
|
||||
}
|
||||
} else {
|
||||
subdb.Statement.Vars = append(stmt.Vars, subdb.Statement.Vars...)
|
||||
subdb.callbacks.Query().Execute(subdb)
|
||||
}
|
||||
|
||||
builder.WriteString(subdb.Statement.SQL.String())
|
||||
stmt.Vars = subdb.Statement.Vars
|
||||
}
|
||||
}
|
||||
|
||||
type execG[T any] struct {
|
||||
g *g[T]
|
||||
}
|
||||
|
||||
func (g execG[T]) First(ctx context.Context) (T, error) {
|
||||
var r T
|
||||
err := g.g.apply(ctx).First(&r).Error
|
||||
return r, err
|
||||
}
|
||||
|
||||
func (g execG[T]) Scan(ctx context.Context, result interface{}) error {
|
||||
var r T
|
||||
err := g.g.apply(ctx).Model(r).Find(&result).Error
|
||||
return err
|
||||
}
|
||||
|
||||
func (g execG[T]) Last(ctx context.Context) (T, error) {
|
||||
var r T
|
||||
err := g.g.apply(ctx).Last(&r).Error
|
||||
return r, err
|
||||
}
|
||||
|
||||
func (g execG[T]) Take(ctx context.Context) (T, error) {
|
||||
var r T
|
||||
err := g.g.apply(ctx).Take(&r).Error
|
||||
return r, err
|
||||
}
|
||||
|
||||
func (g execG[T]) Find(ctx context.Context) ([]T, error) {
|
||||
var r []T
|
||||
err := g.g.apply(ctx).Find(&r).Error
|
||||
return r, err
|
||||
}
|
||||
|
||||
func (g execG[T]) FindInBatches(ctx context.Context, batchSize int, fc func(data []T, batch int) error) error {
|
||||
var data []T
|
||||
return g.g.apply(ctx).FindInBatches(&data, batchSize, func(tx *DB, batch int) error {
|
||||
return fc(data, batch)
|
||||
}).Error
|
||||
}
|
||||
|
||||
func (g execG[T]) Row(ctx context.Context) *sql.Row {
|
||||
return g.g.apply(ctx).Row()
|
||||
}
|
||||
|
||||
func (g execG[T]) Rows(ctx context.Context) (*sql.Rows, error) {
|
||||
return g.g.apply(ctx).Rows()
|
||||
}
|
||||
536
vendor/gorm.io/gorm/gorm.go
generated
vendored
Normal file
536
vendor/gorm.io/gorm/gorm.go
generated
vendored
Normal file
@@ -0,0 +1,536 @@
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sort"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/logger"
|
||||
"gorm.io/gorm/schema"
|
||||
)
|
||||
|
||||
// for Config.cacheStore store PreparedStmtDB key
|
||||
const preparedStmtDBKey = "preparedStmt"
|
||||
|
||||
// Config GORM config
|
||||
type Config struct {
|
||||
// GORM perform single create, update, delete operations in transactions by default to ensure database data integrity
|
||||
// You can disable it by setting `SkipDefaultTransaction` to true
|
||||
SkipDefaultTransaction bool
|
||||
DefaultTransactionTimeout time.Duration
|
||||
|
||||
// NamingStrategy tables, columns naming strategy
|
||||
NamingStrategy schema.Namer
|
||||
// FullSaveAssociations full save associations
|
||||
FullSaveAssociations bool
|
||||
// Logger
|
||||
Logger logger.Interface
|
||||
// NowFunc the function to be used when creating a new timestamp
|
||||
NowFunc func() time.Time
|
||||
// DryRun generate sql without execute
|
||||
DryRun bool
|
||||
// PrepareStmt executes the given query in cached statement
|
||||
PrepareStmt bool
|
||||
// PrepareStmt cache support LRU expired,
|
||||
// default maxsize=int64 Max value and ttl=1h
|
||||
PrepareStmtMaxSize int
|
||||
PrepareStmtTTL time.Duration
|
||||
|
||||
// DisableAutomaticPing
|
||||
DisableAutomaticPing bool
|
||||
// DisableForeignKeyConstraintWhenMigrating
|
||||
DisableForeignKeyConstraintWhenMigrating bool
|
||||
// IgnoreRelationshipsWhenMigrating
|
||||
IgnoreRelationshipsWhenMigrating bool
|
||||
// DisableNestedTransaction disable nested transaction
|
||||
DisableNestedTransaction bool
|
||||
// AllowGlobalUpdate allow global update
|
||||
AllowGlobalUpdate bool
|
||||
// QueryFields executes the SQL query with all fields of the table
|
||||
QueryFields bool
|
||||
// CreateBatchSize default create batch size
|
||||
CreateBatchSize int
|
||||
// TranslateError enabling error translation
|
||||
TranslateError bool
|
||||
// PropagateUnscoped propagate Unscoped to every other nested statement
|
||||
PropagateUnscoped bool
|
||||
|
||||
// ClauseBuilders clause builder
|
||||
ClauseBuilders map[string]clause.ClauseBuilder
|
||||
// ConnPool db conn pool
|
||||
ConnPool ConnPool
|
||||
// Dialector database dialector
|
||||
Dialector
|
||||
// Plugins registered plugins
|
||||
Plugins map[string]Plugin
|
||||
|
||||
callbacks *callbacks
|
||||
cacheStore *sync.Map
|
||||
}
|
||||
|
||||
// Apply update config to new config
|
||||
func (c *Config) Apply(config *Config) error {
|
||||
if config != c {
|
||||
*config = *c
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AfterInitialize initialize plugins after db connected
|
||||
func (c *Config) AfterInitialize(db *DB) error {
|
||||
if db != nil {
|
||||
for _, plugin := range c.Plugins {
|
||||
if err := plugin.Initialize(db); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Option gorm option interface
|
||||
type Option interface {
|
||||
Apply(*Config) error
|
||||
AfterInitialize(*DB) error
|
||||
}
|
||||
|
||||
// DB GORM DB definition
|
||||
type DB struct {
|
||||
*Config
|
||||
Error error
|
||||
RowsAffected int64
|
||||
Statement *Statement
|
||||
clone int
|
||||
}
|
||||
|
||||
// Session session config when create session with Session() method
|
||||
type Session struct {
|
||||
DryRun bool
|
||||
PrepareStmt bool
|
||||
NewDB bool
|
||||
Initialized bool
|
||||
SkipHooks bool
|
||||
SkipDefaultTransaction bool
|
||||
DisableNestedTransaction bool
|
||||
AllowGlobalUpdate bool
|
||||
FullSaveAssociations bool
|
||||
PropagateUnscoped bool
|
||||
QueryFields bool
|
||||
Context context.Context
|
||||
Logger logger.Interface
|
||||
NowFunc func() time.Time
|
||||
CreateBatchSize int
|
||||
}
|
||||
|
||||
// Open initialize db session based on dialector
|
||||
func Open(dialector Dialector, opts ...Option) (db *DB, err error) {
|
||||
config := &Config{}
|
||||
|
||||
sort.Slice(opts, func(i, j int) bool {
|
||||
_, isConfig := opts[i].(*Config)
|
||||
_, isConfig2 := opts[j].(*Config)
|
||||
return isConfig && !isConfig2
|
||||
})
|
||||
|
||||
var skipAfterInitialize bool
|
||||
for _, opt := range opts {
|
||||
if opt != nil {
|
||||
if applyErr := opt.Apply(config); applyErr != nil {
|
||||
return nil, applyErr
|
||||
}
|
||||
defer func(opt Option) {
|
||||
if skipAfterInitialize {
|
||||
return
|
||||
}
|
||||
if errr := opt.AfterInitialize(db); errr != nil {
|
||||
err = errr
|
||||
}
|
||||
}(opt)
|
||||
}
|
||||
}
|
||||
|
||||
if d, ok := dialector.(interface{ Apply(*Config) error }); ok {
|
||||
if err = d.Apply(config); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if config.NamingStrategy == nil {
|
||||
config.NamingStrategy = schema.NamingStrategy{IdentifierMaxLength: 64} // Default Identifier length is 64
|
||||
}
|
||||
|
||||
if config.Logger == nil {
|
||||
config.Logger = logger.Default
|
||||
}
|
||||
|
||||
if config.NowFunc == nil {
|
||||
config.NowFunc = func() time.Time { return time.Now().Local() }
|
||||
}
|
||||
|
||||
if dialector != nil {
|
||||
config.Dialector = dialector
|
||||
}
|
||||
|
||||
if config.Plugins == nil {
|
||||
config.Plugins = map[string]Plugin{}
|
||||
}
|
||||
|
||||
if config.cacheStore == nil {
|
||||
config.cacheStore = &sync.Map{}
|
||||
}
|
||||
|
||||
db = &DB{Config: config, clone: 1}
|
||||
|
||||
db.callbacks = initializeCallbacks(db)
|
||||
|
||||
if config.ClauseBuilders == nil {
|
||||
config.ClauseBuilders = map[string]clause.ClauseBuilder{}
|
||||
}
|
||||
|
||||
if config.Dialector != nil {
|
||||
err = config.Dialector.Initialize(db)
|
||||
if err != nil {
|
||||
if db, _ := db.DB(); db != nil {
|
||||
_ = db.Close()
|
||||
}
|
||||
|
||||
// DB is not initialized, so we skip AfterInitialize
|
||||
skipAfterInitialize = true
|
||||
return
|
||||
}
|
||||
|
||||
if config.TranslateError {
|
||||
if _, ok := db.Dialector.(ErrorTranslator); !ok {
|
||||
config.Logger.Warn(context.Background(), "The TranslateError option is enabled, but the Dialector %s does not implement ErrorTranslator.", db.Dialector.Name())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if config.PrepareStmt {
|
||||
preparedStmt := NewPreparedStmtDB(db.ConnPool, config.PrepareStmtMaxSize, config.PrepareStmtTTL)
|
||||
db.cacheStore.Store(preparedStmtDBKey, preparedStmt)
|
||||
db.ConnPool = preparedStmt
|
||||
}
|
||||
|
||||
db.Statement = &Statement{
|
||||
DB: db,
|
||||
ConnPool: db.ConnPool,
|
||||
Context: context.Background(),
|
||||
Clauses: map[string]clause.Clause{},
|
||||
}
|
||||
|
||||
if err == nil && !config.DisableAutomaticPing {
|
||||
if pinger, ok := db.ConnPool.(interface{ Ping() error }); ok {
|
||||
err = pinger.Ping()
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
config.Logger.Error(context.Background(), "failed to initialize database, got error %v", err)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Session create new db session
|
||||
func (db *DB) Session(config *Session) *DB {
|
||||
var (
|
||||
txConfig = *db.Config
|
||||
tx = &DB{
|
||||
Config: &txConfig,
|
||||
Statement: db.Statement,
|
||||
Error: db.Error,
|
||||
clone: 1,
|
||||
}
|
||||
)
|
||||
if config.CreateBatchSize > 0 {
|
||||
tx.Config.CreateBatchSize = config.CreateBatchSize
|
||||
}
|
||||
|
||||
if config.SkipDefaultTransaction {
|
||||
tx.Config.SkipDefaultTransaction = true
|
||||
}
|
||||
|
||||
if config.AllowGlobalUpdate {
|
||||
txConfig.AllowGlobalUpdate = true
|
||||
}
|
||||
|
||||
if config.FullSaveAssociations {
|
||||
txConfig.FullSaveAssociations = true
|
||||
}
|
||||
|
||||
if config.PropagateUnscoped {
|
||||
txConfig.PropagateUnscoped = true
|
||||
}
|
||||
|
||||
if config.Context != nil || config.PrepareStmt || config.SkipHooks {
|
||||
tx.Statement = tx.Statement.clone()
|
||||
tx.Statement.DB = tx
|
||||
}
|
||||
|
||||
if config.Context != nil {
|
||||
tx.Statement.Context = config.Context
|
||||
}
|
||||
|
||||
if config.PrepareStmt {
|
||||
var preparedStmt *PreparedStmtDB
|
||||
|
||||
if v, ok := db.cacheStore.Load(preparedStmtDBKey); ok {
|
||||
preparedStmt = v.(*PreparedStmtDB)
|
||||
} else {
|
||||
preparedStmt = NewPreparedStmtDB(db.ConnPool, db.PrepareStmtMaxSize, db.PrepareStmtTTL)
|
||||
db.cacheStore.Store(preparedStmtDBKey, preparedStmt)
|
||||
}
|
||||
|
||||
switch t := tx.Statement.ConnPool.(type) {
|
||||
case Tx:
|
||||
tx.Statement.ConnPool = &PreparedStmtTX{
|
||||
Tx: t,
|
||||
PreparedStmtDB: preparedStmt,
|
||||
}
|
||||
default:
|
||||
tx.Statement.ConnPool = &PreparedStmtDB{
|
||||
ConnPool: db.Config.ConnPool,
|
||||
Mux: preparedStmt.Mux,
|
||||
Stmts: preparedStmt.Stmts,
|
||||
}
|
||||
}
|
||||
txConfig.ConnPool = tx.Statement.ConnPool
|
||||
txConfig.PrepareStmt = true
|
||||
}
|
||||
|
||||
if config.SkipHooks {
|
||||
tx.Statement.SkipHooks = true
|
||||
}
|
||||
|
||||
if config.DisableNestedTransaction {
|
||||
txConfig.DisableNestedTransaction = true
|
||||
}
|
||||
|
||||
if !config.NewDB {
|
||||
tx.clone = 2
|
||||
}
|
||||
|
||||
if config.DryRun {
|
||||
tx.Config.DryRun = true
|
||||
}
|
||||
|
||||
if config.QueryFields {
|
||||
tx.Config.QueryFields = true
|
||||
}
|
||||
|
||||
if config.Logger != nil {
|
||||
tx.Config.Logger = config.Logger
|
||||
}
|
||||
|
||||
if config.NowFunc != nil {
|
||||
tx.Config.NowFunc = config.NowFunc
|
||||
}
|
||||
|
||||
if config.Initialized {
|
||||
tx = tx.getInstance()
|
||||
}
|
||||
|
||||
return tx
|
||||
}
|
||||
|
||||
// WithContext change current instance db's context to ctx
|
||||
func (db *DB) WithContext(ctx context.Context) *DB {
|
||||
return db.Session(&Session{Context: ctx})
|
||||
}
|
||||
|
||||
// Debug start debug mode
|
||||
func (db *DB) Debug() (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
return tx.Session(&Session{
|
||||
Logger: db.Logger.LogMode(logger.Info),
|
||||
})
|
||||
}
|
||||
|
||||
// Set store value with key into current db instance's context
|
||||
func (db *DB) Set(key string, value interface{}) *DB {
|
||||
tx := db.getInstance()
|
||||
tx.Statement.Settings.Store(key, value)
|
||||
return tx
|
||||
}
|
||||
|
||||
// Get get value with key from current db instance's context
|
||||
func (db *DB) Get(key string) (interface{}, bool) {
|
||||
return db.Statement.Settings.Load(key)
|
||||
}
|
||||
|
||||
// InstanceSet store value with key into current db instance's context
|
||||
func (db *DB) InstanceSet(key string, value interface{}) *DB {
|
||||
tx := db.getInstance()
|
||||
tx.Statement.Settings.Store(fmt.Sprintf("%p", tx.Statement)+key, value)
|
||||
return tx
|
||||
}
|
||||
|
||||
// InstanceGet get value with key from current db instance's context
|
||||
func (db *DB) InstanceGet(key string) (interface{}, bool) {
|
||||
return db.Statement.Settings.Load(fmt.Sprintf("%p", db.Statement) + key)
|
||||
}
|
||||
|
||||
// Callback returns callback manager
|
||||
func (db *DB) Callback() *callbacks {
|
||||
return db.callbacks
|
||||
}
|
||||
|
||||
// AddError add error to db
|
||||
func (db *DB) AddError(err error) error {
|
||||
if err != nil {
|
||||
if db.Config.TranslateError {
|
||||
if errTranslator, ok := db.Dialector.(ErrorTranslator); ok {
|
||||
err = errTranslator.Translate(err)
|
||||
}
|
||||
}
|
||||
|
||||
if db.Error == nil {
|
||||
db.Error = err
|
||||
} else {
|
||||
db.Error = fmt.Errorf("%v; %w", db.Error, err)
|
||||
}
|
||||
}
|
||||
return db.Error
|
||||
}
|
||||
|
||||
// DB returns `*sql.DB`
|
||||
func (db *DB) DB() (*sql.DB, error) {
|
||||
connPool := db.ConnPool
|
||||
if db.Statement != nil && db.Statement.ConnPool != nil {
|
||||
connPool = db.Statement.ConnPool
|
||||
}
|
||||
if tx, ok := connPool.(*sql.Tx); ok && tx != nil {
|
||||
return (*sql.DB)(reflect.ValueOf(tx).Elem().FieldByName("db").UnsafePointer()), nil
|
||||
}
|
||||
|
||||
if dbConnector, ok := connPool.(GetDBConnector); ok && dbConnector != nil {
|
||||
if sqldb, err := dbConnector.GetDBConn(); sqldb != nil || err != nil {
|
||||
return sqldb, err
|
||||
}
|
||||
}
|
||||
|
||||
if sqldb, ok := connPool.(*sql.DB); ok && sqldb != nil {
|
||||
return sqldb, nil
|
||||
}
|
||||
|
||||
return nil, ErrInvalidDB
|
||||
}
|
||||
|
||||
func (db *DB) getInstance() *DB {
|
||||
if db.clone > 0 {
|
||||
tx := &DB{Config: db.Config, Error: db.Error}
|
||||
|
||||
if db.clone == 1 {
|
||||
// clone with new statement
|
||||
tx.Statement = &Statement{
|
||||
DB: tx,
|
||||
ConnPool: db.Statement.ConnPool,
|
||||
Context: db.Statement.Context,
|
||||
Clauses: map[string]clause.Clause{},
|
||||
Vars: make([]interface{}, 0, 8),
|
||||
SkipHooks: db.Statement.SkipHooks,
|
||||
}
|
||||
if db.Config.PropagateUnscoped {
|
||||
tx.Statement.Unscoped = db.Statement.Unscoped
|
||||
}
|
||||
} else {
|
||||
// with clone statement
|
||||
tx.Statement = db.Statement.clone()
|
||||
tx.Statement.DB = tx
|
||||
}
|
||||
|
||||
return tx
|
||||
}
|
||||
|
||||
return db
|
||||
}
|
||||
|
||||
// Expr returns clause.Expr, which can be used to pass SQL expression as params
|
||||
func Expr(expr string, args ...interface{}) clause.Expr {
|
||||
return clause.Expr{SQL: expr, Vars: args}
|
||||
}
|
||||
|
||||
// SetupJoinTable setup join table schema
|
||||
func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interface{}) error {
|
||||
var (
|
||||
tx = db.getInstance()
|
||||
stmt = tx.Statement
|
||||
modelSchema, joinSchema *schema.Schema
|
||||
)
|
||||
|
||||
err := stmt.Parse(model)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
modelSchema = stmt.Schema
|
||||
|
||||
err = stmt.Parse(joinTable)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
joinSchema = stmt.Schema
|
||||
|
||||
relation, ok := modelSchema.Relationships.Relations[field]
|
||||
isRelation := ok && relation.JoinTable != nil
|
||||
if !isRelation {
|
||||
return fmt.Errorf("failed to find relation: %s", field)
|
||||
}
|
||||
|
||||
for _, ref := range relation.References {
|
||||
f := joinSchema.LookUpField(ref.ForeignKey.DBName)
|
||||
if f == nil {
|
||||
return fmt.Errorf("missing field %s for join table", ref.ForeignKey.DBName)
|
||||
}
|
||||
|
||||
f.DataType = ref.ForeignKey.DataType
|
||||
f.GORMDataType = ref.ForeignKey.GORMDataType
|
||||
if f.Size == 0 {
|
||||
f.Size = ref.ForeignKey.Size
|
||||
}
|
||||
ref.ForeignKey = f
|
||||
}
|
||||
|
||||
for name, rel := range relation.JoinTable.Relationships.Relations {
|
||||
if _, ok := joinSchema.Relationships.Relations[name]; !ok {
|
||||
rel.Schema = joinSchema
|
||||
joinSchema.Relationships.Relations[name] = rel
|
||||
}
|
||||
}
|
||||
relation.JoinTable = joinSchema
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Use use plugin
|
||||
func (db *DB) Use(plugin Plugin) error {
|
||||
name := plugin.Name()
|
||||
if _, ok := db.Plugins[name]; ok {
|
||||
return ErrRegistered
|
||||
}
|
||||
if err := plugin.Initialize(db); err != nil {
|
||||
return err
|
||||
}
|
||||
db.Plugins[name] = plugin
|
||||
return nil
|
||||
}
|
||||
|
||||
// ToSQL for generate SQL string.
|
||||
//
|
||||
// db.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
||||
// return tx.Model(&User{}).Where(&User{Name: "foo", Age: 20})
|
||||
// .Limit(10).Offset(5)
|
||||
// .Order("name ASC")
|
||||
// .First(&User{})
|
||||
// })
|
||||
func (db *DB) ToSQL(queryFn func(tx *DB) *DB) string {
|
||||
tx := queryFn(db.Session(&Session{DryRun: true, SkipDefaultTransaction: true}).getInstance())
|
||||
stmt := tx.Statement
|
||||
|
||||
return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...)
|
||||
}
|
||||
92
vendor/gorm.io/gorm/interfaces.go
generated
vendored
Normal file
92
vendor/gorm.io/gorm/interfaces.go
generated
vendored
Normal file
@@ -0,0 +1,92 @@
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/schema"
|
||||
)
|
||||
|
||||
// Dialector GORM database dialector
|
||||
type Dialector interface {
|
||||
Name() string
|
||||
Initialize(*DB) error
|
||||
Migrator(db *DB) Migrator
|
||||
DataTypeOf(*schema.Field) string
|
||||
DefaultValueOf(*schema.Field) clause.Expression
|
||||
BindVarTo(writer clause.Writer, stmt *Statement, v interface{})
|
||||
QuoteTo(clause.Writer, string)
|
||||
Explain(sql string, vars ...interface{}) string
|
||||
}
|
||||
|
||||
// Plugin GORM plugin interface
|
||||
type Plugin interface {
|
||||
Name() string
|
||||
Initialize(*DB) error
|
||||
}
|
||||
|
||||
type ParamsFilter interface {
|
||||
ParamsFilter(ctx context.Context, sql string, params ...interface{}) (string, []interface{})
|
||||
}
|
||||
|
||||
// ConnPool db conns pool interface
|
||||
type ConnPool interface {
|
||||
PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
|
||||
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
|
||||
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
|
||||
QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
|
||||
}
|
||||
|
||||
// SavePointerDialectorInterface save pointer interface
|
||||
type SavePointerDialectorInterface interface {
|
||||
SavePoint(tx *DB, name string) error
|
||||
RollbackTo(tx *DB, name string) error
|
||||
}
|
||||
|
||||
// TxBeginner tx beginner
|
||||
type TxBeginner interface {
|
||||
BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
|
||||
}
|
||||
|
||||
// ConnPoolBeginner conn pool beginner
|
||||
type ConnPoolBeginner interface {
|
||||
BeginTx(ctx context.Context, opts *sql.TxOptions) (ConnPool, error)
|
||||
}
|
||||
|
||||
// TxCommitter tx committer
|
||||
type TxCommitter interface {
|
||||
Commit() error
|
||||
Rollback() error
|
||||
}
|
||||
|
||||
// Tx sql.Tx interface
|
||||
type Tx interface {
|
||||
ConnPool
|
||||
TxCommitter
|
||||
StmtContext(ctx context.Context, stmt *sql.Stmt) *sql.Stmt
|
||||
}
|
||||
|
||||
// Valuer gorm valuer interface
|
||||
type Valuer interface {
|
||||
GormValue(context.Context, *DB) clause.Expr
|
||||
}
|
||||
|
||||
// GetDBConnector SQL db connector
|
||||
type GetDBConnector interface {
|
||||
GetDBConn() (*sql.DB, error)
|
||||
}
|
||||
|
||||
// Rows rows interface
|
||||
type Rows interface {
|
||||
Columns() ([]string, error)
|
||||
ColumnTypes() ([]*sql.ColumnType, error)
|
||||
Next() bool
|
||||
Scan(dest ...interface{}) error
|
||||
Err() error
|
||||
Close() error
|
||||
}
|
||||
|
||||
type ErrorTranslator interface {
|
||||
Translate(err error) error
|
||||
}
|
||||
493
vendor/gorm.io/gorm/internal/lru/lru.go
generated
vendored
Normal file
493
vendor/gorm.io/gorm/internal/lru/lru.go
generated
vendored
Normal file
@@ -0,0 +1,493 @@
|
||||
package lru
|
||||
|
||||
// golang -lru
|
||||
// https://github.com/hashicorp/golang-lru
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// EvictCallback is used to get a callback when a cache entry is evicted
|
||||
type EvictCallback[K comparable, V any] func(key K, value V)
|
||||
|
||||
// LRU implements a thread-safe LRU with expirable entries.
|
||||
type LRU[K comparable, V any] struct {
|
||||
size int
|
||||
evictList *LruList[K, V]
|
||||
items map[K]*Entry[K, V]
|
||||
onEvict EvictCallback[K, V]
|
||||
|
||||
// expirable options
|
||||
mu sync.Mutex
|
||||
ttl time.Duration
|
||||
done chan struct{}
|
||||
|
||||
// buckets for expiration
|
||||
buckets []bucket[K, V]
|
||||
// uint8 because it's number between 0 and numBuckets
|
||||
nextCleanupBucket uint8
|
||||
}
|
||||
|
||||
// bucket is a container for holding entries to be expired
|
||||
type bucket[K comparable, V any] struct {
|
||||
entries map[K]*Entry[K, V]
|
||||
newestEntry time.Time
|
||||
}
|
||||
|
||||
// noEvictionTTL - very long ttl to prevent eviction
|
||||
const noEvictionTTL = time.Hour * 24 * 365 * 10
|
||||
|
||||
// because of uint8 usage for nextCleanupBucket, should not exceed 256.
|
||||
// casting it as uint8 explicitly requires type conversions in multiple places
|
||||
const numBuckets = 100
|
||||
|
||||
// NewLRU returns a new thread-safe cache with expirable entries.
|
||||
//
|
||||
// Size parameter set to 0 makes cache of unlimited size, e.g. turns LRU mechanism off.
|
||||
//
|
||||
// Providing 0 TTL turns expiring off.
|
||||
//
|
||||
// Delete expired entries every 1/100th of ttl value. Goroutine which deletes expired entries runs indefinitely.
|
||||
func NewLRU[K comparable, V any](size int, onEvict EvictCallback[K, V], ttl time.Duration) *LRU[K, V] {
|
||||
if size < 0 {
|
||||
size = 0
|
||||
}
|
||||
if ttl <= 0 {
|
||||
ttl = noEvictionTTL
|
||||
}
|
||||
|
||||
res := LRU[K, V]{
|
||||
ttl: ttl,
|
||||
size: size,
|
||||
evictList: NewList[K, V](),
|
||||
items: make(map[K]*Entry[K, V]),
|
||||
onEvict: onEvict,
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
|
||||
// initialize the buckets
|
||||
res.buckets = make([]bucket[K, V], numBuckets)
|
||||
for i := 0; i < numBuckets; i++ {
|
||||
res.buckets[i] = bucket[K, V]{entries: make(map[K]*Entry[K, V])}
|
||||
}
|
||||
|
||||
// enable deleteExpired() running in separate goroutine for cache with non-zero TTL
|
||||
//
|
||||
// Important: done channel is never closed, so deleteExpired() goroutine will never exit,
|
||||
// it's decided to add functionality to close it in the version later than v2.
|
||||
if res.ttl != noEvictionTTL {
|
||||
go func(done <-chan struct{}) {
|
||||
ticker := time.NewTicker(res.ttl / numBuckets)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-done:
|
||||
return
|
||||
case <-ticker.C:
|
||||
res.deleteExpired()
|
||||
}
|
||||
}
|
||||
}(res.done)
|
||||
}
|
||||
return &res
|
||||
}
|
||||
|
||||
// Purge clears the cache completely.
|
||||
// onEvict is called for each evicted key.
|
||||
func (c *LRU[K, V]) Purge() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
for k, v := range c.items {
|
||||
if c.onEvict != nil {
|
||||
c.onEvict(k, v.Value)
|
||||
}
|
||||
delete(c.items, k)
|
||||
}
|
||||
for _, b := range c.buckets {
|
||||
for _, ent := range b.entries {
|
||||
delete(b.entries, ent.Key)
|
||||
}
|
||||
}
|
||||
c.evictList.Init()
|
||||
}
|
||||
|
||||
// Add adds a value to the cache. Returns true if an eviction occurred.
|
||||
// Returns false if there was no eviction: the item was already in the cache,
|
||||
// or the size was not exceeded.
|
||||
func (c *LRU[K, V]) Add(key K, value V) (evicted bool) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
now := time.Now()
|
||||
|
||||
// Check for existing item
|
||||
if ent, ok := c.items[key]; ok {
|
||||
c.evictList.MoveToFront(ent)
|
||||
c.removeFromBucket(ent) // remove the entry from its current bucket as expiresAt is renewed
|
||||
ent.Value = value
|
||||
ent.ExpiresAt = now.Add(c.ttl)
|
||||
c.addToBucket(ent)
|
||||
return false
|
||||
}
|
||||
|
||||
// Add new item
|
||||
ent := c.evictList.PushFrontExpirable(key, value, now.Add(c.ttl))
|
||||
c.items[key] = ent
|
||||
c.addToBucket(ent) // adds the entry to the appropriate bucket and sets entry.expireBucket
|
||||
|
||||
evict := c.size > 0 && c.evictList.Length() > c.size
|
||||
// Verify size not exceeded
|
||||
if evict {
|
||||
c.removeOldest()
|
||||
}
|
||||
return evict
|
||||
}
|
||||
|
||||
// Get looks up a key's value from the cache.
|
||||
func (c *LRU[K, V]) Get(key K) (value V, ok bool) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
var ent *Entry[K, V]
|
||||
if ent, ok = c.items[key]; ok {
|
||||
// Expired item check
|
||||
if time.Now().After(ent.ExpiresAt) {
|
||||
return value, false
|
||||
}
|
||||
c.evictList.MoveToFront(ent)
|
||||
return ent.Value, true
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Contains checks if a key is in the cache, without updating the recent-ness
|
||||
// or deleting it for being stale.
|
||||
func (c *LRU[K, V]) Contains(key K) (ok bool) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
_, ok = c.items[key]
|
||||
return ok
|
||||
}
|
||||
|
||||
// Peek returns the key value (or undefined if not found) without updating
|
||||
// the "recently used"-ness of the key.
|
||||
func (c *LRU[K, V]) Peek(key K) (value V, ok bool) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
var ent *Entry[K, V]
|
||||
if ent, ok = c.items[key]; ok {
|
||||
// Expired item check
|
||||
if time.Now().After(ent.ExpiresAt) {
|
||||
return value, false
|
||||
}
|
||||
return ent.Value, true
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Remove removes the provided key from the cache, returning if the
|
||||
// key was contained.
|
||||
func (c *LRU[K, V]) Remove(key K) bool {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if ent, ok := c.items[key]; ok {
|
||||
c.removeElement(ent)
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// RemoveOldest removes the oldest item from the cache.
|
||||
func (c *LRU[K, V]) RemoveOldest() (key K, value V, ok bool) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if ent := c.evictList.Back(); ent != nil {
|
||||
c.removeElement(ent)
|
||||
return ent.Key, ent.Value, true
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// GetOldest returns the oldest entry
|
||||
func (c *LRU[K, V]) GetOldest() (key K, value V, ok bool) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if ent := c.evictList.Back(); ent != nil {
|
||||
return ent.Key, ent.Value, true
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c *LRU[K, V]) KeyValues() map[K]V {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
maps := make(map[K]V)
|
||||
now := time.Now()
|
||||
for ent := c.evictList.Back(); ent != nil; ent = ent.PrevEntry() {
|
||||
if now.After(ent.ExpiresAt) {
|
||||
continue
|
||||
}
|
||||
maps[ent.Key] = ent.Value
|
||||
// keys = append(keys, ent.Key)
|
||||
}
|
||||
return maps
|
||||
}
|
||||
|
||||
// Keys returns a slice of the keys in the cache, from oldest to newest.
|
||||
// Expired entries are filtered out.
|
||||
func (c *LRU[K, V]) Keys() []K {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
keys := make([]K, 0, len(c.items))
|
||||
now := time.Now()
|
||||
for ent := c.evictList.Back(); ent != nil; ent = ent.PrevEntry() {
|
||||
if now.After(ent.ExpiresAt) {
|
||||
continue
|
||||
}
|
||||
keys = append(keys, ent.Key)
|
||||
}
|
||||
return keys
|
||||
}
|
||||
|
||||
// Values returns a slice of the values in the cache, from oldest to newest.
|
||||
// Expired entries are filtered out.
|
||||
func (c *LRU[K, V]) Values() []V {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
values := make([]V, 0, len(c.items))
|
||||
now := time.Now()
|
||||
for ent := c.evictList.Back(); ent != nil; ent = ent.PrevEntry() {
|
||||
if now.After(ent.ExpiresAt) {
|
||||
continue
|
||||
}
|
||||
values = append(values, ent.Value)
|
||||
}
|
||||
return values
|
||||
}
|
||||
|
||||
// Len returns the number of items in the cache.
|
||||
func (c *LRU[K, V]) Len() int {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return c.evictList.Length()
|
||||
}
|
||||
|
||||
// Resize changes the cache size. Size of 0 means unlimited.
|
||||
func (c *LRU[K, V]) Resize(size int) (evicted int) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if size <= 0 {
|
||||
c.size = 0
|
||||
return 0
|
||||
}
|
||||
diff := c.evictList.Length() - size
|
||||
if diff < 0 {
|
||||
diff = 0
|
||||
}
|
||||
for i := 0; i < diff; i++ {
|
||||
c.removeOldest()
|
||||
}
|
||||
c.size = size
|
||||
return diff
|
||||
}
|
||||
|
||||
// Close destroys cleanup goroutine. To clean up the cache, run Purge() before Close().
|
||||
// func (c *LRU[K, V]) Close() {
|
||||
// c.mu.Lock()
|
||||
// defer c.mu.Unlock()
|
||||
// select {
|
||||
// case <-c.done:
|
||||
// return
|
||||
// default:
|
||||
// }
|
||||
// close(c.done)
|
||||
// }
|
||||
|
||||
// removeOldest removes the oldest item from the cache. Has to be called with lock!
|
||||
func (c *LRU[K, V]) removeOldest() {
|
||||
if ent := c.evictList.Back(); ent != nil {
|
||||
c.removeElement(ent)
|
||||
}
|
||||
}
|
||||
|
||||
// removeElement is used to remove a given list element from the cache. Has to be called with lock!
|
||||
func (c *LRU[K, V]) removeElement(e *Entry[K, V]) {
|
||||
c.evictList.Remove(e)
|
||||
delete(c.items, e.Key)
|
||||
c.removeFromBucket(e)
|
||||
if c.onEvict != nil {
|
||||
c.onEvict(e.Key, e.Value)
|
||||
}
|
||||
}
|
||||
|
||||
// deleteExpired deletes expired records from the oldest bucket, waiting for the newest entry
|
||||
// in it to expire first.
|
||||
func (c *LRU[K, V]) deleteExpired() {
|
||||
c.mu.Lock()
|
||||
bucketIdx := c.nextCleanupBucket
|
||||
timeToExpire := time.Until(c.buckets[bucketIdx].newestEntry)
|
||||
// wait for newest entry to expire before cleanup without holding lock
|
||||
if timeToExpire > 0 {
|
||||
c.mu.Unlock()
|
||||
time.Sleep(timeToExpire)
|
||||
c.mu.Lock()
|
||||
}
|
||||
for _, ent := range c.buckets[bucketIdx].entries {
|
||||
c.removeElement(ent)
|
||||
}
|
||||
c.nextCleanupBucket = (c.nextCleanupBucket + 1) % numBuckets
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
// addToBucket adds entry to expire bucket so that it will be cleaned up when the time comes. Has to be called with lock!
|
||||
func (c *LRU[K, V]) addToBucket(e *Entry[K, V]) {
|
||||
bucketID := (numBuckets + c.nextCleanupBucket - 1) % numBuckets
|
||||
e.ExpireBucket = bucketID
|
||||
c.buckets[bucketID].entries[e.Key] = e
|
||||
if c.buckets[bucketID].newestEntry.Before(e.ExpiresAt) {
|
||||
c.buckets[bucketID].newestEntry = e.ExpiresAt
|
||||
}
|
||||
}
|
||||
|
||||
// removeFromBucket removes the entry from its corresponding bucket. Has to be called with lock!
|
||||
func (c *LRU[K, V]) removeFromBucket(e *Entry[K, V]) {
|
||||
delete(c.buckets[e.ExpireBucket].entries, e.Key)
|
||||
}
|
||||
|
||||
// Cap returns the capacity of the cache
|
||||
func (c *LRU[K, V]) Cap() int {
|
||||
return c.size
|
||||
}
|
||||
|
||||
// Entry is an LRU Entry
type Entry[K comparable, V any] struct {
	// next and prev link this entry into its owner's doubly-linked list.
	// The list is kept as a ring: &list.root follows the last entry
	// (list.Back()) and precedes the first one.
	next, prev *Entry[K, V]

	// list is the owning LruList, nil while the entry is detached.
	list *LruList[K, V]

	// Key is the LRU Key of this element.
	Key K

	// Value is the Value stored with this element.
	Value V

	// ExpiresAt is the time this element would be cleaned up, optional.
	ExpiresAt time.Time

	// ExpireBucket is the expiry bucket item was put in, optional.
	ExpireBucket uint8
}

// PrevEntry returns the previous list element or nil.
func (e *Entry[K, V]) PrevEntry() *Entry[K, V] {
	if e.list == nil {
		return nil
	}
	if p := e.prev; p != &e.list.root {
		return p
	}
	return nil
}

// LruList represents a doubly linked list.
// The zero Value for LruList is an empty list ready to use.
type LruList[K comparable, V any] struct {
	root Entry[K, V] // sentinel element; only &root, root.prev and root.next are used
	len  int         // number of entries, excluding the sentinel
}

// Init initializes or clears list l.
func (l *LruList[K, V]) Init() *LruList[K, V] {
	l.root.next, l.root.prev = &l.root, &l.root
	l.len = 0
	return l
}

// NewList returns an initialized list.
func NewList[K comparable, V any]() *LruList[K, V] {
	return new(LruList[K, V]).Init()
}

// Length returns the number of elements of list l.
// The complexity is O(1).
func (l *LruList[K, V]) Length() int {
	return l.len
}

// Back returns the last element of list l or nil if the list is empty.
func (l *LruList[K, V]) Back() *Entry[K, V] {
	if l.len > 0 {
		return l.root.prev
	}
	return nil
}

// lazyInit lazily initializes a zero List Value.
func (l *LruList[K, V]) lazyInit() {
	if l.root.next == nil {
		l.Init()
	}
}

// insert inserts e after at, increments l.len, and returns e.
func (l *LruList[K, V]) insert(e, at *Entry[K, V]) *Entry[K, V] {
	n := at.next
	at.next = e
	e.prev = at
	e.next = n
	n.prev = e
	e.list = l
	l.len++
	return e
}

// insertValue is a convenience wrapper that allocates a fresh entry for
// (k, v, expiresAt) and inserts it after at.
func (l *LruList[K, V]) insertValue(k K, v V, expiresAt time.Time, at *Entry[K, V]) *Entry[K, V] {
	return l.insert(&Entry[K, V]{Key: k, Value: v, ExpiresAt: expiresAt}, at)
}

// Remove removes e from its list, decrements l.len and returns the
// removed entry's Value.
func (l *LruList[K, V]) Remove(e *Entry[K, V]) V {
	e.prev.next = e.next
	e.next.prev = e.prev
	// Clear all links so the detached entry cannot keep its former
	// neighbours (or the list) alive.
	e.next = nil
	e.prev = nil
	e.list = nil
	l.len--

	return e.Value
}

// move moves e to next to at.
func (l *LruList[K, V]) move(e, at *Entry[K, V]) {
	if e == at {
		return
	}
	// Unlink e from its current position...
	e.prev.next = e.next
	e.next.prev = e.prev

	// ...then splice it back in directly after at.
	n := at.next
	at.next = e
	e.prev = at
	e.next = n
	n.prev = e
}

// PushFront inserts a new element e with value v at the front of list l and returns e.
func (l *LruList[K, V]) PushFront(k K, v V) *Entry[K, V] {
	l.lazyInit()
	return l.insertValue(k, v, time.Time{}, &l.root)
}

// PushFrontExpirable inserts a new expirable element e with Value v at the front of list l and returns e.
func (l *LruList[K, V]) PushFrontExpirable(k K, v V, expiresAt time.Time) *Entry[K, V] {
	l.lazyInit()
	return l.insertValue(k, v, expiresAt, &l.root)
}

// MoveToFront moves element e to the front of list l.
// If e is not an element of l, the list is not modified.
// The element must not be nil.
func (l *LruList[K, V]) MoveToFront(e *Entry[K, V]) {
	if e.list != l || l.root.next == e {
		return
	}
	l.move(e, &l.root)
}
|
||||
183
vendor/gorm.io/gorm/internal/stmt_store/stmt_store.go
generated
vendored
Normal file
183
vendor/gorm.io/gorm/internal/stmt_store/stmt_store.go
generated
vendored
Normal file
@@ -0,0 +1,183 @@
|
||||
package stmt_store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"math"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm/internal/lru"
|
||||
)
|
||||
|
||||
// Stmt wraps a prepared *sql.Stmt together with the bookkeeping the
// statement cache needs: whether it was prepared inside a transaction,
// a channel that is closed once preparation finished, and the
// preparation error (if any).
type Stmt struct {
	*sql.Stmt
	Transaction bool
	prepared    chan struct{}
	prepareErr  error
}

// Error returns the error recorded while preparing the statement, or nil.
func (stmt *Stmt) Error() error {
	return stmt.prepareErr
}

// Close waits for preparation to finish, then closes the underlying
// *sql.Stmt if one was successfully prepared.
func (stmt *Stmt) Close() error {
	// Block until the preparing goroutine closes the channel.
	<-stmt.prepared

	if s := stmt.Stmt; s != nil {
		return s.Close()
	}
	return nil
}
|
||||
|
||||
// Store defines an interface for managing the caching operations of SQL statements (Stmt).
// This interface provides methods for creating new statements, retrieving all cache keys,
// getting cached statements, setting cached statements, and deleting cached statements.
type Store interface {
	// New creates a new Stmt object and caches it.
	// Parameters:
	//   ctx: The context for the request, which can carry deadlines, cancellation signals, etc.
	//   key: The key representing the SQL query, used for caching and preparing the statement.
	//   isTransaction: Indicates whether this operation is part of a transaction, which may affect the caching strategy.
	//   connPool: A connection pool that provides database connections.
	//   locker: A synchronization lock already held by the caller; implementations
	//           release it once the cache placeholder is in place, to avoid deadlocks.
	// Returns:
	//   *Stmt: A newly created statement object for executing SQL operations.
	//   error: An error if the statement preparation fails.
	New(ctx context.Context, key string, isTransaction bool, connPool ConnPool, locker sync.Locker) (*Stmt, error)

	// Keys returns a slice of all cache keys in the store.
	Keys() []string

	// Get retrieves a Stmt object from the store based on the given key.
	// Parameters:
	//   key: The key used to look up the Stmt object.
	// Returns:
	//   *Stmt: The found Stmt object, or nil if not found.
	//   bool: Indicates whether the corresponding Stmt object was successfully found.
	Get(key string) (*Stmt, bool)

	// Set stores the given Stmt object in the store and associates it with the specified key.
	// Parameters:
	//   key: The key used to associate the Stmt object.
	//   value: The Stmt object to be stored.
	Set(key string, value *Stmt)

	// Delete removes the Stmt object corresponding to the specified key from the store.
	// Parameters:
	//   key: The key associated with the Stmt object to be deleted.
	Delete(key string)
}
|
||||
|
||||
// defaultMaxSize defines the default maximum capacity of the cache.
// Its value is math.MaxInt (the largest value the platform's int can
// represent), which means that when the cache size is not specified,
// the cache can theoretically store as many elements as possible.
const (
	defaultMaxSize = math.MaxInt
	// defaultTTL defines the default time-to-live (TTL) for each cache entry.
	// When the TTL for cache entries is not specified, each cache entry will expire after 24 hours.
	defaultTTL = time.Hour * 24
)
|
||||
|
||||
// New creates and returns a new Store instance.
|
||||
//
|
||||
// Parameters:
|
||||
// - size: The maximum capacity of the cache. If the provided size is less than or equal to 0,
|
||||
// it defaults to defaultMaxSize.
|
||||
// - ttl: The time-to-live duration for each cache entry. If the provided ttl is less than or equal to 0,
|
||||
// it defaults to defaultTTL.
|
||||
//
|
||||
// This function defines an onEvicted callback that is invoked when a cache entry is evicted.
|
||||
// The callback ensures that if the evicted value (v) is not nil, its Close method is called asynchronously
|
||||
// to release associated resources.
|
||||
//
|
||||
// Returns:
|
||||
// - A Store instance implemented by lruStore, which internally uses an LRU cache with the specified size,
|
||||
// eviction callback, and TTL.
|
||||
func New(size int, ttl time.Duration) Store {
|
||||
if size <= 0 {
|
||||
size = defaultMaxSize
|
||||
}
|
||||
|
||||
if ttl <= 0 {
|
||||
ttl = defaultTTL
|
||||
}
|
||||
|
||||
onEvicted := func(k string, v *Stmt) {
|
||||
if v != nil {
|
||||
go v.Close()
|
||||
}
|
||||
}
|
||||
return &lruStore{lru: lru.NewLRU[string, *Stmt](size, onEvicted, ttl)}
|
||||
}
|
||||
|
||||
type lruStore struct {
|
||||
lru *lru.LRU[string, *Stmt]
|
||||
}
|
||||
|
||||
func (s *lruStore) Keys() []string {
|
||||
return s.lru.Keys()
|
||||
}
|
||||
|
||||
func (s *lruStore) Get(key string) (*Stmt, bool) {
|
||||
stmt, ok := s.lru.Get(key)
|
||||
if ok && stmt != nil {
|
||||
<-stmt.prepared
|
||||
}
|
||||
return stmt, ok
|
||||
}
|
||||
|
||||
func (s *lruStore) Set(key string, value *Stmt) {
|
||||
s.lru.Add(key, value)
|
||||
}
|
||||
|
||||
func (s *lruStore) Delete(key string) {
|
||||
s.lru.Remove(key)
|
||||
}
|
||||
|
||||
// ConnPool is the minimal connection surface this package needs:
// anything that can prepare a statement within a context
// (e.g. *sql.DB, *sql.Tx, or *sql.Conn).
type ConnPool interface {
	PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
}
|
||||
|
||||
// New creates a new Stmt object for executing SQL queries.
// It caches the Stmt object for future use and handles preparation and error states.
// Parameters:
//
//	ctx: Context for the request, used to carry deadlines, cancellation signals, etc.
//	key: The key representing the SQL query, used for caching and preparing the statement.
//	isTransaction: Indicates whether this operation is part of a transaction, affecting cache strategy.
//	conn: A connection pool that provides database connections.
//	locker: A synchronization lock, already held by the caller, that is unlocked after initialization to avoid deadlocks.
//
// Returns:
//
//	*Stmt: A newly created statement object for executing SQL operations.
//	error: An error if the statement preparation fails.
func (s *lruStore) New(ctx context.Context, key string, isTransaction bool, conn ConnPool, locker sync.Locker) (_ *Stmt, err error) {
	// Create a Stmt object and set its Transaction property.
	// The prepared channel synchronizes the preparation state: readers
	// (lruStore.Get, Stmt.Close) block on it until it is closed below.
	cacheStmt := &Stmt{
		Transaction: isTransaction,
		prepared:    make(chan struct{}),
	}
	// Cache the (not yet prepared) Stmt under the key so concurrent
	// callers find this placeholder instead of preparing again.
	s.Set(key, cacheStmt)
	// Unlock after completing initialization to prevent deadlocks.
	locker.Unlock()

	// Ensure the prepared channel is closed after the function execution
	// completes, releasing any goroutine blocked in Get or Close.
	defer close(cacheStmt.prepared)

	// Prepare the SQL statement using the provided connection.
	cacheStmt.Stmt, err = conn.PrepareContext(ctx, key)
	if err != nil {
		// If statement preparation fails, record the error and remove the invalid Stmt object from the cache.
		cacheStmt.prepareErr = err
		s.Delete(key)
		return &Stmt{}, err
	}

	// Return the successfully prepared Stmt object.
	return cacheStmt, nil
}
|
||||
225
vendor/gorm.io/gorm/logger/logger.go
generated
vendored
Normal file
225
vendor/gorm.io/gorm/logger/logger.go
generated
vendored
Normal file
@@ -0,0 +1,225 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm/utils"
|
||||
)
|
||||
|
||||
// ErrRecordNotFound record not found error
|
||||
var ErrRecordNotFound = errors.New("record not found")
|
||||
|
||||
// Colors — ANSI terminal escape sequences used to colorize log output
// when Config.Colorful is enabled.
const (
	Reset       = "\033[0m"
	Red         = "\033[31m"
	Green       = "\033[32m"
	Yellow      = "\033[33m"
	Blue        = "\033[34m"
	Magenta     = "\033[35m"
	Cyan        = "\033[36m"
	White       = "\033[37m"
	BlueBold    = "\033[34;1m"
	MagentaBold = "\033[35;1m"
	RedBold     = "\033[31;1m"
	YellowBold  = "\033[33;1m"
)
|
||||
|
||||
// LogLevel log level. Higher values are more verbose; a logger emits a
// message only when its configured level is at least the message's level.
type LogLevel int

const (
	// Silent silent log level — nothing is logged.
	Silent LogLevel = iota + 1
	// Error error log level
	Error
	// Warn warn log level
	Warn
	// Info info log level
	Info
)
|
||||
|
||||
// Writer log writer interface — the single sink used for all output.
// *log.Logger satisfies it.
type Writer interface {
	Printf(string, ...interface{})
}
|
||||
|
||||
// Config logger config
type Config struct {
	// SlowThreshold is the duration above which Trace reports a query as
	// slow SQL; 0 disables slow-query logging.
	SlowThreshold time.Duration
	// Colorful enables ANSI color codes in the formatted output.
	Colorful bool
	// IgnoreRecordNotFoundError suppresses error-level Trace output for
	// ErrRecordNotFound.
	IgnoreRecordNotFoundError bool
	// ParameterizedQueries makes ParamsFilter drop SQL bind values from logs.
	ParameterizedQueries bool
	// LogLevel is the minimum level that will be emitted.
	LogLevel LogLevel
}
|
||||
|
||||
// Interface logger interface implemented by logger backends; Trace
// receives a callback producing the SQL and affected row count so the
// (possibly expensive) rendering only happens when actually logged.
type Interface interface {
	LogMode(LogLevel) Interface
	Info(context.Context, string, ...interface{})
	Warn(context.Context, string, ...interface{})
	Error(context.Context, string, ...interface{})
	Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error)
}
|
||||
|
||||
var (
	// Discard logger will print any log to io.Discard
	Discard = New(log.New(io.Discard, "", log.LstdFlags), Config{})
	// Default Default logger: warns to stdout, colorized, 200ms slow threshold.
	Default = New(log.New(os.Stdout, "\r\n", log.LstdFlags), Config{
		SlowThreshold:             200 * time.Millisecond,
		LogLevel:                  Warn,
		IgnoreRecordNotFoundError: false,
		Colorful:                  true,
	})
	// Recorder logger records running SQL into a recorder instance
	Recorder = traceRecorder{Interface: Default, BeginAt: time.Now()}

	// RecorderParamsFilter defaults to no-op, allows to be run-over by a different implementation
	RecorderParamsFilter = func(ctx context.Context, sql string, params ...interface{}) (string, []interface{}) {
		return sql, params
	}
)
|
||||
|
||||
// New initialize logger
|
||||
func New(writer Writer, config Config) Interface {
|
||||
var (
|
||||
infoStr = "%s\n[info] "
|
||||
warnStr = "%s\n[warn] "
|
||||
errStr = "%s\n[error] "
|
||||
traceStr = "%s\n[%.3fms] [rows:%v] %s"
|
||||
traceWarnStr = "%s %s\n[%.3fms] [rows:%v] %s"
|
||||
traceErrStr = "%s %s\n[%.3fms] [rows:%v] %s"
|
||||
)
|
||||
|
||||
if config.Colorful {
|
||||
infoStr = Green + "%s\n" + Reset + Green + "[info] " + Reset
|
||||
warnStr = BlueBold + "%s\n" + Reset + Magenta + "[warn] " + Reset
|
||||
errStr = Magenta + "%s\n" + Reset + Red + "[error] " + Reset
|
||||
traceStr = Green + "%s\n" + Reset + Yellow + "[%.3fms] " + BlueBold + "[rows:%v]" + Reset + " %s"
|
||||
traceWarnStr = Green + "%s " + Yellow + "%s\n" + Reset + RedBold + "[%.3fms] " + Yellow + "[rows:%v]" + Magenta + " %s" + Reset
|
||||
traceErrStr = RedBold + "%s " + MagentaBold + "%s\n" + Reset + Yellow + "[%.3fms] " + BlueBold + "[rows:%v]" + Reset + " %s"
|
||||
}
|
||||
|
||||
return &logger{
|
||||
Writer: writer,
|
||||
Config: config,
|
||||
infoStr: infoStr,
|
||||
warnStr: warnStr,
|
||||
errStr: errStr,
|
||||
traceStr: traceStr,
|
||||
traceWarnStr: traceWarnStr,
|
||||
traceErrStr: traceErrStr,
|
||||
}
|
||||
}
|
||||
|
||||
// logger is the default Interface implementation: a Writer plus Config
// and the pre-built format strings New assembled from them.
type logger struct {
	Writer
	Config
	infoStr, warnStr, errStr            string
	traceStr, traceErrStr, traceWarnStr string
}
|
||||
|
||||
// LogMode log mode
|
||||
func (l *logger) LogMode(level LogLevel) Interface {
|
||||
newlogger := *l
|
||||
newlogger.LogLevel = level
|
||||
return &newlogger
|
||||
}
|
||||
|
||||
// Info print info
|
||||
func (l *logger) Info(ctx context.Context, msg string, data ...interface{}) {
|
||||
if l.LogLevel >= Info {
|
||||
l.Printf(l.infoStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...)
|
||||
}
|
||||
}
|
||||
|
||||
// Warn print warn messages
|
||||
func (l *logger) Warn(ctx context.Context, msg string, data ...interface{}) {
|
||||
if l.LogLevel >= Warn {
|
||||
l.Printf(l.warnStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...)
|
||||
}
|
||||
}
|
||||
|
||||
// Error print error messages
|
||||
func (l *logger) Error(ctx context.Context, msg string, data ...interface{}) {
|
||||
if l.LogLevel >= Error {
|
||||
l.Printf(l.errStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...)
|
||||
}
|
||||
}
|
||||
|
||||
// Trace print sql message
|
||||
//
|
||||
//nolint:cyclop
|
||||
func (l *logger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) {
|
||||
if l.LogLevel <= Silent {
|
||||
return
|
||||
}
|
||||
|
||||
elapsed := time.Since(begin)
|
||||
switch {
|
||||
case err != nil && l.LogLevel >= Error && (!errors.Is(err, ErrRecordNotFound) || !l.IgnoreRecordNotFoundError):
|
||||
sql, rows := fc()
|
||||
if rows == -1 {
|
||||
l.Printf(l.traceErrStr, utils.FileWithLineNum(), err, float64(elapsed.Nanoseconds())/1e6, "-", sql)
|
||||
} else {
|
||||
l.Printf(l.traceErrStr, utils.FileWithLineNum(), err, float64(elapsed.Nanoseconds())/1e6, rows, sql)
|
||||
}
|
||||
case elapsed > l.SlowThreshold && l.SlowThreshold != 0 && l.LogLevel >= Warn:
|
||||
sql, rows := fc()
|
||||
slowLog := fmt.Sprintf("SLOW SQL >= %v", l.SlowThreshold)
|
||||
if rows == -1 {
|
||||
l.Printf(l.traceWarnStr, utils.FileWithLineNum(), slowLog, float64(elapsed.Nanoseconds())/1e6, "-", sql)
|
||||
} else {
|
||||
l.Printf(l.traceWarnStr, utils.FileWithLineNum(), slowLog, float64(elapsed.Nanoseconds())/1e6, rows, sql)
|
||||
}
|
||||
case l.LogLevel == Info:
|
||||
sql, rows := fc()
|
||||
if rows == -1 {
|
||||
l.Printf(l.traceStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, "-", sql)
|
||||
} else {
|
||||
l.Printf(l.traceStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, rows, sql)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ParamsFilter filter params
|
||||
func (l *logger) ParamsFilter(ctx context.Context, sql string, params ...interface{}) (string, []interface{}) {
|
||||
if l.Config.ParameterizedQueries {
|
||||
return sql, nil
|
||||
}
|
||||
return sql, params
|
||||
}
|
||||
|
||||
// traceRecorder wraps an Interface and captures the last traced SQL
// statement, its row count, start time and error instead of printing it.
type traceRecorder struct {
	Interface
	BeginAt      time.Time
	SQL          string
	RowsAffected int64
	Err          error
}
|
||||
|
||||
// New trace recorder
|
||||
func (l *traceRecorder) New() *traceRecorder {
|
||||
return &traceRecorder{Interface: l.Interface, BeginAt: time.Now()}
|
||||
}
|
||||
|
||||
// Trace implement logger interface
|
||||
func (l *traceRecorder) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) {
|
||||
l.BeginAt = begin
|
||||
l.SQL, l.RowsAffected = fc()
|
||||
l.Err = err
|
||||
}
|
||||
|
||||
func (l *traceRecorder) ParamsFilter(ctx context.Context, sql string, params ...interface{}) (string, []interface{}) {
|
||||
if RecorderParamsFilter == nil {
|
||||
return sql, params
|
||||
}
|
||||
return RecorderParamsFilter(ctx, sql, params...)
|
||||
}
|
||||
181
vendor/gorm.io/gorm/logger/sql.go
generated
vendored
Normal file
181
vendor/gorm.io/gorm/logger/sql.go
generated
vendored
Normal file
@@ -0,0 +1,181 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode"
|
||||
|
||||
"gorm.io/gorm/utils"
|
||||
)
|
||||
|
||||
const (
	// tmFmtWithMS is the layout used to render time values with milliseconds.
	tmFmtWithMS = "2006-01-02 15:04:05.999"
	// tmFmtZero is emitted for zero time.Time values.
	tmFmtZero = "0000-00-00 00:00:00"
	// nullStr is emitted for nil parameters.
	nullStr = "NULL"
)
|
||||
|
||||
// isPrintable reports whether every rune in s is printable as defined
// by unicode.IsPrint.
func isPrintable(s string) bool {
	return strings.IndexFunc(s, func(r rune) bool {
		return !unicode.IsPrint(r)
	}) < 0
}
|
||||
|
||||
// A list of Go types that should be converted to SQL primitives
// before being rendered (ExplainSQL converts values whose type is
// convertible to one of these).
var convertibleTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeOf(false), reflect.TypeOf([]byte{})}

// RegEx matching the intermediate `$N$` tokens ExplainSQL produces in
// its first numeric-placeholder rewrite pass.
var numericPlaceholderRe = regexp.MustCompile(`\$\d+\$`)
|
||||
|
||||
func isNumeric(k reflect.Kind) bool {
|
||||
switch k {
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
return true
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
return true
|
||||
case reflect.Float32, reflect.Float64:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// ExplainSQL generate SQL string with given parameters, the generated SQL is expected to be used in logger, execute it might introduce a SQL injection vulnerability
func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, avars ...interface{}) string {
	var (
		// convertParams renders avars[idx] into vars[idx] as a SQL
		// literal, recursing for pointers, driver.Valuers and types
		// convertible to the primitives in convertibleTypes.
		convertParams func(interface{}, int)
		vars          = make([]string, len(avars))
	)

	convertParams = func(v interface{}, idx int) {
		switch v := v.(type) {
		case bool:
			vars[idx] = strconv.FormatBool(v)
		case time.Time:
			// Zero times render as the fixed tmFmtZero literal.
			if v.IsZero() {
				vars[idx] = escaper + tmFmtZero + escaper
			} else {
				vars[idx] = escaper + v.Format(tmFmtWithMS) + escaper
			}
		case *time.Time:
			if v != nil {
				if v.IsZero() {
					vars[idx] = escaper + tmFmtZero + escaper
				} else {
					vars[idx] = escaper + v.Format(tmFmtWithMS) + escaper
				}
			} else {
				vars[idx] = nullStr
			}
		case driver.Valuer:
			reflectValue := reflect.ValueOf(v)
			// Only call Value() on a usable receiver: either a non-pointer
			// or a non-nil pointer. A typed-nil Valuer renders as NULL.
			if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) {
				r, _ := v.Value()
				convertParams(r, idx)
			} else {
				vars[idx] = nullStr
			}
		case fmt.Stringer:
			reflectValue := reflect.ValueOf(v)
			// Prefer the underlying kind's natural SQL rendering over the
			// String() form for numeric/bool-backed Stringers.
			switch reflectValue.Kind() {
			case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
				vars[idx] = fmt.Sprintf("%d", reflectValue.Interface())
			case reflect.Float32, reflect.Float64:
				vars[idx] = fmt.Sprintf("%.6f", reflectValue.Interface())
			case reflect.Bool:
				vars[idx] = fmt.Sprintf("%t", reflectValue.Interface())
			case reflect.String:
				// Escape embedded escaper characters by doubling them.
				vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, escaper+escaper) + escaper
			default:
				if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) {
					vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, escaper+escaper) + escaper
				} else {
					vars[idx] = nullStr
				}
			}
		case []byte:
			// Printable byte slices render as quoted text; anything else
			// is replaced by a <binary> marker.
			if s := string(v); isPrintable(s) {
				vars[idx] = escaper + strings.ReplaceAll(s, escaper, escaper+escaper) + escaper
			} else {
				vars[idx] = escaper + "<binary>" + escaper
			}
		case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
			vars[idx] = utils.ToString(v)
		case float32:
			vars[idx] = strconv.FormatFloat(float64(v), 'f', -1, 32)
		case float64:
			vars[idx] = strconv.FormatFloat(v, 'f', -1, 64)
		case string:
			vars[idx] = escaper + strings.ReplaceAll(v, escaper, escaper+escaper) + escaper
		default:
			rv := reflect.ValueOf(v)
			if v == nil || !rv.IsValid() || rv.Kind() == reflect.Ptr && rv.IsNil() {
				vars[idx] = nullStr
			} else if valuer, ok := v.(driver.Valuer); ok {
				v, _ = valuer.Value()
				convertParams(v, idx)
			} else if rv.Kind() == reflect.Ptr && !rv.IsZero() {
				// Dereference pointers and retry with the pointee.
				convertParams(reflect.Indirect(rv).Interface(), idx)
			} else if isNumeric(rv.Kind()) {
				if rv.CanInt() || rv.CanUint() {
					vars[idx] = fmt.Sprintf("%d", rv.Interface())
				} else {
					vars[idx] = fmt.Sprintf("%.6f", rv.Interface())
				}
			} else {
				// Named types convertible to time.Time/bool/[]byte are
				// converted and retried; everything else falls back to %v.
				for _, t := range convertibleTypes {
					if rv.Type().ConvertibleTo(t) {
						convertParams(rv.Convert(t).Interface(), idx)
						return
					}
				}
				vars[idx] = escaper + strings.ReplaceAll(fmt.Sprint(v), escaper, escaper+escaper) + escaper
			}
		}
	}

	for idx, v := range avars {
		convertParams(v, idx)
	}

	if numericPlaceholder == nil {
		// '?' placeholders: substitute positionally, left to right.
		var idx int
		var newSQL strings.Builder

		for _, v := range []byte(sql) {
			if v == '?' {
				if len(vars) > idx {
					newSQL.WriteString(vars[idx])
					idx++
					continue
				}
			}
			newSQL.WriteByte(v)
		}

		sql = newSQL.String()
	} else {
		// Numeric placeholders: first rewrite each match to a `$N$`
		// token, then substitute tokens by their 1-based position.
		sql = numericPlaceholder.ReplaceAllString(sql, "$$$1$$")

		sql = numericPlaceholderRe.ReplaceAllStringFunc(sql, func(v string) string {
			num := v[1 : len(v)-1]
			n, _ := strconv.Atoi(num)

			// position var start from 1 ($1, $2)
			n -= 1
			if n >= 0 && n <= len(vars)-1 {
				return vars[n]
			}
			return v
		})
	}

	return sql
}
|
||||
111
vendor/gorm.io/gorm/migrator.go
generated
vendored
Normal file
111
vendor/gorm.io/gorm/migrator.go
generated
vendored
Normal file
@@ -0,0 +1,111 @@
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/schema"
|
||||
)
|
||||
|
||||
// Migrator returns migrator for the current dialector, with any pending
// statement scopes applied first.
func (db *DB) Migrator() Migrator {
	tx := db.getInstance()

	// apply scopes to migrator; looped because executing scopes may
	// register further scopes — NOTE(review): confirm against executeScopes.
	for len(tx.Statement.scopes) > 0 {
		tx = tx.executeScopes()
	}

	return tx.Dialector.Migrator(tx.Session(&Session{}))
}
|
||||
|
||||
// AutoMigrate run auto migration for given models — a convenience
// wrapper around Migrator().AutoMigrate.
func (db *DB) AutoMigrate(dst ...interface{}) error {
	return db.Migrator().AutoMigrate(dst...)
}
|
||||
|
||||
// ViewOption view option
type ViewOption struct {
	// Replace: when true the migrator issues `CREATE OR REPLACE`; when
	// false a plain `CREATE` (which fails if the view already exists).
	Replace bool
	// CheckOption is optional, e.g. `WITH [ CASCADED | LOCAL ] CHECK OPTION`
	CheckOption string
	// Query is the required subquery defining the view.
	Query *DB
}
|
||||
|
||||
// ColumnType column type interface. Methods returning (value, ok) pairs
// report ok=false when the dialect could not determine that property.
type ColumnType interface {
	Name() string
	DatabaseTypeName() string                 // varchar
	ColumnType() (columnType string, ok bool) // varchar(64)
	PrimaryKey() (isPrimaryKey bool, ok bool)
	AutoIncrement() (isAutoIncrement bool, ok bool)
	Length() (length int64, ok bool)
	DecimalSize() (precision int64, scale int64, ok bool)
	Nullable() (nullable bool, ok bool)
	Unique() (unique bool, ok bool)
	ScanType() reflect.Type
	Comment() (value string, ok bool)
	DefaultValue() (value string, ok bool)
}
|
||||
|
||||
// Index describes a database index as reported by a dialect's migrator.
type Index interface {
	Table() string
	Name() string
	Columns() []string
	PrimaryKey() (isPrimaryKey bool, ok bool)
	Unique() (unique bool, ok bool)
	Option() string
}
|
||||
|
||||
// TableType table type interface describing a table's schema, name,
// type and optional comment.
type TableType interface {
	Schema() string
	Name() string
	Type() string
	Comment() (comment string, ok bool)
}
|
||||
|
||||
// Migrator migrator interface — the schema-management operations every
// dialect must provide.
type Migrator interface {
	// AutoMigrate migrates the given models' schemas.
	AutoMigrate(dst ...interface{}) error

	// Database
	CurrentDatabase() string
	FullDataTypeOf(*schema.Field) clause.Expr
	GetTypeAliases(databaseTypeName string) []string

	// Tables
	CreateTable(dst ...interface{}) error
	DropTable(dst ...interface{}) error
	HasTable(dst interface{}) bool
	RenameTable(oldName, newName interface{}) error
	GetTables() (tableList []string, err error)
	TableType(dst interface{}) (TableType, error)

	// Columns
	AddColumn(dst interface{}, field string) error
	DropColumn(dst interface{}, field string) error
	AlterColumn(dst interface{}, field string) error
	MigrateColumn(dst interface{}, field *schema.Field, columnType ColumnType) error
	// MigrateColumnUnique migrate column's UNIQUE constraint, it's part of MigrateColumn.
	MigrateColumnUnique(dst interface{}, field *schema.Field, columnType ColumnType) error
	HasColumn(dst interface{}, field string) bool
	RenameColumn(dst interface{}, oldName, field string) error
	ColumnTypes(dst interface{}) ([]ColumnType, error)

	// Views
	CreateView(name string, option ViewOption) error
	DropView(name string) error

	// Constraints
	CreateConstraint(dst interface{}, name string) error
	DropConstraint(dst interface{}, name string) error
	HasConstraint(dst interface{}, name string) bool

	// Indexes
	CreateIndex(dst interface{}, name string) error
	DropIndex(dst interface{}, name string) error
	HasIndex(dst interface{}, name string) bool
	RenameIndex(dst interface{}, oldName, newName string) error
	GetIndexes(dst interface{}) ([]Index, error)
}
|
||||
107
vendor/gorm.io/gorm/migrator/column_type.go
generated
vendored
Normal file
107
vendor/gorm.io/gorm/migrator/column_type.go
generated
vendored
Normal file
@@ -0,0 +1,107 @@
|
||||
package migrator
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
// ColumnType column type implements ColumnType interface
|
||||
type ColumnType struct {
|
||||
SQLColumnType *sql.ColumnType
|
||||
NameValue sql.NullString
|
||||
DataTypeValue sql.NullString
|
||||
ColumnTypeValue sql.NullString
|
||||
PrimaryKeyValue sql.NullBool
|
||||
UniqueValue sql.NullBool
|
||||
AutoIncrementValue sql.NullBool
|
||||
LengthValue sql.NullInt64
|
||||
DecimalSizeValue sql.NullInt64
|
||||
ScaleValue sql.NullInt64
|
||||
NullableValue sql.NullBool
|
||||
ScanTypeValue reflect.Type
|
||||
CommentValue sql.NullString
|
||||
DefaultValueValue sql.NullString
|
||||
}
|
||||
|
||||
// Name returns the name or alias of the column.
|
||||
func (ct ColumnType) Name() string {
|
||||
if ct.NameValue.Valid {
|
||||
return ct.NameValue.String
|
||||
}
|
||||
return ct.SQLColumnType.Name()
|
||||
}
|
||||
|
||||
// DatabaseTypeName returns the database system name of the column type. If an empty
|
||||
// string is returned, then the driver type name is not supported.
|
||||
// Consult your driver documentation for a list of driver data types. Length specifiers
|
||||
// are not included.
|
||||
// Common type names include "VARCHAR", "TEXT", "NVARCHAR", "DECIMAL", "BOOL",
|
||||
// "INT", and "BIGINT".
|
||||
func (ct ColumnType) DatabaseTypeName() string {
|
||||
if ct.DataTypeValue.Valid {
|
||||
return ct.DataTypeValue.String
|
||||
}
|
||||
return ct.SQLColumnType.DatabaseTypeName()
|
||||
}
|
||||
|
||||
// ColumnType returns the database type of the column. like `varchar(16)`
|
||||
func (ct ColumnType) ColumnType() (columnType string, ok bool) {
|
||||
return ct.ColumnTypeValue.String, ct.ColumnTypeValue.Valid
|
||||
}
|
||||
|
||||
// PrimaryKey returns the column is primary key or not.
|
||||
func (ct ColumnType) PrimaryKey() (isPrimaryKey bool, ok bool) {
|
||||
return ct.PrimaryKeyValue.Bool, ct.PrimaryKeyValue.Valid
|
||||
}
|
||||
|
||||
// AutoIncrement returns the column is auto increment or not.
|
||||
func (ct ColumnType) AutoIncrement() (isAutoIncrement bool, ok bool) {
|
||||
return ct.AutoIncrementValue.Bool, ct.AutoIncrementValue.Valid
|
||||
}
|
||||
|
||||
// Length returns the column type length for variable length column types
|
||||
func (ct ColumnType) Length() (length int64, ok bool) {
|
||||
if ct.LengthValue.Valid {
|
||||
return ct.LengthValue.Int64, true
|
||||
}
|
||||
return ct.SQLColumnType.Length()
|
||||
}
|
||||
|
||||
// DecimalSize returns the scale and precision of a decimal type.
|
||||
func (ct ColumnType) DecimalSize() (precision int64, scale int64, ok bool) {
|
||||
if ct.DecimalSizeValue.Valid {
|
||||
return ct.DecimalSizeValue.Int64, ct.ScaleValue.Int64, true
|
||||
}
|
||||
return ct.SQLColumnType.DecimalSize()
|
||||
}
|
||||
|
||||
// Nullable reports whether the column may be null.
|
||||
func (ct ColumnType) Nullable() (nullable bool, ok bool) {
|
||||
if ct.NullableValue.Valid {
|
||||
return ct.NullableValue.Bool, true
|
||||
}
|
||||
return ct.SQLColumnType.Nullable()
|
||||
}
|
||||
|
||||
// Unique reports whether the column may be unique.
|
||||
func (ct ColumnType) Unique() (unique bool, ok bool) {
|
||||
return ct.UniqueValue.Bool, ct.UniqueValue.Valid
|
||||
}
|
||||
|
||||
// ScanType returns a Go type suitable for scanning into using Rows.Scan.
|
||||
func (ct ColumnType) ScanType() reflect.Type {
|
||||
if ct.ScanTypeValue != nil {
|
||||
return ct.ScanTypeValue
|
||||
}
|
||||
return ct.SQLColumnType.ScanType()
|
||||
}
|
||||
|
||||
// Comment returns the comment of current column.
|
||||
func (ct ColumnType) Comment() (value string, ok bool) {
|
||||
return ct.CommentValue.String, ct.CommentValue.Valid
|
||||
}
|
||||
|
||||
// DefaultValue returns the default value of current column.
|
||||
func (ct ColumnType) DefaultValue() (value string, ok bool) {
|
||||
return ct.DefaultValueValue.String, ct.DefaultValueValue.Valid
|
||||
}
|
||||
43
vendor/gorm.io/gorm/migrator/index.go
generated
vendored
Normal file
43
vendor/gorm.io/gorm/migrator/index.go
generated
vendored
Normal file
@@ -0,0 +1,43 @@
|
||||
package migrator
|
||||
|
||||
import "database/sql"
|
||||
|
||||
// Index implements gorm.Index interface
|
||||
type Index struct {
|
||||
TableName string
|
||||
NameValue string
|
||||
ColumnList []string
|
||||
PrimaryKeyValue sql.NullBool
|
||||
UniqueValue sql.NullBool
|
||||
OptionValue string
|
||||
}
|
||||
|
||||
// Table return the table name of the index.
|
||||
func (idx Index) Table() string {
|
||||
return idx.TableName
|
||||
}
|
||||
|
||||
// Name return the name of the index.
|
||||
func (idx Index) Name() string {
|
||||
return idx.NameValue
|
||||
}
|
||||
|
||||
// Columns return the columns of the index
|
||||
func (idx Index) Columns() []string {
|
||||
return idx.ColumnList
|
||||
}
|
||||
|
||||
// PrimaryKey returns the index is primary key or not.
|
||||
func (idx Index) PrimaryKey() (isPrimaryKey bool, ok bool) {
|
||||
return idx.PrimaryKeyValue.Bool, idx.PrimaryKeyValue.Valid
|
||||
}
|
||||
|
||||
// Unique returns whether the index is unique or not.
|
||||
func (idx Index) Unique() (unique bool, ok bool) {
|
||||
return idx.UniqueValue.Bool, idx.UniqueValue.Valid
|
||||
}
|
||||
|
||||
// Option return the optional attribute of the index
|
||||
func (idx Index) Option() string {
|
||||
return idx.OptionValue
|
||||
}
|
||||
1024
vendor/gorm.io/gorm/migrator/migrator.go
generated
vendored
Normal file
1024
vendor/gorm.io/gorm/migrator/migrator.go
generated
vendored
Normal file
File diff suppressed because it is too large
Load Diff
33
vendor/gorm.io/gorm/migrator/table_type.go
generated
vendored
Normal file
33
vendor/gorm.io/gorm/migrator/table_type.go
generated
vendored
Normal file
@@ -0,0 +1,33 @@
|
||||
package migrator
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
// TableType table type implements TableType interface
|
||||
type TableType struct {
|
||||
SchemaValue string
|
||||
NameValue string
|
||||
TypeValue string
|
||||
CommentValue sql.NullString
|
||||
}
|
||||
|
||||
// Schema returns the schema of the table.
|
||||
func (ct TableType) Schema() string {
|
||||
return ct.SchemaValue
|
||||
}
|
||||
|
||||
// Name returns the name of the table.
|
||||
func (ct TableType) Name() string {
|
||||
return ct.NameValue
|
||||
}
|
||||
|
||||
// Type returns the type of the table.
|
||||
func (ct TableType) Type() string {
|
||||
return ct.TypeValue
|
||||
}
|
||||
|
||||
// Comment returns the comment of current table.
|
||||
func (ct TableType) Comment() (comment string, ok bool) {
|
||||
return ct.CommentValue.String, ct.CommentValue.Valid
|
||||
}
|
||||
16
vendor/gorm.io/gorm/model.go
generated
vendored
Normal file
16
vendor/gorm.io/gorm/model.go
generated
vendored
Normal file
@@ -0,0 +1,16 @@
|
||||
package gorm
|
||||
|
||||
import "time"
|
||||
|
||||
// Model a basic GoLang struct which includes the following fields: ID, CreatedAt, UpdatedAt, DeletedAt
|
||||
// It may be embedded into your model or you may build your own model without it
|
||||
//
|
||||
// type User struct {
|
||||
// gorm.Model
|
||||
// }
|
||||
type Model struct {
|
||||
ID uint `gorm:"primarykey"`
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
DeletedAt DeletedAt `gorm:"index"`
|
||||
}
|
||||
206
vendor/gorm.io/gorm/prepare_stmt.go
generated
vendored
Normal file
206
vendor/gorm.io/gorm/prepare_stmt.go
generated
vendored
Normal file
@@ -0,0 +1,206 @@
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"reflect"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm/internal/stmt_store"
|
||||
)
|
||||
|
||||
type PreparedStmtDB struct {
|
||||
Stmts stmt_store.Store
|
||||
Mux *sync.RWMutex
|
||||
ConnPool
|
||||
}
|
||||
|
||||
// NewPreparedStmtDB creates and initializes a new instance of PreparedStmtDB.
|
||||
//
|
||||
// Parameters:
|
||||
// - connPool: A connection pool that implements the ConnPool interface, used for managing database connections.
|
||||
// - maxSize: The maximum number of prepared statements that can be stored in the statement store.
|
||||
// - ttl: The time-to-live duration for each prepared statement in the store. Statements older than this duration will be automatically removed.
|
||||
//
|
||||
// Returns:
|
||||
// - A pointer to a PreparedStmtDB instance, which manages prepared statements using the provided connection pool and configuration.
|
||||
func NewPreparedStmtDB(connPool ConnPool, maxSize int, ttl time.Duration) *PreparedStmtDB {
|
||||
return &PreparedStmtDB{
|
||||
ConnPool: connPool, // Assigns the provided connection pool to manage database connections.
|
||||
Stmts: stmt_store.New(maxSize, ttl), // Initializes a new statement store with the specified maximum size and TTL.
|
||||
Mux: &sync.RWMutex{}, // Sets up a read-write mutex for synchronizing access to the statement store.
|
||||
}
|
||||
}
|
||||
|
||||
// GetDBConn returns the underlying *sql.DB connection
|
||||
func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) {
|
||||
if sqldb, ok := db.ConnPool.(*sql.DB); ok {
|
||||
return sqldb, nil
|
||||
}
|
||||
|
||||
if dbConnector, ok := db.ConnPool.(GetDBConnector); ok && dbConnector != nil {
|
||||
return dbConnector.GetDBConn()
|
||||
}
|
||||
|
||||
return nil, ErrInvalidDB
|
||||
}
|
||||
|
||||
// Close closes all prepared statements in the store
|
||||
func (db *PreparedStmtDB) Close() {
|
||||
db.Mux.Lock()
|
||||
defer db.Mux.Unlock()
|
||||
|
||||
for _, key := range db.Stmts.Keys() {
|
||||
db.Stmts.Delete(key)
|
||||
}
|
||||
}
|
||||
|
||||
// Reset Deprecated use Close instead
|
||||
func (db *PreparedStmtDB) Reset() {
|
||||
db.Close()
|
||||
}
|
||||
|
||||
func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (_ *stmt_store.Stmt, err error) {
|
||||
db.Mux.RLock()
|
||||
if db.Stmts != nil {
|
||||
if stmt, ok := db.Stmts.Get(query); ok && (!stmt.Transaction || isTransaction) {
|
||||
db.Mux.RUnlock()
|
||||
return stmt, stmt.Error()
|
||||
}
|
||||
}
|
||||
db.Mux.RUnlock()
|
||||
|
||||
// retry
|
||||
db.Mux.Lock()
|
||||
if db.Stmts != nil {
|
||||
if stmt, ok := db.Stmts.Get(query); ok && (!stmt.Transaction || isTransaction) {
|
||||
db.Mux.Unlock()
|
||||
return stmt, stmt.Error()
|
||||
}
|
||||
}
|
||||
|
||||
return db.Stmts.New(ctx, query, isTransaction, conn, db.Mux)
|
||||
}
|
||||
|
||||
func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (ConnPool, error) {
|
||||
if beginner, ok := db.ConnPool.(TxBeginner); ok {
|
||||
tx, err := beginner.BeginTx(ctx, opt)
|
||||
return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err
|
||||
}
|
||||
|
||||
beginner, ok := db.ConnPool.(ConnPoolBeginner)
|
||||
if !ok {
|
||||
return nil, ErrInvalidTransaction
|
||||
}
|
||||
|
||||
connPool, err := beginner.BeginTx(ctx, opt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if tx, ok := connPool.(Tx); ok {
|
||||
return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, nil
|
||||
}
|
||||
return nil, ErrInvalidTransaction
|
||||
}
|
||||
|
||||
func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args ...interface{}) (result sql.Result, err error) {
|
||||
stmt, err := db.prepare(ctx, db.ConnPool, false, query)
|
||||
if err == nil {
|
||||
result, err = stmt.ExecContext(ctx, args...)
|
||||
if errors.Is(err, driver.ErrBadConn) {
|
||||
db.Stmts.Delete(query)
|
||||
}
|
||||
}
|
||||
return result, err
|
||||
}
|
||||
|
||||
func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) {
|
||||
stmt, err := db.prepare(ctx, db.ConnPool, false, query)
|
||||
if err == nil {
|
||||
rows, err = stmt.QueryContext(ctx, args...)
|
||||
if errors.Is(err, driver.ErrBadConn) {
|
||||
db.Stmts.Delete(query)
|
||||
}
|
||||
}
|
||||
return rows, err
|
||||
}
|
||||
|
||||
func (db *PreparedStmtDB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
|
||||
stmt, err := db.prepare(ctx, db.ConnPool, false, query)
|
||||
if err == nil {
|
||||
return stmt.QueryRowContext(ctx, args...)
|
||||
}
|
||||
return &sql.Row{}
|
||||
}
|
||||
|
||||
func (db *PreparedStmtDB) Ping() error {
|
||||
conn, err := db.GetDBConn()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return conn.Ping()
|
||||
}
|
||||
|
||||
type PreparedStmtTX struct {
|
||||
Tx
|
||||
PreparedStmtDB *PreparedStmtDB
|
||||
}
|
||||
|
||||
func (db *PreparedStmtTX) GetDBConn() (*sql.DB, error) {
|
||||
return db.PreparedStmtDB.GetDBConn()
|
||||
}
|
||||
|
||||
func (tx *PreparedStmtTX) Commit() error {
|
||||
if tx.Tx != nil && !reflect.ValueOf(tx.Tx).IsNil() {
|
||||
return tx.Tx.Commit()
|
||||
}
|
||||
return ErrInvalidTransaction
|
||||
}
|
||||
|
||||
func (tx *PreparedStmtTX) Rollback() error {
|
||||
if tx.Tx != nil && !reflect.ValueOf(tx.Tx).IsNil() {
|
||||
return tx.Tx.Rollback()
|
||||
}
|
||||
return ErrInvalidTransaction
|
||||
}
|
||||
|
||||
func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ...interface{}) (result sql.Result, err error) {
|
||||
stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query)
|
||||
if err == nil {
|
||||
result, err = tx.Tx.StmtContext(ctx, stmt.Stmt).ExecContext(ctx, args...)
|
||||
if errors.Is(err, driver.ErrBadConn) {
|
||||
tx.PreparedStmtDB.Stmts.Delete(query)
|
||||
}
|
||||
}
|
||||
return result, err
|
||||
}
|
||||
|
||||
func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) {
|
||||
stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query)
|
||||
if err == nil {
|
||||
rows, err = tx.Tx.StmtContext(ctx, stmt.Stmt).QueryContext(ctx, args...)
|
||||
if errors.Is(err, driver.ErrBadConn) {
|
||||
tx.PreparedStmtDB.Stmts.Delete(query)
|
||||
}
|
||||
}
|
||||
return rows, err
|
||||
}
|
||||
|
||||
func (tx *PreparedStmtTX) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
|
||||
stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query)
|
||||
if err == nil {
|
||||
return tx.Tx.StmtContext(ctx, stmt.Stmt).QueryRowContext(ctx, args...)
|
||||
}
|
||||
return &sql.Row{}
|
||||
}
|
||||
|
||||
func (tx *PreparedStmtTX) Ping() error {
|
||||
conn, err := tx.GetDBConn()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return conn.Ping()
|
||||
}
|
||||
369
vendor/gorm.io/gorm/scan.go
generated
vendored
Normal file
369
vendor/gorm.io/gorm/scan.go
generated
vendored
Normal file
@@ -0,0 +1,369 @@
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm/schema"
|
||||
"gorm.io/gorm/utils"
|
||||
)
|
||||
|
||||
// prepareValues prepare values slice
|
||||
func prepareValues(values []interface{}, db *DB, columnTypes []*sql.ColumnType, columns []string) {
|
||||
if db.Statement.Schema != nil {
|
||||
for idx, name := range columns {
|
||||
if field := db.Statement.Schema.LookUpField(name); field != nil {
|
||||
values[idx] = reflect.New(reflect.PointerTo(field.FieldType)).Interface()
|
||||
continue
|
||||
}
|
||||
values[idx] = new(interface{})
|
||||
}
|
||||
} else if len(columnTypes) > 0 {
|
||||
for idx, columnType := range columnTypes {
|
||||
if columnType.ScanType() != nil {
|
||||
values[idx] = reflect.New(reflect.PointerTo(columnType.ScanType())).Interface()
|
||||
} else {
|
||||
values[idx] = new(interface{})
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for idx := range columns {
|
||||
values[idx] = new(interface{})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func scanIntoMap(mapValue map[string]interface{}, values []interface{}, columns []string) {
|
||||
for idx, column := range columns {
|
||||
if reflectValue := reflect.Indirect(reflect.Indirect(reflect.ValueOf(values[idx]))); reflectValue.IsValid() {
|
||||
mapValue[column] = reflectValue.Interface()
|
||||
if valuer, ok := mapValue[column].(driver.Valuer); ok {
|
||||
mapValue[column], _ = valuer.Value()
|
||||
} else if b, ok := mapValue[column].(sql.RawBytes); ok {
|
||||
mapValue[column] = string(b)
|
||||
}
|
||||
} else {
|
||||
mapValue[column] = nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (db *DB) scanIntoStruct(rows Rows, reflectValue reflect.Value, values []interface{}, fields []*schema.Field, joinFields [][]*schema.Field) {
|
||||
for idx, field := range fields {
|
||||
if field != nil {
|
||||
values[idx] = field.NewValuePool.Get()
|
||||
} else if len(fields) == 1 {
|
||||
if reflectValue.CanAddr() {
|
||||
values[idx] = reflectValue.Addr().Interface()
|
||||
} else {
|
||||
values[idx] = reflectValue.Interface()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
db.RowsAffected++
|
||||
db.AddError(rows.Scan(values...))
|
||||
joinedNestedSchemaMap := make(map[string]interface{})
|
||||
for idx, field := range fields {
|
||||
if field == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if len(joinFields) == 0 || len(joinFields[idx]) == 0 {
|
||||
db.AddError(field.Set(db.Statement.Context, reflectValue, values[idx]))
|
||||
} else { // joinFields count is larger than 2 when using join
|
||||
var isNilPtrValue bool
|
||||
var relValue reflect.Value
|
||||
// does not contain raw dbname
|
||||
nestedJoinSchemas := joinFields[idx][:len(joinFields[idx])-1]
|
||||
// current reflect value
|
||||
currentReflectValue := reflectValue
|
||||
fullRels := make([]string, 0, len(nestedJoinSchemas))
|
||||
for _, joinSchema := range nestedJoinSchemas {
|
||||
fullRels = append(fullRels, joinSchema.Name)
|
||||
relValue = joinSchema.ReflectValueOf(db.Statement.Context, currentReflectValue)
|
||||
if relValue.Kind() == reflect.Ptr {
|
||||
fullRelsName := utils.JoinNestedRelationNames(fullRels)
|
||||
// same nested structure
|
||||
if _, ok := joinedNestedSchemaMap[fullRelsName]; !ok {
|
||||
if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() {
|
||||
isNilPtrValue = true
|
||||
break
|
||||
}
|
||||
|
||||
relValue.Set(reflect.New(relValue.Type().Elem()))
|
||||
joinedNestedSchemaMap[fullRelsName] = nil
|
||||
}
|
||||
}
|
||||
currentReflectValue = relValue
|
||||
}
|
||||
|
||||
if !isNilPtrValue { // ignore if value is nil
|
||||
f := joinFields[idx][len(joinFields[idx])-1]
|
||||
db.AddError(f.Set(db.Statement.Context, relValue, values[idx]))
|
||||
}
|
||||
}
|
||||
|
||||
// release data to pool
|
||||
field.NewValuePool.Put(values[idx])
|
||||
}
|
||||
}
|
||||
|
||||
// ScanMode scan data mode
|
||||
type ScanMode uint8
|
||||
|
||||
// scan modes
|
||||
const (
|
||||
ScanInitialized ScanMode = 1 << 0 // 1
|
||||
ScanUpdate ScanMode = 1 << 1 // 2
|
||||
ScanOnConflictDoNothing ScanMode = 1 << 2 // 4
|
||||
)
|
||||
|
||||
// Scan scan rows into db statement
|
||||
func Scan(rows Rows, db *DB, mode ScanMode) {
|
||||
var (
|
||||
columns, _ = rows.Columns()
|
||||
values = make([]interface{}, len(columns))
|
||||
initialized = mode&ScanInitialized != 0
|
||||
update = mode&ScanUpdate != 0
|
||||
onConflictDonothing = mode&ScanOnConflictDoNothing != 0
|
||||
)
|
||||
|
||||
if len(db.Statement.ColumnMapping) > 0 {
|
||||
for i, column := range columns {
|
||||
v, ok := db.Statement.ColumnMapping[column]
|
||||
if ok {
|
||||
columns[i] = v
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
db.RowsAffected = 0
|
||||
|
||||
switch dest := db.Statement.Dest.(type) {
|
||||
case map[string]interface{}, *map[string]interface{}:
|
||||
if initialized || rows.Next() {
|
||||
columnTypes, _ := rows.ColumnTypes()
|
||||
prepareValues(values, db, columnTypes, columns)
|
||||
|
||||
db.RowsAffected++
|
||||
db.AddError(rows.Scan(values...))
|
||||
|
||||
mapValue, ok := dest.(map[string]interface{})
|
||||
if !ok {
|
||||
if v, ok := dest.(*map[string]interface{}); ok {
|
||||
if *v == nil {
|
||||
*v = map[string]interface{}{}
|
||||
}
|
||||
mapValue = *v
|
||||
}
|
||||
}
|
||||
scanIntoMap(mapValue, values, columns)
|
||||
}
|
||||
case *[]map[string]interface{}:
|
||||
columnTypes, _ := rows.ColumnTypes()
|
||||
for initialized || rows.Next() {
|
||||
prepareValues(values, db, columnTypes, columns)
|
||||
|
||||
initialized = false
|
||||
db.RowsAffected++
|
||||
db.AddError(rows.Scan(values...))
|
||||
|
||||
mapValue := map[string]interface{}{}
|
||||
scanIntoMap(mapValue, values, columns)
|
||||
*dest = append(*dest, mapValue)
|
||||
}
|
||||
case *int, *int8, *int16, *int32, *int64,
|
||||
*uint, *uint8, *uint16, *uint32, *uint64, *uintptr,
|
||||
*float32, *float64,
|
||||
*bool, *string, *time.Time,
|
||||
*sql.NullInt32, *sql.NullInt64, *sql.NullFloat64,
|
||||
*sql.NullBool, *sql.NullString, *sql.NullTime:
|
||||
for initialized || rows.Next() {
|
||||
initialized = false
|
||||
db.RowsAffected++
|
||||
db.AddError(rows.Scan(dest))
|
||||
}
|
||||
default:
|
||||
var (
|
||||
fields = make([]*schema.Field, len(columns))
|
||||
joinFields [][]*schema.Field
|
||||
sch = db.Statement.Schema
|
||||
reflectValue = db.Statement.ReflectValue
|
||||
)
|
||||
|
||||
if reflectValue.Kind() == reflect.Interface {
|
||||
reflectValue = reflectValue.Elem()
|
||||
}
|
||||
|
||||
reflectValueType := reflectValue.Type()
|
||||
switch reflectValueType.Kind() {
|
||||
case reflect.Array, reflect.Slice:
|
||||
reflectValueType = reflectValueType.Elem()
|
||||
}
|
||||
isPtr := reflectValueType.Kind() == reflect.Ptr
|
||||
if isPtr {
|
||||
reflectValueType = reflectValueType.Elem()
|
||||
}
|
||||
|
||||
if sch != nil {
|
||||
if reflectValueType != sch.ModelType && reflectValueType.Kind() == reflect.Struct {
|
||||
sch, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy)
|
||||
}
|
||||
|
||||
if len(columns) == 1 {
|
||||
// Is Pluck
|
||||
if _, ok := reflect.New(reflectValueType).Interface().(sql.Scanner); (reflectValueType != sch.ModelType && ok) || // is scanner
|
||||
reflectValueType.Kind() != reflect.Struct || // is not struct
|
||||
sch.ModelType.ConvertibleTo(schema.TimeReflectType) { // is time
|
||||
sch = nil
|
||||
}
|
||||
}
|
||||
|
||||
// Not Pluck
|
||||
if sch != nil {
|
||||
matchedFieldCount := make(map[string]int, len(columns))
|
||||
for idx, column := range columns {
|
||||
if field := sch.LookUpField(column); field != nil && field.Readable {
|
||||
fields[idx] = field
|
||||
if count, ok := matchedFieldCount[column]; ok {
|
||||
// handle duplicate fields
|
||||
for _, selectField := range sch.Fields {
|
||||
if selectField.DBName == column && selectField.Readable {
|
||||
if count == 0 {
|
||||
matchedFieldCount[column]++
|
||||
fields[idx] = selectField
|
||||
break
|
||||
}
|
||||
count--
|
||||
}
|
||||
}
|
||||
} else {
|
||||
matchedFieldCount[column] = 1
|
||||
}
|
||||
} else if names := utils.SplitNestedRelationName(column); len(names) > 1 { // has nested relation
|
||||
aliasName := utils.JoinNestedRelationNames(names[0 : len(names)-1])
|
||||
for _, join := range db.Statement.Joins {
|
||||
if join.Alias == aliasName {
|
||||
names = append(strings.Split(join.Name, "."), names[len(names)-1])
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if rel, ok := sch.Relationships.Relations[names[0]]; ok {
|
||||
subNameCount := len(names)
|
||||
// nested relation fields
|
||||
relFields := make([]*schema.Field, 0, subNameCount-1)
|
||||
relFields = append(relFields, rel.Field)
|
||||
for _, name := range names[1 : subNameCount-1] {
|
||||
rel = rel.FieldSchema.Relationships.Relations[name]
|
||||
relFields = append(relFields, rel.Field)
|
||||
}
|
||||
// latest name is raw dbname
|
||||
dbName := names[subNameCount-1]
|
||||
if field := rel.FieldSchema.LookUpField(dbName); field != nil && field.Readable {
|
||||
fields[idx] = field
|
||||
|
||||
if len(joinFields) == 0 {
|
||||
joinFields = make([][]*schema.Field, len(columns))
|
||||
}
|
||||
relFields = append(relFields, field)
|
||||
joinFields[idx] = relFields
|
||||
continue
|
||||
}
|
||||
}
|
||||
var val interface{}
|
||||
values[idx] = &val
|
||||
} else {
|
||||
var val interface{}
|
||||
values[idx] = &val
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
switch reflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
var (
|
||||
elem reflect.Value
|
||||
isArrayKind = reflectValue.Kind() == reflect.Array
|
||||
)
|
||||
|
||||
if !update || reflectValue.Len() == 0 {
|
||||
update = false
|
||||
if isArrayKind {
|
||||
db.Statement.ReflectValue.Set(reflect.Zero(reflectValue.Type()))
|
||||
} else {
|
||||
// if the slice cap is externally initialized, the externally initialized slice is directly used here
|
||||
if reflectValue.Cap() == 0 {
|
||||
db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20))
|
||||
} else {
|
||||
reflectValue.SetLen(0)
|
||||
db.Statement.ReflectValue.Set(reflectValue)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for initialized || rows.Next() {
|
||||
BEGIN:
|
||||
initialized = false
|
||||
|
||||
if update {
|
||||
if int(db.RowsAffected) >= reflectValue.Len() {
|
||||
return
|
||||
}
|
||||
elem = reflectValue.Index(int(db.RowsAffected))
|
||||
if onConflictDonothing {
|
||||
for _, field := range fields {
|
||||
if _, ok := field.ValueOf(db.Statement.Context, elem); !ok {
|
||||
db.RowsAffected++
|
||||
goto BEGIN
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
elem = reflect.New(reflectValueType)
|
||||
}
|
||||
|
||||
db.scanIntoStruct(rows, elem, values, fields, joinFields)
|
||||
|
||||
if !update {
|
||||
if !isPtr {
|
||||
elem = elem.Elem()
|
||||
}
|
||||
if isArrayKind {
|
||||
if reflectValue.Len() >= int(db.RowsAffected) {
|
||||
reflectValue.Index(int(db.RowsAffected - 1)).Set(elem)
|
||||
}
|
||||
} else {
|
||||
reflectValue = reflect.Append(reflectValue, elem)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !update {
|
||||
db.Statement.ReflectValue.Set(reflectValue)
|
||||
}
|
||||
case reflect.Struct, reflect.Ptr:
|
||||
if initialized || rows.Next() {
|
||||
if mode == ScanInitialized && reflectValue.Kind() == reflect.Struct {
|
||||
db.Statement.ReflectValue.Set(reflect.Zero(reflectValue.Type()))
|
||||
}
|
||||
db.scanIntoStruct(rows, reflectValue, values, fields, joinFields)
|
||||
}
|
||||
default:
|
||||
db.AddError(rows.Scan(dest))
|
||||
}
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil && err != db.Error {
|
||||
db.AddError(err)
|
||||
}
|
||||
|
||||
if db.RowsAffected == 0 && db.Statement.RaiseErrorOnNotFound && db.Error == nil {
|
||||
db.AddError(ErrRecordNotFound)
|
||||
}
|
||||
}
|
||||
66
vendor/gorm.io/gorm/schema/constraint.go
generated
vendored
Normal file
66
vendor/gorm.io/gorm/schema/constraint.go
generated
vendored
Normal file
@@ -0,0 +1,66 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
// reg match english letters and midline
|
||||
var regEnLetterAndMidline = regexp.MustCompile(`^[\w-]+$`)
|
||||
|
||||
type CheckConstraint struct {
|
||||
Name string
|
||||
Constraint string // length(phone) >= 10
|
||||
*Field
|
||||
}
|
||||
|
||||
func (chk *CheckConstraint) GetName() string { return chk.Name }
|
||||
|
||||
func (chk *CheckConstraint) Build() (sql string, vars []interface{}) {
|
||||
return "CONSTRAINT ? CHECK (?)", []interface{}{clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint}}
|
||||
}
|
||||
|
||||
// ParseCheckConstraints parse schema check constraints
|
||||
func (schema *Schema) ParseCheckConstraints() map[string]CheckConstraint {
|
||||
checks := map[string]CheckConstraint{}
|
||||
for _, field := range schema.FieldsByDBName {
|
||||
if chk := field.TagSettings["CHECK"]; chk != "" {
|
||||
names := strings.Split(chk, ",")
|
||||
if len(names) > 1 && regEnLetterAndMidline.MatchString(names[0]) {
|
||||
checks[names[0]] = CheckConstraint{Name: names[0], Constraint: strings.Join(names[1:], ","), Field: field}
|
||||
} else {
|
||||
if names[0] == "" {
|
||||
chk = strings.Join(names[1:], ",")
|
||||
}
|
||||
name := schema.namer.CheckerName(schema.Table, field.DBName)
|
||||
checks[name] = CheckConstraint{Name: name, Constraint: chk, Field: field}
|
||||
}
|
||||
}
|
||||
}
|
||||
return checks
|
||||
}
|
||||
|
||||
type UniqueConstraint struct {
|
||||
Name string
|
||||
Field *Field
|
||||
}
|
||||
|
||||
func (uni *UniqueConstraint) GetName() string { return uni.Name }
|
||||
|
||||
func (uni *UniqueConstraint) Build() (sql string, vars []interface{}) {
|
||||
return "CONSTRAINT ? UNIQUE (?)", []interface{}{clause.Column{Name: uni.Name}, clause.Column{Name: uni.Field.DBName}}
|
||||
}
|
||||
|
||||
// ParseUniqueConstraints parse schema unique constraints
|
||||
func (schema *Schema) ParseUniqueConstraints() map[string]UniqueConstraint {
|
||||
uniques := make(map[string]UniqueConstraint)
|
||||
for _, field := range schema.Fields {
|
||||
if field.Unique {
|
||||
name := schema.namer.UniqueName(schema.Table, field.DBName)
|
||||
uniques[name] = UniqueConstraint{Name: name, Field: field}
|
||||
}
|
||||
}
|
||||
return uniques
|
||||
}
|
||||
1002
vendor/gorm.io/gorm/schema/field.go
generated
vendored
Normal file
1002
vendor/gorm.io/gorm/schema/field.go
generated
vendored
Normal file
File diff suppressed because it is too large
Load Diff
167
vendor/gorm.io/gorm/schema/index.go
generated
vendored
Normal file
167
vendor/gorm.io/gorm/schema/index.go
generated
vendored
Normal file
@@ -0,0 +1,167 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Index struct {
|
||||
Name string
|
||||
Class string // UNIQUE | FULLTEXT | SPATIAL
|
||||
Type string // btree, hash, gist, spgist, gin, and brin
|
||||
Where string
|
||||
Comment string
|
||||
Option string // WITH PARSER parser_name
|
||||
Fields []IndexOption // Note: IndexOption's Field maybe the same
|
||||
}
|
||||
|
||||
type IndexOption struct {
|
||||
*Field
|
||||
Expression string
|
||||
Sort string // DESC, ASC
|
||||
Collate string
|
||||
Length int
|
||||
Priority int
|
||||
}
|
||||
|
||||
// ParseIndexes parse schema indexes
|
||||
func (schema *Schema) ParseIndexes() []*Index {
|
||||
indexesByName := map[string]*Index{}
|
||||
indexes := []*Index{}
|
||||
|
||||
for _, field := range schema.Fields {
|
||||
if field.TagSettings["INDEX"] != "" || field.TagSettings["UNIQUEINDEX"] != "" {
|
||||
fieldIndexes, err := parseFieldIndexes(field)
|
||||
if err != nil {
|
||||
schema.err = err
|
||||
break
|
||||
}
|
||||
for _, index := range fieldIndexes {
|
||||
idx := indexesByName[index.Name]
|
||||
if idx == nil {
|
||||
idx = &Index{Name: index.Name}
|
||||
indexesByName[index.Name] = idx
|
||||
indexes = append(indexes, idx)
|
||||
}
|
||||
idx.Name = index.Name
|
||||
if idx.Class == "" {
|
||||
idx.Class = index.Class
|
||||
}
|
||||
if idx.Type == "" {
|
||||
idx.Type = index.Type
|
||||
}
|
||||
if idx.Where == "" {
|
||||
idx.Where = index.Where
|
||||
}
|
||||
if idx.Comment == "" {
|
||||
idx.Comment = index.Comment
|
||||
}
|
||||
if idx.Option == "" {
|
||||
idx.Option = index.Option
|
||||
}
|
||||
|
||||
idx.Fields = append(idx.Fields, index.Fields...)
|
||||
sort.Slice(idx.Fields, func(i, j int) bool {
|
||||
return idx.Fields[i].Priority < idx.Fields[j].Priority
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, index := range indexes {
|
||||
if index.Class == "UNIQUE" && len(index.Fields) == 1 {
|
||||
index.Fields[0].Field.UniqueIndex = index.Name
|
||||
}
|
||||
}
|
||||
return indexes
|
||||
}
|
||||
|
||||
func (schema *Schema) LookIndex(name string) *Index {
|
||||
if schema != nil {
|
||||
indexes := schema.ParseIndexes()
|
||||
for _, index := range indexes {
|
||||
if index.Name == name {
|
||||
return index
|
||||
}
|
||||
|
||||
for _, field := range index.Fields {
|
||||
if field.Name == name {
|
||||
return index
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseFieldIndexes(field *Field) (indexes []Index, err error) {
|
||||
for _, value := range strings.Split(field.Tag.Get("gorm"), ";") {
|
||||
if value != "" {
|
||||
v := strings.Split(value, ":")
|
||||
k := strings.TrimSpace(strings.ToUpper(v[0]))
|
||||
if k == "INDEX" || k == "UNIQUEINDEX" {
|
||||
var (
|
||||
name string
|
||||
tag = strings.Join(v[1:], ":")
|
||||
idx = strings.IndexByte(tag, ',')
|
||||
tagSetting = strings.Join(strings.Split(tag, ",")[1:], ",")
|
||||
settings = ParseTagSetting(tagSetting, ",")
|
||||
length, _ = strconv.Atoi(settings["LENGTH"])
|
||||
)
|
||||
|
||||
if idx == -1 {
|
||||
idx = len(tag)
|
||||
}
|
||||
|
||||
name = tag[0:idx]
|
||||
if name == "" {
|
||||
subName := field.Name
|
||||
const key = "COMPOSITE"
|
||||
if composite, found := settings[key]; found {
|
||||
if len(composite) == 0 || composite == key {
|
||||
err = fmt.Errorf(
|
||||
"the composite tag of %s.%s cannot be empty",
|
||||
field.Schema.Name,
|
||||
field.Name)
|
||||
return
|
||||
}
|
||||
subName = composite
|
||||
}
|
||||
name = field.Schema.namer.IndexName(
|
||||
field.Schema.Table, subName)
|
||||
}
|
||||
|
||||
if (k == "UNIQUEINDEX") || settings["UNIQUE"] != "" {
|
||||
settings["CLASS"] = "UNIQUE"
|
||||
}
|
||||
|
||||
priority, err := strconv.Atoi(settings["PRIORITY"])
|
||||
if err != nil {
|
||||
priority = 10
|
||||
}
|
||||
|
||||
indexes = append(indexes, Index{
|
||||
Name: name,
|
||||
Class: settings["CLASS"],
|
||||
Type: settings["TYPE"],
|
||||
Where: settings["WHERE"],
|
||||
Comment: settings["COMMENT"],
|
||||
Option: settings["OPTION"],
|
||||
Fields: []IndexOption{{
|
||||
Field: field,
|
||||
Expression: settings["EXPRESSION"],
|
||||
Sort: settings["SORT"],
|
||||
Collate: settings["COLLATE"],
|
||||
Length: length,
|
||||
Priority: priority,
|
||||
}},
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
err = nil
|
||||
return
|
||||
}
|
||||
42
vendor/gorm.io/gorm/schema/interfaces.go
generated
vendored
Normal file
42
vendor/gorm.io/gorm/schema/interfaces.go
generated
vendored
Normal file
@@ -0,0 +1,42 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
// ConstraintInterface database constraint interface
|
||||
type ConstraintInterface interface {
|
||||
GetName() string
|
||||
Build() (sql string, vars []interface{})
|
||||
}
|
||||
|
||||
// GormDataTypeInterface gorm data type interface
|
||||
type GormDataTypeInterface interface {
|
||||
GormDataType() string
|
||||
}
|
||||
|
||||
// FieldNewValuePool field new scan value pool
|
||||
type FieldNewValuePool interface {
|
||||
Get() interface{}
|
||||
Put(interface{})
|
||||
}
|
||||
|
||||
// CreateClausesInterface create clauses interface
|
||||
type CreateClausesInterface interface {
|
||||
CreateClauses(*Field) []clause.Interface
|
||||
}
|
||||
|
||||
// QueryClausesInterface query clauses interface
|
||||
type QueryClausesInterface interface {
|
||||
QueryClauses(*Field) []clause.Interface
|
||||
}
|
||||
|
||||
// UpdateClausesInterface update clauses interface
|
||||
type UpdateClausesInterface interface {
|
||||
UpdateClauses(*Field) []clause.Interface
|
||||
}
|
||||
|
||||
// DeleteClausesInterface delete clauses interface
|
||||
type DeleteClausesInterface interface {
|
||||
DeleteClauses(*Field) []clause.Interface
|
||||
}
|
||||
196
vendor/gorm.io/gorm/schema/naming.go
generated
vendored
Normal file
196
vendor/gorm.io/gorm/schema/naming.go
generated
vendored
Normal file
@@ -0,0 +1,196 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"crypto/sha1"
|
||||
"encoding/hex"
|
||||
"regexp"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/jinzhu/inflection"
|
||||
"golang.org/x/text/cases"
|
||||
"golang.org/x/text/language"
|
||||
)
|
||||
|
||||
// Namer is the strategy interface GORM consults to derive every database
// identifier (table, column, index, constraint names) from Go names.
type Namer interface {
	TableName(table string) string
	SchemaName(table string) string
	ColumnName(table, column string) string
	JoinTableName(joinTable string) string
	RelationshipFKName(Relationship) string
	CheckerName(table, column string) string
	IndexName(table, column string) string
	UniqueName(table, column string) string
}

// Replacer is a minimal replacement interface like strings.Replacer; it is
// applied to Go names before they are converted to database names.
type Replacer interface {
	Replace(name string) string
}

// Compile-time check that NamingStrategy satisfies Namer.
var _ Namer = (*NamingStrategy)(nil)

// NamingStrategy is the default Namer implementation: snake_case, pluralized
// table names with an optional prefix.
type NamingStrategy struct {
	// TablePrefix is prepended to every generated table name.
	TablePrefix string
	// SingularTable disables pluralization of table names.
	SingularTable bool
	// NameReplacer, if set, pre-processes Go names before conversion.
	NameReplacer Replacer
	// NoLowerCase keeps names as-is instead of converting to snake_case.
	NoLowerCase bool
	// IdentifierMaxLength caps generated identifiers (0 means the default 64).
	IdentifierMaxLength int
}
|
||||
|
||||
// TableName converts a struct name to a prefixed, snake_case (and, unless
// SingularTable is set, pluralized) table name.
func (ns NamingStrategy) TableName(str string) string {
	if ns.SingularTable {
		return ns.TablePrefix + ns.toDBName(str)
	}
	return ns.TablePrefix + inflection.Plural(ns.toDBName(str))
}

// SchemaName generates a schema (struct) name from a table name; it does not
// guarantee being the exact inverse of TableName.
func (ns NamingStrategy) SchemaName(table string) string {
	table = strings.TrimPrefix(table, ns.TablePrefix)

	if ns.SingularTable {
		return ns.toSchemaName(table)
	}
	return ns.toSchemaName(inflection.Singular(table))
}

// ColumnName converts a field name to a column name; the table argument is
// unused by this strategy but required by the Namer interface.
func (ns NamingStrategy) ColumnName(table, column string) string {
	return ns.toDBName(column)
}

// JoinTableName converts a string to a join table name.
func (ns NamingStrategy) JoinTableName(str string) string {
	// An already-lowercase name is assumed to be an explicit table name:
	// only the prefix is applied, no case conversion or pluralization.
	if !ns.NoLowerCase && strings.ToLower(str) == str {
		return ns.TablePrefix + str
	}

	if ns.SingularTable {
		return ns.TablePrefix + ns.toDBName(str)
	}
	return ns.TablePrefix + inflection.Plural(ns.toDBName(str))
}

// RelationshipFKName generates a foreign-key constraint name ("fk_...") for a
// relation.
func (ns NamingStrategy) RelationshipFKName(rel Relationship) string {
	return ns.formatName("fk", rel.Schema.Table, ns.toDBName(rel.Name))
}

// CheckerName generates a check-constraint name ("chk_...").
func (ns NamingStrategy) CheckerName(table, column string) string {
	return ns.formatName("chk", table, column)
}

// IndexName generates an index name ("idx_...").
func (ns NamingStrategy) IndexName(table, column string) string {
	return ns.formatName("idx", table, ns.toDBName(column))
}

// UniqueName generates a unique-constraint name ("uni_...").
func (ns NamingStrategy) UniqueName(table, column string) string {
	return ns.formatName("uni", table, ns.toDBName(column))
}
|
||||
|
||||
func (ns NamingStrategy) formatName(prefix, table, name string) string {
|
||||
formattedName := strings.ReplaceAll(strings.Join([]string{
|
||||
prefix, table, name,
|
||||
}, "_"), ".", "_")
|
||||
|
||||
if ns.IdentifierMaxLength == 0 {
|
||||
ns.IdentifierMaxLength = 64
|
||||
}
|
||||
|
||||
if utf8.RuneCountInString(formattedName) > ns.IdentifierMaxLength {
|
||||
h := sha1.New()
|
||||
h.Write([]byte(formattedName))
|
||||
bs := h.Sum(nil)
|
||||
|
||||
formattedName = formattedName[0:ns.IdentifierMaxLength-8] + hex.EncodeToString(bs)[:8]
|
||||
}
|
||||
return formattedName
|
||||
}
|
||||
|
||||
var (
	// commonInitialisms are acronyms kept upper-case in Go names; list taken
	// from https://github.com/golang/lint/blob/master/lint.go#L770
	commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UID", "UI", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"}
	// commonInitialismsReplacer rewrites each initialism to its Title-case
	// form (e.g. "HTTP" -> "Http") so toDBName can snake_case it as one word.
	commonInitialismsReplacer *strings.Replacer
)

// init builds commonInitialismsReplacer from the initialism list once at
// package load.
func init() {
	commonInitialismsForReplacer := make([]string, 0, len(commonInitialisms))
	for _, initialism := range commonInitialisms {
		commonInitialismsForReplacer = append(commonInitialismsForReplacer, initialism, cases.Title(language.Und).String(initialism))
	}
	commonInitialismsReplacer = strings.NewReplacer(commonInitialismsForReplacer...)
}
|
||||
|
||||
// toDBName converts a Go (CamelCase) name to a database (snake_case) name,
// applying NameReplacer first and honoring NoLowerCase. Common initialisms
// ("HTTP", "ID", ...) are normalized so they become one snake_case word.
func (ns NamingStrategy) toDBName(name string) string {
	if name == "" {
		return ""
	}

	if ns.NameReplacer != nil {
		tmpName := ns.NameReplacer.Replace(name)

		// An empty replacement result falls back to the original name.
		if tmpName == "" {
			return name
		}

		name = tmpName
	}

	if ns.NoLowerCase {
		return name
	}

	var (
		value                          = commonInitialismsReplacer.Replace(name)
		buf                            strings.Builder
		lastCase, nextCase, nextNumber bool // upper case == true
		curCase                        = value[0] <= 'Z' && value[0] >= 'A'
	)

	// Walk all bytes but the last, deciding per character whether to insert
	// an underscore before lower-casing; the +32 converts ASCII A-Z to a-z.
	// NOTE(review): this logic is byte-oriented and assumes ASCII input for
	// case decisions — non-ASCII runes pass through unchanged.
	for i, v := range value[:len(value)-1] {
		nextCase = value[i+1] <= 'Z' && value[i+1] >= 'A'
		nextNumber = value[i+1] >= '0' && value[i+1] <= '9'

		if curCase {
			if lastCase && (nextCase || nextNumber) {
				// Inside an upper-case run: just lower-case, no separator.
				buf.WriteRune(v + 32)
			} else {
				// Start of a new word: insert "_" unless adjacent to one.
				if i > 0 && value[i-1] != '_' && value[i+1] != '_' {
					buf.WriteByte('_')
				}
				buf.WriteRune(v + 32)
			}
		} else {
			buf.WriteRune(v)
		}

		lastCase = curCase
		curCase = nextCase
	}

	// Handle the final byte, which the loop above deliberately skipped.
	if curCase {
		if !lastCase && len(value) > 1 {
			buf.WriteByte('_')
		}
		buf.WriteByte(value[len(value)-1] + 32)
	} else {
		buf.WriteByte(value[len(value)-1])
	}
	ret := buf.String()
	return ret
}
|
||||
|
||||
// toSchemaName converts a snake_case database name back to a CamelCase schema
// name, restoring common initialisms ("http_id" -> "HTTPID" style).
// NOTE(review): the regexps are recompiled per call and per initialism; kept
// as-is to stay identical to upstream vendored code.
func (ns NamingStrategy) toSchemaName(name string) string {
	result := strings.ReplaceAll(cases.Title(language.Und, cases.NoLower).String(strings.ReplaceAll(name, "_", " ")), " ", "")
	for _, initialism := range commonInitialisms {
		result = regexp.MustCompile(cases.Title(language.Und, cases.NoLower).String(strings.ToLower(initialism))+"([A-Z]|$|_)").ReplaceAllString(result, initialism+"$1")
	}
	return result
}
|
||||
19
vendor/gorm.io/gorm/schema/pool.go
generated
vendored
Normal file
19
vendor/gorm.io/gorm/schema/pool.go
generated
vendored
Normal file
@@ -0,0 +1,19 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// sync pools
var (
	// normalPool maps reflect.Type -> *sync.Pool of freshly allocated values
	// of that type, shared across schemas.
	normalPool sync.Map
	// poolInitializer returns (creating on first use) the value pool for the
	// given type; pooled values are produced by reflect.New.
	poolInitializer = func(reflectType reflect.Type) FieldNewValuePool {
		v, _ := normalPool.LoadOrStore(reflectType, &sync.Pool{
			New: func() interface{} {
				return reflect.New(reflectType).Interface()
			},
		})
		return v.(FieldNewValuePool)
	}
)
|
||||
773
vendor/gorm.io/gorm/schema/relationship.go
generated
vendored
Normal file
773
vendor/gorm.io/gorm/schema/relationship.go
generated
vendored
Normal file
@@ -0,0 +1,773 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/jinzhu/inflection"
|
||||
"golang.org/x/text/cases"
|
||||
"golang.org/x/text/language"
|
||||
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
// RelationshipType relationship type
type RelationshipType string

const (
	HasOne    RelationshipType = "has_one"      // HasOneRel has one relationship
	HasMany   RelationshipType = "has_many"     // HasManyRel has many relationship
	BelongsTo RelationshipType = "belongs_to"   // BelongsToRel belongs to relationship
	Many2Many RelationshipType = "many_to_many" // Many2ManyRel many to many relationship
	// has is a provisional type used during parsing before a relation is
	// resolved to HasOne or HasMany.
	has RelationshipType = "has"
)

// Relationships groups all relations of a schema by kind and by name.
type Relationships struct {
	HasOne    []*Relationship
	BelongsTo []*Relationship
	HasMany   []*Relationship
	Many2Many []*Relationship
	Relations map[string]*Relationship

	// EmbeddedRelations holds relations declared on embedded structs, keyed
	// by the embedding field's name.
	EmbeddedRelations map[string]*Relationships

	// Mux guards concurrent mutation of Relations during schema parsing.
	Mux sync.RWMutex
}

// Relationship describes one relation between two schemas.
type Relationship struct {
	Name        string
	Type        RelationshipType
	Field       *Field
	Polymorphic *Polymorphic
	References  []*Reference
	Schema      *Schema
	FieldSchema *Schema
	JoinTable   *Schema
	// foreignKeys / primaryKeys come from the `foreignKey:` / `references:`
	// struct tags.
	foreignKeys, primaryKeys []string
}

// Polymorphic captures the type/ID field pair of a polymorphic relation.
type Polymorphic struct {
	PolymorphicID   *Field
	PolymorphicType *Field
	// Value is the discriminator stored in the type column (the owner table
	// name by default).
	Value string
}

// Reference links one primary-key field to one foreign-key field (or to a
// constant PrimaryValue for polymorphic type columns).
type Reference struct {
	PrimaryKey    *Field
	PrimaryValue  string
	ForeignKey    *Field
	OwnPrimaryKey bool
}
|
||||
|
||||
// parseRelation builds a Relationship for the given field by parsing the
// field's schema and dispatching on its tags (polymorphic, many2many,
// belongsTo) or, absent tags, guessing from the field kind. On success the
// relation is registered on the schema; on failure schema.err is set and the
// (possibly partial) relation — or nil — is returned.
func (schema *Schema) parseRelation(field *Field) *Relationship {
	var (
		err        error
		fieldValue = reflect.New(field.IndirectFieldType).Interface()
		relation   = &Relationship{
			Name:        field.Name,
			Field:       field,
			Schema:      schema,
			foreignKeys: toColumns(field.TagSettings["FOREIGNKEY"]),
			primaryKeys: toColumns(field.TagSettings["REFERENCES"]),
		}
	)

	cacheStore := schema.cacheStore

	if relation.FieldSchema, err = getOrParse(fieldValue, cacheStore, schema.namer); err != nil {
		schema.err = fmt.Errorf("failed to parse field: %s, error: %w", field.Name, err)
		return nil
	}

	// Dispatch: explicit tags win; otherwise guess from the field's kind
	// (struct => has-one-ish, slice => has-many-ish).
	if hasPolymorphicRelation(field.TagSettings) {
		schema.buildPolymorphicRelation(relation, field)
	} else if many2many := field.TagSettings["MANY2MANY"]; many2many != "" {
		schema.buildMany2ManyRelation(relation, field, many2many)
	} else if belongsTo := field.TagSettings["BELONGSTO"]; belongsTo != "" {
		schema.guessRelation(relation, field, guessBelongs)
	} else {
		switch field.IndirectFieldType.Kind() {
		case reflect.Struct:
			schema.guessRelation(relation, field, guessGuess)
		case reflect.Slice:
			schema.guessRelation(relation, field, guessHas)
		default:
			schema.err = fmt.Errorf("unsupported data type %v for %v on field %s", relation.FieldSchema, schema,
				field.Name)
		}
	}

	// Resolve the provisional "has" type to HasOne/HasMany, and cross-register
	// the relation on the target schema (guarded by its mutex) so the other
	// side can find it.
	if relation.Type == has {
		if relation.FieldSchema != relation.Schema && relation.Polymorphic == nil && field.OwnerSchema == nil {
			relation.FieldSchema.Relationships.Mux.Lock()
			relation.FieldSchema.Relationships.Relations["_"+relation.Schema.Name+"_"+relation.Name] = relation
			relation.FieldSchema.Relationships.Mux.Unlock()
		}

		switch field.IndirectFieldType.Kind() {
		case reflect.Struct:
			relation.Type = HasOne
		case reflect.Slice:
			relation.Type = HasMany
		}
	}

	// Only register the relation if no error occurred anywhere above.
	if schema.err == nil {
		schema.setRelation(relation)
		switch relation.Type {
		case HasOne:
			schema.Relationships.HasOne = append(schema.Relationships.HasOne, relation)
		case HasMany:
			schema.Relationships.HasMany = append(schema.Relationships.HasMany, relation)
		case BelongsTo:
			schema.Relationships.BelongsTo = append(schema.Relationships.BelongsTo, relation)
		case Many2Many:
			schema.Relationships.Many2Many = append(schema.Relationships.Many2Many, relation)
		}
	}

	return relation
}
|
||||
|
||||
// hasPolymorphicRelation reports whether the tag settings declare a
// polymorphic relation, either via the `POLYMORPHIC` tag alone or via both
// the `POLYMORPHICTYPE` and `POLYMORPHICID` tags together.
func hasPolymorphicRelation(tagSettings map[string]string) bool {
	_, explicit := tagSettings["POLYMORPHIC"]
	_, hasType := tagSettings["POLYMORPHICTYPE"]
	_, hasID := tagSettings["POLYMORPHICID"]
	return explicit || (hasType && hasID)
}
|
||||
|
||||
// setRelation registers the relation under its name on the schema, and — when
// the relation comes from an embedded struct — also under the nested
// EmbeddedRelations path described by the field's EmbeddedBindNames.
func (schema *Schema) setRelation(relation *Relationship) {
	// set non-embedded relation
	if rel := schema.Relationships.Relations[relation.Name]; rel != nil {
		// An existing entry is only overwritten when it came from a nested
		// (multi-bind-name) field; a direct field keeps priority.
		if len(rel.Field.BindNames) > 1 {
			schema.Relationships.Relations[relation.Name] = relation
		}
	} else {
		schema.Relationships.Relations[relation.Name] = relation
	}

	// set embedded relation
	if len(relation.Field.EmbeddedBindNames) <= 1 {
		return
	}
	// Walk/create the EmbeddedRelations tree along the bind-name path; the
	// final segment stores the relation itself.
	relationships := &schema.Relationships
	for i, name := range relation.Field.EmbeddedBindNames {
		if i < len(relation.Field.EmbeddedBindNames)-1 {
			if relationships.EmbeddedRelations == nil {
				relationships.EmbeddedRelations = map[string]*Relationships{}
			}
			if r := relationships.EmbeddedRelations[name]; r == nil {
				relationships.EmbeddedRelations[name] = &Relationships{}
			}
			relationships = relationships.EmbeddedRelations[name]
		} else {
			if relationships.Relations == nil {
				relationships.Relations = map[string]*Relationship{}
			}
			relationships.Relations[relation.Name] = relation
		}
	}
}
|
||||
|
||||
// buildPolymorphicRelation wires a polymorphic relation: the target schema
// holds a <name>Type and <name>ID field pair, and the type column stores the
// owner's table name (or an explicit polymorphicValue).
//
// User has many Toys, its `Polymorphic` is `Owner`, Pet has one Toy, its `Polymorphic` is `Owner`
//
//	type User struct {
//	  Toys []Toy `gorm:"polymorphic:Owner;"`
//	}
//	type Pet struct {
//	  Toy Toy `gorm:"polymorphic:Owner;"`
//	}
//	type Toy struct {
//	  OwnerID   int
//	  OwnerType string
//	}
func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Field) {
	polymorphic := field.TagSettings["POLYMORPHIC"]

	relation.Polymorphic = &Polymorphic{
		Value: schema.Table,
	}

	// Default field names derive from the polymorphic prefix; explicit
	// polymorphicType/polymorphicId tags override them.
	var (
		typeName = polymorphic + "Type"
		typeId   = polymorphic + "ID"
	)

	if value, ok := field.TagSettings["POLYMORPHICTYPE"]; ok {
		typeName = strings.TrimSpace(value)
	}

	if value, ok := field.TagSettings["POLYMORPHICID"]; ok {
		typeId = strings.TrimSpace(value)
	}

	relation.Polymorphic.PolymorphicType = relation.FieldSchema.FieldsByName[typeName]
	relation.Polymorphic.PolymorphicID = relation.FieldSchema.FieldsByName[typeId]

	if value, ok := field.TagSettings["POLYMORPHICVALUE"]; ok {
		relation.Polymorphic.Value = strings.TrimSpace(value)
	}

	if relation.Polymorphic.PolymorphicType == nil {
		schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s",
			relation.FieldSchema, schema, field.Name, polymorphic+"Type")
	}

	if relation.Polymorphic.PolymorphicID == nil {
		schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s",
			relation.FieldSchema, schema, field.Name, polymorphic+"ID")
	}

	if schema.err == nil {
		// Reference 1: the type column, matched against the constant Value.
		relation.References = append(relation.References, &Reference{
			PrimaryValue: relation.Polymorphic.Value,
			ForeignKey:   relation.Polymorphic.PolymorphicType,
		})

		primaryKeyField := schema.PrioritizedPrimaryField
		if len(relation.foreignKeys) > 0 {
			// At most one explicit foreign key is allowed for polymorphic
			// relations.
			if primaryKeyField = schema.LookUpField(relation.foreignKeys[0]); primaryKeyField == nil || len(relation.foreignKeys) > 1 {
				schema.err = fmt.Errorf("invalid polymorphic foreign keys %+v for %v on field %s", relation.foreignKeys,
					schema, field.Name)
			}
		}

		if primaryKeyField == nil {
			schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing primaryKey field",
				relation.FieldSchema, schema, field.Name)
			return
		}

		// use same data type for foreign keys
		if copyableDataType(primaryKeyField.DataType) {
			relation.Polymorphic.PolymorphicID.DataType = primaryKeyField.DataType
		}
		relation.Polymorphic.PolymorphicID.GORMDataType = primaryKeyField.GORMDataType
		if relation.Polymorphic.PolymorphicID.Size == 0 {
			relation.Polymorphic.PolymorphicID.Size = primaryKeyField.Size
		}

		// Reference 2: the ID column, matched against the owner primary key.
		relation.References = append(relation.References, &Reference{
			PrimaryKey:    primaryKeyField,
			ForeignKey:    relation.Polymorphic.PolymorphicID,
			OwnPrimaryKey: true,
		})
	}

	// Provisional type; parseRelation resolves it to HasOne/HasMany.
	relation.Type = has
}
|
||||
|
||||
// buildMany2ManyRelation synthesizes the join-table schema for a many2many
// relation: it assembles a struct type whose fields mirror both sides'
// primary keys (honoring joinForeignKey/joinReferences tag overrides), parses
// it into relation.JoinTable, and wires BelongsTo relations plus References
// from the join table back to both schemas.
func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Field, many2many string) {
	relation.Type = Many2Many

	var (
		err             error
		joinTableFields []reflect.StructField
		fieldsMap       = map[string]*Field{}
		ownFieldsMap    = map[string]*Field{} // fix self join many2many
		referFieldsMap  = map[string]*Field{}
		joinForeignKeys = toColumns(field.TagSettings["JOINFOREIGNKEY"])
		joinReferences  = toColumns(field.TagSettings["JOINREFERENCES"])
	)

	ownForeignFields := schema.PrimaryFields
	refForeignFields := relation.FieldSchema.PrimaryFields

	// Explicit foreignKey tag replaces the owner's primary-key field set.
	if len(relation.foreignKeys) > 0 {
		ownForeignFields = []*Field{}
		for _, foreignKey := range relation.foreignKeys {
			if field := schema.LookUpField(foreignKey); field != nil {
				ownForeignFields = append(ownForeignFields, field)
			} else {
				schema.err = fmt.Errorf("invalid foreign key: %s", foreignKey)
				return
			}
		}
	}

	// Explicit references tag replaces the target's primary-key field set.
	if len(relation.primaryKeys) > 0 {
		refForeignFields = []*Field{}
		for _, foreignKey := range relation.primaryKeys {
			if field := relation.FieldSchema.LookUpField(foreignKey); field != nil {
				refForeignFields = append(refForeignFields, field)
			} else {
				schema.err = fmt.Errorf("invalid foreign key: %s", foreignKey)
				return
			}
		}
	}

	// One join-table field per owner primary key, named <Schema><Field> unless
	// overridden by joinForeignKey.
	for idx, ownField := range ownForeignFields {
		joinFieldName := cases.Title(language.Und, cases.NoLower).String(schema.Name) + ownField.Name
		if len(joinForeignKeys) > idx {
			joinFieldName = cases.Title(language.Und, cases.NoLower).String(joinForeignKeys[idx])
		}

		ownFieldsMap[joinFieldName] = ownField
		fieldsMap[joinFieldName] = ownField
		joinTableFields = append(joinTableFields, reflect.StructField{
			Name:    joinFieldName,
			PkgPath: ownField.StructField.PkgPath,
			Type:    ownField.StructField.Type,
			Tag: removeSettingFromTag(appendSettingFromTag(ownField.StructField.Tag, "primaryKey"),
				"column", "autoincrement", "index", "unique", "uniqueindex"),
		})
	}

	// One join-table field per target primary key; self-join collisions get a
	// disambiguated name.
	for idx, relField := range refForeignFields {
		joinFieldName := cases.Title(language.Und, cases.NoLower).String(relation.FieldSchema.Name) + relField.Name

		if _, ok := ownFieldsMap[joinFieldName]; ok {
			if field.Name != relation.FieldSchema.Name {
				joinFieldName = inflection.Singular(field.Name) + relField.Name
			} else {
				joinFieldName += "Reference"
			}
		}

		if len(joinReferences) > idx {
			joinFieldName = cases.Title(language.Und, cases.NoLower).String(joinReferences[idx])
		}

		referFieldsMap[joinFieldName] = relField

		if _, ok := fieldsMap[joinFieldName]; !ok {
			fieldsMap[joinFieldName] = relField
			joinTableFields = append(joinTableFields, reflect.StructField{
				Name:    joinFieldName,
				PkgPath: relField.StructField.PkgPath,
				Type:    relField.StructField.Type,
				Tag: removeSettingFromTag(appendSettingFromTag(relField.StructField.Tag, "primaryKey"),
					"column", "autoincrement", "index", "unique", "uniqueindex"),
			})
		}
	}

	// Ignored marker field carrying the owner model type (tagged gorm:"-").
	joinTableFields = append(joinTableFields, reflect.StructField{
		Name: cases.Title(language.Und, cases.NoLower).String(schema.Name) + field.Name,
		Type: schema.ModelType,
		Tag:  `gorm:"-"`,
	})

	// NOTE(review): on Parse error only schema.err is set and execution
	// continues; the following lines assume relation.JoinTable is non-nil on
	// the error path — confirm against upstream behavior of Parse.
	if relation.JoinTable, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore,
		schema.namer); err != nil {
		schema.err = err
	}
	relation.JoinTable.Name = many2many
	relation.JoinTable.Table = schema.namer.JoinTableName(many2many)
	relation.JoinTable.PrimaryFields = make([]*Field, 0, len(relation.JoinTable.Fields))

	relName := relation.Schema.Name
	relRefName := relation.FieldSchema.Name
	if relName == relRefName {
		// Self-join: distinguish the second side by the relation field name.
		relRefName = relation.Field.Name
	}

	// BelongsTo relation from the join table back to the owner schema.
	if _, ok := relation.JoinTable.Relationships.Relations[relName]; !ok {
		relation.JoinTable.Relationships.Relations[relName] = &Relationship{
			Name:        relName,
			Type:        BelongsTo,
			Schema:      relation.JoinTable,
			FieldSchema: relation.Schema,
		}
	} else {
		relation.JoinTable.Relationships.Relations[relName].References = []*Reference{}
	}

	// BelongsTo relation from the join table back to the target schema.
	if _, ok := relation.JoinTable.Relationships.Relations[relRefName]; !ok {
		relation.JoinTable.Relationships.Relations[relRefName] = &Relationship{
			Name:        relRefName,
			Type:        BelongsTo,
			Schema:      relation.JoinTable,
			FieldSchema: relation.FieldSchema,
		}
	} else {
		relation.JoinTable.Relationships.Relations[relRefName].References = []*Reference{}
	}

	// build references
	for _, f := range relation.JoinTable.Fields {
		if f.Creatable || f.Readable || f.Updatable {
			// use same data type for foreign keys
			if copyableDataType(fieldsMap[f.Name].DataType) {
				f.DataType = fieldsMap[f.Name].DataType
			}
			f.GORMDataType = fieldsMap[f.Name].GORMDataType
			if f.Size == 0 {
				f.Size = fieldsMap[f.Name].Size
			}
			relation.JoinTable.PrimaryFields = append(relation.JoinTable.PrimaryFields, f)

			// Owner-side join column: reference owner PK, OwnPrimaryKey=true.
			if of, ok := ownFieldsMap[f.Name]; ok {
				joinRel := relation.JoinTable.Relationships.Relations[relName]
				joinRel.Field = relation.Field
				joinRel.References = append(joinRel.References, &Reference{
					PrimaryKey: of,
					ForeignKey: f,
				})

				relation.References = append(relation.References, &Reference{
					PrimaryKey:    of,
					ForeignKey:    f,
					OwnPrimaryKey: true,
				})
			}

			// Target-side join column: reference target PK.
			if rf, ok := referFieldsMap[f.Name]; ok {
				joinRefRel := relation.JoinTable.Relationships.Relations[relRefName]
				if joinRefRel.Field == nil {
					joinRefRel.Field = relation.Field
				}
				joinRefRel.References = append(joinRefRel.References, &Reference{
					PrimaryKey: rf,
					ForeignKey: f,
				})

				relation.References = append(relation.References, &Reference{
					PrimaryKey: rf,
					ForeignKey: f,
				})
			}
		}
	}
}
|
||||
|
||||
// guessLevel enumerates the strategies guessRelation tries, in escalation
// order, when a relation is not fully specified by tags.
type guessLevel int

const (
	guessGuess           guessLevel = iota // undecided: pick belongs/has from field ownership
	guessBelongs                           // foreign key lives on this schema
	guessEmbeddedBelongs                   // foreign key lives on the embedding owner schema
	guessHas                               // foreign key lives on the target schema
	guessEmbeddedHas                       // foreign key on target, owner is the embedding schema
)
|
||||
|
||||
// guessRelation infers the primary/foreign field pairing for a relation at
// the given guess level. When the current level cannot be satisfied it
// recursively escalates through guessBelongs -> guessEmbeddedBelongs ->
// guessHas -> guessEmbeddedHas, and finally records an error on the schema.
func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl guessLevel) {
	var (
		primaryFields, foreignFields []*Field
		primarySchema, foreignSchema = schema, relation.FieldSchema
		gl                           = cgl
	)

	// Undecided: a field whose schema is the target schema belongs to it.
	if gl == guessGuess {
		if field.Schema == relation.FieldSchema {
			gl = guessBelongs
		} else {
			gl = guessHas
		}
	}

	// reguessOrErr retries at the next guess level, or sets the final error
	// once all levels are exhausted.
	reguessOrErr := func() {
		switch cgl {
		case guessGuess:
			schema.guessRelation(relation, field, guessBelongs)
		case guessBelongs:
			schema.guessRelation(relation, field, guessEmbeddedBelongs)
		case guessEmbeddedBelongs:
			schema.guessRelation(relation, field, guessHas)
		case guessHas:
			schema.guessRelation(relation, field, guessEmbeddedHas)
		// case guessEmbeddedHas:
		default:
			schema.err = fmt.Errorf("invalid field found for struct %v's field %s: define a valid foreign key for relations or implement the Valuer/Scanner interface",
				schema, field.Name)
		}
	}

	// Decide which schema owns the primary key and which holds the FK.
	switch gl {
	case guessBelongs:
		primarySchema, foreignSchema = relation.FieldSchema, schema
	case guessEmbeddedBelongs:
		if field.OwnerSchema == nil {
			reguessOrErr()
			return
		}
		primarySchema, foreignSchema = relation.FieldSchema, field.OwnerSchema
	case guessHas:
	case guessEmbeddedHas:
		if field.OwnerSchema == nil {
			reguessOrErr()
			return
		}
		primarySchema, foreignSchema = field.OwnerSchema, relation.FieldSchema
	}

	if len(relation.foreignKeys) > 0 {
		// Explicit foreignKey tag: each named field must exist on the FK side.
		for _, foreignKey := range relation.foreignKeys {
			f := foreignSchema.LookUpField(foreignKey)
			if f == nil {
				reguessOrErr()
				return
			}
			foreignFields = append(foreignFields, f)
		}
	} else {
		primarySchemaName := primarySchema.Name
		if primarySchemaName == "" {
			primarySchemaName = relation.FieldSchema.Name
		}

		if len(relation.primaryKeys) > 0 {
			for _, primaryKey := range relation.primaryKeys {
				if f := primarySchema.LookUpField(primaryKey); f != nil {
					primaryFields = append(primaryFields, f)
				}
			}
		} else {
			primaryFields = primarySchema.PrimaryFields
		}

		// For each candidate primary key, look for a conventionally-named FK
		// field (<Schema><PK>, ...ID, ...Id, or the namer's column form).
		// NOTE(review): the loop appends matched pairs to primaryFields while
		// ranging over it; the reference-building loop below assumes
		// foreignFields[i] pairs with primaryFields[i] — upstream quirk, kept
		// byte-identical.
	primaryFieldLoop:
		for _, primaryField := range primaryFields {
			lookUpName := primarySchemaName + primaryField.Name
			if gl == guessBelongs {
				lookUpName = field.Name + primaryField.Name
			}

			lookUpNames := []string{lookUpName}
			if len(primaryFields) == 1 {
				lookUpNames = append(lookUpNames, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID",
					strings.TrimSuffix(lookUpName, primaryField.Name)+"Id", schema.namer.ColumnName(foreignSchema.Table,
						strings.TrimSuffix(lookUpName, primaryField.Name)+"ID"))
			}

			for _, name := range lookUpNames {
				if f := foreignSchema.LookUpFieldByBindName(field.BindNames, name); f != nil {
					foreignFields = append(foreignFields, f)
					primaryFields = append(primaryFields, primaryField)
					continue primaryFieldLoop
				}
			}
			for _, name := range lookUpNames {
				if f := foreignSchema.LookUpField(name); f != nil {
					foreignFields = append(foreignFields, f)
					primaryFields = append(primaryFields, primaryField)
					continue primaryFieldLoop
				}
			}
		}
	}

	switch {
	case len(foreignFields) == 0:
		reguessOrErr()
		return
	case len(relation.primaryKeys) > 0:
		// Validate/complete primaryFields against the explicit references tag.
		for idx, primaryKey := range relation.primaryKeys {
			if f := primarySchema.LookUpField(primaryKey); f != nil {
				if len(primaryFields) < idx+1 {
					primaryFields = append(primaryFields, f)
				} else if f != primaryFields[idx] {
					reguessOrErr()
					return
				}
			} else {
				reguessOrErr()
				return
			}
		}
	case len(primaryFields) == 0:
		// No explicit primaries: fall back to the prioritized primary field or
		// a positional match against all primary fields.
		if len(foreignFields) == 1 && primarySchema.PrioritizedPrimaryField != nil {
			primaryFields = append(primaryFields, primarySchema.PrioritizedPrimaryField)
		} else if len(primarySchema.PrimaryFields) == len(foreignFields) {
			primaryFields = append(primaryFields, primarySchema.PrimaryFields...)
		} else {
			reguessOrErr()
			return
		}
	}

	// build references
	for idx, foreignField := range foreignFields {
		// use same data type for foreign keys
		if copyableDataType(primaryFields[idx].DataType) {
			foreignField.DataType = primaryFields[idx].DataType
		}
		foreignField.GORMDataType = primaryFields[idx].GORMDataType
		if foreignField.Size == 0 {
			foreignField.Size = primaryFields[idx].Size
		}

		relation.References = append(relation.References, &Reference{
			PrimaryKey:    primaryFields[idx],
			ForeignKey:    foreignField,
			OwnPrimaryKey: (schema == primarySchema && gl == guessHas) || (field.OwnerSchema == primarySchema && gl == guessEmbeddedHas),
		})
	}

	if gl == guessHas || gl == guessEmbeddedHas {
		relation.Type = has
	} else {
		relation.Type = BelongsTo
	}
}
|
||||
|
||||
// Constraint is ForeignKey Constraint
type Constraint struct {
	Name            string
	Field           *Field
	Schema          *Schema  // schema the constraint is defined on (FK side)
	ForeignKeys     []*Field // referencing columns
	ReferenceSchema *Schema  // schema being referenced (PK side)
	References      []*Field // referenced columns
	OnDelete        string   // ON DELETE action from the tag, if any
	OnUpdate        string   // ON UPDATE action from the tag, if any
}

// GetName returns the constraint's identifier (implements ConstraintInterface).
func (constraint *Constraint) GetName() string { return constraint.Name }

// Build renders the constraint as a "CONSTRAINT ? FOREIGN KEY ? REFERENCES ??"
// SQL fragment plus the clause values filling the placeholders, appending
// ON DELETE / ON UPDATE actions when configured.
func (constraint *Constraint) Build() (sql string, vars []interface{}) {
	sql = "CONSTRAINT ? FOREIGN KEY ? REFERENCES ??"
	if constraint.OnDelete != "" {
		sql += " ON DELETE " + constraint.OnDelete
	}

	if constraint.OnUpdate != "" {
		sql += " ON UPDATE " + constraint.OnUpdate
	}

	foreignKeys := make([]interface{}, 0, len(constraint.ForeignKeys))
	for _, field := range constraint.ForeignKeys {
		foreignKeys = append(foreignKeys, clause.Column{Name: field.DBName})
	}

	references := make([]interface{}, 0, len(constraint.References))
	for _, field := range constraint.References {
		references = append(references, clause.Column{Name: field.DBName})
	}
	// Placeholder order: constraint name, FK column list, referenced table,
	// referenced column list.
	vars = append(vars, clause.Table{Name: constraint.Name}, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references)
	return
}
|
||||
|
||||
// ParseConstraint builds the foreign-key Constraint described by the field's
// `constraint:` tag. It returns nil when the tag is "-" (constraint disabled)
// or when, for a BelongsTo relation, an equivalent mirror relation already
// exists on the other side (avoiding a duplicate constraint).
func (rel *Relationship) ParseConstraint() *Constraint {
	str := rel.Field.TagSettings["CONSTRAINT"]
	if str == "-" {
		return nil
	}

	if rel.Type == BelongsTo {
		// Skip if the target schema has a relation back to us with exactly
		// the same reference pairs — the has-side constraint suffices.
		for _, r := range rel.FieldSchema.Relationships.Relations {
			if r != rel && r.FieldSchema == rel.Schema && len(rel.References) == len(r.References) {
				matched := true
				for idx, ref := range r.References {
					if !(rel.References[idx].PrimaryKey == ref.PrimaryKey && rel.References[idx].ForeignKey == ref.ForeignKey &&
						rel.References[idx].PrimaryValue == ref.PrimaryValue) {
						matched = false
						break
					}
				}

				if matched {
					return nil
				}
			}
		}
	}

	var (
		name     string
		idx      = strings.IndexByte(str, ',')
		settings = ParseTagSetting(str, ",")
	)

	// optimize match english letters and midline
	// The following code is basically called in for.
	// In order to avoid the performance problems caused by repeated compilation of regular expressions,
	// it only needs to be done once outside, so optimization is done here.
	if idx != -1 && regEnLetterAndMidline.MatchString(str[0:idx]) {
		// First comma-separated token is a valid identifier: use it as the
		// explicit constraint name.
		name = str[0:idx]
	} else {
		name = rel.Schema.namer.RelationshipFKName(*rel)
	}

	constraint := Constraint{
		Name:     name,
		Field:    rel.Field,
		OnUpdate: settings["ONUPDATE"],
		OnDelete: settings["ONDELETE"],
	}

	// Collect FK/PK column pairs; for join tables only owner-side references
	// participate in this constraint.
	for _, ref := range rel.References {
		if ref.PrimaryKey != nil && (rel.JoinTable == nil || ref.OwnPrimaryKey) {
			constraint.ForeignKeys = append(constraint.ForeignKeys, ref.ForeignKey)
			constraint.References = append(constraint.References, ref.PrimaryKey)

			if ref.OwnPrimaryKey {
				constraint.Schema = ref.ForeignKey.Schema
				constraint.ReferenceSchema = rel.Schema
			} else {
				constraint.Schema = rel.Schema
				constraint.ReferenceSchema = ref.PrimaryKey.Schema
			}
		}
	}

	return &constraint
}
|
||||
|
||||
// ToQueryConditions builds the clause expressions used to select the rows
// related to the owner value(s) held in reflectValue. For many2many
// relations the conditions target the join table; otherwise they target the
// related schema's table directly.
func (rel *Relationship) ToQueryConditions(ctx context.Context, reflectValue reflect.Value) (conds []clause.Expression) {
	table := rel.FieldSchema.Table
	foreignFields := []*Field{}
	relForeignKeys := []string{}

	if rel.JoinTable != nil {
		// many2many: query through the join table.
		table = rel.JoinTable.Table
		for _, ref := range rel.References {
			if ref.OwnPrimaryKey {
				// Owner-side key: collect values from the owner instances and
				// match them against the join table's FK column below.
				foreignFields = append(foreignFields, ref.PrimaryKey)
				relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName)
			} else if ref.PrimaryValue != "" {
				// Polymorphic-style constant value on the join table column.
				conds = append(conds, clause.Eq{
					Column: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName},
					Value:  ref.PrimaryValue,
				})
			} else {
				// Join-table column must equal the related table's key column.
				conds = append(conds, clause.Eq{
					Column: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName},
					Value:  clause.Column{Table: rel.FieldSchema.Table, Name: ref.PrimaryKey.DBName},
				})
			}
		}
	} else {
		for _, ref := range rel.References {
			if ref.OwnPrimaryKey {
				// has one / has many: related FK column IN (owner PK values).
				relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName)
				foreignFields = append(foreignFields, ref.PrimaryKey)
			} else if ref.PrimaryValue != "" {
				// Constant (polymorphic) value on the related table.
				conds = append(conds, clause.Eq{
					Column: clause.Column{Table: rel.FieldSchema.Table, Name: ref.ForeignKey.DBName},
					Value:  ref.PrimaryValue,
				})
			} else {
				// belongs to: related PK column IN (owner FK values).
				relForeignKeys = append(relForeignKeys, ref.PrimaryKey.DBName)
				foreignFields = append(foreignFields, ref.ForeignKey)
			}
		}
	}

	// Extract the key values from the owner value(s) and turn them into the
	// final IN condition over the collected columns.
	_, foreignValues := GetIdentityFieldValuesMap(ctx, reflectValue, foreignFields)
	column, values := ToQueryValues(table, relForeignKeys, foreignValues)

	conds = append(conds, clause.IN{Column: column, Values: values})
	return
}
|
||||
|
||||
func copyableDataType(str DataType) bool {
|
||||
lowerStr := strings.ToLower(string(str))
|
||||
for _, s := range []string{"auto_increment", "primary key"} {
|
||||
if strings.Contains(lowerStr, s) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
424
vendor/gorm.io/gorm/schema/schema.go
generated
vendored
Normal file
424
vendor/gorm.io/gorm/schema/schema.go
generated
vendored
Normal file
@@ -0,0 +1,424 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"go/ast"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
// callbackType names a GORM model hook method (e.g. BeforeCreate). The
// string value serves both as the reflected method name looked up on the
// model and as the name of the corresponding boolean flag field on Schema.
type callbackType string

const (
	callbackTypeBeforeCreate callbackType = "BeforeCreate"
	callbackTypeBeforeUpdate callbackType = "BeforeUpdate"
	callbackTypeAfterCreate  callbackType = "AfterCreate"
	callbackTypeAfterUpdate  callbackType = "AfterUpdate"
	callbackTypeBeforeSave   callbackType = "BeforeSave"
	callbackTypeAfterSave    callbackType = "AfterSave"
	callbackTypeBeforeDelete callbackType = "BeforeDelete"
	callbackTypeAfterDelete  callbackType = "AfterDelete"
	callbackTypeAfterFind    callbackType = "AfterFind"
)

// ErrUnsupportedDataType is returned when a value cannot be parsed into a
// Schema (nil, non-struct, or an unnamed non-struct type).
var ErrUnsupportedDataType = errors.New("unsupported data type")
|
||||
|
||||
// Schema is the parsed model metadata GORM uses to map a Go struct onto a
// database table: its fields, primary keys, relationships, statement
// clauses and hook-method flags.
type Schema struct {
	Name                     string
	ModelType                reflect.Type
	Table                    string
	PrioritizedPrimaryField  *Field
	DBNames                  []string
	PrimaryFields            []*Field
	PrimaryFieldDBNames      []string
	Fields                   []*Field
	FieldsByName             map[string]*Field
	FieldsByBindName         map[string]*Field // embedded fields is 'Embed.Field'
	FieldsByDBName           map[string]*Field
	FieldsWithDefaultDBValue []*Field // fields with default value assigned by database
	Relationships            Relationships
	CreateClauses            []clause.Interface
	QueryClauses             []clause.Interface
	UpdateClauses            []clause.Interface
	DeleteClauses            []clause.Interface
	// Hook flags: true when the model implements the matching
	// `func(*gorm.DB) error` callback method (see ParseWithSpecialTableName).
	BeforeCreate, AfterCreate bool
	BeforeUpdate, AfterUpdate bool
	BeforeDelete, AfterDelete bool
	BeforeSave, AfterSave     bool
	AfterFind                 bool
	err                       error
	initialized               chan struct{} // closed once parsing has fully completed
	namer                     Namer
	cacheStore                *sync.Map
}
|
||||
|
||||
func (schema Schema) String() string {
|
||||
if schema.ModelType.Name() == "" {
|
||||
return fmt.Sprintf("%s(%s)", schema.Name, schema.Table)
|
||||
}
|
||||
return fmt.Sprintf("%s.%s", schema.ModelType.PkgPath(), schema.ModelType.Name())
|
||||
}
|
||||
|
||||
func (schema Schema) MakeSlice() reflect.Value {
|
||||
slice := reflect.MakeSlice(reflect.SliceOf(reflect.PointerTo(schema.ModelType)), 0, 20)
|
||||
results := reflect.New(slice.Type())
|
||||
results.Elem().Set(slice)
|
||||
|
||||
return results
|
||||
}
|
||||
|
||||
func (schema Schema) LookUpField(name string) *Field {
|
||||
if field, ok := schema.FieldsByDBName[name]; ok {
|
||||
return field
|
||||
}
|
||||
if field, ok := schema.FieldsByName[name]; ok {
|
||||
return field
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// LookUpFieldByBindName looks for the closest field in the embedded struct.
//
//	type Struct struct {
//		Embedded struct {
//			ID string // is selected by LookUpFieldByBindName([]string{"Embedded", "ID"}, "ID")
//		}
//		ID string // is selected by LookUpFieldByBindName([]string{"ID"}, "ID")
//	}
func (schema Schema) LookUpFieldByBindName(bindNames []string, name string) *Field {
	if len(bindNames) == 0 {
		return nil
	}
	// Try the deepest (most specific) binding path first, peeling one path
	// segment per iteration. Note that at i == 0 the lookup key degenerates
	// to "."+name (strings.Join of an empty slice is "").
	for i := len(bindNames) - 1; i >= 0; i-- {
		find := strings.Join(bindNames[:i], ".") + "." + name
		if field, ok := schema.FieldsByBindName[find]; ok {
			return field
		}
	}
	return nil
}
|
||||
|
||||
// Tabler is implemented by models that provide their own table name,
// overriding the name derived by the Namer.
type Tabler interface {
	TableName() string
}

// TablerWithNamer is implemented by models that derive their table name
// from the active naming strategy; it takes precedence over Tabler.
type TablerWithNamer interface {
	TableName(Namer) string
}

// Parse get data type from dialector
func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) {
	return ParseWithSpecialTableName(dest, cacheStore, namer, "")
}
|
||||
|
||||
// ParseWithSpecialTableName get data type from dialector with extra schema table.
// It resolves dest down to its underlying struct type, parses every exported
// field, determines primary keys and hook methods, caches the resulting
// *Schema in cacheStore, and finally parses relationships. Concurrent
// callers for the same type wait on schema.initialized.
func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Namer, specialTableName string) (*Schema, error) {
	if dest == nil {
		return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
	}

	// A typed-nil pointer still carries the type information we need; swap
	// in a fresh instance so reflect.Indirect below does not fail.
	value := reflect.ValueOf(dest)
	if value.Kind() == reflect.Ptr && value.IsNil() {
		value = reflect.New(value.Type().Elem())
	}
	modelType := reflect.Indirect(value).Type()

	if modelType.Kind() == reflect.Interface {
		modelType = reflect.Indirect(reflect.ValueOf(dest)).Elem().Type()
	}

	// Unwrap slices/arrays/pointers until the element struct type is reached.
	for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr {
		modelType = modelType.Elem()
	}

	if modelType.Kind() != reflect.Struct {
		if modelType.PkgPath() == "" {
			return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
		}
		return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name())
	}

	// Cache the Schema for performance,
	// Use the modelType or modelType + schemaTable (if it present) as cache key.
	var schemaCacheKey interface{}
	if specialTableName != "" {
		schemaCacheKey = fmt.Sprintf("%p-%s", modelType, specialTableName)
	} else {
		schemaCacheKey = modelType
	}

	// Load exist schema cache, return if exists
	if v, ok := cacheStore.Load(schemaCacheKey); ok {
		s := v.(*Schema)
		// Wait for the initialization of other goroutines to complete
		<-s.initialized
		return s, s.err
	}

	// Resolve the table name; later checks take precedence over earlier ones.
	modelValue := reflect.New(modelType)
	tableName := namer.TableName(modelType.Name())
	if tabler, ok := modelValue.Interface().(Tabler); ok {
		tableName = tabler.TableName()
	}
	if tabler, ok := modelValue.Interface().(TablerWithNamer); ok {
		tableName = tabler.TableName(namer)
	}
	if en, ok := namer.(embeddedNamer); ok {
		tableName = en.Table
	}
	if specialTableName != "" && specialTableName != tableName {
		tableName = specialTableName
	}

	schema := &Schema{
		Name:             modelType.Name(),
		ModelType:        modelType,
		Table:            tableName,
		FieldsByName:     map[string]*Field{},
		FieldsByBindName: map[string]*Field{},
		FieldsByDBName:   map[string]*Field{},
		Relationships:    Relationships{Relations: map[string]*Relationship{}},
		cacheStore:       cacheStore,
		namer:            namer,
		initialized:      make(chan struct{}),
	}
	// When the schema initialization is completed, the channel will be closed
	defer close(schema.initialized)

	// Load exist schema cache, return if exists
	// (re-checked here in case another goroutine stored it meanwhile).
	if v, ok := cacheStore.Load(schemaCacheKey); ok {
		s := v.(*Schema)
		// Wait for the initialization of other goroutines to complete
		<-s.initialized
		return s, s.err
	}

	// Parse every exported struct field; embedded schemas contribute their
	// own flattened field list.
	for i := 0; i < modelType.NumField(); i++ {
		if fieldStruct := modelType.Field(i); ast.IsExported(fieldStruct.Name) {
			if field := schema.ParseField(fieldStruct); field.EmbeddedSchema != nil {
				schema.Fields = append(schema.Fields, field.EmbeddedSchema.Fields...)
			} else {
				schema.Fields = append(schema.Fields, field)
			}
		}
	}

	for _, field := range schema.Fields {
		if field.DBName == "" && field.DataType != "" {
			field.DBName = namer.ColumnName(schema.Table, field.Name)
		}

		bindName := field.BindName()
		if field.DBName != "" {
			// nonexistence or shortest path or first appear prioritized if has permission
			if v, ok := schema.FieldsByDBName[field.DBName]; !ok || ((field.Creatable || field.Updatable || field.Readable) && len(field.BindNames) < len(v.BindNames)) {
				if _, ok := schema.FieldsByDBName[field.DBName]; !ok {
					schema.DBNames = append(schema.DBNames, field.DBName)
				}
				schema.FieldsByDBName[field.DBName] = field
				schema.FieldsByName[field.Name] = field
				schema.FieldsByBindName[bindName] = field

				// The field we just displaced may have been registered as a
				// primary key; remove it so the new field can take its place.
				if v != nil && v.PrimaryKey {
					for idx, f := range schema.PrimaryFields {
						if f == v {
							schema.PrimaryFields = append(schema.PrimaryFields[0:idx], schema.PrimaryFields[idx+1:]...)
						}
					}
				}

				if field.PrimaryKey {
					schema.PrimaryFields = append(schema.PrimaryFields, field)
				}
			}
		}

		// Fields ignored with `gorm:"-"` can still be replaced in the
		// name-indexed maps by a later field with the same name.
		if of, ok := schema.FieldsByName[field.Name]; !ok || of.TagSettings["-"] == "-" {
			schema.FieldsByName[field.Name] = field
		}
		if of, ok := schema.FieldsByBindName[bindName]; !ok || of.TagSettings["-"] == "-" {
			schema.FieldsByBindName[bindName] = field
		}

		field.setupValuerAndSetter()
	}

	// Prefer an "id"/"ID" field as the prioritized primary key, promoting it
	// to primary key when no explicit primary key exists.
	prioritizedPrimaryField := schema.LookUpField("id")
	if prioritizedPrimaryField == nil {
		prioritizedPrimaryField = schema.LookUpField("ID")
	}

	if prioritizedPrimaryField != nil {
		if prioritizedPrimaryField.PrimaryKey {
			schema.PrioritizedPrimaryField = prioritizedPrimaryField
		} else if len(schema.PrimaryFields) == 0 {
			prioritizedPrimaryField.PrimaryKey = true
			schema.PrioritizedPrimaryField = prioritizedPrimaryField
			schema.PrimaryFields = append(schema.PrimaryFields, prioritizedPrimaryField)
		}
	}

	if schema.PrioritizedPrimaryField == nil {
		if len(schema.PrimaryFields) == 1 {
			schema.PrioritizedPrimaryField = schema.PrimaryFields[0]
		} else if len(schema.PrimaryFields) > 1 {
			// If there are multiple primary keys, the AUTOINCREMENT field is prioritized
			for _, field := range schema.PrimaryFields {
				if field.AutoIncrement {
					schema.PrioritizedPrimaryField = field
					break
				}
			}
		}
	}

	for _, field := range schema.PrimaryFields {
		schema.PrimaryFieldDBNames = append(schema.PrimaryFieldDBNames, field.DBName)
	}

	for _, field := range schema.Fields {
		if field.DataType != "" && field.HasDefaultValue && field.DefaultValueInterface == nil {
			schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field)
		}
	}

	// An integer prioritized primary key without an explicit AUTOINCREMENT
	// tag is treated as auto-increment (so the database assigns its value).
	if field := schema.PrioritizedPrimaryField; field != nil {
		switch field.GORMDataType {
		case Int, Uint:
			if _, ok := field.TagSettings["AUTOINCREMENT"]; !ok {
				if !field.HasDefaultValue || field.DefaultValueInterface != nil {
					schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field)
				}

				field.HasDefaultValue = true
				field.AutoIncrement = true
			}
		}
	}

	// Detect which hook methods the model implements and record them as
	// flags on the schema (the flag field name equals the callback name).
	callbackTypes := []callbackType{
		callbackTypeBeforeCreate, callbackTypeAfterCreate,
		callbackTypeBeforeUpdate, callbackTypeAfterUpdate,
		callbackTypeBeforeSave, callbackTypeAfterSave,
		callbackTypeBeforeDelete, callbackTypeAfterDelete,
		callbackTypeAfterFind,
	}
	for _, cbName := range callbackTypes {
		if methodValue := callBackToMethodValue(modelValue, cbName); methodValue.IsValid() {
			switch methodValue.Type().String() {
			case "func(*gorm.DB) error": // TODO hack
				reflect.Indirect(reflect.ValueOf(schema)).FieldByName(string(cbName)).SetBool(true)
			default:
				logger.Default.Warn(context.Background(), "Model %v don't match %vInterface, should be `%v(*gorm.DB) error`. Please see https://gorm.io/docs/hooks.html", schema, cbName, cbName)
			}
		}
	}

	// Cache the schema
	if v, loaded := cacheStore.LoadOrStore(schemaCacheKey, schema); loaded {
		s := v.(*Schema)
		// Wait for the initialization of other goroutines to complete
		<-s.initialized
		return s, s.err
	}

	defer func() {
		if schema.err != nil {
			logger.Default.Error(context.Background(), schema.err.Error())
			// NOTE(review): this deletes the modelType key, but when
			// specialTableName is set the entry was stored under a composite
			// string key — confirm upstream intent.
			cacheStore.Delete(modelType)
		}
	}()

	// Parse relationships and field-provided statement clauses, unless this
	// parse is for an embedded schema (marked in the cache store).
	if _, embedded := schema.cacheStore.Load(embeddedCacheKey); !embedded {
		for _, field := range schema.Fields {
			if field.DataType == "" && field.GORMDataType == "" && (field.Creatable || field.Updatable || field.Readable) {
				if schema.parseRelation(field); schema.err != nil {
					return schema, schema.err
				} else {
					schema.FieldsByName[field.Name] = field
					schema.FieldsByBindName[field.BindName()] = field
				}
			}

			fieldValue := reflect.New(field.IndirectFieldType)
			fieldInterface := fieldValue.Interface()
			if fc, ok := fieldInterface.(CreateClausesInterface); ok {
				field.Schema.CreateClauses = append(field.Schema.CreateClauses, fc.CreateClauses(field)...)
			}

			if fc, ok := fieldInterface.(QueryClausesInterface); ok {
				field.Schema.QueryClauses = append(field.Schema.QueryClauses, fc.QueryClauses(field)...)
			}

			if fc, ok := fieldInterface.(UpdateClausesInterface); ok {
				field.Schema.UpdateClauses = append(field.Schema.UpdateClauses, fc.UpdateClauses(field)...)
			}

			if fc, ok := fieldInterface.(DeleteClausesInterface); ok {
				field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, fc.DeleteClauses(field)...)
			}
		}
	}

	return schema, schema.err
}
|
||||
|
||||
// callBackToMethodValue returns the model's hook method for cbType, or an
// invalid Value when the callback type is unknown.
//
// This unrolling is needed to show to the compiler the exact set of methods
// that can be used on the modelType.
// Prior to go1.22 any use of MethodByName would cause the linker to
// abandon dead code elimination for the entire binary.
// As of go1.22 the compiler supports one special case of a string constant
// being passed to MethodByName. For enterprise customers or those building
// large binaries, this gives a significant reduction in binary size.
// https://github.com/golang/go/issues/62257
func callBackToMethodValue(modelType reflect.Value, cbType callbackType) reflect.Value {
	switch cbType {
	case callbackTypeBeforeCreate:
		return modelType.MethodByName(string(callbackTypeBeforeCreate))
	case callbackTypeAfterCreate:
		return modelType.MethodByName(string(callbackTypeAfterCreate))
	case callbackTypeBeforeUpdate:
		return modelType.MethodByName(string(callbackTypeBeforeUpdate))
	case callbackTypeAfterUpdate:
		return modelType.MethodByName(string(callbackTypeAfterUpdate))
	case callbackTypeBeforeSave:
		return modelType.MethodByName(string(callbackTypeBeforeSave))
	case callbackTypeAfterSave:
		return modelType.MethodByName(string(callbackTypeAfterSave))
	case callbackTypeBeforeDelete:
		return modelType.MethodByName(string(callbackTypeBeforeDelete))
	case callbackTypeAfterDelete:
		return modelType.MethodByName(string(callbackTypeAfterDelete))
	case callbackTypeAfterFind:
		return modelType.MethodByName(string(callbackTypeAfterFind))
	default:
		return reflect.ValueOf(nil)
	}
}
|
||||
|
||||
func getOrParse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) {
|
||||
modelType := reflect.ValueOf(dest).Type()
|
||||
for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
|
||||
if modelType.Kind() != reflect.Struct {
|
||||
if modelType.PkgPath() == "" {
|
||||
return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
|
||||
}
|
||||
return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name())
|
||||
}
|
||||
|
||||
if v, ok := cacheStore.Load(modelType); ok {
|
||||
return v.(*Schema), nil
|
||||
}
|
||||
|
||||
return Parse(dest, cacheStore, namer)
|
||||
}
|
||||
173
vendor/gorm.io/gorm/schema/serializer.go
generated
vendored
Normal file
173
vendor/gorm.io/gorm/schema/serializer.go
generated
vendored
Normal file
@@ -0,0 +1,173 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"encoding/gob"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// serializerMap holds registered serializers keyed by lower-cased name.
var serializerMap = sync.Map{}

// RegisterSerializer register serializer under a case-insensitive name.
func RegisterSerializer(name string, serializer SerializerInterface) {
	serializerMap.Store(strings.ToLower(name), serializer)
}

// GetSerializer get serializer by its case-insensitive name; ok is false
// when no serializer is registered under that name.
func GetSerializer(name string) (serializer SerializerInterface, ok bool) {
	v, ok := serializerMap.Load(strings.ToLower(name))
	if ok {
		serializer, ok = v.(SerializerInterface)
	}
	return serializer, ok
}

// init registers the built-in serializers: json, unixtime and gob.
func init() {
	RegisterSerializer("json", JSONSerializer{})
	RegisterSerializer("unixtime", UnixSecondSerializer{})
	RegisterSerializer("gob", GobSerializer{})
}
|
||||
|
||||
// serializer adapts a field's SerializerInterface to the database/sql
// Scanner and driver.Valuer interfaces, carrying the context needed by the
// serializer calls.
type serializer struct {
	Field           *Field
	Serializer      SerializerInterface
	SerializeValuer SerializerValuerInterface
	Destination     reflect.Value
	Context         context.Context
	value           interface{} // raw database value captured by Scan
	fieldValue      interface{} // Go field value to be serialized by Value
}

// Scan implements sql.Scanner interface. It only captures the raw database
// value; the actual decoding happens later via the field's serializer.
func (s *serializer) Scan(value interface{}) error {
	s.value = value
	return nil
}

// Value implements driver.Valuer interface by delegating to the field's
// serializer valuer.
func (s serializer) Value() (driver.Value, error) {
	return s.SerializeValuer.Value(s.Context, s.Field, s.Destination, s.fieldValue)
}

// SerializerInterface serializer interface
type SerializerInterface interface {
	Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) error
	SerializerValuerInterface
}

// SerializerValuerInterface serializer valuer interface
type SerializerValuerInterface interface {
	Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (interface{}, error)
}
|
||||
|
||||
// JSONSerializer json serializer
|
||||
type JSONSerializer struct{}
|
||||
|
||||
// Scan implements serializer interface
|
||||
func (JSONSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) (err error) {
|
||||
fieldValue := reflect.New(field.FieldType)
|
||||
|
||||
if dbValue != nil {
|
||||
var bytes []byte
|
||||
switch v := dbValue.(type) {
|
||||
case []byte:
|
||||
bytes = v
|
||||
case string:
|
||||
bytes = []byte(v)
|
||||
default:
|
||||
bytes, err = json.Marshal(v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if len(bytes) > 0 {
|
||||
err = json.Unmarshal(bytes, fieldValue.Interface())
|
||||
}
|
||||
}
|
||||
|
||||
field.ReflectValueOf(ctx, dst).Set(fieldValue.Elem())
|
||||
return
|
||||
}
|
||||
|
||||
// Value implements serializer interface: it marshals the field value to a
// JSON string for storage.
func (JSONSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) {
	result, err := json.Marshal(fieldValue)
	// "null" marks a nil/empty value: store an empty string for NOT NULL
	// columns so the write does not violate the constraint, otherwise store
	// SQL NULL.
	if string(result) == "null" {
		if field.TagSettings["NOT NULL"] != "" {
			return "", nil
		}
		return nil, err
	}
	return string(result), err
}
|
||||
|
||||
// UnixSecondSerializer json serializer
|
||||
type UnixSecondSerializer struct{}
|
||||
|
||||
// Scan implements serializer interface
|
||||
func (UnixSecondSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) (err error) {
|
||||
t := sql.NullTime{}
|
||||
if err = t.Scan(dbValue); err == nil && t.Valid {
|
||||
err = field.Set(ctx, dst, t.Time.Unix())
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Value implements serializer interface
|
||||
func (UnixSecondSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (result interface{}, err error) {
|
||||
rv := reflect.ValueOf(fieldValue)
|
||||
switch v := fieldValue.(type) {
|
||||
case int64, int, uint, uint64, int32, uint32, int16, uint16:
|
||||
result = time.Unix(reflect.Indirect(rv).Int(), 0).UTC()
|
||||
case *int64, *int, *uint, *uint64, *int32, *uint32, *int16, *uint16:
|
||||
if rv.IsZero() {
|
||||
return nil, nil
|
||||
}
|
||||
result = time.Unix(reflect.Indirect(rv).Int(), 0).UTC()
|
||||
default:
|
||||
err = fmt.Errorf("invalid field type %#v for UnixSecondSerializer, only int, uint supported", v)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// GobSerializer serializes field values to and from gob-encoded byte
// columns.
type GobSerializer struct{}

// Scan implements serializer interface: it gob-decodes the database []byte
// value into the destination field. Any non-[]byte driver value is an
// error.
func (GobSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) (err error) {
	fieldValue := reflect.New(field.FieldType)

	if dbValue != nil {
		var bytesValue []byte
		switch v := dbValue.(type) {
		case []byte:
			bytesValue = v
		default:
			return fmt.Errorf("failed to unmarshal gob value: %#v", dbValue)
		}
		if len(bytesValue) > 0 {
			decoder := gob.NewDecoder(bytes.NewBuffer(bytesValue))
			err = decoder.Decode(fieldValue.Interface())
		}
	}
	// Assign even when dbValue is nil so the field is reset to its zero
	// value.
	field.ReflectValueOf(ctx, dst).Set(fieldValue.Elem())
	return
}

// Value implements serializer interface: it gob-encodes the field value
// into a byte slice for storage.
func (GobSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) {
	buf := new(bytes.Buffer)
	err := gob.NewEncoder(buf).Encode(fieldValue)
	return buf.Bytes(), err
}
|
||||
213
vendor/gorm.io/gorm/schema/utils.go
generated
vendored
Normal file
213
vendor/gorm.io/gorm/schema/utils.go
generated
vendored
Normal file
@@ -0,0 +1,213 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/utils"
|
||||
)
|
||||
|
||||
var embeddedCacheKey = "embedded_cache_store"
|
||||
|
||||
// ParseTagSetting parses a gorm struct-tag string into a KEY -> value map.
// Entries are separated by sep; each entry is either "key" or "key:value".
// Keys are trimmed and upper-cased; a value-less key maps to itself. A
// separator inside a value can be escaped with a trailing backslash, which
// joins the entry with the following segment.
func ParseTagSetting(str string, sep string) map[string]string {
	settings := map[string]string{}
	names := strings.Split(str, sep)

	for i := 0; i < len(names); i++ {
		j := i
		if len(names[j]) > 0 {
			for {
				if names[j][len(names[j])-1] == '\\' {
					// Escaped separator: merge the next segment into this one.
					i++
					if i >= len(names) {
						// Dangling escape at the very end of the tag: drop the
						// trailing backslash instead of indexing past the
						// slice (previously an index-out-of-range panic).
						names[j] = names[j][:len(names[j])-1]
						break
					}
					names[j] = names[j][0:len(names[j])-1] + sep + names[i]
					names[i] = ""
				} else {
					break
				}
			}
		}

		values := strings.Split(names[j], ":")
		k := strings.TrimSpace(strings.ToUpper(values[0]))

		if len(values) >= 2 {
			// Re-join so values containing ':' survive intact.
			settings[k] = strings.Join(values[1:], ":")
		} else if k != "" {
			settings[k] = k
		}
	}

	return settings
}
|
||||
|
||||
// toColumns splits a comma-separated column list into trimmed names; an
// empty input yields a nil slice.
func toColumns(val string) []string {
	if val == "" {
		return nil
	}
	parts := strings.Split(val, ",")
	columns := make([]string, 0, len(parts))
	for _, part := range parts {
		columns = append(columns, strings.TrimSpace(part))
	}
	return columns
}
|
||||
|
||||
func removeSettingFromTag(tag reflect.StructTag, names ...string) reflect.StructTag {
|
||||
for _, name := range names {
|
||||
tag = reflect.StructTag(regexp.MustCompile(`(?i)(gorm:.*?)(`+name+`(:.*?)?)(;|("))`).ReplaceAllString(string(tag), "${1}${5}"))
|
||||
}
|
||||
return tag
|
||||
}
|
||||
|
||||
func appendSettingFromTag(tag reflect.StructTag, value string) reflect.StructTag {
|
||||
t := tag.Get("gorm")
|
||||
if strings.Contains(t, value) {
|
||||
return tag
|
||||
}
|
||||
return reflect.StructTag(fmt.Sprintf(`gorm:"%s;%s"`, value, t))
|
||||
}
|
||||
|
||||
// GetRelationsValues get relations's values from a reflect value: it walks
// the chain of relationships, at each step collecting the non-zero related
// values (as a []*ModelType slice) and using them as the input for the next
// step. The final slice is returned.
func GetRelationsValues(ctx context.Context, reflectValue reflect.Value, rels []*Relationship) (reflectResults reflect.Value) {
	for _, rel := range rels {
		reflectResults = reflect.MakeSlice(reflect.SliceOf(reflect.PointerTo(rel.FieldSchema.ModelType)), 0, 1)

		// appendToResults collects the relationship field's value(s) from a
		// single owner value, skipping zero values and normalizing every
		// element to a pointer.
		appendToResults := func(value reflect.Value) {
			if _, isZero := rel.Field.ValueOf(ctx, value); !isZero {
				result := reflect.Indirect(rel.Field.ReflectValueOf(ctx, value))
				switch result.Kind() {
				case reflect.Struct:
					reflectResults = reflect.Append(reflectResults, result.Addr())
				case reflect.Slice, reflect.Array:
					for i := 0; i < result.Len(); i++ {
						if elem := result.Index(i); elem.Kind() == reflect.Ptr {
							reflectResults = reflect.Append(reflectResults, elem)
						} else {
							reflectResults = reflect.Append(reflectResults, elem.Addr())
						}
					}
				}
			}
		}

		switch reflectValue.Kind() {
		case reflect.Struct:
			appendToResults(reflectValue)
		case reflect.Slice:
			for i := 0; i < reflectValue.Len(); i++ {
				appendToResults(reflectValue.Index(i))
			}
		}

		// The collected values become the owners for the next relationship
		// in the chain.
		reflectValue = reflectResults
	}

	return
}
|
||||
|
||||
// GetIdentityFieldValuesMap get identity map from fields: for each distinct
// combination of the given fields' values it records the reflect.Values
// holding that combination. It returns the map (keyed by a string key built
// from the field values) and the list of distinct value combinations.
// Values whose fields are all zero are skipped; (nil, nil) is returned when
// a single struct has only zero values.
func GetIdentityFieldValuesMap(ctx context.Context, reflectValue reflect.Value, fields []*Field) (map[string][]reflect.Value, [][]interface{}) {
	var (
		results       = [][]interface{}{}
		dataResults   = map[string][]reflect.Value{}
		loaded        = map[interface{}]bool{}
		notZero, zero bool
	)

	if reflectValue.Kind() == reflect.Ptr ||
		reflectValue.Kind() == reflect.Interface {
		reflectValue = reflectValue.Elem()
	}

	switch reflectValue.Kind() {
	case reflect.Struct:
		results = [][]interface{}{make([]interface{}, len(fields))}

		for idx, field := range fields {
			results[0][idx], zero = field.ValueOf(ctx, reflectValue)
			notZero = notZero || !zero
		}

		// All identity fields are zero: nothing to key on.
		if !notZero {
			return nil, nil
		}

		dataResults[utils.ToStringKey(results[0]...)] = []reflect.Value{reflectValue}
	case reflect.Slice, reflect.Array:
		for i := 0; i < reflectValue.Len(); i++ {
			elem := reflectValue.Index(i)
			// Deduplicate elements by identity (address when available) so
			// the same instance is only processed once.
			elemKey := elem.Interface()
			if elem.Kind() != reflect.Ptr && elem.CanAddr() {
				elemKey = elem.Addr().Interface()
			}

			if _, ok := loaded[elemKey]; ok {
				continue
			}
			loaded[elemKey] = true

			fieldValues := make([]interface{}, len(fields))
			notZero = false
			for idx, field := range fields {
				fieldValues[idx], zero = field.ValueOf(ctx, elem)
				notZero = notZero || !zero
			}

			if notZero {
				dataKey := utils.ToStringKey(fieldValues...)
				if _, ok := dataResults[dataKey]; !ok {
					// First time this value combination is seen.
					results = append(results, fieldValues)
					dataResults[dataKey] = []reflect.Value{elem}
				} else {
					// Additional element sharing an existing combination.
					dataResults[dataKey] = append(dataResults[dataKey], elem)
				}
			}
		}
	}

	return dataResults, results
}
|
||||
|
||||
// GetIdentityFieldValuesMapFromValues get identity map from fields
|
||||
func GetIdentityFieldValuesMapFromValues(ctx context.Context, values []interface{}, fields []*Field) (map[string][]reflect.Value, [][]interface{}) {
|
||||
resultsMap := map[string][]reflect.Value{}
|
||||
results := [][]interface{}{}
|
||||
|
||||
for _, v := range values {
|
||||
rm, rs := GetIdentityFieldValuesMap(ctx, reflect.Indirect(reflect.ValueOf(v)), fields)
|
||||
for k, v := range rm {
|
||||
resultsMap[k] = append(resultsMap[k], v...)
|
||||
}
|
||||
results = append(results, rs...)
|
||||
}
|
||||
return resultsMap, results
|
||||
}
|
||||
|
||||
// ToQueryValues to query values: it converts collected foreign-key values
// into the column(s) and value list for a clause.IN condition. With a
// single key it returns one clause.Column and the scalar values; with a
// composite key it returns a []clause.Column and per-row value tuples.
func ToQueryValues(table string, foreignKeys []string, foreignValues [][]interface{}) (interface{}, []interface{}) {
	queryValues := make([]interface{}, len(foreignValues))
	if len(foreignKeys) == 1 {
		for idx, r := range foreignValues {
			queryValues[idx] = r[0]
		}

		return clause.Column{Table: table, Name: foreignKeys[0]}, queryValues
	}

	columns := make([]clause.Column, len(foreignKeys))
	for idx, key := range foreignKeys {
		columns[idx] = clause.Column{Table: table, Name: key}
	}

	for idx, r := range foreignValues {
		queryValues[idx] = r
	}

	return columns, queryValues
}
|
||||
|
||||
// embeddedNamer wraps a Namer but pins the table name to the embedding
// schema's table, so embedded fields resolve columns against the owner's
// table.
type embeddedNamer struct {
	Table string
	Namer
}
|
||||
170
vendor/gorm.io/gorm/soft_delete.go
generated
vendored
Normal file
170
vendor/gorm.io/gorm/soft_delete.go
generated
vendored
Normal file
@@ -0,0 +1,170 @@
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"reflect"
|
||||
|
||||
"github.com/jinzhu/now"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/schema"
|
||||
)
|
||||
|
||||
// DeletedAt is a nullable deletion timestamp used by GORM's soft-delete
// support; an invalid (zero) value means the record is not deleted.
type DeletedAt sql.NullTime

// Scan implements the Scanner interface by delegating to sql.NullTime.
func (n *DeletedAt) Scan(value interface{}) error {
	return (*sql.NullTime)(n).Scan(value)
}

// Value implements the driver Valuer interface; an invalid DeletedAt maps
// to SQL NULL.
func (n DeletedAt) Value() (driver.Value, error) {
	if n.Valid {
		return n.Time, nil
	}
	return nil, nil
}

// MarshalJSON renders a valid timestamp as its JSON encoding and an invalid
// one as JSON null.
func (n DeletedAt) MarshalJSON() ([]byte, error) {
	if !n.Valid {
		return json.Marshal(nil)
	}
	return json.Marshal(n.Time)
}

// UnmarshalJSON accepts JSON null (clearing the value) or a timestamp,
// marking the value valid only on a successful parse.
func (n *DeletedAt) UnmarshalJSON(b []byte) error {
	if string(b) == "null" {
		n.Valid = false
		return nil
	}
	if err := json.Unmarshal(b, &n.Time); err != nil {
		return err
	}
	n.Valid = true
	return nil
}
|
||||
|
||||
func (DeletedAt) QueryClauses(f *schema.Field) []clause.Interface {
|
||||
return []clause.Interface{SoftDeleteQueryClause{Field: f, ZeroValue: parseZeroValueTag(f)}}
|
||||
}
|
||||
|
||||
func parseZeroValueTag(f *schema.Field) sql.NullString {
|
||||
if v, ok := f.TagSettings["ZEROVALUE"]; ok {
|
||||
if _, err := now.Parse(v); err == nil {
|
||||
return sql.NullString{String: v, Valid: true}
|
||||
}
|
||||
}
|
||||
return sql.NullString{Valid: false}
|
||||
}
|
||||
|
||||
type SoftDeleteQueryClause struct {
|
||||
ZeroValue sql.NullString
|
||||
Field *schema.Field
|
||||
}
|
||||
|
||||
func (sd SoftDeleteQueryClause) Name() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (sd SoftDeleteQueryClause) Build(clause.Builder) {
|
||||
}
|
||||
|
||||
func (sd SoftDeleteQueryClause) MergeClause(*clause.Clause) {
|
||||
}
|
||||
|
||||
func (sd SoftDeleteQueryClause) ModifyStatement(stmt *Statement) {
|
||||
if _, ok := stmt.Clauses["soft_delete_enabled"]; !ok && !stmt.Statement.Unscoped {
|
||||
if c, ok := stmt.Clauses["WHERE"]; ok {
|
||||
if where, ok := c.Expression.(clause.Where); ok && len(where.Exprs) >= 1 {
|
||||
for _, expr := range where.Exprs {
|
||||
if orCond, ok := expr.(clause.OrConditions); ok && len(orCond.Exprs) == 1 {
|
||||
where.Exprs = []clause.Expression{clause.And(where.Exprs...)}
|
||||
c.Expression = where
|
||||
stmt.Clauses["WHERE"] = c
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
stmt.AddClause(clause.Where{Exprs: []clause.Expression{
|
||||
clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: sd.Field.DBName}, Value: sd.ZeroValue},
|
||||
}})
|
||||
stmt.Clauses["soft_delete_enabled"] = clause.Clause{}
|
||||
}
|
||||
}
|
||||
|
||||
func (DeletedAt) UpdateClauses(f *schema.Field) []clause.Interface {
|
||||
return []clause.Interface{SoftDeleteUpdateClause{Field: f, ZeroValue: parseZeroValueTag(f)}}
|
||||
}
|
||||
|
||||
type SoftDeleteUpdateClause struct {
|
||||
ZeroValue sql.NullString
|
||||
Field *schema.Field
|
||||
}
|
||||
|
||||
func (sd SoftDeleteUpdateClause) Name() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (sd SoftDeleteUpdateClause) Build(clause.Builder) {
|
||||
}
|
||||
|
||||
func (sd SoftDeleteUpdateClause) MergeClause(*clause.Clause) {
|
||||
}
|
||||
|
||||
func (sd SoftDeleteUpdateClause) ModifyStatement(stmt *Statement) {
|
||||
if stmt.SQL.Len() == 0 && !stmt.Statement.Unscoped {
|
||||
SoftDeleteQueryClause(sd).ModifyStatement(stmt)
|
||||
}
|
||||
}
|
||||
|
||||
func (DeletedAt) DeleteClauses(f *schema.Field) []clause.Interface {
|
||||
return []clause.Interface{SoftDeleteDeleteClause{Field: f, ZeroValue: parseZeroValueTag(f)}}
|
||||
}
|
||||
|
||||
type SoftDeleteDeleteClause struct {
|
||||
ZeroValue sql.NullString
|
||||
Field *schema.Field
|
||||
}
|
||||
|
||||
func (sd SoftDeleteDeleteClause) Name() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (sd SoftDeleteDeleteClause) Build(clause.Builder) {
|
||||
}
|
||||
|
||||
func (sd SoftDeleteDeleteClause) MergeClause(*clause.Clause) {
|
||||
}
|
||||
|
||||
func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) {
|
||||
if stmt.SQL.Len() == 0 && !stmt.Statement.Unscoped {
|
||||
curTime := stmt.DB.NowFunc()
|
||||
stmt.AddClause(clause.Set{{Column: clause.Column{Name: sd.Field.DBName}, Value: curTime}})
|
||||
stmt.SetColumn(sd.Field.DBName, curTime, true)
|
||||
|
||||
if stmt.Schema != nil {
|
||||
_, queryValues := schema.GetIdentityFieldValuesMap(stmt.Context, stmt.ReflectValue, stmt.Schema.PrimaryFields)
|
||||
column, values := schema.ToQueryValues(stmt.Table, stmt.Schema.PrimaryFieldDBNames, queryValues)
|
||||
|
||||
if len(values) > 0 {
|
||||
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})
|
||||
}
|
||||
|
||||
if stmt.ReflectValue.CanAddr() && stmt.Dest != stmt.Model && stmt.Model != nil {
|
||||
_, queryValues = schema.GetIdentityFieldValuesMap(stmt.Context, reflect.ValueOf(stmt.Model), stmt.Schema.PrimaryFields)
|
||||
column, values = schema.ToQueryValues(stmt.Table, stmt.Schema.PrimaryFieldDBNames, queryValues)
|
||||
|
||||
if len(values) > 0 {
|
||||
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
SoftDeleteQueryClause(sd).ModifyStatement(stmt)
|
||||
stmt.AddClauseIfNotExists(clause.Update{})
|
||||
stmt.Build(stmt.DB.Callback().Update().Clauses...)
|
||||
}
|
||||
}
|
||||
757
vendor/gorm.io/gorm/statement.go
generated
vendored
Normal file
757
vendor/gorm.io/gorm/statement.go
generated
vendored
Normal file
@@ -0,0 +1,757 @@
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/logger"
|
||||
"gorm.io/gorm/schema"
|
||||
"gorm.io/gorm/utils"
|
||||
)
|
||||
|
||||
// Statement statement
|
||||
type Statement struct {
|
||||
*DB
|
||||
TableExpr *clause.Expr
|
||||
Table string
|
||||
Model interface{}
|
||||
Unscoped bool
|
||||
Dest interface{}
|
||||
ReflectValue reflect.Value
|
||||
Clauses map[string]clause.Clause
|
||||
BuildClauses []string
|
||||
Distinct bool
|
||||
Selects []string // selected columns
|
||||
Omits []string // omit columns
|
||||
ColumnMapping map[string]string // map columns
|
||||
Joins []join
|
||||
Preloads map[string][]interface{}
|
||||
Settings sync.Map
|
||||
ConnPool ConnPool
|
||||
Schema *schema.Schema
|
||||
Context context.Context
|
||||
RaiseErrorOnNotFound bool
|
||||
SkipHooks bool
|
||||
SQL strings.Builder
|
||||
Vars []interface{}
|
||||
CurDestIndex int
|
||||
attrs []interface{}
|
||||
assigns []interface{}
|
||||
scopes []func(*DB) *DB
|
||||
Result *result
|
||||
}
|
||||
|
||||
type join struct {
|
||||
Name string
|
||||
Alias string
|
||||
Conds []interface{}
|
||||
On *clause.Where
|
||||
Selects []string
|
||||
Omits []string
|
||||
Expression clause.Expression
|
||||
JoinType clause.JoinType
|
||||
}
|
||||
|
||||
// StatementModifier statement modifier interface
|
||||
type StatementModifier interface {
|
||||
ModifyStatement(*Statement)
|
||||
}
|
||||
|
||||
// WriteString write string
|
||||
func (stmt *Statement) WriteString(str string) (int, error) {
|
||||
return stmt.SQL.WriteString(str)
|
||||
}
|
||||
|
||||
// WriteByte write byte
|
||||
func (stmt *Statement) WriteByte(c byte) error {
|
||||
return stmt.SQL.WriteByte(c)
|
||||
}
|
||||
|
||||
// WriteQuoted write quoted value
|
||||
func (stmt *Statement) WriteQuoted(value interface{}) {
|
||||
stmt.QuoteTo(&stmt.SQL, value)
|
||||
}
|
||||
|
||||
// QuoteTo write quoted value to writer
|
||||
func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) {
|
||||
write := func(raw bool, str string) {
|
||||
if raw {
|
||||
writer.WriteString(str)
|
||||
} else {
|
||||
stmt.DB.Dialector.QuoteTo(writer, str)
|
||||
}
|
||||
}
|
||||
|
||||
switch v := field.(type) {
|
||||
case clause.Table:
|
||||
if v.Name == clause.CurrentTable {
|
||||
if stmt.TableExpr != nil {
|
||||
stmt.TableExpr.Build(stmt)
|
||||
} else {
|
||||
write(v.Raw, stmt.Table)
|
||||
}
|
||||
} else {
|
||||
write(v.Raw, v.Name)
|
||||
}
|
||||
|
||||
if v.Alias != "" {
|
||||
writer.WriteByte(' ')
|
||||
write(v.Raw, v.Alias)
|
||||
}
|
||||
case clause.Column:
|
||||
if v.Table != "" {
|
||||
if v.Table == clause.CurrentTable {
|
||||
write(v.Raw, stmt.Table)
|
||||
} else {
|
||||
write(v.Raw, v.Table)
|
||||
}
|
||||
writer.WriteByte('.')
|
||||
}
|
||||
|
||||
if v.Name == clause.PrimaryKey {
|
||||
if stmt.Schema == nil {
|
||||
stmt.DB.AddError(ErrModelValueRequired)
|
||||
} else if stmt.Schema.PrioritizedPrimaryField != nil {
|
||||
write(v.Raw, stmt.Schema.PrioritizedPrimaryField.DBName)
|
||||
} else if len(stmt.Schema.DBNames) > 0 {
|
||||
write(v.Raw, stmt.Schema.DBNames[0])
|
||||
} else {
|
||||
stmt.DB.AddError(ErrModelAccessibleFieldsRequired) //nolint:typecheck,errcheck
|
||||
}
|
||||
} else {
|
||||
write(v.Raw, v.Name)
|
||||
}
|
||||
|
||||
if v.Alias != "" {
|
||||
writer.WriteString(" AS ")
|
||||
write(v.Raw, v.Alias)
|
||||
}
|
||||
case []clause.Column:
|
||||
writer.WriteByte('(')
|
||||
for idx, d := range v {
|
||||
if idx > 0 {
|
||||
writer.WriteByte(',')
|
||||
}
|
||||
stmt.QuoteTo(writer, d)
|
||||
}
|
||||
writer.WriteByte(')')
|
||||
case clause.Expr:
|
||||
v.Build(stmt)
|
||||
case string:
|
||||
stmt.DB.Dialector.QuoteTo(writer, v)
|
||||
case []string:
|
||||
writer.WriteByte('(')
|
||||
for idx, d := range v {
|
||||
if idx > 0 {
|
||||
writer.WriteByte(',')
|
||||
}
|
||||
stmt.DB.Dialector.QuoteTo(writer, d)
|
||||
}
|
||||
writer.WriteByte(')')
|
||||
default:
|
||||
stmt.DB.Dialector.QuoteTo(writer, fmt.Sprint(field))
|
||||
}
|
||||
}
|
||||
|
||||
// Quote returns quoted value
|
||||
func (stmt *Statement) Quote(field interface{}) string {
|
||||
var builder strings.Builder
|
||||
stmt.QuoteTo(&builder, field)
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
// AddVar add var
|
||||
func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) {
|
||||
for idx, v := range vars {
|
||||
if idx > 0 {
|
||||
writer.WriteByte(',')
|
||||
}
|
||||
|
||||
switch v := v.(type) {
|
||||
case sql.NamedArg:
|
||||
stmt.Vars = append(stmt.Vars, v.Value)
|
||||
case clause.Column, clause.Table:
|
||||
stmt.QuoteTo(writer, v)
|
||||
case Valuer:
|
||||
reflectValue := reflect.ValueOf(v)
|
||||
if reflectValue.Kind() == reflect.Ptr && reflectValue.IsNil() {
|
||||
stmt.AddVar(writer, nil)
|
||||
} else {
|
||||
stmt.AddVar(writer, v.GormValue(stmt.Context, stmt.DB))
|
||||
}
|
||||
case clause.Interface:
|
||||
c := clause.Clause{Name: v.Name()}
|
||||
v.MergeClause(&c)
|
||||
c.Build(stmt)
|
||||
case clause.Expression:
|
||||
v.Build(stmt)
|
||||
case driver.Valuer:
|
||||
stmt.Vars = append(stmt.Vars, v)
|
||||
stmt.DB.Dialector.BindVarTo(writer, stmt, v)
|
||||
case []byte:
|
||||
stmt.Vars = append(stmt.Vars, v)
|
||||
stmt.DB.Dialector.BindVarTo(writer, stmt, v)
|
||||
case []interface{}:
|
||||
if len(v) > 0 {
|
||||
writer.WriteByte('(')
|
||||
stmt.AddVar(writer, v...)
|
||||
writer.WriteByte(')')
|
||||
} else {
|
||||
writer.WriteString("(NULL)")
|
||||
}
|
||||
case interface{ getInstance() *DB }:
|
||||
cv := v.getInstance()
|
||||
|
||||
subdb := cv.Session(&Session{Logger: logger.Discard, DryRun: true}).getInstance()
|
||||
if cv.Statement.SQL.Len() > 0 {
|
||||
var (
|
||||
vars = subdb.Statement.Vars
|
||||
sql = cv.Statement.SQL.String()
|
||||
)
|
||||
|
||||
subdb.Statement.Vars = make([]interface{}, 0, len(vars))
|
||||
for _, vv := range vars {
|
||||
subdb.Statement.Vars = append(subdb.Statement.Vars, vv)
|
||||
bindvar := strings.Builder{}
|
||||
cv.BindVarTo(&bindvar, subdb.Statement, vv)
|
||||
sql = strings.Replace(sql, bindvar.String(), "?", 1)
|
||||
}
|
||||
|
||||
subdb.Statement.SQL.Reset()
|
||||
subdb.Statement.Vars = stmt.Vars
|
||||
if strings.Contains(sql, "@") {
|
||||
clause.NamedExpr{SQL: sql, Vars: vars}.Build(subdb.Statement)
|
||||
} else {
|
||||
clause.Expr{SQL: sql, Vars: vars}.Build(subdb.Statement)
|
||||
}
|
||||
} else {
|
||||
subdb.Statement.Vars = append(stmt.Vars, subdb.Statement.Vars...)
|
||||
subdb.callbacks.Query().Execute(subdb)
|
||||
}
|
||||
|
||||
writer.WriteString(subdb.Statement.SQL.String())
|
||||
stmt.Vars = subdb.Statement.Vars
|
||||
default:
|
||||
switch rv := reflect.ValueOf(v); rv.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
if rv.Len() == 0 {
|
||||
writer.WriteString("(NULL)")
|
||||
} else if rv.Type().Elem() == reflect.TypeOf(uint8(0)) {
|
||||
stmt.Vars = append(stmt.Vars, v)
|
||||
stmt.DB.Dialector.BindVarTo(writer, stmt, v)
|
||||
} else {
|
||||
writer.WriteByte('(')
|
||||
for i := 0; i < rv.Len(); i++ {
|
||||
if i > 0 {
|
||||
writer.WriteByte(',')
|
||||
}
|
||||
stmt.AddVar(writer, rv.Index(i).Interface())
|
||||
}
|
||||
writer.WriteByte(')')
|
||||
}
|
||||
default:
|
||||
stmt.Vars = append(stmt.Vars, v)
|
||||
stmt.DB.Dialector.BindVarTo(writer, stmt, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// AddClause add clause
|
||||
func (stmt *Statement) AddClause(v clause.Interface) {
|
||||
if optimizer, ok := v.(StatementModifier); ok {
|
||||
optimizer.ModifyStatement(stmt)
|
||||
} else {
|
||||
name := v.Name()
|
||||
c := stmt.Clauses[name]
|
||||
c.Name = name
|
||||
v.MergeClause(&c)
|
||||
stmt.Clauses[name] = c
|
||||
}
|
||||
}
|
||||
|
||||
// AddClauseIfNotExists add clause if not exists
|
||||
func (stmt *Statement) AddClauseIfNotExists(v clause.Interface) {
|
||||
if c, ok := stmt.Clauses[v.Name()]; !ok || c.Expression == nil {
|
||||
stmt.AddClause(v)
|
||||
}
|
||||
}
|
||||
|
||||
// BuildCondition build condition
|
||||
func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []clause.Expression {
|
||||
if s, ok := query.(string); ok {
|
||||
// if it is a number, then treats it as primary key
|
||||
if _, err := strconv.Atoi(s); err != nil {
|
||||
if s == "" && len(args) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(args) == 0 || (len(args) > 0 && strings.Contains(s, "?")) {
|
||||
// looks like a where condition
|
||||
return []clause.Expression{clause.Expr{SQL: s, Vars: args}}
|
||||
}
|
||||
|
||||
if len(args) > 0 && strings.Contains(s, "@") {
|
||||
// looks like a named query
|
||||
return []clause.Expression{clause.NamedExpr{SQL: s, Vars: args}}
|
||||
}
|
||||
|
||||
if strings.Contains(strings.TrimSpace(s), " ") {
|
||||
// looks like a where condition
|
||||
return []clause.Expression{clause.Expr{SQL: s, Vars: args}}
|
||||
}
|
||||
|
||||
if len(args) == 1 {
|
||||
return []clause.Expression{clause.Eq{Column: s, Value: args[0]}}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
conds := make([]clause.Expression, 0, 4)
|
||||
args = append([]interface{}{query}, args...)
|
||||
for idx, arg := range args {
|
||||
if arg == nil {
|
||||
continue
|
||||
}
|
||||
if valuer, ok := arg.(driver.Valuer); ok {
|
||||
arg, _ = valuer.Value()
|
||||
}
|
||||
|
||||
curTable := stmt.Table
|
||||
if curTable == "" {
|
||||
curTable = clause.CurrentTable
|
||||
}
|
||||
|
||||
switch v := arg.(type) {
|
||||
case clause.Expression:
|
||||
conds = append(conds, v)
|
||||
case *DB:
|
||||
v.executeScopes()
|
||||
|
||||
if cs, ok := v.Statement.Clauses["WHERE"]; ok {
|
||||
if where, ok := cs.Expression.(clause.Where); ok {
|
||||
if len(where.Exprs) == 1 {
|
||||
if orConds, ok := where.Exprs[0].(clause.OrConditions); ok {
|
||||
where.Exprs[0] = clause.AndConditions(orConds)
|
||||
}
|
||||
}
|
||||
conds = append(conds, clause.And(where.Exprs...))
|
||||
} else if cs.Expression != nil {
|
||||
conds = append(conds, cs.Expression)
|
||||
}
|
||||
}
|
||||
case map[interface{}]interface{}:
|
||||
for i, j := range v {
|
||||
conds = append(conds, clause.Eq{Column: i, Value: j})
|
||||
}
|
||||
case map[string]string:
|
||||
keys := make([]string, 0, len(v))
|
||||
for i := range v {
|
||||
keys = append(keys, i)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
|
||||
for _, key := range keys {
|
||||
column := clause.Column{Name: key, Table: curTable}
|
||||
conds = append(conds, clause.Eq{Column: column, Value: v[key]})
|
||||
}
|
||||
case map[string]interface{}:
|
||||
keys := make([]string, 0, len(v))
|
||||
for i := range v {
|
||||
keys = append(keys, i)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
|
||||
for _, key := range keys {
|
||||
reflectValue := reflect.Indirect(reflect.ValueOf(v[key]))
|
||||
column := clause.Column{Name: key, Table: curTable}
|
||||
switch reflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
if _, ok := v[key].(driver.Valuer); ok {
|
||||
conds = append(conds, clause.Eq{Column: column, Value: v[key]})
|
||||
} else if _, ok := v[key].(Valuer); ok {
|
||||
conds = append(conds, clause.Eq{Column: column, Value: v[key]})
|
||||
} else {
|
||||
// optimize reflect value length
|
||||
valueLen := reflectValue.Len()
|
||||
values := make([]interface{}, valueLen)
|
||||
for i := 0; i < valueLen; i++ {
|
||||
values[i] = reflectValue.Index(i).Interface()
|
||||
}
|
||||
|
||||
conds = append(conds, clause.IN{Column: column, Values: values})
|
||||
}
|
||||
default:
|
||||
conds = append(conds, clause.Eq{Column: column, Value: v[key]})
|
||||
}
|
||||
}
|
||||
default:
|
||||
reflectValue := reflect.Indirect(reflect.ValueOf(arg))
|
||||
for reflectValue.Kind() == reflect.Ptr {
|
||||
reflectValue = reflectValue.Elem()
|
||||
}
|
||||
|
||||
if s, err := schema.Parse(arg, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil {
|
||||
selectedColumns := map[string]bool{}
|
||||
if idx == 0 {
|
||||
for _, v := range args[1:] {
|
||||
if vs, ok := v.(string); ok {
|
||||
selectedColumns[vs] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
restricted := len(selectedColumns) != 0
|
||||
|
||||
switch reflectValue.Kind() {
|
||||
case reflect.Struct:
|
||||
for _, field := range s.Fields {
|
||||
selected := selectedColumns[field.DBName] || selectedColumns[field.Name]
|
||||
if selected || (!restricted && field.Readable) {
|
||||
if v, isZero := field.ValueOf(stmt.Context, reflectValue); !isZero || selected {
|
||||
if field.DBName != "" {
|
||||
conds = append(conds, clause.Eq{Column: clause.Column{Table: curTable, Name: field.DBName}, Value: v})
|
||||
} else if field.DataType != "" {
|
||||
conds = append(conds, clause.Eq{Column: clause.Column{Table: curTable, Name: field.Name}, Value: v})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
case reflect.Slice, reflect.Array:
|
||||
for i := 0; i < reflectValue.Len(); i++ {
|
||||
for _, field := range s.Fields {
|
||||
selected := selectedColumns[field.DBName] || selectedColumns[field.Name]
|
||||
if selected || (!restricted && field.Readable) {
|
||||
if v, isZero := field.ValueOf(stmt.Context, reflectValue.Index(i)); !isZero || selected {
|
||||
if field.DBName != "" {
|
||||
conds = append(conds, clause.Eq{Column: clause.Column{Table: curTable, Name: field.DBName}, Value: v})
|
||||
} else if field.DataType != "" {
|
||||
conds = append(conds, clause.Eq{Column: clause.Column{Table: curTable, Name: field.Name}, Value: v})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if restricted {
|
||||
break
|
||||
}
|
||||
} else if !reflectValue.IsValid() {
|
||||
stmt.AddError(ErrInvalidData)
|
||||
} else if len(conds) == 0 {
|
||||
if len(args) == 1 {
|
||||
switch reflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
// optimize reflect value length
|
||||
valueLen := reflectValue.Len()
|
||||
values := make([]interface{}, valueLen)
|
||||
for i := 0; i < valueLen; i++ {
|
||||
values[i] = reflectValue.Index(i).Interface()
|
||||
}
|
||||
|
||||
if len(values) > 0 {
|
||||
conds = append(conds, clause.IN{Column: clause.Column{Table: curTable, Name: clause.PrimaryKey}, Values: values})
|
||||
return []clause.Expression{clause.And(conds...)}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
conds = append(conds, clause.IN{Column: clause.Column{Table: curTable, Name: clause.PrimaryKey}, Values: args})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(conds) > 0 {
|
||||
return []clause.Expression{clause.And(conds...)}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Build build sql with clauses names
|
||||
func (stmt *Statement) Build(clauses ...string) {
|
||||
var firstClauseWritten bool
|
||||
|
||||
for _, name := range clauses {
|
||||
if c, ok := stmt.Clauses[name]; ok {
|
||||
if firstClauseWritten {
|
||||
stmt.WriteByte(' ')
|
||||
}
|
||||
|
||||
firstClauseWritten = true
|
||||
if b, ok := stmt.DB.ClauseBuilders[name]; ok {
|
||||
b(c, stmt)
|
||||
} else {
|
||||
c.Build(stmt)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (stmt *Statement) Parse(value interface{}) (err error) {
|
||||
return stmt.ParseWithSpecialTableName(value, "")
|
||||
}
|
||||
|
||||
func (stmt *Statement) ParseWithSpecialTableName(value interface{}, specialTableName string) (err error) {
|
||||
if stmt.Schema, err = schema.ParseWithSpecialTableName(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy, specialTableName); err == nil && stmt.Table == "" {
|
||||
if tables := strings.Split(stmt.Schema.Table, "."); len(tables) == 2 {
|
||||
stmt.TableExpr = &clause.Expr{SQL: stmt.Quote(stmt.Schema.Table)}
|
||||
stmt.Table = tables[1]
|
||||
return
|
||||
}
|
||||
|
||||
stmt.Table = stmt.Schema.Table
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (stmt *Statement) clone() *Statement {
|
||||
newStmt := &Statement{
|
||||
TableExpr: stmt.TableExpr,
|
||||
Table: stmt.Table,
|
||||
Model: stmt.Model,
|
||||
Unscoped: stmt.Unscoped,
|
||||
Dest: stmt.Dest,
|
||||
ReflectValue: stmt.ReflectValue,
|
||||
Clauses: map[string]clause.Clause{},
|
||||
Distinct: stmt.Distinct,
|
||||
Selects: stmt.Selects,
|
||||
Omits: stmt.Omits,
|
||||
ColumnMapping: stmt.ColumnMapping,
|
||||
Preloads: map[string][]interface{}{},
|
||||
ConnPool: stmt.ConnPool,
|
||||
Schema: stmt.Schema,
|
||||
Context: stmt.Context,
|
||||
RaiseErrorOnNotFound: stmt.RaiseErrorOnNotFound,
|
||||
SkipHooks: stmt.SkipHooks,
|
||||
Result: stmt.Result,
|
||||
}
|
||||
|
||||
if stmt.SQL.Len() > 0 {
|
||||
newStmt.SQL.WriteString(stmt.SQL.String())
|
||||
newStmt.Vars = make([]interface{}, 0, len(stmt.Vars))
|
||||
newStmt.Vars = append(newStmt.Vars, stmt.Vars...)
|
||||
}
|
||||
|
||||
for k, c := range stmt.Clauses {
|
||||
newStmt.Clauses[k] = c
|
||||
}
|
||||
|
||||
for k, p := range stmt.Preloads {
|
||||
newStmt.Preloads[k] = p
|
||||
}
|
||||
|
||||
if len(stmt.Joins) > 0 {
|
||||
newStmt.Joins = make([]join, len(stmt.Joins))
|
||||
copy(newStmt.Joins, stmt.Joins)
|
||||
}
|
||||
|
||||
if len(stmt.scopes) > 0 {
|
||||
newStmt.scopes = make([]func(*DB) *DB, len(stmt.scopes))
|
||||
copy(newStmt.scopes, stmt.scopes)
|
||||
}
|
||||
|
||||
stmt.Settings.Range(func(k, v interface{}) bool {
|
||||
newStmt.Settings.Store(k, v)
|
||||
return true
|
||||
})
|
||||
|
||||
return newStmt
|
||||
}
|
||||
|
||||
// SetColumn set column's value
|
||||
//
|
||||
// stmt.SetColumn("Name", "jinzhu") // Hooks Method
|
||||
// stmt.SetColumn("Name", "jinzhu", true) // Callbacks Method
|
||||
func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks ...bool) {
|
||||
if v, ok := stmt.Dest.(map[string]interface{}); ok {
|
||||
v[name] = value
|
||||
} else if v, ok := stmt.Dest.([]map[string]interface{}); ok {
|
||||
for _, m := range v {
|
||||
m[name] = value
|
||||
}
|
||||
} else if stmt.Schema != nil {
|
||||
if field := stmt.Schema.LookUpField(name); field != nil {
|
||||
destValue := reflect.ValueOf(stmt.Dest)
|
||||
for destValue.Kind() == reflect.Ptr {
|
||||
destValue = destValue.Elem()
|
||||
}
|
||||
|
||||
if stmt.ReflectValue != destValue {
|
||||
if !destValue.CanAddr() {
|
||||
destValueCanAddr := reflect.New(destValue.Type())
|
||||
destValueCanAddr.Elem().Set(destValue)
|
||||
stmt.Dest = destValueCanAddr.Interface()
|
||||
destValue = destValueCanAddr.Elem()
|
||||
}
|
||||
|
||||
switch destValue.Kind() {
|
||||
case reflect.Struct:
|
||||
stmt.AddError(field.Set(stmt.Context, destValue, value))
|
||||
default:
|
||||
stmt.AddError(ErrInvalidData)
|
||||
}
|
||||
}
|
||||
|
||||
switch stmt.ReflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
if len(fromCallbacks) > 0 {
|
||||
for i := 0; i < stmt.ReflectValue.Len(); i++ {
|
||||
stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue.Index(i), value))
|
||||
}
|
||||
} else {
|
||||
stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue.Index(stmt.CurDestIndex), value))
|
||||
}
|
||||
case reflect.Struct:
|
||||
if !stmt.ReflectValue.CanAddr() {
|
||||
stmt.AddError(ErrInvalidValue)
|
||||
return
|
||||
}
|
||||
|
||||
stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, value))
|
||||
}
|
||||
} else {
|
||||
stmt.AddError(ErrInvalidField)
|
||||
}
|
||||
} else {
|
||||
stmt.AddError(ErrInvalidField)
|
||||
}
|
||||
}
|
||||
|
||||
// Changed check model changed or not when updating
|
||||
func (stmt *Statement) Changed(fields ...string) bool {
|
||||
modelValue := stmt.ReflectValue
|
||||
switch modelValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
modelValue = stmt.ReflectValue.Index(stmt.CurDestIndex)
|
||||
}
|
||||
|
||||
selectColumns, restricted := stmt.SelectAndOmitColumns(false, true)
|
||||
changed := func(field *schema.Field) bool {
|
||||
fieldValue, _ := field.ValueOf(stmt.Context, modelValue)
|
||||
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
|
||||
if mv, mok := stmt.Dest.(map[string]interface{}); mok {
|
||||
if fv, ok := mv[field.Name]; ok {
|
||||
return !utils.AssertEqual(fv, fieldValue)
|
||||
} else if fv, ok := mv[field.DBName]; ok {
|
||||
return !utils.AssertEqual(fv, fieldValue)
|
||||
}
|
||||
} else {
|
||||
destValue := reflect.ValueOf(stmt.Dest)
|
||||
for destValue.Kind() == reflect.Ptr {
|
||||
destValue = destValue.Elem()
|
||||
}
|
||||
|
||||
changedValue, zero := field.ValueOf(stmt.Context, destValue)
|
||||
if v {
|
||||
return !utils.AssertEqual(changedValue, fieldValue)
|
||||
}
|
||||
return !zero && !utils.AssertEqual(changedValue, fieldValue)
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
if len(fields) == 0 {
|
||||
for _, field := range stmt.Schema.FieldsByDBName {
|
||||
if changed(field) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for _, name := range fields {
|
||||
if field := stmt.Schema.LookUpField(name); field != nil {
|
||||
if changed(field) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
var matchName = func() func(tableColumn string) (table, column string) {
|
||||
nameMatcher := regexp.MustCompile(`^(?:\W?(\w+?)\W?\.)?(?:(\*)|\W?(\w+?)\W?)$`)
|
||||
return func(tableColumn string) (table, column string) {
|
||||
if matches := nameMatcher.FindStringSubmatch(tableColumn); len(matches) == 4 {
|
||||
table = matches[1]
|
||||
star := matches[2]
|
||||
columnName := matches[3]
|
||||
if star != "" {
|
||||
return table, star
|
||||
}
|
||||
return table, columnName
|
||||
}
|
||||
return "", ""
|
||||
}
|
||||
}()
|
||||
|
||||
// SelectAndOmitColumns get select and omit columns, select -> true, omit -> false
|
||||
func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (map[string]bool, bool) {
|
||||
results := map[string]bool{}
|
||||
notRestricted := false
|
||||
|
||||
processColumn := func(column string, result bool) {
|
||||
if stmt.Schema == nil {
|
||||
results[column] = result
|
||||
} else if column == "*" {
|
||||
notRestricted = result
|
||||
for _, dbName := range stmt.Schema.DBNames {
|
||||
results[dbName] = result
|
||||
}
|
||||
} else if column == clause.Associations {
|
||||
for _, rel := range stmt.Schema.Relationships.Relations {
|
||||
results[rel.Name] = result
|
||||
}
|
||||
} else if field := stmt.Schema.LookUpField(column); field != nil && field.DBName != "" {
|
||||
results[field.DBName] = result
|
||||
} else if table, col := matchName(column); col != "" && (table == stmt.Table || table == "") {
|
||||
if col == "*" {
|
||||
for _, dbName := range stmt.Schema.DBNames {
|
||||
results[dbName] = result
|
||||
}
|
||||
} else {
|
||||
results[col] = result
|
||||
}
|
||||
} else {
|
||||
results[column] = result
|
||||
}
|
||||
}
|
||||
|
||||
// select columns
|
||||
for _, column := range stmt.Selects {
|
||||
processColumn(column, true)
|
||||
}
|
||||
|
||||
// omit columns
|
||||
for _, column := range stmt.Omits {
|
||||
processColumn(column, false)
|
||||
}
|
||||
|
||||
if stmt.Schema != nil {
|
||||
for _, field := range stmt.Schema.FieldsByName {
|
||||
name := field.DBName
|
||||
if name == "" {
|
||||
name = field.Name
|
||||
}
|
||||
|
||||
if requireCreate && !field.Creatable {
|
||||
results[name] = false
|
||||
} else if requireUpdate && !field.Updatable {
|
||||
results[name] = false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return results, !notRestricted && len(stmt.Selects) > 0
|
||||
}
|
||||
179
vendor/gorm.io/gorm/utils/utils.go
generated
vendored
Normal file
179
vendor/gorm.io/gorm/utils/utils.go
generated
vendored
Normal file
@@ -0,0 +1,179 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
var gormSourceDir string
|
||||
|
||||
func init() {
|
||||
_, file, _, _ := runtime.Caller(0)
|
||||
// compatible solution to get gorm source directory with various operating systems
|
||||
gormSourceDir = sourceDir(file)
|
||||
}
|
||||
|
||||
func sourceDir(file string) string {
|
||||
dir := filepath.Dir(file)
|
||||
dir = filepath.Dir(dir)
|
||||
|
||||
s := filepath.Dir(dir)
|
||||
if filepath.Base(s) != "gorm.io" {
|
||||
s = dir
|
||||
}
|
||||
return filepath.ToSlash(s) + "/"
|
||||
}
|
||||
|
||||
// FileWithLineNum return the file name and line number of the current file
|
||||
func FileWithLineNum() string {
|
||||
pcs := [13]uintptr{}
|
||||
// the third caller usually from gorm internal
|
||||
len := runtime.Callers(3, pcs[:])
|
||||
frames := runtime.CallersFrames(pcs[:len])
|
||||
for i := 0; i < len; i++ {
|
||||
// second return value is "more", not "ok"
|
||||
frame, _ := frames.Next()
|
||||
if (!strings.HasPrefix(frame.File, gormSourceDir) ||
|
||||
strings.HasSuffix(frame.File, "_test.go")) && !strings.HasSuffix(frame.File, ".gen.go") {
|
||||
return string(strconv.AppendInt(append([]byte(frame.File), ':'), int64(frame.Line), 10))
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func IsValidDBNameChar(c rune) bool {
|
||||
return !unicode.IsLetter(c) && !unicode.IsNumber(c) && c != '.' && c != '*' && c != '_' && c != '$' && c != '@'
|
||||
}
|
||||
|
||||
// CheckTruth check string true or not
|
||||
func CheckTruth(vals ...string) bool {
|
||||
for _, val := range vals {
|
||||
if val != "" && !strings.EqualFold(val, "false") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func ToStringKey(values ...interface{}) string {
|
||||
results := make([]string, len(values))
|
||||
|
||||
for idx, value := range values {
|
||||
if valuer, ok := value.(driver.Valuer); ok {
|
||||
value, _ = valuer.Value()
|
||||
}
|
||||
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
results[idx] = v
|
||||
case []byte:
|
||||
results[idx] = string(v)
|
||||
case uint:
|
||||
results[idx] = strconv.FormatUint(uint64(v), 10)
|
||||
default:
|
||||
results[idx] = "nil"
|
||||
vv := reflect.ValueOf(v)
|
||||
if vv.IsValid() && !vv.IsZero() {
|
||||
results[idx] = fmt.Sprint(reflect.Indirect(vv).Interface())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(results, "_")
|
||||
}
|
||||
|
||||
func Contains(elems []string, elem string) bool {
|
||||
for _, e := range elems {
|
||||
if elem == e {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func AssertEqual(x, y interface{}) bool {
|
||||
if reflect.DeepEqual(x, y) {
|
||||
return true
|
||||
}
|
||||
if x == nil || y == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
xval := reflect.ValueOf(x)
|
||||
yval := reflect.ValueOf(y)
|
||||
if xval.Kind() == reflect.Ptr && xval.IsNil() ||
|
||||
yval.Kind() == reflect.Ptr && yval.IsNil() {
|
||||
return false
|
||||
}
|
||||
|
||||
if valuer, ok := x.(driver.Valuer); ok {
|
||||
x, _ = valuer.Value()
|
||||
}
|
||||
if valuer, ok := y.(driver.Valuer); ok {
|
||||
y, _ = valuer.Value()
|
||||
}
|
||||
return reflect.DeepEqual(x, y)
|
||||
}
|
||||
|
||||
func ToString(value interface{}) string {
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
return v
|
||||
case int:
|
||||
return strconv.FormatInt(int64(v), 10)
|
||||
case int8:
|
||||
return strconv.FormatInt(int64(v), 10)
|
||||
case int16:
|
||||
return strconv.FormatInt(int64(v), 10)
|
||||
case int32:
|
||||
return strconv.FormatInt(int64(v), 10)
|
||||
case int64:
|
||||
return strconv.FormatInt(v, 10)
|
||||
case uint:
|
||||
return strconv.FormatUint(uint64(v), 10)
|
||||
case uint8:
|
||||
return strconv.FormatUint(uint64(v), 10)
|
||||
case uint16:
|
||||
return strconv.FormatUint(uint64(v), 10)
|
||||
case uint32:
|
||||
return strconv.FormatUint(uint64(v), 10)
|
||||
case uint64:
|
||||
return strconv.FormatUint(v, 10)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
const nestedRelationSplit = "__"
|
||||
|
||||
// NestedRelationName nested relationships like `Manager__Company`
|
||||
func NestedRelationName(prefix, name string) string {
|
||||
return prefix + nestedRelationSplit + name
|
||||
}
|
||||
|
||||
// SplitNestedRelationName Split nested relationships to `[]string{"Manager","Company"}`
|
||||
func SplitNestedRelationName(name string) []string {
|
||||
return strings.Split(name, nestedRelationSplit)
|
||||
}
|
||||
|
||||
// JoinNestedRelationNames nested relationships like `Manager__Company`
|
||||
func JoinNestedRelationNames(relationNames []string) string {
|
||||
return strings.Join(relationNames, nestedRelationSplit)
|
||||
}
|
||||
|
||||
// RTrimSlice Right trims the given slice by given length
|
||||
func RTrimSlice[T any](v []T, trimLen int) []T {
|
||||
if trimLen >= len(v) { // trimLen greater than slice len means fully sliced
|
||||
return v[:0]
|
||||
}
|
||||
if trimLen < 0 { // negative trimLen is ignored
|
||||
return v[:]
|
||||
}
|
||||
return v[:len(v)-trimLen]
|
||||
}
|
||||
Reference in New Issue
Block a user