grpc拦截器 扩展
grpc 的拦截器作为 aop 编程的利器,相信大家在使用 grpc 时,一定都体会过。这里呢,主要说一下平时使用时的不足:
只能用于全局,不能灵活的对每个或者一类方法进行拦截处理。如果只做一些业务无关的操作(记录请求日志,发起非业务错误重试),借助 grpc-multi-interceptor 还是能很好实现。但跟业务相关的一些处理,则显得力不从心了。
刚好最近在写 grpc server 端的时候,需要做方法调用身份验证,调用者权限验证。但是这两类验证并不是对所有方法都拦截,而是有好些方法是完全放开的。所以呢只能把两者的验证写在一个拦截器里面的,然后定义一个白名单切片,在白名单中就跳过。但是写下来感觉不是很好,毕竟两者功能不一样,只是有共同的白名单。还有一些方法需要做一些统一的提前处理,这些处理也可能不一样,有些少一点,有些多一点,顿时感觉 grpc 的拦截器有些不灵活。没有其它框架里面类似中间件,拦截器来的强大。
于是自己理了下思路,决定扩展下 grpc 的拦截器,主要想要实现:
支持设置多个拦截器
全局拦截器上实现分组,可以设置白名单,同一分组的拦截器适用该组的白名单
对单一方法添加拦截器
在谈实现之前,先简单说一下 grpc 的四类拦截器吧
grpc 拦截器简介#
server 端#
unary interceptor 只要实现 grpc.UnaryServerInterceptor:
type UnaryServerInterceptor func(ctx context.Context, req interface{}, info *UnaryServerInfo, handler UnaryHandler) (resp interface{}, err error)
例:
func UnaryServerInterceptorDemo(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
log.Printf("before handling. Info: %+v", info)
resp, err := handler(ctx, req)
log.Printf("after handling. resp: %+v", resp)
return resp, err
}
stream interceptor 实现 grpc.StreamServerInterceptor
type StreamServerInterceptor func(srv interface{}, ss ServerStream, info *StreamServerInfo, handler StreamHandler) error
例:
func StreamServerInterceptorDemo(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
log.Printf("before handling. Info: %+v", info)
err := handler(srv, ss)
log.Printf("after handling. err: %v", err)
return err
}
然后在服务初始化时:
srv := grpc.NewServer(
grpc.UnaryInterceptor(UnaryServerInterceptorDemo),
grpc.StreamInterceptor(StreamServerInterceptorDemo),
)
user.RegisterUserServiceServer(srv, &UserService{})
srv.Server(listen)
client 端#
unary interceptor 实现 grpc.UnaryClientInterceptor:
type UnaryClientInterceptor fuc(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, invoker UnaryInvoker, opts ...CallOption) error
例:
func UnaryClientInterceptorDemo(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
log.Printf("before invoker. method: %+v, request:%+v", method, req)
err := invoker(ctx, method, req, reply, cc, opts...)
log.Printf("after invoker. reply: %+v", reply)
return err
}
stream interceptor 实现 grpc.StreamClientInterceptor:
type StreamClientInterceptor func(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, streamer Streamer, opts ...CallOption) (ClientStream, error)
例:
func StreamServerInterceptorDemo(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
log.Printf("before handling. Info: %+v", info)
err := handler(srv, ss)
log.Printf("after handling. err: %v", err)
return err
}
在初始化客户端时:
grpc.DialContext(ctx, target,
grpc.WithUnaryInterceptor(UnaryClientInterceptorDemo),
grpc.WithStreamInterceptor(StreamServerInterceptorDemo),
...
)
默认的拦截器是 unary 跟 stream 每种只能配置一个,配置多个,需要借助 grpc-multi-interceptor](GitHub - kazegusuri/grpc-multi-interceptor)
实现自己的增强拦截器#
首先扩展 UnaryServerInterceptor, 定义如下结构体:
// 用于存放一组拦截器
type unaryServerInterceptorGroup struct {
handlers []grpc.UnaryServerInterceptor // 包含的拦截器
skip map[string]struct{} // 该组的白名单,存放不被拦截的方法名
}
// 用于配置所有的UnaryServerInterceptor
type UnaryServerInterceptors struct {
global []*unaryServerInterceptorGroup // 全局拦截器组
part map[string][]grpc.UnaryServerInterceptor // 局部拦截器
}
定义好结构以后,来实现添加拦截器方法,需要能够添加全局拦截器,添加单一方法的拦截器(局部拦截器)。
添加全局拦截器,约定调用一次该方法,则添加一组拦截器,方法参数需要传入拦截器切片。如果需要添加单个全局拦截器,则拦截器切片只放一个拦截器,即一个拦截器就是一组。 拦截器的执行顺序按 添加方法的调用顺序。
添加单一方法的拦截器,需要传入拦截器针对的方法名,以及该方法的拦截器。局部拦截器的执行顺序为 添加方法的调用顺序,然后同一个添加方法传入的拦截器顺序。
无论添加全局拦截器跟局部拦截器的顺序怎么样,都是先执行全局拦截器再执行局部拦截器。
// 添加全局拦截器
// @param interceptors 添加的拦截器切片
// @param skipMethods 改组拦截器需要忽略的方法名
func (usi *UnaryServerInterceptors) UseGlobal(interceptors []grpc.UnaryServerInterceptor, skipMethods ...string) {
skip := make(map[string]struct{}, len(skipMethods))
// 将白名单切片转换为map
for _, method := range skipMethods {
skip[method] = struct{}{}
}
// 构造拦截器组放置到拦截器组切片末尾
usi.global = append(usi.global, &unaryServerInterceptorGroup{
handlers: interceptors,
skip: skip,
})
}
// 添加局部拦截器
// @param method 针对的方法名
// @param interceptors 添加的拦截器
func (usi *UnaryServerInterceptors) UseMethod(method string, interceptors ...grpc.UnaryServerInterceptor) {
// 局部拦截器用map存放,key为方法全名,判断是否初始化
if usi.part == nil {
usi.part = make(map[string][]grpc.UnaryServerInterceptor)
usi.part[method] = interceptors
return
}
// 已经初始化,判断该方法名的拦截器是否添加过,没有直接赋值
if _, ok := usi.part[method]; !ok {
usi.part[method] = interceptors
return
}
// 已经存在,将新增的加至末尾
usi.part[method] = append(usi.part[method], interceptors...)
}
上面的 skipMethods 跟 method 参数,如果不知道 grpc 方法全名的命名规则,可以直接查看生成的 protoc.pb.go 文件
比如:
func (c *userServiceClient) GetUserInfo(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*UserInfo, error) {
out := new(UserInfo)
err := c.cc.Invoke(ctx, "/user.UserService/GetUserInfo", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func _UserService_GetUserInfo_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(Empty)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(UserServiceServer).GetUserInfo(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/user.UserService/GetUserInfo",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(UserServiceServer).GetUserInfo(ctx, req.(*Empty))
}
return interceptor(ctx, in, info, handler)
}
tips~ 上面是 grpc 生成的客户端跟服务端代码。其中 "/user.UserService/GetUserInfo"
就是方法名,以 /
开头,然后是 protoc 文件定义的包名,定义的服务名,包名跟服务名用.
隔开,再跟 /
和定义的方法名。
添加实现了,接下来就是实现导出 grpc.UnaryServerInterceptor 方法了
func (usi *UnaryServerInterceptors) UnaryServerInterceptor() grpc.UnaryServerInterceptor {
// 返回grpc.UnaryServerInterceptor类型的匿名函数
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo,
handler grpc.UnaryHandler)(interface{}, error) {
// 在该匿名函数中,需要依次调用设置的拦截器
// 获取全局拦截器组的数量
ggCount := len(usi.global)
// 获取局部拦截器的数量
mCount := len(usi.part[info.FullMethod])
// 如果没有设置任何拦截器,直接调用后续处理函数
if ggCount + mCount == 0 {
return handler(ctx, req)
}
// 开始调用拦截器
var (
groupI, handlerI int // 全局拦截器组执行计数,拦截器执行计数
chainHandler grpc.UnaryHandler // 调用函数名,递归调用,需要函数名
)
chainHandler = func(ctx context.Context, req interface{}) (i interface{}, err error) {
// 拦截器组执行计数小于全局拦截器组的数量,说明当前调用仍然执行全局拦截器
if groupI < ggCount {
group := usi.global[groupI]
// 判断当前组白名单是否包含该方法
if _, ok := group.skip[info.FullMethod]; !ok {
// 不包含,得到当前拦截器组执行到哪个拦截器的索引,计数加1
index := handlerI
handlerI++
if index < len(group.handlers) {
// 执行当前拦截器
return group.handlers[index](ctx, req, info, chainHandler)
}
// 上步得到的索引大于该组数量,则该组已经执行完成,需要跳到下一组
// 先将拦截器计数归0
handlerI = 0
}
// 拦截器组执行完或者方法在拦截器组白名单,都跳到下一组拦截器
// 拦截器组计数加1
groupI++
return chainHandler(ctx, req)
}
// 全局拦截器组执行完以后,执行针对该方法的局部拦截器
// 拦截器计数在执行完全局拦截器组以后被归0,复用来计数局部拦截器
if handlerI < mCount {
special := usi.part[info.FullMethod]
index := handlerI
handlerI++
return special[index](ctx, req, info, chainHandler)
}
// 局部拦截器执行完以后,执行后续处理函数,也是递归跳出点
return handler(ctx, req)
}
// 再次导出
return chainHandler(ctx, req)
}
}
这样,增强版 UnaryServerInterceptor 拦截器就实现了
但上面的代码其实还是有点小问题,就是在执行下一个全局拦截器组时以及跳到局部拦截器时,都是使用 return chainHandler(ctx, req)
,本次执行其实没有意义的,加深了无用的函数调用栈。所以可以采用 goto 或者 for 循环忽略无意义的调用
我们采用 for 循环来优化:
func (usi *UnaryServerInterceptors) UnaryServerInterceptor() grpc.UnaryServerInterceptor {
// 返回grpc.UnaryServerInterceptor类型的匿名函数
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo,
handler grpc.UnaryHandler)(interface{}, error) {
...
chainHandler = func(ctx context.Context, req interface{}) (i interface{}, err error) {
// 拦截器组执行计数小于全局拦截器组的数量,说明当前调用仍然执行全局拦截器
if groupI < ggCount {
for {
group := usi.global[groupI]
// 判断当前组白名单是否包含该方法
if _, ok := group.skip[info.FullMethod]; !ok {
// 不包含,得到当前拦截器组执行到哪个拦截器的索引,计数加1
index := handlerI
handlerI++
if index < len(group.handlers) {
// 执行当前拦截器
return group.handlers[index](ctx, req, info, chainHandler)
}
// 上步得到的索引大于该组数量,则该组已经执行完成,需要跳到下一组
// 先将拦截器计数归0
handlerI = 0
}
// 拦截器组执行完或者方法在拦截器组白名单,都跳到下一组拦截器
// 拦截器组计数加1
groupI++
if groupI >= ggCount { // 拦截器组执行完跳出循环
break
}
}
}
...
}
...
}
}
到这,我们可以在代码中使用了
func LoadUnaryInterceptors() grpc.UnaryServerInterceptor {
md := &UnaryServerInterceptors{}
// 加载全局拦截器
md.UseGlobal([]grpc.UnaryServerInterceptor{AuthGuard},
// 完全放开的api
"/user.UserService/RegisterRule",
"/user.UserService/SendRegisterCode",
...
)
...
// 加载局部拦截器
md.UseMethod("/user.UserService/SendRegisterCode", DisableGuard, MsgCodeRateLimitGuard)
return md.UnaryServerInterceptor()
}
srv := grpc.NewServer(
grpc.UnaryInterceptor(LoadUnaryInterceptors()),
)
其它类型的拦截器扩展也跟此差不多,就不讲解了。可以参考代码:github.com/welllog/grpc_intercepto...
本作品采用《CC 协议》,转载必须注明作者和本文链接
推荐文章: