gorm(v1)无法批量插入的解决办法(已实践)

gorm 2.0 以下的版本不支持批量插入,那就只好自己造个轮子了。
利用反射机制获取数据集的类型和字段,拼接生成批量插入的 SQL,最后通过 db.Exec 执行。
具体代码如下(有很大的优化空间):

// BatchCreate inserts every element of data with multi-row INSERT statements
// (at most 100 rows per statement), working around the lack of batch insert
// in gorm versions below 2.0.
//
// data must be a slice of structs or of pointers to structs. Anything else
// returns the "数据类型不支持" error. An empty slice is a no-op. Columns come
// from the element type:
//   - an embedded gorm.Model contributes only created_at and updated_at; a
//     zero value is replaced with the current time, formatted as
//     "2006-01-02 15:04:05";
//   - any other field is written only if it carries a `gorm:"column:xxx"` tag.
//
// The table name is derived from the element type name by getTableName.
// Values are passed to db.Exec as bound "?" parameters instead of being
// spliced into the SQL text, so quotes in string values cannot break (or
// inject into) the statement. The first db.Exec error is returned as-is. No
// transaction is opened here, so chunks sent before a failure stay inserted.
func BatchCreate(db *gorm.DB, data interface{}) error {
	// Rows per INSERT statement; larger inputs are split into several statements.
	const chunkSize = 100

	list := reflect.ValueOf(data)
	if list.Kind() != reflect.Slice {
		return errors.New("数据类型不支持")
	}
	n := list.Len()
	if n == 0 {
		return nil
	}

	elemType := list.Type().Elem()
	if elemType.Kind() == reflect.Ptr {
		elemType = elemType.Elem()
	}
	if elemType.Kind() != reflect.Struct {
		return errors.New("数据类型不支持")
	}

	columns, fields := batchCreateColumns(elemType)
	head := fmt.Sprintf("insert into %s (%s) values", getTableName(elemType.Name()), strings.Join(columns, ","))
	// One "(?,?,...)" group per row, with one placeholder per column.
	rowMarks := "(" + strings.TrimSuffix(strings.Repeat("?,", len(fields)), ",") + ")"

	// Substituted for zero created_at/updated_at values, as the original did.
	now := time.Now().Format("2006-01-02 15:04:05")

	marks := make([]string, 0, chunkSize)
	args := make([]interface{}, 0, chunkSize*len(fields))
	flush := func() error {
		stmt := head + strings.Join(marks, ",")
		if err := db.Exec(stmt, args...).Error; err != nil {
			return err
		}
		// db.Exec has finished with the slices, so their storage can be reused.
		marks, args = marks[:0], args[:0]
		return nil
	}

	for i := 0; i < n; i++ {
		row := list.Index(i)
		if row.Kind() == reflect.Ptr {
			if row.IsNil() {
				return fmt.Errorf("BatchCreate: element %d is nil", i)
			}
			row = row.Elem()
		}
		for _, fld := range fields {
			arg := batchCreateArg(row.FieldByIndex(fld.index))
			if fld.timestamp && arg == "" {
				arg = now
			}
			args = append(args, arg)
		}
		marks = append(marks, rowMarks)
		if len(marks) >= chunkSize {
			if err := flush(); err != nil {
				return err
			}
		}
	}

	if len(marks) > 0 {
		return flush()
	}
	return nil
}

// batchCreateField locates one inserted column's value inside a row struct.
type batchCreateField struct {
	index     []int // field index path for reflect.Value.FieldByIndex
	timestamp bool  // created_at/updated_at from gorm.Model; zero means "now"
}

// batchCreateColumns walks the row struct type once and returns the column
// names together with where to read each column's value, in the same order.
func batchCreateColumns(t reflect.Type) ([]string, []batchCreateField) {
	var (
		columns []string
		fields  []batchCreateField
	)
	for i := 0; i < t.NumField(); i++ {
		f := t.Field(i)
		if f.Type.String() == "gorm.Model" {
			// Only the two timestamps are inserted; ID and DeletedAt are skipped.
			for j := 0; j < f.Type.NumField(); j++ {
				var col string
				switch f.Type.Field(j).Name {
				case "CreatedAt":
					col = "created_at"
				case "UpdatedAt":
					col = "updated_at"
				default:
					continue
				}
				columns = append(columns, col)
				fields = append(fields, batchCreateField{index: []int{i, j}, timestamp: true})
			}
			continue
		}
		if col := getTagValues(f.Tag.Get("gorm"))["column"]; col != "" {
			columns = append(columns, col)
			fields = append(fields, batchCreateField{index: []int{i}})
		}
	}
	return columns, fields
}

// batchCreateArg converts one field value into a bind argument. time.Time
// values are formatted as "2006-01-02 15:04:05". A zero time, a nil
// *time.Time, or an unsupported type yields "" (the same '' the string-built
// SQL used to write).
func batchCreateArg(v reflect.Value) interface{} {
	if v.CanInterface() {
		switch t := v.Interface().(type) {
		case time.Time:
			s := t.Format("2006-01-02 15:04:05")
			if s == "0001-01-01 00:00:00" {
				return ""
			}
			return s
		case *time.Time:
			if t == nil {
				return ""
			}
			return batchCreateArg(reflect.ValueOf(*t))
		}
	}
	switch v.Kind() {
	case reflect.String:
		return v.String()
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return v.Int()
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		return v.Uint()
	case reflect.Float32, reflect.Float64:
		return v.Float()
	case reflect.Bool:
		return v.Bool()
	}
	return ""
}

// tableNameWordRe matches one capitalised word ("User", "Profile") of a
// CamelCase type name. It is compiled once rather than on every call.
var tableNameWordRe = regexp.MustCompile("[A-Z]([a-z]+)")

// getTableName derives a table name from a struct type name. It joins the
// capitalised words with "_" and lower-cases the result, so "UserProfile"
// becomes "user_profile".
//
// Only an upper-case letter followed by at least one lower-case letter counts
// as a word. Digits and leading runs of capitals are therefore dropped
// ("HTTPLog" becomes "log"). The name is not pluralised.
// NOTE(review): gorm v1's own default naming pluralises table names unless
// SingularTable(true) is set. Confirm that this matches the real schema.
func getTableName(modelName string) string {
	return strings.ToLower(strings.Join(tableNameWordRe.FindAllString(modelName, -1), "_"))
}

// getTagValues parses a gorm struct tag such as "column:name;primary_key"
// into a key → value map.
//
// Segments are separated by ";" and split at the first ":" only, so values
// may themselves contain ':' (e.g. "default:'12:00'"). Keys and values are
// trimmed of surrounding spaces, and keys stay case-sensitive. A segment
// without ':' (a flag such as "primary_key") maps to "". Empty segments,
// including an entirely empty tag from an untagged field, are skipped
// instead of causing an index-out-of-range panic.
func getTagValues(tag string) map[string]string {
	fieldMap := make(map[string]string)
	for _, part := range strings.Split(tag, ";") {
		if strings.TrimSpace(part) == "" {
			continue
		}
		kv := strings.SplitN(part, ":", 2)
		val := ""
		if len(kv) == 2 {
			val = strings.TrimSpace(kv[1])
		}
		fieldMap[strings.TrimSpace(kv[0])] = val
	}
	return fieldMap
}

//获取插入的字段值
func getField(data interface{}, fieldType string) string {
    switch fieldType {
    case "string":
        return fmt.Sprintf("'%s'", data.(string))
    case "uint64":
        return fmt.Sprintf("%d", data.(uint64))
    case "uint32":
        return fmt.Sprintf("%d", data.(uint32))
    case "time.Time":
        s := data.(time.Time).Format("2006-01-02 15:04:05")
        if s == "0001-01-01 00:00:00" {
            return "''"
        }
        return fmt.Sprintf("'%s'", s)
    case "*time.Time":
        dtime := data.(*time.Time)
        if dtime == nil {
            return "''"
        }
        s := dtime.Format("2006-01-02 15:04:05")
        if s == "0001-01-01 00:00:00" {
            return "''"
        }
        return fmt.Sprintf("'%s'", s)
    }

    return "''"
}

使用方式:

// Model is a reference copy of gorm.Model (gorm v1), shown only to list the
// fields that an embedded gorm.Model brings in. BatchCreate recognises the
// embedded field by its type name "gorm.Model". It writes only CreatedAt
// (created_at) and UpdatedAt (updated_at); ID and DeletedAt are not inserted.
// NOTE(review): embedding this local copy instead of gorm.Model would not be
// recognised, because its type name is not "gorm.Model".
type Model struct {
    ID        uint `gorm:"primary_key"`
    CreatedAt time.Time
    UpdatedAt time.Time
    DeletedAt *time.Time `sql:"index"`
}

// User is the example model for BatchCreate. getTableName("User") yields the
// table name "user". Name and Sex are inserted through their
// `gorm:"column:..."` tags, and created_at/updated_at come from the embedded
// gorm.Model. Fields without a column tag would not be inserted.
type User struct {
    gorm.Model

    Name          string `gorm:"column:name"`
    Sex         uint32 `gorm:"column:sex"`
}
// Example usage. CreatedAt/UpdatedAt are left zero, so BatchCreate fills in
// the current time. conn is assumed to be an open *gorm.DB.
users := []User{
	{Name: "小明", Sex: 0},
	{Name: "小红", Sex: 1},
}
// BatchCreate reports the first failed INSERT. Don't drop the error.
if err := BatchCreate(conn, users); err != nil {
	log.Printf("batch create users: %v", err)
}
本作品采用《CC 协议》,转载必须注明作者和本文链接
GitHub地址:github.com/bllon
讨论数量: 0
(= ̄ω ̄=)··· 暂无内容!

讨论应以学习和精进为目的。请勿发布不友善或者负能量的内容,与人为善,比聪明更重要!