grpc拦截器 扩展

grpc的拦截器作为aop编程的利器,相信大家在使用grpc时,一定都体会过。这里呢,主要说一下平时使用时的不足:

只能用于全局,不能灵活的对每个或者一类方法进行拦截处理。如果只做一些业务无关的操作(记录请求日志,发起非业务错误重试),借助grpc-multi-interceptor还是能很好实现。但跟业务相关的一些处理,则显得力不从心了。


刚好最近在写grpc server端的时候,需要做方法调用身份验证,调用者权限验证。但是这两类验证并不是对所有方法都拦截,而是有好些方法是完全放开的。所以呢只能把两者的验证写在一个拦截器里面的,然后定义一个白名单切片,在白名单中就跳过。但是写下来感觉不是很好,毕竟两者功能不一样,只是有共同的白名单。还有一些方法需要做一些统一的提前处理,这些处理也可能不一样,有些少一点,有些多一点,顿时感觉grpc的拦截器有些不灵活。没有其它框架里面类似中间件,拦截器来的强大。

于是自己理了下思路,决定扩展下grpc的拦截器,主要想要实现:

  1. 支持设置多个拦截器

  2. 全局拦截器上实现分组,可以设置白名单,同一分组的拦截器适用该组的白名单

  3. 对单一方法添加拦截器

在谈实现之前,先简单说一下grpc的四类拦截器吧

grpc拦截器简介

server端

unary interceptor 只要实现grpc.UnaryServerInterceptor:

type UnaryServerInterceptor func(ctx context.Context, req interface{}, info *UnaryServerInfo, handler UnaryHandler) (resp interface{}, err error)

例:

func UnaryServerInterceptorDemo(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
    log.Printf("before handling. Info: %+v", info)
    resp, err := handler(ctx, req)
    log.Printf("after handling. resp: %+v", resp)
    return resp, err
}

stream interceptor 实现grpc.StreamServerInterceptor

type StreamServerInterceptor func(srv interface{}, ss ServerStream, info *StreamServerInfo, handler StreamHandler) error

例:

func StreamServerInterceptorDemo(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
    log.Printf("before handling. Info: %+v", info)
    err := handler(srv, ss)
    log.Printf("after handling. err: %v", err)
    return err
}

然后在服务初始化时:

srv := grpc.NewServer(
    grpc.UnaryInterceptor(UnaryServerInterceptorDemo),
    grpc.StreamInterceptor(StreamServerInterceptorDemo),
)

user.RegisterUserServiceServer(srv, &UserService{})

srv.Server(listen)

client端

unary interceptor 实现 grpc.UnaryClientInterceptor:

type UnaryClientInterceptor fuc(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, invoker UnaryInvoker, opts ...CallOption) error

例:

func UnaryClientInterceptorDemo(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
    log.Printf("before invoker. method: %+v, request:%+v", method, req)
    err := invoker(ctx, method, req, reply, cc, opts...)
    log.Printf("after invoker. reply: %+v", reply)
    return err
}

stream interceptor 实现grpc.StreamClientInterceptor:

type StreamClientInterceptor func(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, streamer Streamer, opts ...CallOption) (ClientStream, error)

例:

func StreamServerInterceptorDemo(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
    log.Printf("before handling. Info: %+v", info)
    err := handler(srv, ss)
    log.Printf("after handling. err: %v", err)
    return err
}

在初始化客户端时:

grpc.DialContext(ctx, target, 
    grpc.WithUnaryInterceptor(UnaryClientInterceptorDemo),
    grpc.WithStreamInterceptor(StreamServerInterceptorDemo),
    ...
)

默认的拦截器是unary 跟 stream 每种只能配置一个,配置多个,需要借助grpc-multi-interceptor](GitHub - kazegusuri/grpc-multi-interceptor)

实现自己的增强拦截器

首先扩展UnaryServerInterceptor, 定义如下结构体:

// 用于存放一组拦截器
type unaryServerInterceptorGroup struct {
    handlers []grpc.UnaryServerInterceptor  // 包含的拦截器
    skip map[string]struct{}                // 该组的白名单,存放不被拦截的方法名
}

// 用于配置所有的UnaryServerInterceptor
type UnaryServerInterceptors struct {
    global []*unaryServerInterceptorGroup  // 全局拦截器组
    part map[string][]grpc.UnaryServerInterceptor // 局部拦截器
}

定义好结构以后,来实现添加拦截器方法,需要能够添加全局拦截器,添加单一方法的拦截器(局部拦截器)。

  1. 添加全局拦截器,约定调用一次该方法,则添加一组拦截器,方法参数需要传入拦截器切片。如果需要添加单个全局拦截器,则拦截器切片只放一个拦截器,即一个拦截器就是一组。 拦截器的执行顺序按 添加方法的调用顺序。

  2. 添加单一方法的拦截器,需要传入拦截器针对的方法名,以及该方法的拦截器。局部拦截器的执行顺序为 添加方法的调用顺序,然后同一个添加方法传入的拦截器顺序。

  3. 无论添加全局拦截器跟局部拦截器的顺序怎么样,都是先执行全局拦截器再执行局部拦截器。

// 添加全局拦截器
// @param interceptors 添加的拦截器切片
// @param skipMethods  改组拦截器需要忽略的方法名
func (usi *UnaryServerInterceptors) UseGlobal(interceptors []grpc.UnaryServerInterceptor, skipMethods ...string) {
    skip := make(map[string]struct{}, len(skipMethods))
    // 将白名单切片转换为map
    for _, method := range skipMethods {
        skip[method] = struct{}{}
    }

    // 构造拦截器组放置到拦截器组切片末尾
    usi.global = append(usi.global, &unaryServerInterceptorGroup{
        handlers: interceptors,
        skip:     skip,
    })
}

// 添加局部拦截器
// @param method 针对的方法名
// @param interceptors 添加的拦截器
func (usi *UnaryServerInterceptors) UseMethod(method string, interceptors ...grpc.UnaryServerInterceptor) {
    // 局部拦截器用map存放,key为方法全名,判断是否初始化
    if usi.part == nil {
        usi.part = make(map[string][]grpc.UnaryServerInterceptor)
        usi.part[method] = interceptors
        return
    }

    // 已经初始化,判断该方法名的拦截器是否添加过,没有直接赋值
    if _, ok := usi.part[method]; !ok {
        usi.part[method] = interceptors
        return
    }

    // 已经存在,将新增的加至末尾
    usi.part[method] = append(usi.part[method], interceptors...)
}

上面的skipMethods跟method参数,如果不知道grpc方法全名的命名规则,可以直接查看生成的protoc.pb.go文件

比如:

func (c *userServiceClient) GetUserInfo(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*UserInfo, error) {
    out := new(UserInfo)
    err := c.cc.Invoke(ctx, "/user.UserService/GetUserInfo", in, out, opts...)
    if err != nil {
        return nil, err
    }
    return out, nil
}


func _UserService_GetUserInfo_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
    in := new(Empty)
    if err := dec(in); err != nil {
        return nil, err
    }
    if interceptor == nil {
        return srv.(UserServiceServer).GetUserInfo(ctx, in)
    }
    info := &grpc.UnaryServerInfo{
        Server:     srv,
        FullMethod: "/user.UserService/GetUserInfo",
    }
    handler := func(ctx context.Context, req interface{}) (interface{}, error) {
        return srv.(UserServiceServer).GetUserInfo(ctx, req.(*Empty))
    }
    return interceptor(ctx, in, info, handler)
}

tips~ 上面是grpc生成的客户端跟服务端代码。其中"/user.UserService/GetUserInfo"就是方法名,以/开头,然后是protoc文件定义的包名,定义的服务名,包名跟服务名用.隔开,再跟/和定义的方法名。

添加实现了,接下来就是实现导出grpc.UnaryServerInterceptor方法了

func (usi *UnaryServerInterceptors) UnaryServerInterceptor() grpc.UnaryServerInterceptor {
    // 返回grpc.UnaryServerInterceptor类型的匿名函数
    return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo,
    handler grpc.UnaryHandler)(interface{}, error) {

        // 在该匿名函数中,需要依次调用设置的拦截器

        // 获取全局拦截器组的数量
        ggCount := len(usi.global)
        // 获取局部拦截器的数量
        mCount := len(usi.part[info.FullMethod])

        // 如果没有设置任何拦截器,直接调用后续处理函数
        if ggCount + mCount == 0 {
            return handler(ctx, req)
        }

        // 开始调用拦截器
        var (
            groupI, handlerI int  // 全局拦截器组执行计数,拦截器执行计数
            chainHandler grpc.UnaryHandler // 调用函数名,递归调用,需要函数名
        )
        chainHandler = func(ctx context.Context, req interface{}) (i interface{}, err error) {
            // 拦截器组执行计数小于全局拦截器组的数量,说明当前调用仍然执行全局拦截器
            if groupI < ggCount {
                group := usi.global[groupI]
                // 判断当前组白名单是否包含该方法
                if _, ok := group.skip[info.FullMethod]; !ok {
                    // 不包含,得到当前拦截器组执行到哪个拦截器的索引,计数加1
                    index := handlerI
                    handlerI++
                    if index < len(group.handlers) {
                        // 执行当前拦截器
                        return group.handlers[index](ctx, req, info, chainHandler)
                    }
                    // 上步得到的索引大于该组数量,则该组已经执行完成,需要跳到下一组
                    // 先将拦截器计数归0
                    handlerI = 0
                }
                // 拦截器组执行完或者方法在拦截器组白名单,都跳到下一组拦截器
                // 拦截器组计数加1
                groupI++
                return chainHandler(ctx, req)
            }

            // 全局拦截器组执行完以后,执行针对该方法的局部拦截器
            // 拦截器计数在执行完全局拦截器组以后被归0,复用来计数局部拦截器
            if handlerI < mCount {
                special := usi.part[info.FullMethod]
                index := handlerI
                handlerI++
                return special[index](ctx, req, info, chainHandler)
            }

            // 局部拦截器执行完以后,执行后续处理函数,也是递归跳出点
            return handler(ctx, req)
        }

        // 再次导出
        return chainHandler(ctx, req)
    }
}

这样,增强版UnaryServerInterceptor拦截器就实现了

但上面的代码其实还是有点小问题,就是在执行下一个全局拦截器组时以及跳到局部拦截器时,都是使用return chainHandler(ctx, req),本次执行其实没有意义的,加深了无用的函数调用栈。所以可以采用goto或者for循环忽略无意义的调用

我们采用for循环来优化:

func (usi *UnaryServerInterceptors) UnaryServerInterceptor() grpc.UnaryServerInterceptor {
    // 返回grpc.UnaryServerInterceptor类型的匿名函数
    return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo,
    handler grpc.UnaryHandler)(interface{}, error) {
        ...

        chainHandler = func(ctx context.Context, req interface{}) (i interface{}, err error) {
            // 拦截器组执行计数小于全局拦截器组的数量,说明当前调用仍然执行全局拦截器
            if groupI < ggCount {
                for {
                    group := usi.global[groupI]
                    // 判断当前组白名单是否包含该方法
                    if _, ok := group.skip[info.FullMethod]; !ok {
                        // 不包含,得到当前拦截器组执行到哪个拦截器的索引,计数加1
                        index := handlerI
                        handlerI++
                        if index < len(group.handlers) {
                            // 执行当前拦截器
                            return group.handlers[index](ctx, req, info, chainHandler)
                        }
                        // 上步得到的索引大于该组数量,则该组已经执行完成,需要跳到下一组
                        // 先将拦截器计数归0
                        handlerI = 0
                    }
                    // 拦截器组执行完或者方法在拦截器组白名单,都跳到下一组拦截器
                    // 拦截器组计数加1
                    groupI++
                    if groupI >= ggCount { // 拦截器组执行完跳出循环
                        break
                    }
                }
            }

            ...
        }

        ...
    }
}

到这,我们可以在代码中使用了

func LoadUnaryInterceptors() grpc.UnaryServerInterceptor  {
    md := &UnaryServerInterceptors{}

    // 加载全局拦截器
    md.UseGlobal([]grpc.UnaryServerInterceptor{AuthGuard},
        // 完全放开的api
        "/user.UserService/RegisterRule",
        "/user.UserService/SendRegisterCode",
        ...
        )

    ...

    // 加载局部拦截器
    md.UseMethod("/user.UserService/SendRegisterCode", DisableGuard, MsgCodeRateLimitGuard)


    return md.UnaryServerInterceptor()
}


srv := grpc.NewServer(
    grpc.UnaryInterceptor(LoadUnaryInterceptors()),
)

其它类型的拦截器扩展也跟此差不多,就不讲解了。可以参考代码:github.com/welllog/grpc_intercepto...

本作品采用《CC 协议》,转载必须注明作者和本文链接
~by orinfy
讨论数量: 0
(= ̄ω ̄=)··· 暂无内容!

讨论应以学习和精进为目的。请勿发布不友善或者负能量的内容,与人为善,比聪明更重要!