04-中间件的使用

未匹配的标注

middleware 作用域

1. Kratos service Middleware (grpc, http)
2. HTTP Filter (mux)
3. gRPC UnaryInterceptor

是否登录中间件

注意:中间件匹配的路由路径是 gRPC 的 operation,而不是 HTTP 本身的路由,格式为 /包名.服务名/方法名(例如 /realworld.v1.RealWorld/Login)。

注册中间件

# 1. 编写自定义中间件
# 2. 在下方路径的文件中注册中间件
internal\server\http.go

jwt

golang-jwt

自定义中间件

文件路径: internal\pkg\middleware\auth\auth.go

package auth

import (
    "context"
    "errors"
    "fmt"
    "strings"
    "time"

    "github.com/davecgh/go-spew/spew"
    "github.com/go-kratos/kratos/v2/middleware"
    "github.com/go-kratos/kratos/v2/transport"
    "github.com/golang-jwt/jwt/v4"
)

// GenerateToken returns an HS256-signed JWT string for username.
//
// The token carries two claims: "username" (the given username) and "nbf"
// (not-before), fixed at 2015-10-10 12:00:00 UTC. The token is signed with
// secret interpreted as raw bytes. GenerateToken panics if signing fails.
func GenerateToken(secret, username string) string {
    notBefore := time.Date(2015, 10, 10, 12, 0, 0, 0, time.UTC)
    claims := jwt.MapClaims{
        "username": username,
        "nbf":      notBefore.Unix(),
    }

    // Sign the claims with the shared secret to obtain the encoded token.
    signed, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString([]byte(secret))
    if err != nil {
        panic(err)
    }
    return signed
}

// JWTAuth returns a Kratos server middleware that only lets a request reach
// the wrapped handler when it carries a valid HMAC-signed JWT.
//
// The token is read from the "Authorization" request header, which must have
// the form "Token <jwt>" (the scheme is compared case-insensitively), and is
// verified with secret. The wrapped handler is not invoked and an error is
// returned when:
//   - no server transport is attached to ctx,
//   - the header is missing or does not have the expected form,
//   - the token cannot be parsed/verified or is not HMAC-signed,
//   - the claims are not jwt.MapClaims or the token is not valid.
func JWTAuth(secret string) middleware.Middleware {
    return func(handler middleware.Handler) middleware.Handler {
        return func(ctx context.Context, req interface{}) (interface{}, error) {
            tr, ok := transport.FromServerContext(ctx)
            if !ok {
                // Fail closed: without transport information the header
                // cannot be inspected, so the request must not be passed on
                // as if it had been authenticated.
                return nil, errors.New("wrong context for middleware")
            }

            header := tr.RequestHeader().Get("Authorization")
            auths := strings.SplitN(header, " ", 2)
            if len(auths) != 2 || !strings.EqualFold(auths[0], "Token") {
                return nil, errors.New("jwt token missing")
            }

            token, err := jwt.Parse(auths[1], func(token *jwt.Token) (interface{}, error) {
                // Only accept HMAC signing methods; the key returned below
                // is an HMAC secret.
                if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
                    return nil, fmt.Errorf("Unexpected signing method: %v", token.Header["alg"])
                }
                return []byte(secret), nil
            })
            if err != nil {
                return nil, err
            }

            claims, ok := token.Claims.(jwt.MapClaims)
            if !ok || !token.Valid {
                return nil, errors.New("Token Invalid")
            }
            // Debug output of the username claim of the accepted token.
            spew.Dump(claims["username"])

            return handler(ctx, req)
        }
    }
}

auth_test.go

package auth

import (
    "testing"

    "github.com/davecgh/go-spew/spew"
)

// TestGenerateToken checks that GenerateToken produces a non-empty token for
// a sample secret and username.
func TestGenerateToken(t *testing.T) {
    tk := GenerateToken("bypo", "testuser")
    if tk == "" {
        t.Fatal("GenerateToken returned an empty token")
    }
    // Dump the token for manual inspection (visible with `go test -v`).
    spew.Dump(tk)
}

注册

package server
// 路径 internal\server\http.go

import (
    "context"
    "fmt"
    v1 "kratos-realworld/api/realworld/v1"
    "kratos-realworld/internal/conf"
    "kratos-realworld/internal/pkg/middleware/auth"
    "kratos-realworld/internal/service"
    nethttp "net/http"

    "github.com/go-kratos/kratos/v2/log"
    "github.com/go-kratos/kratos/v2/middleware/recovery"
    "github.com/go-kratos/kratos/v2/middleware/selector"
    "github.com/go-kratos/kratos/v2/transport/http"
    "github.com/gorilla/handlers"
)

// NewSkipRoutersMatcher returns a selector.MatchFunc that reports false for a
// fixed set of operations and true for every other operation.
//
// NOTE(review): presumably the selector skips the wrapped middleware for the
// operations reported as false — confirm against the selector documentation.
func NewSkipRoutersMatcher() selector.MatchFunc {
    // Note: matching is done on the operation name, not on the HTTP route
    // itself. A gRPC path is assembled as /package.Service/Method.
    skipped := []string{
        "/realworld.v1.RealWorld/Login",
        "/realworld.v1.RealWorld/Register",
        "/realworld.v1.RealWorld/GetArticle",
        "/realworld.v1.RealWorld/ListArticles",
        "/realworld.v1.RealWorld/GetComments",
        "/realworld.v1.RealWorld/GetTags",
        "/realworld.v1.RealWorld/GetProfile",
    }

    skipSet := make(map[string]struct{}, len(skipped))
    for _, op := range skipped {
        skipSet[op] = struct{}{}
    }

    return func(ctx context.Context, operation string) bool {
        _, found := skipSet[operation]
        return !found
    }
}

// NewHTTPServer creates the HTTP server and registers the RealWorld service
// on it.
//
// The server is configured with the package's error encoder, the recovery
// middleware, the JWT auth middleware (wrapped in a selector driven by
// NewSkipRoutersMatcher and keyed on jwtc.Secret), a filter that prints a
// line before and after each request, and a CORS filter. Network, address
// and timeout are taken from c.Http when they are set.
func NewHTTPServer(c *conf.Server, jwtc *conf.JWT, greeter *service.RealWorldService, logger log.Logger) *http.Server {
    // JWT auth middleware, applied selectively according to the matcher.
    jwtAuth := selector.Server(auth.JWTAuth(jwtc.Secret)).Match(NewSkipRoutersMatcher()).Build()

    // Filter that prints a marker on the way in and on the way out.
    traceFilter := func(next nethttp.Handler) nethttp.Handler {
        return nethttp.HandlerFunc(func(w nethttp.ResponseWriter, r *nethttp.Request) {
            fmt.Println("route filter in")
            next.ServeHTTP(w, r)
            fmt.Println("route filter out")
        })
    }

    // Cross-origin settings.
    // NOTE(review): DELETE is not in the allowed methods — confirm that no
    // route needs it.
    corsFilter := handlers.CORS(
        handlers.AllowedHeaders([]string{"X-Requested-With", "Content-Type", "Authorization"}),
        handlers.AllowedMethods([]string{"GET", "POST", "PUT", "HEAD", "OPTIONS"}),
        handlers.AllowedOrigins([]string{"*"}),
    )

    opts := []http.ServerOption{
        http.ErrorEncoder(errorEncoder),
        http.Middleware(recovery.Recovery(), jwtAuth),
        http.Filter(traceFilter, corsFilter),
    }

    // Optional settings: only applied when present in the configuration.
    if network := c.Http.Network; network != "" {
        opts = append(opts, http.Network(network))
    }
    if addr := c.Http.Addr; addr != "" {
        opts = append(opts, http.Address(addr))
    }
    if timeout := c.Http.Timeout; timeout != nil {
        opts = append(opts, http.Timeout(timeout.AsDuration()))
    }

    server := http.NewServer(opts...)
    v1.RegisterRealWorldHTTPServer(server, greeter)
    return server
}

本文章首发在 LearnKu.com 网站上。

上一篇 下一篇
讨论数量: 0
发起讨论 只看当前版本


暂无话题~