go-zero + gorm 写测试的一点随笔

开始

涉及测试的类型

  • 单元测试
  • 集成测试
    • 有服务依赖的, 比如数据库依赖, 其它服务依赖. 会去启动一个别的服务
    • 数据库依赖使用go-mysql-server, redis使用mini-redis(也可以启动一个真正的数据库来测试)

例子仓库地址

  • github.com/seth-shi/go-zero-testin...
  • 服务的架构如下
    • id 服务是雪花id服务, 零依赖
    • post 服务依赖雪花服务, 数据库, Redis
      ├─app
      │  ├─id
      │  │  └─rpc
      │  │      ├─etc (配置文件)
      │  │      ├─id (grpc 代码生成)
      │  │      └─internal
      │  │          ├─config (配置定义)
      │  │          ├─logic (业务逻辑)
      │  │          ├─mock (单元测试数据)
      │  │          ├─server (go-zero 服务端生成)
      │  │          └─svc (服务依赖定义)
      │  └─post
      │      └─rpc
      │          ├─etc (配置文件)
      │          ├─internal
      │          │  ├─config (配置定义)
      │          │  ├─faker (集成测试数据)
      │          │  ├─logic (业务逻辑)
      │          │  ├─mock (单元测试数据)
      │          │  ├─model (gorm 生成)
      │          │  │  ├─do (数据库查询操作)
      │          │  │  └─entity (gorm gen 生成模型定义)
      │          │  ├─server (go-zero 服务端生成)
      │          │  └─svc (服务依赖定义)
      │          └─post (grpc 代码生成)
      └─pkg (公共代码)
      └─go.mod

单元测试

id 服务

syntax = "proto3";

package id;
option go_package="./id";

message IdRequest {
}

message IdResponse {
  uint64 id = 1;
  uint64 node = 2;
}

service Id {
  rpc Get(IdRequest) returns(IdResponse);
}

post 服务

syntax = "proto3";

package post;
option go_package="./post";

message PostRequest {
  uint64  id = 1;
}

message PostResponse {
  uint64 id = 1;
  string title = 2;
  string content = 3;
  uint64 createdAt = 4;
  uint64 viewCount = 5;
}

service Post {
  rpc Get(PostRequest) returns(PostResponse);
}

代码

  • 服务依赖
package svc

import (
    "github.com/redis/go-redis/v9"
    "github.com/seth-shi/go-zero-testing-example/app/id/rpc/id"
    "github.com/seth-shi/go-zero-testing-example/app/post/rpc/internal/config"
    "github.com/seth-shi/go-zero-testing-example/app/post/rpc/internal/model/do"
    "github.com/seth-shi/go-zero-testing-example/app/post/rpc/internal/model/entity"
    "github.com/zeromicro/go-zero/core/logx"
    "github.com/zeromicro/go-zero/zrpc"
    "gorm.io/driver/mysql"
    "gorm.io/gorm"
)

type ServiceContext struct {
    Config config.Config
    Redis  *redis.Client
    IdRpc  id.IdClient

    Query *do.Query
}

func NewServiceContext(c config.Config) *ServiceContext {

    conn, err := gorm.Open(mysql.Open(c.DataSource))
    logx.Must(err)

    idClient := id.NewIdClient(zrpc.MustNewClient(c.IdRpc).Conn())
    entity.SetIdGenerator(idClient)

    rdb := redis.NewClient(
        &redis.Options{
            Addr:     c.RedisConf.Host,
            Password: c.RedisConf.Pass,
            DB:       0,
        },
    )

    return &ServiceContext{
        Config: c,
        Redis:  rdb,
        IdRpc:  idClient,
        Query:  do.Use(conn),
    }
}
  • !!!业务逻辑
package logic

import (
    "context"
    "fmt"

    "github.com/samber/lo"
    "github.com/seth-shi/go-zero-testing-example/app/post/rpc/internal/svc"
    "github.com/seth-shi/go-zero-testing-example/app/post/rpc/post"

    "github.com/zeromicro/go-zero/core/logx"
)

type GetLogic struct {
    ctx    context.Context
    svcCtx *svc.ServiceContext
    logx.Logger
}

func NewGetLogic(ctx context.Context, svcCtx *svc.ServiceContext) *GetLogic {
    return &GetLogic{
        ctx:    ctx,
        svcCtx: svcCtx,
        Logger: logx.WithContext(ctx),
    }
}

func (l *GetLogic) Get(in *post.PostRequest) (*post.PostResponse, error) {

    // 数据库查询数据
    p, err := l.
        svcCtx.
        Query.
        Post.
        WithContext(l.ctx).
        Where(l.svcCtx.Query.Post.ID.Eq(in.GetId())).
        First()
    if err != nil {
        return nil, err
    }

    // redis + 1 浏览量
    redisKey := fmt.Sprintf("post:%d", p.ID)
    val, err := l.svcCtx.Redis.Incr(l.ctx, redisKey).Result()
    if err != nil {
        return nil, err
    }

    resp := &post.PostResponse{
        Id:        p.ID,
        Title:     lo.FromPtr(p.Title),
        Content:   lo.FromPtr(p.Content),
        CreatedAt: uint64(p.CreatedAt.Unix()),
        ViewCount: uint64(val),
    }
    return resp, nil
}

开始写单元测试

  • 业务逻辑单测
package logic

import (
    "context"
    "database/sql/driver"
    "errors"
    "testing"
    "time"

    "github.com/DATA-DOG/go-sqlmock"
    "github.com/seth-shi/go-zero-testing-example/app/post/rpc/internal/config"
    "github.com/seth-shi/go-zero-testing-example/app/post/rpc/internal/mock"
    "github.com/seth-shi/go-zero-testing-example/app/post/rpc/internal/model/do"
    "github.com/seth-shi/go-zero-testing-example/app/post/rpc/internal/svc"
    "github.com/seth-shi/go-zero-testing-example/app/post/rpc/post"
    "github.com/stretchr/testify/assert"
)

// 注意, 此部分是单元测试, 不依赖任何外部依赖
// 逻辑的实现尽量通过接口的方式去实现
// 区别于服务根目录下的集成测试, 集成测试会启动服务包括依赖
func TestGetLogic_Get(t *testing.T) {

    var (
        mockVal = mock.GetValue()
        svcCtx  = &svc.ServiceContext{
            Config: config.Config{},
            Redis:  mockVal.Redis,
            IdRpc:  mockVal.IdServer,
            Query:  do.Use(mockVal.Database),
        }
        errNotFound      = errors.New("not found")
        errRedisNotFound = errors.New("redis not found")
        selectSql        = "SELECT (.+) FROM `posts` WHERE `posts`.`id` = (.+)"
        columns          = []string{"id", "title", "content", "created_at", "updated_at"}
        row              = []driver.Value{1, "title", "content", time.Now(), time.Now()}
    )

    logic := NewGetLogic(context.Background(), svcCtx)

    // mock 数据库返回
    mockVal.
        DatabaseMock.
        ExpectQuery(selectSql).
        WithArgs(1, 1).
        WillReturnRows(sqlmock.NewRows(columns).AddRow(row...))
    mockVal.RedisMock.ExpectIncr("post:1").SetVal(1)
    resp, err := logic.Get(&post.PostRequest{Id: 1})
    assert.NoError(t, err)
    assert.Equal(t, uint64(1), resp.GetId())
    assert.Equal(t, "title", resp.GetTitle())

    // redis 返回错误的场景
    mockVal.
        DatabaseMock.
        ExpectQuery(selectSql).
        WithArgs(1, 1).
        WillReturnRows(sqlmock.NewRows(columns).AddRow(row...))
    mockVal.RedisMock.ExpectIncr("post:1").SetErr(errRedisNotFound)
    _, err3 := logic.Get(&post.PostRequest{Id: 1})
    assert.ErrorIs(t, err3, errRedisNotFound)

    // 数据库返回错误的场景
    mockVal.
        DatabaseMock.
        ExpectQuery(selectSql).
        WithArgs(1, 1).
        WillReturnError(errNotFound)
    _, err2 := logic.Get(&post.PostRequest{Id: 1})
    assert.ErrorIs(t, err2, errNotFound)
}
  • mock 包代码组装数据
package mock

import (
    "context"
    "sync"

    "github.com/DATA-DOG/go-sqlmock"
    "github.com/go-redis/redismock/v9"
    "github.com/redis/go-redis/v9"
    "github.com/seth-shi/go-zero-testing-example/app/id/rpc/id"
    "github.com/stretchr/testify/mock"
    "github.com/zeromicro/go-zero/core/logx"
    "google.golang.org/grpc"
    "gorm.io/driver/mysql"
    "gorm.io/gorm"
)

type value struct {
    IdServer     *idMock
    Database     *gorm.DB
    DatabaseMock sqlmock.Sqlmock
    RedisMock    redismock.ClientMock
    Redis        *redis.Client

    cacheStore sync.Map
}

var GetValue = sync.OnceValue(
    func() value {

        db, dbMock := makeDatabase()
        redis, redisMock := redismock.NewClientMock()

        return value{
            IdServer:     &idMock{},
            Database:     db,
            DatabaseMock: dbMock,
            Redis:        redis,
            RedisMock:    redisMock,
            cacheStore:   sync.Map{},
        }
    },
)

type idMock struct {
    mock.Mock
}

func (m *idMock) Get(ctx context.Context, in *id.IdRequest, opts ...grpc.CallOption) (*id.IdResponse, error) {
    args := m.Called()
    idResp := uint64(args.Int(0))

    return &id.IdResponse{
        Id:   idResp,
        Node: idResp,
    }, args.Error(1)
}

func makeDatabase() (*gorm.DB, sqlmock.Sqlmock) {

    db, dbMock, err := sqlmock.New()
    logx.Must(err)
    dbMock.
        ExpectQuery("SELECT VERSION()").
        WillReturnRows(sqlmock.NewRows([]string{"VERSION()"}).AddRow("5.7"))

    gormDB, err := gorm.Open(
        mysql.New(
            mysql.Config{
                Conn: db,
            },
        ), &gorm.Config{},
    )
    logx.Must(err)

    return gormDB, dbMock
}
  • 至此, 我们就完成此业务代码的 100% 测试覆盖

集成测试

  • 需要改造一下main方法
package main

import (
    "flag"
    "fmt"

    "github.com/seth-shi/go-zero-testing-example/app/post/rpc/internal/config"
    "github.com/seth-shi/go-zero-testing-example/app/post/rpc/internal/server"
    "github.com/seth-shi/go-zero-testing-example/app/post/rpc/internal/svc"
    "github.com/seth-shi/go-zero-testing-example/app/post/rpc/post"
    "github.com/zeromicro/go-zero/core/conf"
    "github.com/zeromicro/go-zero/core/logx"
    "github.com/zeromicro/go-zero/core/service"
    "github.com/zeromicro/go-zero/zrpc"
    "google.golang.org/grpc"
    "google.golang.org/grpc/reflection"
)

var svcCtxGet = getCtxByConfigFile

func getCtxByConfigFile() (*svc.ServiceContext, error) {
    flag.Parse()
    var c config.Config
    if err := conf.Load("etc/post.yaml", &c); err != nil {
        return nil, err
    }

    return svc.NewServiceContext(c), nil
}

func main() {

    ctx, err := svcCtxGet()
    logx.Must(err)
    s := zrpc.MustNewServer(
        ctx.Config.RpcServerConf, func(grpcServer *grpc.Server) {
            post.RegisterPostServer(grpcServer, server.NewPostServer(ctx))

            if ctx.Config.Mode == service.DevMode || ctx.Config.Mode == service.TestMode {
                reflection.Register(grpcServer)
            }
        },
    )
    defer s.Stop()

    fmt.Printf("Starting rpc server at %s...\n", ctx.Config.ListenOn)
    s.Start()
}
  • 集成测试方法
package main

import (
    "context"
    "os"
    "testing"

    "github.com/redis/go-redis/v9"
    "github.com/samber/lo"
    "github.com/seth-shi/go-zero-testing-example/app/post/rpc/internal/config"
    "github.com/seth-shi/go-zero-testing-example/app/post/rpc/internal/faker"
    "github.com/seth-shi/go-zero-testing-example/app/post/rpc/internal/model/do"
    "github.com/seth-shi/go-zero-testing-example/app/post/rpc/internal/svc"
    "github.com/seth-shi/go-zero-testing-example/app/post/rpc/post"
    "github.com/stretchr/testify/assert"
    "github.com/zeromicro/go-zero/zrpc"
)

func TestMain(m *testing.M) {

    // 使用默认配置
    svcCtxGet = func() (*svc.ServiceContext, error) {

        fakerVal := faker.GetValue()
        return &svc.ServiceContext{
            Config: config.Config{
                RpcServerConf: zrpc.RpcServerConf{
                    ListenOn: fakerVal.RpcListen,
                },
            },
            Redis: redis.NewClient(
                &redis.Options{
                    Addr: fakerVal.RedisAddr,
                    DB:   0,
                },
            ),
            IdRpc: fakerVal.IdServer,
            Query: do.Use(fakerVal.Gorm),
        }, nil
    }

    go main()

    // 运行测试
    code := m.Run()
    os.Exit(code)
}

// 集成测试
func TestGet(t *testing.T) {

    var (
        fakerVal  = faker.GetValue()
        postModel = fakerVal.Models.PostModel
    )
    conn, err := zrpc.NewClient(
        zrpc.RpcClientConf{
            Target:   fakerVal.RpcListen,
            NonBlock: false,
        },
    )
    assert.NoError(t, err)
    client := post.NewPostClient(conn.Conn())
    resp, err := client.Get(context.Background(), &post.PostRequest{Id: postModel.ID})
    assert.NoError(t, err)
    assert.NotZero(t, resp.GetId())
    assert.Equal(t, resp.GetId(), postModel.ID)
    assert.Equal(t, resp.Title, lo.FromPtr(postModel.Title))
}
  • 集成测试依赖服务
package faker

import (
    "context"
    "fmt"
    "math/rand"
    "sync"

    "github.com/alicebob/miniredis/v2"
    "github.com/samber/lo"
    "github.com/seth-shi/go-zero-testing-example/app/id/rpc/id"
    "github.com/seth-shi/go-zero-testing-example/app/post/rpc/internal/model/entity"
    "github.com/seth-shi/go-zero-testing-example/pkg"
    "github.com/zeromicro/go-zero/core/logx"
    "google.golang.org/grpc"
    "gorm.io/driver/mysql"
    "gorm.io/gorm"
)

type value struct {
    IdServer    *idGenerator
    Redis       *miniredis.Miniredis
    Models      *fakerModels
    Gorm        *gorm.DB
    RedisAddr   string
    RpcListen   string
    DatabaseDsn string
}

var GetValue = sync.OnceValue(
    func() value {

        redis, redisAddr := pkg.FakerRedisServer()
        dsn := pkg.FakerDatabaseServer()

        rpcPort, err := pkg.GetAvailablePort()
        logx.Must(err)

        conn, err := gorm.Open(mysql.Open(dsn))
        logx.Must(err)

        idGen := &idGenerator{
            startId: uint64(rand.Int() + 1),
            locker:  &sync.RWMutex{},
        }
        return value{
            IdServer:    idGen,
            Redis:       redis,
            RedisAddr:   redisAddr,
            DatabaseDsn: dsn,
            Models:      makeDatabase(dsn, idGen),
            RpcListen:   fmt.Sprintf(":%d", rpcPort),
            Gorm:        conn,
        }
    },
)

type idGenerator struct {
    startId uint64
    locker  sync.Locker
}

func (m *idGenerator) Get(ctx context.Context, in *id.IdRequest, opts ...grpc.CallOption) (*id.IdResponse, error) {

    m.locker.Lock()
    defer m.locker.Unlock()

    m.startId++

    return &id.IdResponse{
        Id:   m.startId,
        Node: 1,
    }, nil
}


type fakerModels struct {
    PostModel *entity.Post
}

func makeDatabase(dsn string, gen *idGenerator) *fakerModels {

    db, err := gorm.Open(
        mysql.Open(dsn),
    )
    logx.Must(err)

    // 创建表结构
    logx.Must(db.Migrator().CreateTable(&entity.Post{}))

    // 插入测试数据
    entity.SetIdGenerator(gen)
    postModel := &entity.Post{
        Title:   lo.ToPtr("test"),
        Content: lo.ToPtr("content"),
    }
    logx.Must(db.Create(postModel).Error)
    entity.SetIdGenerator(nil)

    return &fakerModels{PostModel: postModel}
}

集成测试依赖数据库

package pkg

import (
    "fmt"
    "log"

    "github.com/alicebob/miniredis/v2"
    sqle "github.com/dolthub/go-mysql-server"
    "github.com/dolthub/go-mysql-server/memory"
    "github.com/dolthub/go-mysql-server/server"
    "github.com/zeromicro/go-zero/core/logx"
)

// FakerDatabaseServer 测试环境可以使用容器化的 dsn/**
func FakerDatabaseServer() string {

    var (
        username = "root"
        password = ""
        host     = "localhost"
        dbname   = "test_db"
        port     int
        err      error
    )

    db := memory.NewDatabase(dbname)
    db.BaseDatabase.EnablePrimaryKeyIndexes()
    provider := memory.NewDBProvider(db)
    engine := sqle.NewDefault(provider)
    mysqlDb := engine.Analyzer.Catalog.MySQLDb
    mysqlDb.SetEnabled(true)
    mysqlDb.AddRootAccount()

    port, err = GetAvailablePort()
    logx.Must(err)

    config := server.Config{
        Protocol: "tcp",
        Address:  fmt.Sprintf("%s:%d", host, port),
    }
    s, err := server.NewServer(
        config,
        engine,
        memory.NewSessionBuilder(provider),
        nil,
    )
    logx.Must(err)
    go func() {
        logx.Must(s.Start())
    }()

    dsn := fmt.Sprintf(
        "%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&loc=Local&parseTime=true",
        username,
        password,
        host,
        port,
        dbname,
    )

    return dsn
}

func FakerRedisServer() (*miniredis.Miniredis, string) {
    m := miniredis.NewMiniRedis()
    if err := m.Start(); err != nil {
        log.Fatalf("could not start miniredis: %s", err)
    }

    return m, m.Addr()
}
  • 至此, 就完成了集成测试的部分

End

本作品采用《CC 协议》,转载必须注明作者和本文链接
当神不再是我们的信仰,那么信仰自己吧,努力让自己变好,不辜负自己的信仰!
讨论数量: 0
(= ̄ω ̄=)··· 暂无内容!

讨论应以学习和精进为目的。请勿发布不友善或者负能量的内容,与人为善,比聪明更重要!
未填写
文章
41
粉丝
158
喜欢
711
收藏
347
排名:30
访问:22.2 万
私信
所有博文
社区赞助商