You've already forked openaccounting-server
forked from cybercinch/openaccounting-server
deps: update dependencies for GORM, Viper, and SQLite support
- Add GORM v1.25.12 with MySQL and SQLite drivers - Add Viper v1.19.0 for configuration management - Add UUID package for GORM model IDs - Update vendor directory with new dependencies - Update Go module requirements and checksums 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
453
vendor/gorm.io/gorm/callbacks/associations.go
generated
vendored
Normal file
453
vendor/gorm.io/gorm/callbacks/associations.go
generated
vendored
Normal file
@@ -0,0 +1,453 @@
|
||||
package callbacks
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/schema"
|
||||
"gorm.io/gorm/utils"
|
||||
)
|
||||
|
||||
func SaveBeforeAssociations(create bool) func(db *gorm.DB) {
|
||||
return func(db *gorm.DB) {
|
||||
if db.Error == nil && db.Statement.Schema != nil {
|
||||
selectColumns, restricted := db.Statement.SelectAndOmitColumns(create, !create)
|
||||
|
||||
// Save Belongs To associations
|
||||
for _, rel := range db.Statement.Schema.Relationships.BelongsTo {
|
||||
if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) {
|
||||
continue
|
||||
}
|
||||
|
||||
setupReferences := func(obj reflect.Value, elem reflect.Value) {
|
||||
for _, ref := range rel.References {
|
||||
if !ref.OwnPrimaryKey {
|
||||
pv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, elem)
|
||||
db.AddError(ref.ForeignKey.Set(db.Statement.Context, obj, pv))
|
||||
|
||||
if dest, ok := db.Statement.Dest.(map[string]interface{}); ok {
|
||||
dest[ref.ForeignKey.DBName] = pv
|
||||
if _, ok := dest[rel.Name]; ok {
|
||||
dest[rel.Name] = elem.Interface()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
switch db.Statement.ReflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
var (
|
||||
rValLen = db.Statement.ReflectValue.Len()
|
||||
objs = make([]reflect.Value, 0, rValLen)
|
||||
fieldType = rel.Field.FieldType
|
||||
isPtr = fieldType.Kind() == reflect.Ptr
|
||||
)
|
||||
|
||||
if !isPtr {
|
||||
fieldType = reflect.PointerTo(fieldType)
|
||||
}
|
||||
|
||||
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
|
||||
distinctElems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
|
||||
identityMap := map[string]bool{}
|
||||
for i := 0; i < rValLen; i++ {
|
||||
obj := db.Statement.ReflectValue.Index(i)
|
||||
if reflect.Indirect(obj).Kind() != reflect.Struct {
|
||||
break
|
||||
}
|
||||
if _, zero := rel.Field.ValueOf(db.Statement.Context, obj); !zero { // check belongs to relation value
|
||||
rv := rel.Field.ReflectValueOf(db.Statement.Context, obj) // relation reflect value
|
||||
if !isPtr {
|
||||
rv = rv.Addr()
|
||||
}
|
||||
objs = append(objs, obj)
|
||||
elems = reflect.Append(elems, rv)
|
||||
|
||||
relPrimaryValues := make([]interface{}, 0, len(rel.FieldSchema.PrimaryFields))
|
||||
for _, pf := range rel.FieldSchema.PrimaryFields {
|
||||
if pfv, ok := pf.ValueOf(db.Statement.Context, rv); !ok {
|
||||
relPrimaryValues = append(relPrimaryValues, pfv)
|
||||
}
|
||||
}
|
||||
cacheKey := utils.ToStringKey(relPrimaryValues...)
|
||||
if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] {
|
||||
if cacheKey != "" { // has primary fields
|
||||
identityMap[cacheKey] = true
|
||||
}
|
||||
|
||||
distinctElems = reflect.Append(distinctElems, rv)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if elems.Len() > 0 {
|
||||
if saveAssociations(db, rel, distinctElems, selectColumns, restricted, nil) == nil {
|
||||
for i := 0; i < elems.Len(); i++ {
|
||||
setupReferences(objs[i], elems.Index(i))
|
||||
}
|
||||
}
|
||||
}
|
||||
case reflect.Struct:
|
||||
if _, zero := rel.Field.ValueOf(db.Statement.Context, db.Statement.ReflectValue); !zero {
|
||||
rv := rel.Field.ReflectValueOf(db.Statement.Context, db.Statement.ReflectValue) // relation reflect value
|
||||
if rv.Kind() != reflect.Ptr {
|
||||
rv = rv.Addr()
|
||||
}
|
||||
|
||||
if saveAssociations(db, rel, rv, selectColumns, restricted, nil) == nil {
|
||||
setupReferences(db.Statement.ReflectValue, rv)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func SaveAfterAssociations(create bool) func(db *gorm.DB) {
|
||||
return func(db *gorm.DB) {
|
||||
if db.Error == nil && db.Statement.Schema != nil {
|
||||
selectColumns, restricted := db.Statement.SelectAndOmitColumns(create, !create)
|
||||
|
||||
// Save Has One associations
|
||||
for _, rel := range db.Statement.Schema.Relationships.HasOne {
|
||||
if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) {
|
||||
continue
|
||||
}
|
||||
|
||||
switch db.Statement.ReflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
var (
|
||||
fieldType = rel.Field.FieldType
|
||||
isPtr = fieldType.Kind() == reflect.Ptr
|
||||
)
|
||||
|
||||
if !isPtr {
|
||||
fieldType = reflect.PointerTo(fieldType)
|
||||
}
|
||||
|
||||
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
|
||||
|
||||
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
|
||||
obj := db.Statement.ReflectValue.Index(i)
|
||||
|
||||
if reflect.Indirect(obj).Kind() == reflect.Struct {
|
||||
if _, zero := rel.Field.ValueOf(db.Statement.Context, obj); !zero {
|
||||
rv := rel.Field.ReflectValueOf(db.Statement.Context, obj)
|
||||
if rv.Kind() != reflect.Ptr {
|
||||
rv = rv.Addr()
|
||||
}
|
||||
|
||||
for _, ref := range rel.References {
|
||||
if ref.OwnPrimaryKey {
|
||||
fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, obj)
|
||||
db.AddError(ref.ForeignKey.Set(db.Statement.Context, rv, fv))
|
||||
} else if ref.PrimaryValue != "" {
|
||||
db.AddError(ref.ForeignKey.Set(db.Statement.Context, rv, ref.PrimaryValue))
|
||||
}
|
||||
}
|
||||
|
||||
elems = reflect.Append(elems, rv)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if elems.Len() > 0 {
|
||||
assignmentColumns := make([]string, 0, len(rel.References))
|
||||
for _, ref := range rel.References {
|
||||
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
|
||||
}
|
||||
|
||||
saveAssociations(db, rel, elems, selectColumns, restricted, assignmentColumns)
|
||||
}
|
||||
case reflect.Struct:
|
||||
if _, zero := rel.Field.ValueOf(db.Statement.Context, db.Statement.ReflectValue); !zero {
|
||||
f := rel.Field.ReflectValueOf(db.Statement.Context, db.Statement.ReflectValue)
|
||||
if f.Kind() != reflect.Ptr {
|
||||
f = f.Addr()
|
||||
}
|
||||
|
||||
assignmentColumns := make([]string, 0, len(rel.References))
|
||||
for _, ref := range rel.References {
|
||||
if ref.OwnPrimaryKey {
|
||||
fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, db.Statement.ReflectValue)
|
||||
db.AddError(ref.ForeignKey.Set(db.Statement.Context, f, fv))
|
||||
} else if ref.PrimaryValue != "" {
|
||||
db.AddError(ref.ForeignKey.Set(db.Statement.Context, f, ref.PrimaryValue))
|
||||
}
|
||||
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
|
||||
}
|
||||
|
||||
saveAssociations(db, rel, f, selectColumns, restricted, assignmentColumns)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Save Has Many associations
|
||||
for _, rel := range db.Statement.Schema.Relationships.HasMany {
|
||||
if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) {
|
||||
continue
|
||||
}
|
||||
|
||||
fieldType := rel.Field.IndirectFieldType.Elem()
|
||||
isPtr := fieldType.Kind() == reflect.Ptr
|
||||
if !isPtr {
|
||||
fieldType = reflect.PointerTo(fieldType)
|
||||
}
|
||||
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
|
||||
identityMap := map[string]bool{}
|
||||
appendToElems := func(v reflect.Value) {
|
||||
if _, zero := rel.Field.ValueOf(db.Statement.Context, v); !zero {
|
||||
f := reflect.Indirect(rel.Field.ReflectValueOf(db.Statement.Context, v))
|
||||
|
||||
for i := 0; i < f.Len(); i++ {
|
||||
elem := f.Index(i)
|
||||
for _, ref := range rel.References {
|
||||
if ref.OwnPrimaryKey {
|
||||
pv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, v)
|
||||
db.AddError(ref.ForeignKey.Set(db.Statement.Context, elem, pv))
|
||||
} else if ref.PrimaryValue != "" {
|
||||
db.AddError(ref.ForeignKey.Set(db.Statement.Context, elem, ref.PrimaryValue))
|
||||
}
|
||||
}
|
||||
|
||||
relPrimaryValues := make([]interface{}, 0, len(rel.FieldSchema.PrimaryFields))
|
||||
for _, pf := range rel.FieldSchema.PrimaryFields {
|
||||
if pfv, ok := pf.ValueOf(db.Statement.Context, elem); !ok {
|
||||
relPrimaryValues = append(relPrimaryValues, pfv)
|
||||
}
|
||||
}
|
||||
|
||||
cacheKey := utils.ToStringKey(relPrimaryValues...)
|
||||
if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] {
|
||||
if cacheKey != "" { // has primary fields
|
||||
identityMap[cacheKey] = true
|
||||
}
|
||||
|
||||
if isPtr {
|
||||
elems = reflect.Append(elems, elem)
|
||||
} else {
|
||||
elems = reflect.Append(elems, elem.Addr())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
switch db.Statement.ReflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
|
||||
obj := db.Statement.ReflectValue.Index(i)
|
||||
if reflect.Indirect(obj).Kind() == reflect.Struct {
|
||||
appendToElems(obj)
|
||||
}
|
||||
}
|
||||
case reflect.Struct:
|
||||
appendToElems(db.Statement.ReflectValue)
|
||||
}
|
||||
|
||||
if elems.Len() > 0 {
|
||||
assignmentColumns := make([]string, 0, len(rel.References))
|
||||
for _, ref := range rel.References {
|
||||
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
|
||||
}
|
||||
|
||||
saveAssociations(db, rel, elems, selectColumns, restricted, assignmentColumns)
|
||||
}
|
||||
}
|
||||
|
||||
// Save Many2Many associations
|
||||
for _, rel := range db.Statement.Schema.Relationships.Many2Many {
|
||||
if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) {
|
||||
continue
|
||||
}
|
||||
|
||||
fieldType := rel.Field.IndirectFieldType.Elem()
|
||||
isPtr := fieldType.Kind() == reflect.Ptr
|
||||
if !isPtr {
|
||||
fieldType = reflect.PointerTo(fieldType)
|
||||
}
|
||||
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
|
||||
distinctElems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
|
||||
joins := reflect.MakeSlice(reflect.SliceOf(reflect.PointerTo(rel.JoinTable.ModelType)), 0, 10)
|
||||
objs := []reflect.Value{}
|
||||
|
||||
appendToJoins := func(obj reflect.Value, elem reflect.Value) {
|
||||
joinValue := reflect.New(rel.JoinTable.ModelType)
|
||||
for _, ref := range rel.References {
|
||||
if ref.OwnPrimaryKey {
|
||||
fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, obj)
|
||||
db.AddError(ref.ForeignKey.Set(db.Statement.Context, joinValue, fv))
|
||||
} else if ref.PrimaryValue != "" {
|
||||
db.AddError(ref.ForeignKey.Set(db.Statement.Context, joinValue, ref.PrimaryValue))
|
||||
} else {
|
||||
fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, elem)
|
||||
db.AddError(ref.ForeignKey.Set(db.Statement.Context, joinValue, fv))
|
||||
}
|
||||
}
|
||||
joins = reflect.Append(joins, joinValue)
|
||||
}
|
||||
|
||||
identityMap := map[string]bool{}
|
||||
appendToElems := func(v reflect.Value) {
|
||||
if _, zero := rel.Field.ValueOf(db.Statement.Context, v); !zero {
|
||||
f := reflect.Indirect(rel.Field.ReflectValueOf(db.Statement.Context, v))
|
||||
for i := 0; i < f.Len(); i++ {
|
||||
elem := f.Index(i)
|
||||
if !isPtr {
|
||||
elem = elem.Addr()
|
||||
}
|
||||
objs = append(objs, v)
|
||||
elems = reflect.Append(elems, elem)
|
||||
|
||||
relPrimaryValues := make([]interface{}, 0, len(rel.FieldSchema.PrimaryFields))
|
||||
for _, pf := range rel.FieldSchema.PrimaryFields {
|
||||
if pfv, ok := pf.ValueOf(db.Statement.Context, elem); !ok {
|
||||
relPrimaryValues = append(relPrimaryValues, pfv)
|
||||
}
|
||||
}
|
||||
|
||||
cacheKey := utils.ToStringKey(relPrimaryValues...)
|
||||
if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] {
|
||||
if cacheKey != "" { // has primary fields
|
||||
identityMap[cacheKey] = true
|
||||
}
|
||||
|
||||
distinctElems = reflect.Append(distinctElems, elem)
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
switch db.Statement.ReflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
|
||||
obj := db.Statement.ReflectValue.Index(i)
|
||||
if reflect.Indirect(obj).Kind() == reflect.Struct {
|
||||
appendToElems(obj)
|
||||
}
|
||||
}
|
||||
case reflect.Struct:
|
||||
appendToElems(db.Statement.ReflectValue)
|
||||
}
|
||||
|
||||
// optimize elems of reflect value length
|
||||
if elemLen := elems.Len(); elemLen > 0 {
|
||||
if v, ok := selectColumns[rel.Name+".*"]; !ok || v {
|
||||
saveAssociations(db, rel, distinctElems, selectColumns, restricted, nil)
|
||||
}
|
||||
|
||||
for i := 0; i < elemLen; i++ {
|
||||
appendToJoins(objs[i], elems.Index(i))
|
||||
}
|
||||
}
|
||||
|
||||
if joins.Len() > 0 {
|
||||
db.AddError(db.Session(&gorm.Session{NewDB: true}).Clauses(clause.OnConflict{DoNothing: true}).Session(&gorm.Session{
|
||||
SkipHooks: db.Statement.SkipHooks,
|
||||
DisableNestedTransaction: true,
|
||||
}).Create(joins.Interface()).Error)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func onConflictOption(stmt *gorm.Statement, s *schema.Schema, defaultUpdatingColumns []string) (onConflict clause.OnConflict) {
|
||||
if len(defaultUpdatingColumns) > 0 || stmt.DB.FullSaveAssociations {
|
||||
onConflict.Columns = make([]clause.Column, 0, len(s.PrimaryFieldDBNames))
|
||||
for _, dbName := range s.PrimaryFieldDBNames {
|
||||
onConflict.Columns = append(onConflict.Columns, clause.Column{Name: dbName})
|
||||
}
|
||||
|
||||
onConflict.UpdateAll = stmt.DB.FullSaveAssociations
|
||||
if !onConflict.UpdateAll {
|
||||
onConflict.DoUpdates = clause.AssignmentColumns(defaultUpdatingColumns)
|
||||
}
|
||||
} else {
|
||||
onConflict.DoNothing = true
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func saveAssociations(db *gorm.DB, rel *schema.Relationship, rValues reflect.Value, selectColumns map[string]bool, restricted bool, defaultUpdatingColumns []string) error {
|
||||
// stop save association loop
|
||||
if checkAssociationsSaved(db, rValues) {
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
selects, omits []string
|
||||
onConflict = onConflictOption(db.Statement, rel.FieldSchema, defaultUpdatingColumns)
|
||||
refName = rel.Name + "."
|
||||
values = rValues.Interface()
|
||||
)
|
||||
|
||||
for name, ok := range selectColumns {
|
||||
columnName := ""
|
||||
if strings.HasPrefix(name, refName) {
|
||||
columnName = strings.TrimPrefix(name, refName)
|
||||
}
|
||||
|
||||
if columnName != "" {
|
||||
if ok {
|
||||
selects = append(selects, columnName)
|
||||
} else {
|
||||
omits = append(omits, columnName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tx := db.Session(&gorm.Session{NewDB: true}).Clauses(onConflict).Session(&gorm.Session{
|
||||
FullSaveAssociations: db.FullSaveAssociations,
|
||||
SkipHooks: db.Statement.SkipHooks,
|
||||
DisableNestedTransaction: true,
|
||||
})
|
||||
|
||||
db.Statement.Settings.Range(func(k, v interface{}) bool {
|
||||
tx.Statement.Settings.Store(k, v)
|
||||
return true
|
||||
})
|
||||
|
||||
if tx.Statement.FullSaveAssociations {
|
||||
tx = tx.Set("gorm:update_track_time", true)
|
||||
}
|
||||
|
||||
if len(selects) > 0 {
|
||||
tx = tx.Select(selects)
|
||||
} else if restricted && len(omits) == 0 {
|
||||
tx = tx.Omit(clause.Associations)
|
||||
}
|
||||
|
||||
if len(omits) > 0 {
|
||||
tx = tx.Omit(omits...)
|
||||
}
|
||||
|
||||
return db.AddError(tx.Create(values).Error)
|
||||
}
|
||||
|
||||
// check association values has been saved
|
||||
// if values kind is Struct, check it has been saved
|
||||
// if values kind is Slice/Array, check all items have been saved
|
||||
var visitMapStoreKey = "gorm:saved_association_map"
|
||||
|
||||
func checkAssociationsSaved(db *gorm.DB, values reflect.Value) bool {
|
||||
if visit, ok := db.Get(visitMapStoreKey); ok {
|
||||
if v, ok := visit.(*visitMap); ok {
|
||||
if loadOrStoreVisitMap(v, values) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
} else {
|
||||
vistMap := make(visitMap)
|
||||
loadOrStoreVisitMap(&vistMap, values)
|
||||
db.Set(visitMapStoreKey, &vistMap)
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
83
vendor/gorm.io/gorm/callbacks/callbacks.go
generated
vendored
Normal file
83
vendor/gorm.io/gorm/callbacks/callbacks.go
generated
vendored
Normal file
@@ -0,0 +1,83 @@
|
||||
package callbacks
|
||||
|
||||
import (
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var (
|
||||
createClauses = []string{"INSERT", "VALUES", "ON CONFLICT"}
|
||||
queryClauses = []string{"SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR"}
|
||||
updateClauses = []string{"UPDATE", "SET", "WHERE"}
|
||||
deleteClauses = []string{"DELETE", "FROM", "WHERE"}
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
LastInsertIDReversed bool
|
||||
CreateClauses []string
|
||||
QueryClauses []string
|
||||
UpdateClauses []string
|
||||
DeleteClauses []string
|
||||
}
|
||||
|
||||
func RegisterDefaultCallbacks(db *gorm.DB, config *Config) {
|
||||
enableTransaction := func(db *gorm.DB) bool {
|
||||
return !db.SkipDefaultTransaction
|
||||
}
|
||||
|
||||
if len(config.CreateClauses) == 0 {
|
||||
config.CreateClauses = createClauses
|
||||
}
|
||||
if len(config.QueryClauses) == 0 {
|
||||
config.QueryClauses = queryClauses
|
||||
}
|
||||
if len(config.DeleteClauses) == 0 {
|
||||
config.DeleteClauses = deleteClauses
|
||||
}
|
||||
if len(config.UpdateClauses) == 0 {
|
||||
config.UpdateClauses = updateClauses
|
||||
}
|
||||
|
||||
createCallback := db.Callback().Create()
|
||||
createCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction)
|
||||
createCallback.Register("gorm:before_create", BeforeCreate)
|
||||
createCallback.Register("gorm:save_before_associations", SaveBeforeAssociations(true))
|
||||
createCallback.Register("gorm:create", Create(config))
|
||||
createCallback.Register("gorm:save_after_associations", SaveAfterAssociations(true))
|
||||
createCallback.Register("gorm:after_create", AfterCreate)
|
||||
createCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
|
||||
createCallback.Clauses = config.CreateClauses
|
||||
|
||||
queryCallback := db.Callback().Query()
|
||||
queryCallback.Register("gorm:query", Query)
|
||||
queryCallback.Register("gorm:preload", Preload)
|
||||
queryCallback.Register("gorm:after_query", AfterQuery)
|
||||
queryCallback.Clauses = config.QueryClauses
|
||||
|
||||
deleteCallback := db.Callback().Delete()
|
||||
deleteCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction)
|
||||
deleteCallback.Register("gorm:before_delete", BeforeDelete)
|
||||
deleteCallback.Register("gorm:delete_before_associations", DeleteBeforeAssociations)
|
||||
deleteCallback.Register("gorm:delete", Delete(config))
|
||||
deleteCallback.Register("gorm:after_delete", AfterDelete)
|
||||
deleteCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
|
||||
deleteCallback.Clauses = config.DeleteClauses
|
||||
|
||||
updateCallback := db.Callback().Update()
|
||||
updateCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction)
|
||||
updateCallback.Register("gorm:setup_reflect_value", SetupUpdateReflectValue)
|
||||
updateCallback.Register("gorm:before_update", BeforeUpdate)
|
||||
updateCallback.Register("gorm:save_before_associations", SaveBeforeAssociations(false))
|
||||
updateCallback.Register("gorm:update", Update(config))
|
||||
updateCallback.Register("gorm:save_after_associations", SaveAfterAssociations(false))
|
||||
updateCallback.Register("gorm:after_update", AfterUpdate)
|
||||
updateCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
|
||||
updateCallback.Clauses = config.UpdateClauses
|
||||
|
||||
rowCallback := db.Callback().Row()
|
||||
rowCallback.Register("gorm:row", RowQuery)
|
||||
rowCallback.Clauses = config.QueryClauses
|
||||
|
||||
rawCallback := db.Callback().Raw()
|
||||
rawCallback.Register("gorm:raw", RawExec)
|
||||
rawCallback.Clauses = config.QueryClauses
|
||||
}
|
||||
32
vendor/gorm.io/gorm/callbacks/callmethod.go
generated
vendored
Normal file
32
vendor/gorm.io/gorm/callbacks/callmethod.go
generated
vendored
Normal file
@@ -0,0 +1,32 @@
|
||||
package callbacks
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func callMethod(db *gorm.DB, fc func(value interface{}, tx *gorm.DB) bool) {
|
||||
tx := db.Session(&gorm.Session{NewDB: true})
|
||||
if called := fc(db.Statement.ReflectValue.Interface(), tx); !called {
|
||||
switch db.Statement.ReflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
db.Statement.CurDestIndex = 0
|
||||
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
|
||||
if value := reflect.Indirect(db.Statement.ReflectValue.Index(i)); value.CanAddr() {
|
||||
fc(value.Addr().Interface(), tx)
|
||||
} else {
|
||||
db.AddError(gorm.ErrInvalidValue)
|
||||
return
|
||||
}
|
||||
db.Statement.CurDestIndex++
|
||||
}
|
||||
case reflect.Struct:
|
||||
if db.Statement.ReflectValue.CanAddr() {
|
||||
fc(db.Statement.ReflectValue.Addr().Interface(), tx)
|
||||
} else {
|
||||
db.AddError(gorm.ErrInvalidValue)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
406
vendor/gorm.io/gorm/callbacks/create.go
generated
vendored
Normal file
406
vendor/gorm.io/gorm/callbacks/create.go
generated
vendored
Normal file
@@ -0,0 +1,406 @@
|
||||
package callbacks
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/schema"
|
||||
"gorm.io/gorm/utils"
|
||||
)
|
||||
|
||||
// BeforeCreate before create hooks
|
||||
func BeforeCreate(db *gorm.DB) {
|
||||
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeCreate) {
|
||||
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
|
||||
if db.Statement.Schema.BeforeSave {
|
||||
if i, ok := value.(BeforeSaveInterface); ok {
|
||||
called = true
|
||||
db.AddError(i.BeforeSave(tx))
|
||||
}
|
||||
}
|
||||
|
||||
if db.Statement.Schema.BeforeCreate {
|
||||
if i, ok := value.(BeforeCreateInterface); ok {
|
||||
called = true
|
||||
db.AddError(i.BeforeCreate(tx))
|
||||
}
|
||||
}
|
||||
return called
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Create create hook
|
||||
func Create(config *Config) func(db *gorm.DB) {
|
||||
supportReturning := utils.Contains(config.CreateClauses, "RETURNING")
|
||||
|
||||
return func(db *gorm.DB) {
|
||||
if db.Error != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if db.Statement.Schema != nil {
|
||||
if !db.Statement.Unscoped {
|
||||
for _, c := range db.Statement.Schema.CreateClauses {
|
||||
db.Statement.AddClause(c)
|
||||
}
|
||||
}
|
||||
|
||||
if supportReturning && len(db.Statement.Schema.FieldsWithDefaultDBValue) > 0 {
|
||||
if _, ok := db.Statement.Clauses["RETURNING"]; !ok {
|
||||
fromColumns := make([]clause.Column, 0, len(db.Statement.Schema.FieldsWithDefaultDBValue))
|
||||
for _, field := range db.Statement.Schema.FieldsWithDefaultDBValue {
|
||||
fromColumns = append(fromColumns, clause.Column{Name: field.DBName})
|
||||
}
|
||||
db.Statement.AddClause(clause.Returning{Columns: fromColumns})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if db.Statement.SQL.Len() == 0 {
|
||||
db.Statement.SQL.Grow(180)
|
||||
db.Statement.AddClauseIfNotExists(clause.Insert{})
|
||||
db.Statement.AddClause(ConvertToCreateValues(db.Statement))
|
||||
|
||||
db.Statement.Build(db.Statement.BuildClauses...)
|
||||
}
|
||||
|
||||
isDryRun := !db.DryRun && db.Error == nil
|
||||
if !isDryRun {
|
||||
return
|
||||
}
|
||||
|
||||
ok, mode := hasReturning(db, supportReturning)
|
||||
if ok {
|
||||
if c, ok := db.Statement.Clauses["ON CONFLICT"]; ok {
|
||||
if onConflict, _ := c.Expression.(clause.OnConflict); onConflict.DoNothing {
|
||||
mode |= gorm.ScanOnConflictDoNothing
|
||||
}
|
||||
}
|
||||
|
||||
rows, err := db.Statement.ConnPool.QueryContext(
|
||||
db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...,
|
||||
)
|
||||
if db.AddError(err) == nil {
|
||||
defer func() {
|
||||
db.AddError(rows.Close())
|
||||
}()
|
||||
gorm.Scan(rows, db, mode)
|
||||
|
||||
if db.Statement.Result != nil {
|
||||
db.Statement.Result.RowsAffected = db.RowsAffected
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
result, err := db.Statement.ConnPool.ExecContext(
|
||||
db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...,
|
||||
)
|
||||
if err != nil {
|
||||
db.AddError(err)
|
||||
return
|
||||
}
|
||||
|
||||
db.RowsAffected, _ = result.RowsAffected()
|
||||
|
||||
if db.Statement.Result != nil {
|
||||
db.Statement.Result.Result = result
|
||||
db.Statement.Result.RowsAffected = db.RowsAffected
|
||||
}
|
||||
|
||||
if db.RowsAffected == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
var (
|
||||
pkField *schema.Field
|
||||
pkFieldName = "@id"
|
||||
)
|
||||
|
||||
insertID, err := result.LastInsertId()
|
||||
insertOk := err == nil && insertID > 0
|
||||
|
||||
if !insertOk {
|
||||
if !supportReturning {
|
||||
db.AddError(err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if db.Statement.Schema != nil {
|
||||
if db.Statement.Schema.PrioritizedPrimaryField == nil || !db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue {
|
||||
return
|
||||
}
|
||||
pkField = db.Statement.Schema.PrioritizedPrimaryField
|
||||
pkFieldName = db.Statement.Schema.PrioritizedPrimaryField.DBName
|
||||
}
|
||||
|
||||
// append @id column with value for auto-increment primary key
|
||||
// the @id value is correct, when: 1. without setting auto-increment primary key, 2. database AutoIncrementIncrement = 1
|
||||
switch values := db.Statement.Dest.(type) {
|
||||
case map[string]interface{}:
|
||||
values[pkFieldName] = insertID
|
||||
case *map[string]interface{}:
|
||||
(*values)[pkFieldName] = insertID
|
||||
case []map[string]interface{}, *[]map[string]interface{}:
|
||||
mapValues, ok := values.([]map[string]interface{})
|
||||
if !ok {
|
||||
if v, ok := values.(*[]map[string]interface{}); ok {
|
||||
if *v != nil {
|
||||
mapValues = *v
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if config.LastInsertIDReversed {
|
||||
insertID -= int64(len(mapValues)-1) * schema.DefaultAutoIncrementIncrement
|
||||
}
|
||||
|
||||
for _, mapValue := range mapValues {
|
||||
if mapValue != nil {
|
||||
mapValue[pkFieldName] = insertID
|
||||
}
|
||||
insertID += schema.DefaultAutoIncrementIncrement
|
||||
}
|
||||
default:
|
||||
if pkField == nil {
|
||||
return
|
||||
}
|
||||
|
||||
switch db.Statement.ReflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
if config.LastInsertIDReversed {
|
||||
for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- {
|
||||
rv := db.Statement.ReflectValue.Index(i)
|
||||
if reflect.Indirect(rv).Kind() != reflect.Struct {
|
||||
break
|
||||
}
|
||||
|
||||
_, isZero := pkField.ValueOf(db.Statement.Context, rv)
|
||||
if isZero {
|
||||
db.AddError(pkField.Set(db.Statement.Context, rv, insertID))
|
||||
insertID -= pkField.AutoIncrementIncrement
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
|
||||
rv := db.Statement.ReflectValue.Index(i)
|
||||
if reflect.Indirect(rv).Kind() != reflect.Struct {
|
||||
break
|
||||
}
|
||||
|
||||
if _, isZero := pkField.ValueOf(db.Statement.Context, rv); isZero {
|
||||
db.AddError(pkField.Set(db.Statement.Context, rv, insertID))
|
||||
insertID += pkField.AutoIncrementIncrement
|
||||
}
|
||||
}
|
||||
}
|
||||
case reflect.Struct:
|
||||
_, isZero := pkField.ValueOf(db.Statement.Context, db.Statement.ReflectValue)
|
||||
if isZero {
|
||||
db.AddError(pkField.Set(db.Statement.Context, db.Statement.ReflectValue, insertID))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// AfterCreate after create hooks
|
||||
func AfterCreate(db *gorm.DB) {
|
||||
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) {
|
||||
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
|
||||
if db.Statement.Schema.AfterCreate {
|
||||
if i, ok := value.(AfterCreateInterface); ok {
|
||||
called = true
|
||||
db.AddError(i.AfterCreate(tx))
|
||||
}
|
||||
}
|
||||
|
||||
if db.Statement.Schema.AfterSave {
|
||||
if i, ok := value.(AfterSaveInterface); ok {
|
||||
called = true
|
||||
db.AddError(i.AfterSave(tx))
|
||||
}
|
||||
}
|
||||
return called
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ConvertToCreateValues convert to create values
|
||||
func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
|
||||
curTime := stmt.DB.NowFunc()
|
||||
|
||||
switch value := stmt.Dest.(type) {
|
||||
case map[string]interface{}:
|
||||
values = ConvertMapToValuesForCreate(stmt, value)
|
||||
case *map[string]interface{}:
|
||||
values = ConvertMapToValuesForCreate(stmt, *value)
|
||||
case []map[string]interface{}:
|
||||
values = ConvertSliceOfMapToValuesForCreate(stmt, value)
|
||||
case *[]map[string]interface{}:
|
||||
values = ConvertSliceOfMapToValuesForCreate(stmt, *value)
|
||||
default:
|
||||
var (
|
||||
selectColumns, restricted = stmt.SelectAndOmitColumns(true, false)
|
||||
_, updateTrackTime = stmt.Get("gorm:update_track_time")
|
||||
isZero bool
|
||||
)
|
||||
stmt.Settings.Delete("gorm:update_track_time")
|
||||
|
||||
values = clause.Values{Columns: make([]clause.Column, 0, len(stmt.Schema.DBNames))}
|
||||
|
||||
for _, db := range stmt.Schema.DBNames {
|
||||
if field := stmt.Schema.FieldsByDBName[db]; !field.HasDefaultValue || field.DefaultValueInterface != nil {
|
||||
if v, ok := selectColumns[db]; (ok && v) || (!ok && (!restricted || field.AutoCreateTime > 0 || field.AutoUpdateTime > 0)) {
|
||||
values.Columns = append(values.Columns, clause.Column{Name: db})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
switch stmt.ReflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
rValLen := stmt.ReflectValue.Len()
|
||||
if rValLen == 0 {
|
||||
stmt.AddError(gorm.ErrEmptySlice)
|
||||
return
|
||||
}
|
||||
|
||||
stmt.SQL.Grow(rValLen * 18)
|
||||
stmt.Vars = make([]interface{}, 0, rValLen*len(values.Columns))
|
||||
values.Values = make([][]interface{}, rValLen)
|
||||
|
||||
defaultValueFieldsHavingValue := map[*schema.Field][]interface{}{}
|
||||
for i := 0; i < rValLen; i++ {
|
||||
rv := reflect.Indirect(stmt.ReflectValue.Index(i))
|
||||
if !rv.IsValid() {
|
||||
stmt.AddError(fmt.Errorf("slice data #%v is invalid: %w", i, gorm.ErrInvalidData))
|
||||
return
|
||||
}
|
||||
|
||||
values.Values[i] = make([]interface{}, len(values.Columns))
|
||||
for idx, column := range values.Columns {
|
||||
field := stmt.Schema.FieldsByDBName[column.Name]
|
||||
if values.Values[i][idx], isZero = field.ValueOf(stmt.Context, rv); isZero {
|
||||
if field.DefaultValueInterface != nil {
|
||||
values.Values[i][idx] = field.DefaultValueInterface
|
||||
stmt.AddError(field.Set(stmt.Context, rv, field.DefaultValueInterface))
|
||||
} else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 {
|
||||
stmt.AddError(field.Set(stmt.Context, rv, curTime))
|
||||
values.Values[i][idx], _ = field.ValueOf(stmt.Context, rv)
|
||||
}
|
||||
} else if field.AutoUpdateTime > 0 && updateTrackTime {
|
||||
stmt.AddError(field.Set(stmt.Context, rv, curTime))
|
||||
values.Values[i][idx], _ = field.ValueOf(stmt.Context, rv)
|
||||
}
|
||||
}
|
||||
|
||||
for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
|
||||
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
|
||||
if rvOfvalue, isZero := field.ValueOf(stmt.Context, rv); !isZero {
|
||||
if len(defaultValueFieldsHavingValue[field]) == 0 {
|
||||
defaultValueFieldsHavingValue[field] = make([]interface{}, rValLen)
|
||||
}
|
||||
defaultValueFieldsHavingValue[field][i] = rvOfvalue
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
|
||||
if vs, ok := defaultValueFieldsHavingValue[field]; ok {
|
||||
values.Columns = append(values.Columns, clause.Column{Name: field.DBName})
|
||||
for idx := range values.Values {
|
||||
if vs[idx] == nil {
|
||||
values.Values[idx] = append(values.Values[idx], stmt.DefaultValueOf(field))
|
||||
} else {
|
||||
values.Values[idx] = append(values.Values[idx], vs[idx])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
case reflect.Struct:
|
||||
values.Values = [][]interface{}{make([]interface{}, len(values.Columns))}
|
||||
for idx, column := range values.Columns {
|
||||
field := stmt.Schema.FieldsByDBName[column.Name]
|
||||
if values.Values[0][idx], isZero = field.ValueOf(stmt.Context, stmt.ReflectValue); isZero {
|
||||
if field.DefaultValueInterface != nil {
|
||||
values.Values[0][idx] = field.DefaultValueInterface
|
||||
stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, field.DefaultValueInterface))
|
||||
} else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 {
|
||||
stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, curTime))
|
||||
values.Values[0][idx], _ = field.ValueOf(stmt.Context, stmt.ReflectValue)
|
||||
}
|
||||
} else if field.AutoUpdateTime > 0 && updateTrackTime {
|
||||
stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, curTime))
|
||||
values.Values[0][idx], _ = field.ValueOf(stmt.Context, stmt.ReflectValue)
|
||||
}
|
||||
}
|
||||
|
||||
for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
|
||||
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) && field.DefaultValueInterface == nil {
|
||||
if rvOfvalue, isZero := field.ValueOf(stmt.Context, stmt.ReflectValue); !isZero {
|
||||
values.Columns = append(values.Columns, clause.Column{Name: field.DBName})
|
||||
values.Values[0] = append(values.Values[0], rvOfvalue)
|
||||
}
|
||||
}
|
||||
}
|
||||
default:
|
||||
stmt.AddError(gorm.ErrInvalidData)
|
||||
}
|
||||
}
|
||||
|
||||
if c, ok := stmt.Clauses["ON CONFLICT"]; ok {
|
||||
if onConflict, _ := c.Expression.(clause.OnConflict); onConflict.UpdateAll {
|
||||
if stmt.Schema != nil && len(values.Columns) >= 1 {
|
||||
selectColumns, restricted := stmt.SelectAndOmitColumns(true, true)
|
||||
|
||||
columns := make([]string, 0, len(values.Columns)-1)
|
||||
for _, column := range values.Columns {
|
||||
if field := stmt.Schema.LookUpField(column.Name); field != nil {
|
||||
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
|
||||
if !field.PrimaryKey && (!field.HasDefaultValue || field.DefaultValueInterface != nil ||
|
||||
strings.EqualFold(field.DefaultValue, "NULL")) && field.AutoCreateTime == 0 {
|
||||
if field.AutoUpdateTime > 0 {
|
||||
assignment := clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: curTime}
|
||||
switch field.AutoUpdateTime {
|
||||
case schema.UnixNanosecond:
|
||||
assignment.Value = curTime.UnixNano()
|
||||
case schema.UnixMillisecond:
|
||||
assignment.Value = curTime.UnixMilli()
|
||||
case schema.UnixSecond:
|
||||
assignment.Value = curTime.Unix()
|
||||
}
|
||||
|
||||
onConflict.DoUpdates = append(onConflict.DoUpdates, assignment)
|
||||
} else {
|
||||
columns = append(columns, column.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
onConflict.DoUpdates = append(onConflict.DoUpdates, clause.AssignmentColumns(columns)...)
|
||||
if len(onConflict.DoUpdates) == 0 {
|
||||
onConflict.DoNothing = true
|
||||
}
|
||||
|
||||
// use primary fields as default OnConflict columns
|
||||
if len(onConflict.Columns) == 0 {
|
||||
for _, field := range stmt.Schema.PrimaryFields {
|
||||
onConflict.Columns = append(onConflict.Columns, clause.Column{Name: field.DBName})
|
||||
}
|
||||
}
|
||||
stmt.AddClause(onConflict)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return values
|
||||
}
|
||||
195
vendor/gorm.io/gorm/callbacks/delete.go
generated
vendored
Normal file
195
vendor/gorm.io/gorm/callbacks/delete.go
generated
vendored
Normal file
@@ -0,0 +1,195 @@
|
||||
package callbacks
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/schema"
|
||||
"gorm.io/gorm/utils"
|
||||
)
|
||||
|
||||
func BeforeDelete(db *gorm.DB) {
|
||||
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.BeforeDelete {
|
||||
callMethod(db, func(value interface{}, tx *gorm.DB) bool {
|
||||
if i, ok := value.(BeforeDeleteInterface); ok {
|
||||
db.AddError(i.BeforeDelete(tx))
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func DeleteBeforeAssociations(db *gorm.DB) {
|
||||
if db.Error == nil && db.Statement.Schema != nil {
|
||||
selectColumns, restricted := db.Statement.SelectAndOmitColumns(true, false)
|
||||
if !restricted {
|
||||
return
|
||||
}
|
||||
|
||||
for column, v := range selectColumns {
|
||||
if !v {
|
||||
continue
|
||||
}
|
||||
|
||||
rel, ok := db.Statement.Schema.Relationships.Relations[column]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
switch rel.Type {
|
||||
case schema.HasOne, schema.HasMany:
|
||||
queryConds := rel.ToQueryConditions(db.Statement.Context, db.Statement.ReflectValue)
|
||||
modelValue := reflect.New(rel.FieldSchema.ModelType).Interface()
|
||||
tx := db.Session(&gorm.Session{NewDB: true}).Model(modelValue)
|
||||
withoutConditions := false
|
||||
if db.Statement.Unscoped {
|
||||
tx = tx.Unscoped()
|
||||
}
|
||||
|
||||
if len(db.Statement.Selects) > 0 {
|
||||
selects := make([]string, 0, len(db.Statement.Selects))
|
||||
for _, s := range db.Statement.Selects {
|
||||
if s == clause.Associations {
|
||||
selects = append(selects, s)
|
||||
} else if columnPrefix := column + "."; strings.HasPrefix(s, columnPrefix) {
|
||||
selects = append(selects, strings.TrimPrefix(s, columnPrefix))
|
||||
}
|
||||
}
|
||||
|
||||
if len(selects) > 0 {
|
||||
tx = tx.Select(selects)
|
||||
}
|
||||
}
|
||||
|
||||
for _, cond := range queryConds {
|
||||
if c, ok := cond.(clause.IN); ok && len(c.Values) == 0 {
|
||||
withoutConditions = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !withoutConditions && db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil {
|
||||
return
|
||||
}
|
||||
case schema.Many2Many:
|
||||
var (
|
||||
queryConds = make([]clause.Expression, 0, len(rel.References))
|
||||
foreignFields = make([]*schema.Field, 0, len(rel.References))
|
||||
relForeignKeys = make([]string, 0, len(rel.References))
|
||||
modelValue = reflect.New(rel.JoinTable.ModelType).Interface()
|
||||
table = rel.JoinTable.Table
|
||||
tx = db.Session(&gorm.Session{NewDB: true}).Model(modelValue).Table(table)
|
||||
)
|
||||
|
||||
for _, ref := range rel.References {
|
||||
if ref.OwnPrimaryKey {
|
||||
foreignFields = append(foreignFields, ref.PrimaryKey)
|
||||
relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName)
|
||||
} else if ref.PrimaryValue != "" {
|
||||
queryConds = append(queryConds, clause.Eq{
|
||||
Column: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName},
|
||||
Value: ref.PrimaryValue,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
_, foreignValues := schema.GetIdentityFieldValuesMap(db.Statement.Context, db.Statement.ReflectValue, foreignFields)
|
||||
column, values := schema.ToQueryValues(table, relForeignKeys, foreignValues)
|
||||
queryConds = append(queryConds, clause.IN{Column: column, Values: values})
|
||||
|
||||
if db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
func Delete(config *Config) func(db *gorm.DB) {
|
||||
supportReturning := utils.Contains(config.DeleteClauses, "RETURNING")
|
||||
|
||||
return func(db *gorm.DB) {
|
||||
if db.Error != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if db.Statement.Schema != nil {
|
||||
for _, c := range db.Statement.Schema.DeleteClauses {
|
||||
db.Statement.AddClause(c)
|
||||
}
|
||||
}
|
||||
|
||||
if db.Statement.SQL.Len() == 0 {
|
||||
db.Statement.SQL.Grow(100)
|
||||
db.Statement.AddClauseIfNotExists(clause.Delete{})
|
||||
|
||||
if db.Statement.Schema != nil {
|
||||
_, queryValues := schema.GetIdentityFieldValuesMap(db.Statement.Context, db.Statement.ReflectValue, db.Statement.Schema.PrimaryFields)
|
||||
column, values := schema.ToQueryValues(db.Statement.Table, db.Statement.Schema.PrimaryFieldDBNames, queryValues)
|
||||
|
||||
if len(values) > 0 {
|
||||
db.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})
|
||||
}
|
||||
|
||||
if db.Statement.ReflectValue.CanAddr() && db.Statement.Dest != db.Statement.Model && db.Statement.Model != nil {
|
||||
_, queryValues = schema.GetIdentityFieldValuesMap(db.Statement.Context, reflect.ValueOf(db.Statement.Model), db.Statement.Schema.PrimaryFields)
|
||||
column, values = schema.ToQueryValues(db.Statement.Table, db.Statement.Schema.PrimaryFieldDBNames, queryValues)
|
||||
|
||||
if len(values) > 0 {
|
||||
db.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
db.Statement.AddClauseIfNotExists(clause.From{})
|
||||
|
||||
db.Statement.Build(db.Statement.BuildClauses...)
|
||||
}
|
||||
|
||||
checkMissingWhereConditions(db)
|
||||
|
||||
if !db.DryRun && db.Error == nil {
|
||||
ok, mode := hasReturning(db, supportReturning)
|
||||
if !ok {
|
||||
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||
|
||||
if db.AddError(err) == nil {
|
||||
db.RowsAffected, _ = result.RowsAffected()
|
||||
|
||||
if db.Statement.Result != nil {
|
||||
db.Statement.Result.Result = result
|
||||
db.Statement.Result.RowsAffected = db.RowsAffected
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil {
|
||||
gorm.Scan(rows, db, mode)
|
||||
|
||||
if db.Statement.Result != nil {
|
||||
db.Statement.Result.RowsAffected = db.RowsAffected
|
||||
}
|
||||
db.AddError(rows.Close())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func AfterDelete(db *gorm.DB) {
|
||||
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.AfterDelete {
|
||||
callMethod(db, func(value interface{}, tx *gorm.DB) bool {
|
||||
if i, ok := value.(AfterDeleteInterface); ok {
|
||||
db.AddError(i.AfterDelete(tx))
|
||||
return true
|
||||
}
|
||||
return false
|
||||
})
|
||||
}
|
||||
}
|
||||
152
vendor/gorm.io/gorm/callbacks/helper.go
generated
vendored
Normal file
152
vendor/gorm.io/gorm/callbacks/helper.go
generated
vendored
Normal file
@@ -0,0 +1,152 @@
|
||||
package callbacks
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"sort"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
// ConvertMapToValuesForCreate convert map to values
|
||||
func ConvertMapToValuesForCreate(stmt *gorm.Statement, mapValue map[string]interface{}) (values clause.Values) {
|
||||
values.Columns = make([]clause.Column, 0, len(mapValue))
|
||||
selectColumns, restricted := stmt.SelectAndOmitColumns(true, false)
|
||||
|
||||
keys := make([]string, 0, len(mapValue))
|
||||
for k := range mapValue {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
|
||||
for _, k := range keys {
|
||||
value := mapValue[k]
|
||||
if stmt.Schema != nil {
|
||||
if field := stmt.Schema.LookUpField(k); field != nil {
|
||||
k = field.DBName
|
||||
}
|
||||
}
|
||||
|
||||
if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) {
|
||||
values.Columns = append(values.Columns, clause.Column{Name: k})
|
||||
if len(values.Values) == 0 {
|
||||
values.Values = [][]interface{}{{}}
|
||||
}
|
||||
|
||||
values.Values[0] = append(values.Values[0], value)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// ConvertSliceOfMapToValuesForCreate convert slice of map to values
|
||||
func ConvertSliceOfMapToValuesForCreate(stmt *gorm.Statement, mapValues []map[string]interface{}) (values clause.Values) {
|
||||
columns := make([]string, 0, len(mapValues))
|
||||
|
||||
// when the length of mapValues is zero,return directly here
|
||||
// no need to call stmt.SelectAndOmitColumns method
|
||||
if len(mapValues) == 0 {
|
||||
stmt.AddError(gorm.ErrEmptySlice)
|
||||
return
|
||||
}
|
||||
|
||||
var (
|
||||
result = make(map[string][]interface{}, len(mapValues))
|
||||
selectColumns, restricted = stmt.SelectAndOmitColumns(true, false)
|
||||
)
|
||||
|
||||
for idx, mapValue := range mapValues {
|
||||
for k, v := range mapValue {
|
||||
if stmt.Schema != nil {
|
||||
if field := stmt.Schema.LookUpField(k); field != nil {
|
||||
k = field.DBName
|
||||
}
|
||||
}
|
||||
|
||||
if _, ok := result[k]; !ok {
|
||||
if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) {
|
||||
result[k] = make([]interface{}, len(mapValues))
|
||||
columns = append(columns, k)
|
||||
} else {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
result[k][idx] = v
|
||||
}
|
||||
}
|
||||
|
||||
sort.Strings(columns)
|
||||
values.Values = make([][]interface{}, len(mapValues))
|
||||
values.Columns = make([]clause.Column, len(columns))
|
||||
for idx, column := range columns {
|
||||
values.Columns[idx] = clause.Column{Name: column}
|
||||
|
||||
for i, v := range result[column] {
|
||||
if len(values.Values[i]) == 0 {
|
||||
values.Values[i] = make([]interface{}, len(columns))
|
||||
}
|
||||
|
||||
values.Values[i][idx] = v
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func hasReturning(tx *gorm.DB, supportReturning bool) (bool, gorm.ScanMode) {
|
||||
if supportReturning {
|
||||
if c, ok := tx.Statement.Clauses["RETURNING"]; ok {
|
||||
returning, _ := c.Expression.(clause.Returning)
|
||||
if len(returning.Columns) == 0 || (len(returning.Columns) == 1 && returning.Columns[0].Name == "*") {
|
||||
return true, 0
|
||||
}
|
||||
return true, gorm.ScanUpdate
|
||||
}
|
||||
}
|
||||
return false, 0
|
||||
}
|
||||
|
||||
func checkMissingWhereConditions(db *gorm.DB) {
|
||||
if !db.AllowGlobalUpdate && db.Error == nil {
|
||||
where, withCondition := db.Statement.Clauses["WHERE"]
|
||||
if withCondition {
|
||||
if _, withSoftDelete := db.Statement.Clauses["soft_delete_enabled"]; withSoftDelete {
|
||||
whereClause, _ := where.Expression.(clause.Where)
|
||||
withCondition = len(whereClause.Exprs) > 1
|
||||
}
|
||||
}
|
||||
if !withCondition {
|
||||
db.AddError(gorm.ErrMissingWhereClause)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
type visitMap = map[reflect.Value]bool
|
||||
|
||||
// Check if circular values, return true if loaded
|
||||
func loadOrStoreVisitMap(visitMap *visitMap, v reflect.Value) (loaded bool) {
|
||||
if v.Kind() == reflect.Ptr {
|
||||
v = v.Elem()
|
||||
}
|
||||
|
||||
switch v.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
loaded = true
|
||||
for i := 0; i < v.Len(); i++ {
|
||||
if !loadOrStoreVisitMap(visitMap, v.Index(i)) {
|
||||
loaded = false
|
||||
}
|
||||
}
|
||||
case reflect.Struct, reflect.Interface:
|
||||
if v.CanAddr() {
|
||||
p := v.Addr()
|
||||
if _, ok := (*visitMap)[p]; ok {
|
||||
return true
|
||||
}
|
||||
(*visitMap)[p] = true
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
39
vendor/gorm.io/gorm/callbacks/interfaces.go
generated
vendored
Normal file
39
vendor/gorm.io/gorm/callbacks/interfaces.go
generated
vendored
Normal file
@@ -0,0 +1,39 @@
|
||||
package callbacks
|
||||
|
||||
import "gorm.io/gorm"
|
||||
|
||||
type BeforeCreateInterface interface {
|
||||
BeforeCreate(*gorm.DB) error
|
||||
}
|
||||
|
||||
type AfterCreateInterface interface {
|
||||
AfterCreate(*gorm.DB) error
|
||||
}
|
||||
|
||||
type BeforeUpdateInterface interface {
|
||||
BeforeUpdate(*gorm.DB) error
|
||||
}
|
||||
|
||||
type AfterUpdateInterface interface {
|
||||
AfterUpdate(*gorm.DB) error
|
||||
}
|
||||
|
||||
type BeforeSaveInterface interface {
|
||||
BeforeSave(*gorm.DB) error
|
||||
}
|
||||
|
||||
type AfterSaveInterface interface {
|
||||
AfterSave(*gorm.DB) error
|
||||
}
|
||||
|
||||
type BeforeDeleteInterface interface {
|
||||
BeforeDelete(*gorm.DB) error
|
||||
}
|
||||
|
||||
type AfterDeleteInterface interface {
|
||||
AfterDelete(*gorm.DB) error
|
||||
}
|
||||
|
||||
type AfterFindInterface interface {
|
||||
AfterFind(*gorm.DB) error
|
||||
}
|
||||
351
vendor/gorm.io/gorm/callbacks/preload.go
generated
vendored
Normal file
351
vendor/gorm.io/gorm/callbacks/preload.go
generated
vendored
Normal file
@@ -0,0 +1,351 @@
|
||||
package callbacks
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/schema"
|
||||
"gorm.io/gorm/utils"
|
||||
)
|
||||
|
||||
// parsePreloadMap extracts nested preloads. e.g.
|
||||
//
|
||||
// // schema has a "k0" relation and a "k7.k8" embedded relation
|
||||
// parsePreloadMap(schema, map[string][]interface{}{
|
||||
// clause.Associations: {"arg1"},
|
||||
// "k1": {"arg2"},
|
||||
// "k2.k3": {"arg3"},
|
||||
// "k4.k5.k6": {"arg4"},
|
||||
// })
|
||||
// // preloadMap is
|
||||
// map[string]map[string][]interface{}{
|
||||
// "k0": {},
|
||||
// "k7": {
|
||||
// "k8": {},
|
||||
// },
|
||||
// "k1": {},
|
||||
// "k2": {
|
||||
// "k3": {"arg3"},
|
||||
// },
|
||||
// "k4": {
|
||||
// "k5.k6": {"arg4"},
|
||||
// },
|
||||
// }
|
||||
func parsePreloadMap(s *schema.Schema, preloads map[string][]interface{}) map[string]map[string][]interface{} {
|
||||
preloadMap := map[string]map[string][]interface{}{}
|
||||
setPreloadMap := func(name, value string, args []interface{}) {
|
||||
if _, ok := preloadMap[name]; !ok {
|
||||
preloadMap[name] = map[string][]interface{}{}
|
||||
}
|
||||
if value != "" {
|
||||
preloadMap[name][value] = args
|
||||
}
|
||||
}
|
||||
|
||||
for name, args := range preloads {
|
||||
preloadFields := strings.Split(name, ".")
|
||||
value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), ".")
|
||||
if preloadFields[0] == clause.Associations {
|
||||
for _, relation := range s.Relationships.Relations {
|
||||
if relation.Schema == s {
|
||||
setPreloadMap(relation.Name, value, args)
|
||||
}
|
||||
}
|
||||
|
||||
for embedded, embeddedRelations := range s.Relationships.EmbeddedRelations {
|
||||
for _, value := range embeddedValues(embeddedRelations) {
|
||||
setPreloadMap(embedded, value, args)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
setPreloadMap(preloadFields[0], value, args)
|
||||
}
|
||||
}
|
||||
return preloadMap
|
||||
}
|
||||
|
||||
func embeddedValues(embeddedRelations *schema.Relationships) []string {
|
||||
if embeddedRelations == nil {
|
||||
return nil
|
||||
}
|
||||
names := make([]string, 0, len(embeddedRelations.Relations)+len(embeddedRelations.EmbeddedRelations))
|
||||
for _, relation := range embeddedRelations.Relations {
|
||||
// skip first struct name
|
||||
names = append(names, strings.Join(relation.Field.EmbeddedBindNames[1:], "."))
|
||||
}
|
||||
for _, relations := range embeddedRelations.EmbeddedRelations {
|
||||
names = append(names, embeddedValues(relations)...)
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
||||
// preloadEntryPoint enters layer by layer. It will call real preload if it finds the right entry point.
|
||||
// If the current relationship is embedded or joined, current query will be ignored.
|
||||
//
|
||||
//nolint:cyclop
|
||||
func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relationships, preloads map[string][]interface{}, associationsConds []interface{}) error {
|
||||
preloadMap := parsePreloadMap(db.Statement.Schema, preloads)
|
||||
|
||||
// avoid random traversal of the map
|
||||
preloadNames := make([]string, 0, len(preloadMap))
|
||||
for key := range preloadMap {
|
||||
preloadNames = append(preloadNames, key)
|
||||
}
|
||||
sort.Strings(preloadNames)
|
||||
|
||||
isJoined := func(name string) (joined bool, nestedJoins []string) {
|
||||
for _, join := range joins {
|
||||
if _, ok := relationships.Relations[join]; ok && name == join {
|
||||
joined = true
|
||||
continue
|
||||
}
|
||||
join0, join1, cut := strings.Cut(join, ".")
|
||||
if cut {
|
||||
if _, ok := relationships.Relations[join0]; ok && name == join0 {
|
||||
joined = true
|
||||
nestedJoins = append(nestedJoins, join1)
|
||||
}
|
||||
}
|
||||
}
|
||||
return joined, nestedJoins
|
||||
}
|
||||
|
||||
for _, name := range preloadNames {
|
||||
if relations := relationships.EmbeddedRelations[name]; relations != nil {
|
||||
if err := preloadEntryPoint(db, joins, relations, preloadMap[name], associationsConds); err != nil {
|
||||
return err
|
||||
}
|
||||
} else if rel := relationships.Relations[name]; rel != nil {
|
||||
if joined, nestedJoins := isJoined(name); joined {
|
||||
switch rv := db.Statement.ReflectValue; rv.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
if rv.Len() > 0 {
|
||||
reflectValue := rel.FieldSchema.MakeSlice().Elem()
|
||||
for i := 0; i < rv.Len(); i++ {
|
||||
frv := rel.Field.ReflectValueOf(db.Statement.Context, rv.Index(i))
|
||||
if frv.Kind() != reflect.Ptr {
|
||||
reflectValue = reflect.Append(reflectValue, frv.Addr())
|
||||
} else {
|
||||
if frv.IsNil() {
|
||||
continue
|
||||
}
|
||||
reflectValue = reflect.Append(reflectValue, frv)
|
||||
}
|
||||
}
|
||||
|
||||
tx := preloadDB(db, reflectValue, reflectValue.Interface())
|
||||
if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
case reflect.Struct, reflect.Pointer:
|
||||
reflectValue := rel.Field.ReflectValueOf(db.Statement.Context, rv)
|
||||
tx := preloadDB(db, reflectValue, reflectValue.Interface())
|
||||
if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil {
|
||||
return err
|
||||
}
|
||||
default:
|
||||
return gorm.ErrInvalidData
|
||||
}
|
||||
} else {
|
||||
tx := db.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks})
|
||||
tx.Statement.ReflectValue = db.Statement.ReflectValue
|
||||
tx.Statement.Unscoped = db.Statement.Unscoped
|
||||
if err := preload(tx, rel, append(preloads[name], associationsConds...), preloadMap[name]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("%s: %w for schema %s", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func preloadDB(db *gorm.DB, reflectValue reflect.Value, dest interface{}) *gorm.DB {
|
||||
tx := db.Session(&gorm.Session{Context: db.Statement.Context, NewDB: true, SkipHooks: db.Statement.SkipHooks, Initialized: true})
|
||||
db.Statement.Settings.Range(func(k, v interface{}) bool {
|
||||
tx.Statement.Settings.Store(k, v)
|
||||
return true
|
||||
})
|
||||
|
||||
if err := tx.Statement.Parse(dest); err != nil {
|
||||
tx.AddError(err)
|
||||
return tx
|
||||
}
|
||||
tx.Statement.ReflectValue = reflectValue
|
||||
tx.Statement.Unscoped = db.Statement.Unscoped
|
||||
return tx
|
||||
}
|
||||
|
||||
func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}) error {
|
||||
var (
|
||||
reflectValue = tx.Statement.ReflectValue
|
||||
relForeignKeys []string
|
||||
relForeignFields []*schema.Field
|
||||
foreignFields []*schema.Field
|
||||
foreignValues [][]interface{}
|
||||
identityMap = map[string][]reflect.Value{}
|
||||
inlineConds []interface{}
|
||||
)
|
||||
|
||||
if rel.JoinTable != nil {
|
||||
var (
|
||||
joinForeignFields = make([]*schema.Field, 0, len(rel.References))
|
||||
joinRelForeignFields = make([]*schema.Field, 0, len(rel.References))
|
||||
joinForeignKeys = make([]string, 0, len(rel.References))
|
||||
)
|
||||
|
||||
for _, ref := range rel.References {
|
||||
if ref.OwnPrimaryKey {
|
||||
joinForeignKeys = append(joinForeignKeys, ref.ForeignKey.DBName)
|
||||
joinForeignFields = append(joinForeignFields, ref.ForeignKey)
|
||||
foreignFields = append(foreignFields, ref.PrimaryKey)
|
||||
} else if ref.PrimaryValue != "" {
|
||||
tx = tx.Where(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue})
|
||||
} else {
|
||||
joinRelForeignFields = append(joinRelForeignFields, ref.ForeignKey)
|
||||
relForeignKeys = append(relForeignKeys, ref.PrimaryKey.DBName)
|
||||
relForeignFields = append(relForeignFields, ref.PrimaryKey)
|
||||
}
|
||||
}
|
||||
|
||||
joinIdentityMap, joinForeignValues := schema.GetIdentityFieldValuesMap(tx.Statement.Context, reflectValue, foreignFields)
|
||||
if len(joinForeignValues) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
joinResults := rel.JoinTable.MakeSlice().Elem()
|
||||
column, values := schema.ToQueryValues(clause.CurrentTable, joinForeignKeys, joinForeignValues)
|
||||
if err := tx.Where(clause.IN{Column: column, Values: values}).Find(joinResults.Addr().Interface()).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// convert join identity map to relation identity map
|
||||
fieldValues := make([]interface{}, len(joinForeignFields))
|
||||
joinFieldValues := make([]interface{}, len(joinRelForeignFields))
|
||||
for i := 0; i < joinResults.Len(); i++ {
|
||||
joinIndexValue := joinResults.Index(i)
|
||||
for idx, field := range joinForeignFields {
|
||||
fieldValues[idx], _ = field.ValueOf(tx.Statement.Context, joinIndexValue)
|
||||
}
|
||||
|
||||
for idx, field := range joinRelForeignFields {
|
||||
joinFieldValues[idx], _ = field.ValueOf(tx.Statement.Context, joinIndexValue)
|
||||
}
|
||||
|
||||
if results, ok := joinIdentityMap[utils.ToStringKey(fieldValues...)]; ok {
|
||||
joinKey := utils.ToStringKey(joinFieldValues...)
|
||||
identityMap[joinKey] = append(identityMap[joinKey], results...)
|
||||
}
|
||||
}
|
||||
|
||||
_, foreignValues = schema.GetIdentityFieldValuesMap(tx.Statement.Context, joinResults, joinRelForeignFields)
|
||||
} else {
|
||||
for _, ref := range rel.References {
|
||||
if ref.OwnPrimaryKey {
|
||||
relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName)
|
||||
relForeignFields = append(relForeignFields, ref.ForeignKey)
|
||||
foreignFields = append(foreignFields, ref.PrimaryKey)
|
||||
} else if ref.PrimaryValue != "" {
|
||||
tx = tx.Where(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue})
|
||||
} else {
|
||||
relForeignKeys = append(relForeignKeys, ref.PrimaryKey.DBName)
|
||||
relForeignFields = append(relForeignFields, ref.PrimaryKey)
|
||||
foreignFields = append(foreignFields, ref.ForeignKey)
|
||||
}
|
||||
}
|
||||
|
||||
identityMap, foreignValues = schema.GetIdentityFieldValuesMap(tx.Statement.Context, reflectValue, foreignFields)
|
||||
if len(foreignValues) == 0 {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// nested preload
|
||||
for p, pvs := range preloads {
|
||||
tx = tx.Preload(p, pvs...)
|
||||
}
|
||||
|
||||
reflectResults := rel.FieldSchema.MakeSlice().Elem()
|
||||
column, values := schema.ToQueryValues(clause.CurrentTable, relForeignKeys, foreignValues)
|
||||
|
||||
if len(values) != 0 {
|
||||
tx = tx.Model(reflectResults.Addr().Interface()).Where(clause.IN{Column: column, Values: values})
|
||||
|
||||
for _, cond := range conds {
|
||||
if fc, ok := cond.(func(*gorm.DB) *gorm.DB); ok {
|
||||
tx = fc(tx)
|
||||
} else {
|
||||
inlineConds = append(inlineConds, cond)
|
||||
}
|
||||
}
|
||||
|
||||
if len(inlineConds) > 0 {
|
||||
tx = tx.Where(inlineConds[0], inlineConds[1:]...)
|
||||
}
|
||||
|
||||
if err := tx.Find(reflectResults.Addr().Interface()).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
fieldValues := make([]interface{}, len(relForeignFields))
|
||||
|
||||
// clean up old values before preloading
|
||||
switch reflectValue.Kind() {
|
||||
case reflect.Struct:
|
||||
switch rel.Type {
|
||||
case schema.HasMany, schema.Many2Many:
|
||||
tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue, reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface()))
|
||||
default:
|
||||
tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue, reflect.New(rel.Field.FieldType).Interface()))
|
||||
}
|
||||
case reflect.Slice, reflect.Array:
|
||||
for i := 0; i < reflectValue.Len(); i++ {
|
||||
switch rel.Type {
|
||||
case schema.HasMany, schema.Many2Many:
|
||||
tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue.Index(i), reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface()))
|
||||
default:
|
||||
tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue.Index(i), reflect.New(rel.Field.FieldType).Interface()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for i := 0; i < reflectResults.Len(); i++ {
|
||||
elem := reflectResults.Index(i)
|
||||
for idx, field := range relForeignFields {
|
||||
fieldValues[idx], _ = field.ValueOf(tx.Statement.Context, elem)
|
||||
}
|
||||
|
||||
datas, ok := identityMap[utils.ToStringKey(fieldValues...)]
|
||||
if !ok {
|
||||
return fmt.Errorf("failed to assign association %#v, make sure foreign fields exists", elem.Interface())
|
||||
}
|
||||
|
||||
for _, data := range datas {
|
||||
reflectFieldValue := rel.Field.ReflectValueOf(tx.Statement.Context, data)
|
||||
if reflectFieldValue.Kind() == reflect.Ptr && reflectFieldValue.IsNil() {
|
||||
reflectFieldValue.Set(reflect.New(rel.Field.FieldType.Elem()))
|
||||
}
|
||||
|
||||
reflectFieldValue = reflect.Indirect(reflectFieldValue)
|
||||
switch reflectFieldValue.Kind() {
|
||||
case reflect.Struct:
|
||||
tx.AddError(rel.Field.Set(tx.Statement.Context, data, elem.Interface()))
|
||||
case reflect.Slice, reflect.Array:
|
||||
if reflectFieldValue.Type().Elem().Kind() == reflect.Ptr {
|
||||
tx.AddError(rel.Field.Set(tx.Statement.Context, data, reflect.Append(reflectFieldValue, elem).Interface()))
|
||||
} else {
|
||||
tx.AddError(rel.Field.Set(tx.Statement.Context, data, reflect.Append(reflectFieldValue, elem.Elem()).Interface()))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return tx.Error
|
||||
}
|
||||
314
vendor/gorm.io/gorm/callbacks/query.go
generated
vendored
Normal file
314
vendor/gorm.io/gorm/callbacks/query.go
generated
vendored
Normal file
@@ -0,0 +1,314 @@
|
||||
package callbacks
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/schema"
|
||||
"gorm.io/gorm/utils"
|
||||
)
|
||||
|
||||
func Query(db *gorm.DB) {
|
||||
if db.Error == nil {
|
||||
BuildQuerySQL(db)
|
||||
|
||||
if !db.DryRun && db.Error == nil {
|
||||
rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||
if err != nil {
|
||||
db.AddError(err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
db.AddError(rows.Close())
|
||||
}()
|
||||
gorm.Scan(rows, db, 0)
|
||||
|
||||
if db.Statement.Result != nil {
|
||||
db.Statement.Result.RowsAffected = db.RowsAffected
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BuildQuerySQL(db *gorm.DB) {
|
||||
if db.Statement.Schema != nil {
|
||||
for _, c := range db.Statement.Schema.QueryClauses {
|
||||
db.Statement.AddClause(c)
|
||||
}
|
||||
}
|
||||
|
||||
if db.Statement.SQL.Len() == 0 {
|
||||
db.Statement.SQL.Grow(100)
|
||||
clauseSelect := clause.Select{Distinct: db.Statement.Distinct}
|
||||
|
||||
if db.Statement.ReflectValue.Kind() == reflect.Struct && db.Statement.ReflectValue.Type() == db.Statement.Schema.ModelType {
|
||||
var conds []clause.Expression
|
||||
for _, primaryField := range db.Statement.Schema.PrimaryFields {
|
||||
if v, isZero := primaryField.ValueOf(db.Statement.Context, db.Statement.ReflectValue); !isZero {
|
||||
conds = append(conds, clause.Eq{Column: clause.Column{Table: db.Statement.Table, Name: primaryField.DBName}, Value: v})
|
||||
}
|
||||
}
|
||||
|
||||
if len(conds) > 0 {
|
||||
db.Statement.AddClause(clause.Where{Exprs: conds})
|
||||
}
|
||||
}
|
||||
|
||||
if len(db.Statement.Selects) > 0 {
|
||||
clauseSelect.Columns = make([]clause.Column, len(db.Statement.Selects))
|
||||
for idx, name := range db.Statement.Selects {
|
||||
if db.Statement.Schema == nil {
|
||||
clauseSelect.Columns[idx] = clause.Column{Name: name, Raw: true}
|
||||
} else if f := db.Statement.Schema.LookUpField(name); f != nil {
|
||||
clauseSelect.Columns[idx] = clause.Column{Name: f.DBName}
|
||||
} else {
|
||||
clauseSelect.Columns[idx] = clause.Column{Name: name, Raw: true}
|
||||
}
|
||||
}
|
||||
} else if db.Statement.Schema != nil && len(db.Statement.Omits) > 0 {
|
||||
selectColumns, _ := db.Statement.SelectAndOmitColumns(false, false)
|
||||
clauseSelect.Columns = make([]clause.Column, 0, len(db.Statement.Schema.DBNames))
|
||||
for _, dbName := range db.Statement.Schema.DBNames {
|
||||
if v, ok := selectColumns[dbName]; (ok && v) || !ok {
|
||||
clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{Table: db.Statement.Table, Name: dbName})
|
||||
}
|
||||
}
|
||||
} else if db.Statement.Schema != nil && db.Statement.ReflectValue.IsValid() {
|
||||
queryFields := db.QueryFields
|
||||
if !queryFields {
|
||||
switch db.Statement.ReflectValue.Kind() {
|
||||
case reflect.Struct:
|
||||
queryFields = db.Statement.ReflectValue.Type() != db.Statement.Schema.ModelType
|
||||
case reflect.Slice:
|
||||
queryFields = db.Statement.ReflectValue.Type().Elem() != db.Statement.Schema.ModelType
|
||||
}
|
||||
}
|
||||
|
||||
if queryFields {
|
||||
stmt := gorm.Statement{DB: db}
|
||||
// smaller struct
|
||||
if err := stmt.Parse(db.Statement.Dest); err == nil && (db.QueryFields || stmt.Schema.ModelType != db.Statement.Schema.ModelType) {
|
||||
clauseSelect.Columns = make([]clause.Column, len(stmt.Schema.DBNames))
|
||||
|
||||
for idx, dbName := range stmt.Schema.DBNames {
|
||||
clauseSelect.Columns[idx] = clause.Column{Table: db.Statement.Table, Name: dbName}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// inline joins
|
||||
fromClause := clause.From{}
|
||||
if v, ok := db.Statement.Clauses["FROM"].Expression.(clause.From); ok {
|
||||
fromClause = v
|
||||
}
|
||||
|
||||
if len(db.Statement.Joins) != 0 || len(fromClause.Joins) != 0 {
|
||||
if len(db.Statement.Selects) == 0 && len(db.Statement.Omits) == 0 && db.Statement.Schema != nil {
|
||||
clauseSelect.Columns = make([]clause.Column, len(db.Statement.Schema.DBNames))
|
||||
for idx, dbName := range db.Statement.Schema.DBNames {
|
||||
clauseSelect.Columns[idx] = clause.Column{Table: db.Statement.Table, Name: dbName}
|
||||
}
|
||||
}
|
||||
|
||||
specifiedRelationsName := map[string]string{clause.CurrentTable: clause.CurrentTable}
|
||||
for _, join := range db.Statement.Joins {
|
||||
if db.Statement.Schema != nil {
|
||||
var isRelations bool // is relations or raw sql
|
||||
var relations []*schema.Relationship
|
||||
relation, ok := db.Statement.Schema.Relationships.Relations[join.Name]
|
||||
if ok {
|
||||
isRelations = true
|
||||
relations = append(relations, relation)
|
||||
} else {
|
||||
// handle nested join like "Manager.Company"
|
||||
nestedJoinNames := strings.Split(join.Name, ".")
|
||||
if len(nestedJoinNames) > 1 {
|
||||
isNestedJoin := true
|
||||
guessNestedRelations := make([]*schema.Relationship, 0, len(nestedJoinNames))
|
||||
currentRelations := db.Statement.Schema.Relationships.Relations
|
||||
for _, relname := range nestedJoinNames {
|
||||
// incomplete match, only treated as raw sql
|
||||
if relation, ok = currentRelations[relname]; ok {
|
||||
guessNestedRelations = append(guessNestedRelations, relation)
|
||||
currentRelations = relation.FieldSchema.Relationships.Relations
|
||||
} else {
|
||||
isNestedJoin = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if isNestedJoin {
|
||||
isRelations = true
|
||||
relations = guessNestedRelations
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if isRelations {
|
||||
genJoinClause := func(joinType clause.JoinType, tableAliasName string, parentTableName string, relation *schema.Relationship) clause.Join {
|
||||
columnStmt := gorm.Statement{
|
||||
Table: tableAliasName, DB: db, Schema: relation.FieldSchema,
|
||||
Selects: join.Selects, Omits: join.Omits,
|
||||
}
|
||||
|
||||
selectColumns, restricted := columnStmt.SelectAndOmitColumns(false, false)
|
||||
for _, s := range relation.FieldSchema.DBNames {
|
||||
if v, ok := selectColumns[s]; (ok && v) || (!ok && !restricted) {
|
||||
clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{
|
||||
Table: tableAliasName,
|
||||
Name: s,
|
||||
Alias: utils.NestedRelationName(tableAliasName, s),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if join.Expression != nil {
|
||||
return clause.Join{
|
||||
Type: join.JoinType,
|
||||
Expression: join.Expression,
|
||||
}
|
||||
}
|
||||
|
||||
exprs := make([]clause.Expression, len(relation.References))
|
||||
for idx, ref := range relation.References {
|
||||
if ref.OwnPrimaryKey {
|
||||
exprs[idx] = clause.Eq{
|
||||
Column: clause.Column{Table: parentTableName, Name: ref.PrimaryKey.DBName},
|
||||
Value: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName},
|
||||
}
|
||||
} else {
|
||||
if ref.PrimaryValue == "" {
|
||||
exprs[idx] = clause.Eq{
|
||||
Column: clause.Column{Table: parentTableName, Name: ref.ForeignKey.DBName},
|
||||
Value: clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName},
|
||||
}
|
||||
} else {
|
||||
exprs[idx] = clause.Eq{
|
||||
Column: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName},
|
||||
Value: ref.PrimaryValue,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
onStmt := gorm.Statement{Table: tableAliasName, DB: db, Clauses: map[string]clause.Clause{}}
|
||||
for _, c := range relation.FieldSchema.QueryClauses {
|
||||
onStmt.AddClause(c)
|
||||
}
|
||||
|
||||
if join.On != nil {
|
||||
onStmt.AddClause(join.On)
|
||||
}
|
||||
|
||||
if cs, ok := onStmt.Clauses["WHERE"]; ok {
|
||||
if where, ok := cs.Expression.(clause.Where); ok {
|
||||
where.Build(&onStmt)
|
||||
|
||||
if onSQL := onStmt.SQL.String(); onSQL != "" {
|
||||
vars := onStmt.Vars
|
||||
for idx, v := range vars {
|
||||
bindvar := strings.Builder{}
|
||||
onStmt.Vars = vars[0 : idx+1]
|
||||
db.Dialector.BindVarTo(&bindvar, &onStmt, v)
|
||||
onSQL = strings.Replace(onSQL, bindvar.String(), "?", 1)
|
||||
}
|
||||
|
||||
exprs = append(exprs, clause.Expr{SQL: onSQL, Vars: vars})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return clause.Join{
|
||||
Type: joinType,
|
||||
Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName},
|
||||
ON: clause.Where{Exprs: exprs},
|
||||
}
|
||||
}
|
||||
|
||||
parentTableName := clause.CurrentTable
|
||||
for idx, rel := range relations {
|
||||
// joins table alias like "Manager, Company, Manager__Company"
|
||||
curAliasName := rel.Name
|
||||
if parentTableName != clause.CurrentTable {
|
||||
curAliasName = utils.NestedRelationName(parentTableName, curAliasName)
|
||||
}
|
||||
|
||||
if _, ok := specifiedRelationsName[curAliasName]; !ok {
|
||||
aliasName := curAliasName
|
||||
if idx == len(relations)-1 && join.Alias != "" {
|
||||
aliasName = join.Alias
|
||||
}
|
||||
|
||||
fromClause.Joins = append(fromClause.Joins, genJoinClause(join.JoinType, aliasName, specifiedRelationsName[parentTableName], rel))
|
||||
specifiedRelationsName[curAliasName] = aliasName
|
||||
}
|
||||
|
||||
parentTableName = curAliasName
|
||||
}
|
||||
} else {
|
||||
fromClause.Joins = append(fromClause.Joins, clause.Join{
|
||||
Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds},
|
||||
})
|
||||
}
|
||||
} else {
|
||||
fromClause.Joins = append(fromClause.Joins, clause.Join{
|
||||
Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
db.Statement.AddClause(fromClause)
|
||||
} else {
|
||||
db.Statement.AddClauseIfNotExists(clause.From{})
|
||||
}
|
||||
|
||||
db.Statement.AddClauseIfNotExists(clauseSelect)
|
||||
|
||||
db.Statement.Build(db.Statement.BuildClauses...)
|
||||
}
|
||||
}
|
||||
|
||||
func Preload(db *gorm.DB) {
|
||||
if db.Error == nil && len(db.Statement.Preloads) > 0 {
|
||||
if db.Statement.Schema == nil {
|
||||
db.AddError(fmt.Errorf("%w when using preload", gorm.ErrModelValueRequired))
|
||||
return
|
||||
}
|
||||
|
||||
joins := make([]string, 0, len(db.Statement.Joins))
|
||||
for _, join := range db.Statement.Joins {
|
||||
joins = append(joins, join.Name)
|
||||
}
|
||||
|
||||
tx := preloadDB(db, db.Statement.ReflectValue, db.Statement.Dest)
|
||||
if tx.Error != nil {
|
||||
return
|
||||
}
|
||||
|
||||
db.AddError(preloadEntryPoint(tx, joins, &tx.Statement.Schema.Relationships, db.Statement.Preloads, db.Statement.Preloads[clause.Associations]))
|
||||
}
|
||||
}
|
||||
|
||||
func AfterQuery(db *gorm.DB) {
|
||||
// clear the joins after query because preload need it
|
||||
if v, ok := db.Statement.Clauses["FROM"].Expression.(clause.From); ok {
|
||||
fromClause := db.Statement.Clauses["FROM"]
|
||||
fromClause.Expression = clause.From{Tables: v.Tables, Joins: utils.RTrimSlice(v.Joins, len(db.Statement.Joins))} // keep the original From Joins
|
||||
db.Statement.Clauses["FROM"] = fromClause
|
||||
}
|
||||
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.AfterFind && db.RowsAffected > 0 {
|
||||
callMethod(db, func(value interface{}, tx *gorm.DB) bool {
|
||||
if i, ok := value.(AfterFindInterface); ok {
|
||||
db.AddError(i.AfterFind(tx))
|
||||
return true
|
||||
}
|
||||
return false
|
||||
})
|
||||
}
|
||||
}
|
||||
22
vendor/gorm.io/gorm/callbacks/raw.go
generated
vendored
Normal file
22
vendor/gorm.io/gorm/callbacks/raw.go
generated
vendored
Normal file
@@ -0,0 +1,22 @@
|
||||
package callbacks
|
||||
|
||||
import (
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func RawExec(db *gorm.DB) {
|
||||
if db.Error == nil && !db.DryRun {
|
||||
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||
if err != nil {
|
||||
db.AddError(err)
|
||||
return
|
||||
}
|
||||
|
||||
db.RowsAffected, _ = result.RowsAffected()
|
||||
|
||||
if db.Statement.Result != nil {
|
||||
db.Statement.Result.Result = result
|
||||
db.Statement.Result.RowsAffected = db.RowsAffected
|
||||
}
|
||||
}
|
||||
}
|
||||
23
vendor/gorm.io/gorm/callbacks/row.go
generated
vendored
Normal file
23
vendor/gorm.io/gorm/callbacks/row.go
generated
vendored
Normal file
@@ -0,0 +1,23 @@
|
||||
package callbacks
|
||||
|
||||
import (
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func RowQuery(db *gorm.DB) {
|
||||
if db.Error == nil {
|
||||
BuildQuerySQL(db)
|
||||
if db.DryRun || db.Error != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if isRows, ok := db.Get("rows"); ok && isRows.(bool) {
|
||||
db.Statement.Settings.Delete("rows")
|
||||
db.Statement.Dest, db.Error = db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||
} else {
|
||||
db.Statement.Dest = db.Statement.ConnPool.QueryRowContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||
}
|
||||
|
||||
db.RowsAffected = -1
|
||||
}
|
||||
}
|
||||
32
vendor/gorm.io/gorm/callbacks/transaction.go
generated
vendored
Normal file
32
vendor/gorm.io/gorm/callbacks/transaction.go
generated
vendored
Normal file
@@ -0,0 +1,32 @@
|
||||
package callbacks
|
||||
|
||||
import (
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func BeginTransaction(db *gorm.DB) {
|
||||
if !db.Config.SkipDefaultTransaction && db.Error == nil {
|
||||
if tx := db.Begin(); tx.Error == nil {
|
||||
db.Statement.ConnPool = tx.Statement.ConnPool
|
||||
db.InstanceSet("gorm:started_transaction", true)
|
||||
} else if tx.Error == gorm.ErrInvalidTransaction {
|
||||
tx.Error = nil
|
||||
} else {
|
||||
db.Error = tx.Error
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func CommitOrRollbackTransaction(db *gorm.DB) {
|
||||
if !db.Config.SkipDefaultTransaction {
|
||||
if _, ok := db.InstanceGet("gorm:started_transaction"); ok {
|
||||
if db.Error != nil {
|
||||
db.Rollback()
|
||||
} else {
|
||||
db.Commit()
|
||||
}
|
||||
|
||||
db.Statement.ConnPool = db.ConnPool
|
||||
}
|
||||
}
|
||||
}
|
||||
313
vendor/gorm.io/gorm/callbacks/update.go
generated
vendored
Normal file
313
vendor/gorm.io/gorm/callbacks/update.go
generated
vendored
Normal file
@@ -0,0 +1,313 @@
|
||||
package callbacks
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"sort"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/schema"
|
||||
"gorm.io/gorm/utils"
|
||||
)
|
||||
|
||||
func SetupUpdateReflectValue(db *gorm.DB) {
|
||||
if db.Error == nil && db.Statement.Schema != nil {
|
||||
if !db.Statement.ReflectValue.CanAddr() || db.Statement.Model != db.Statement.Dest {
|
||||
db.Statement.ReflectValue = reflect.ValueOf(db.Statement.Model)
|
||||
for db.Statement.ReflectValue.Kind() == reflect.Ptr {
|
||||
db.Statement.ReflectValue = db.Statement.ReflectValue.Elem()
|
||||
}
|
||||
|
||||
if dest, ok := db.Statement.Dest.(map[string]interface{}); ok {
|
||||
for _, rel := range db.Statement.Schema.Relationships.BelongsTo {
|
||||
if _, ok := dest[rel.Name]; ok {
|
||||
db.AddError(rel.Field.Set(db.Statement.Context, db.Statement.ReflectValue, dest[rel.Name]))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// BeforeUpdate before update hooks
|
||||
func BeforeUpdate(db *gorm.DB) {
|
||||
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) {
|
||||
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
|
||||
if db.Statement.Schema.BeforeSave {
|
||||
if i, ok := value.(BeforeSaveInterface); ok {
|
||||
called = true
|
||||
db.AddError(i.BeforeSave(tx))
|
||||
}
|
||||
}
|
||||
|
||||
if db.Statement.Schema.BeforeUpdate {
|
||||
if i, ok := value.(BeforeUpdateInterface); ok {
|
||||
called = true
|
||||
db.AddError(i.BeforeUpdate(tx))
|
||||
}
|
||||
}
|
||||
|
||||
return called
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Update update hook
|
||||
func Update(config *Config) func(db *gorm.DB) {
|
||||
supportReturning := utils.Contains(config.UpdateClauses, "RETURNING")
|
||||
|
||||
return func(db *gorm.DB) {
|
||||
if db.Error != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if db.Statement.Schema != nil {
|
||||
for _, c := range db.Statement.Schema.UpdateClauses {
|
||||
db.Statement.AddClause(c)
|
||||
}
|
||||
}
|
||||
|
||||
if db.Statement.SQL.Len() == 0 {
|
||||
db.Statement.SQL.Grow(180)
|
||||
db.Statement.AddClauseIfNotExists(clause.Update{})
|
||||
if _, ok := db.Statement.Clauses["SET"]; !ok {
|
||||
if set := ConvertToAssignments(db.Statement); len(set) != 0 {
|
||||
defer delete(db.Statement.Clauses, "SET")
|
||||
db.Statement.AddClause(set)
|
||||
} else {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
db.Statement.Build(db.Statement.BuildClauses...)
|
||||
}
|
||||
|
||||
checkMissingWhereConditions(db)
|
||||
|
||||
if !db.DryRun && db.Error == nil {
|
||||
if ok, mode := hasReturning(db, supportReturning); ok {
|
||||
if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil {
|
||||
dest := db.Statement.Dest
|
||||
db.Statement.Dest = db.Statement.ReflectValue.Addr().Interface()
|
||||
gorm.Scan(rows, db, mode)
|
||||
db.Statement.Dest = dest
|
||||
db.AddError(rows.Close())
|
||||
|
||||
if db.Statement.Result != nil {
|
||||
db.Statement.Result.RowsAffected = db.RowsAffected
|
||||
}
|
||||
}
|
||||
} else {
|
||||
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||
|
||||
if db.AddError(err) == nil {
|
||||
db.RowsAffected, _ = result.RowsAffected()
|
||||
}
|
||||
|
||||
if db.Statement.Result != nil {
|
||||
db.Statement.Result.Result = result
|
||||
db.Statement.Result.RowsAffected = db.RowsAffected
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// AfterUpdate after update hooks
|
||||
func AfterUpdate(db *gorm.DB) {
|
||||
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) {
|
||||
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
|
||||
if db.Statement.Schema.AfterUpdate {
|
||||
if i, ok := value.(AfterUpdateInterface); ok {
|
||||
called = true
|
||||
db.AddError(i.AfterUpdate(tx))
|
||||
}
|
||||
}
|
||||
|
||||
if db.Statement.Schema.AfterSave {
|
||||
if i, ok := value.(AfterSaveInterface); ok {
|
||||
called = true
|
||||
db.AddError(i.AfterSave(tx))
|
||||
}
|
||||
}
|
||||
|
||||
return called
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ConvertToAssignments convert to update assignments
|
||||
func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
|
||||
var (
|
||||
selectColumns, restricted = stmt.SelectAndOmitColumns(false, true)
|
||||
assignValue func(field *schema.Field, value interface{})
|
||||
)
|
||||
|
||||
switch stmt.ReflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
assignValue = func(field *schema.Field, value interface{}) {
|
||||
for i := 0; i < stmt.ReflectValue.Len(); i++ {
|
||||
if stmt.ReflectValue.CanAddr() {
|
||||
field.Set(stmt.Context, stmt.ReflectValue.Index(i), value)
|
||||
}
|
||||
}
|
||||
}
|
||||
case reflect.Struct:
|
||||
assignValue = func(field *schema.Field, value interface{}) {
|
||||
if stmt.ReflectValue.CanAddr() {
|
||||
field.Set(stmt.Context, stmt.ReflectValue, value)
|
||||
}
|
||||
}
|
||||
default:
|
||||
assignValue = func(field *schema.Field, value interface{}) {
|
||||
}
|
||||
}
|
||||
|
||||
updatingValue := reflect.ValueOf(stmt.Dest)
|
||||
for updatingValue.Kind() == reflect.Ptr {
|
||||
updatingValue = updatingValue.Elem()
|
||||
}
|
||||
|
||||
if !updatingValue.CanAddr() || stmt.Dest != stmt.Model {
|
||||
switch stmt.ReflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
if size := stmt.ReflectValue.Len(); size > 0 {
|
||||
var isZero bool
|
||||
for i := 0; i < size; i++ {
|
||||
for _, field := range stmt.Schema.PrimaryFields {
|
||||
_, isZero = field.ValueOf(stmt.Context, stmt.ReflectValue.Index(i))
|
||||
if !isZero {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !isZero {
|
||||
_, primaryValues := schema.GetIdentityFieldValuesMap(stmt.Context, stmt.ReflectValue, stmt.Schema.PrimaryFields)
|
||||
column, values := schema.ToQueryValues("", stmt.Schema.PrimaryFieldDBNames, primaryValues)
|
||||
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})
|
||||
}
|
||||
}
|
||||
case reflect.Struct:
|
||||
for _, field := range stmt.Schema.PrimaryFields {
|
||||
if value, isZero := field.ValueOf(stmt.Context, stmt.ReflectValue); !isZero {
|
||||
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
switch value := updatingValue.Interface().(type) {
|
||||
case map[string]interface{}:
|
||||
set = make([]clause.Assignment, 0, len(value))
|
||||
|
||||
keys := make([]string, 0, len(value))
|
||||
for k := range value {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
|
||||
for _, k := range keys {
|
||||
kv := value[k]
|
||||
if _, ok := kv.(*gorm.DB); ok {
|
||||
kv = []interface{}{kv}
|
||||
}
|
||||
|
||||
if stmt.Schema != nil {
|
||||
if field := stmt.Schema.LookUpField(k); field != nil {
|
||||
if field.DBName != "" {
|
||||
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
|
||||
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: kv})
|
||||
assignValue(field, value[k])
|
||||
}
|
||||
} else if v, ok := selectColumns[field.Name]; (ok && v) || (!ok && !restricted) {
|
||||
assignValue(field, value[k])
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) {
|
||||
set = append(set, clause.Assignment{Column: clause.Column{Name: k}, Value: kv})
|
||||
}
|
||||
}
|
||||
|
||||
if !stmt.SkipHooks && stmt.Schema != nil {
|
||||
for _, dbName := range stmt.Schema.DBNames {
|
||||
field := stmt.Schema.LookUpField(dbName)
|
||||
if field.AutoUpdateTime > 0 && value[field.Name] == nil && value[field.DBName] == nil {
|
||||
if v, ok := selectColumns[field.DBName]; (ok && v) || !ok {
|
||||
now := stmt.DB.NowFunc()
|
||||
assignValue(field, now)
|
||||
|
||||
if field.AutoUpdateTime == schema.UnixNanosecond {
|
||||
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano()})
|
||||
} else if field.AutoUpdateTime == schema.UnixMillisecond {
|
||||
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixMilli()})
|
||||
} else if field.AutoUpdateTime == schema.UnixSecond {
|
||||
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.Unix()})
|
||||
} else {
|
||||
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
default:
|
||||
updatingSchema := stmt.Schema
|
||||
var isDiffSchema bool
|
||||
if !updatingValue.CanAddr() || stmt.Dest != stmt.Model {
|
||||
// different schema
|
||||
updatingStmt := &gorm.Statement{DB: stmt.DB}
|
||||
if err := updatingStmt.Parse(stmt.Dest); err == nil {
|
||||
updatingSchema = updatingStmt.Schema
|
||||
isDiffSchema = true
|
||||
}
|
||||
}
|
||||
|
||||
switch updatingValue.Kind() {
|
||||
case reflect.Struct:
|
||||
set = make([]clause.Assignment, 0, len(stmt.Schema.FieldsByDBName))
|
||||
for _, dbName := range stmt.Schema.DBNames {
|
||||
if field := updatingSchema.LookUpField(dbName); field != nil {
|
||||
if !field.PrimaryKey || !updatingValue.CanAddr() || stmt.Dest != stmt.Model {
|
||||
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && (!restricted || (!stmt.SkipHooks && field.AutoUpdateTime > 0))) {
|
||||
value, isZero := field.ValueOf(stmt.Context, updatingValue)
|
||||
if !stmt.SkipHooks && field.AutoUpdateTime > 0 {
|
||||
if field.AutoUpdateTime == schema.UnixNanosecond {
|
||||
value = stmt.DB.NowFunc().UnixNano()
|
||||
} else if field.AutoUpdateTime == schema.UnixMillisecond {
|
||||
value = stmt.DB.NowFunc().UnixMilli()
|
||||
} else if field.AutoUpdateTime == schema.UnixSecond {
|
||||
value = stmt.DB.NowFunc().Unix()
|
||||
} else {
|
||||
value = stmt.DB.NowFunc()
|
||||
}
|
||||
isZero = false
|
||||
}
|
||||
|
||||
if (ok || !isZero) && field.Updatable {
|
||||
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value})
|
||||
assignField := field
|
||||
if isDiffSchema {
|
||||
if originField := stmt.Schema.LookUpField(dbName); originField != nil {
|
||||
assignField = originField
|
||||
}
|
||||
}
|
||||
assignValue(assignField, value)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if value, isZero := field.ValueOf(stmt.Context, updatingValue); !isZero {
|
||||
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
default:
|
||||
stmt.AddError(gorm.ErrInvalidData)
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
Reference in New Issue
Block a user