gRPC拦截器

gRPC拦截器可以在 gRPC调用之前或调用之后执行一些逻辑,如监控,认证,记录日志等。

gRPC 默认的拦截器只能够添加一个拦截器。

简单模式

客户端使用 grpc.WithUnaryInterceptor方法添加拦截器,服务端使用 grpc.UnaryInterceptor方法添加拦截器。

proto文件

syntax = "proto3";

package proto;

option go_package = "/cal;cal";

message RequestInfo {
  int64 number1 = 1;
  int64 number2 = 2;
}

message ResponseInfo {
  int64 res = 1;
}

service Cal {
  rpc Add (RequestInfo) returns (ResponseInfo) {}
}

客户端拦截器

package main

import (
    "context"
    "fmt"
    "google.golang.org/grpc"
    "google.golang.org/grpc/credentials/insecure"
    "test/cal"
    "time"
)

func main() {
    // 添加拦截器
    opt := grpc.WithUnaryInterceptor(func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
        // 调用前
        start := time.Now()
        err := invoker(ctx, method, req, reply, cc, opts...)

        // 调用后
        fmt.Println(time.Since(start))
        return err
    })

    // 建立连接
    conn, _ := grpc.Dial("127.0.0.1:8080", grpc.WithTransportCredentials(insecure.NewCredentials()), opt)

    // 实例化客户端
    client := cal.NewCalClient(conn)

    // 调用服务
    res, _ := client.Add(context.Background(), &cal.RequestInfo{
        Number1: 1,
        Number2: 1,
    })
    fmt.Println(res.Res)
}

服务端拦截器

package main

import (
    "fmt"
    "golang.org/x/net/context"
    "google.golang.org/grpc"
    "net"
    "test/cal"
)

type Cal struct {
    cal.UnimplementedCalServer
}

func (c *Cal) Add(ctx context.Context, req *cal.RequestInfo) (*cal.ResponseInfo, error) {
    return &cal.ResponseInfo{Res: req.Number1 + req.Number2}, nil
}

func main() {
    // 监听
    listen, _ := net.Listen("tcp", ":8080")

    // 设置拦截器
    opt := grpc.UnaryInterceptor(func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) {
        fmt.Println("调用前")
        res, err := handler(ctx, req)
        fmt.Println("调用后")
        return res, err
    })

    // 实例化grpc服务
    s := grpc.NewServer(opt)

    // 注册服务
    cal.RegisterCalServer(s, &Cal{})

    // 启动
    s.Serve(listen)
}

流模式

客户端使用 grpc.WithChainStreamInterceptor方法添加拦截器,服务端使用 grpc.StreamInterceptor方法添加拦截器。

proto文件

syntax = "proto3";

package proto;

option go_package = "/stream;stream";

message RequestInfo {
  string data = 1;
}

message ResponseInfo {
  string data = 1;
}

service Stream {
  rpc AllStream (stream RequestInfo) returns (stream ResponseInfo) {}
}

客户端拦截器

package main

import (
    "context"
    "fmt"
    "google.golang.org/grpc"
    "google.golang.org/grpc/credentials/insecure"
    "sync"
    "test/stream"
    "time"
)

func main() {
    // 设置客户端流模式拦截器
    opt := grpc.WithChainStreamInterceptor(func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
        fmt.Println("调用前")
        s, err := streamer(ctx, desc, cc, method, opts...)
        fmt.Println("调用后")
        return s, err
    })

    // 建立连接
    conn, _ := grpc.Dial("127.0.0.1:8080", grpc.WithTransportCredentials(insecure.NewCredentials()), opt)

    // 实例化客户端
    client := stream.NewStreamClient(conn)

    // 调用服务
    wg := sync.WaitGroup{}
    wg.Add(2)
    all, _ := client.AllStream(context.Background())
    go func() {
        defer wg.Done()
        for {
            if res, err := all.Recv(); err != nil {
                fmt.Println(err)
                break
            } else {
                fmt.Println(res.Data)
            }
        }
    }()
    go func() {
        defer wg.Done()
        for i := 0; i < 10; i++ {
            _ = all.Send(&stream.RequestInfo{
                Data: fmt.Sprintf("客户端消息:%v", time.Now().Unix()),
            })
            time.Sleep(time.Second)
        }
    }()
    wg.Wait()
}

服务端拦截器

package main

import (
    "fmt"
    "google.golang.org/grpc"
    "net"
    "sync"
    "test/stream"
    "time"
)

type Stream struct {
    stream.UnimplementedStreamServer
}

func (s *Stream) AllStream(all stream.Stream_AllStreamServer) error {
    wg := sync.WaitGroup{}
    wg.Add(2)
    go func() {
        defer wg.Done()
        for {
            if res, err := all.Recv(); err != nil {
                fmt.Println(err)
                break
            } else {
                fmt.Println(res.Data)
            }
        }
    }()
    go func() {
        defer wg.Done()

        for i := 0; i < 10; i++ {
            _ = all.Send(&stream.ResponseInfo{
                Data: fmt.Sprintf("服务端消息:%v", time.Now().Unix()),
            })
            time.Sleep(time.Second)
        }
    }()
    wg.Wait()
    return nil
}

func main() {
    opt := grpc.StreamInterceptor(func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
        fmt.Println("调用前")
        err := handler(srv, ss)
        fmt.Println("调用后")
        return err
    })
    // 监听
    listen, _ := net.Listen("tcp", ":8080")

    // 实例化grpc服务
    s := grpc.NewServer(opt)

    // 注册服务
    stream.RegisterStreamServer(s, &Stream{})

    // 启动
    s.Serve(listen)
}
本作品采用《CC 协议》,转载必须注明作者和本文链接
讨论数量: 0
(= ̄ω ̄=)··· 暂无内容!

讨论应以学习和精进为目的。请勿发布不友善或者负能量的内容,与人为善,比聪明更重要!