04-中间件的使用
middleware 作用域
1. Kratos service Middleware (grpc, http)
2. HTTP Filter(net/http 层的 Handler 包装,作用于 mux)
3. gRPC UnaryInterceptor
是否登录(JWT 鉴权)中间件
注意:selector 匹配的是 gRPC 形式的 operation(/包名.服务名/方法名),而不是 HTTP 路由路径
注册中间件
# 1. 编写自定义中间件
# 2. 在下方路径中注册中间件
internal\server\http.go
JWT 鉴权
自定义中间件
文件路径: internal\pkg\middleware\auth\auth.go
package auth
import (
"context"
"errors"
"fmt"
"strings"
"time"
"github.com/davecgh/go-spew/spew"
"github.com/go-kratos/kratos/v2/middleware"
"github.com/go-kratos/kratos/v2/transport"
"github.com/golang-jwt/jwt/v4"
)
// GenerateToken builds an HS256-signed JWT whose claims are the given
// username plus a fixed "nbf" (not before) of 2015-10-10 12:00:00 UTC, signs
// it with secret, and returns the compact serialized token.
//
// It panics if signing fails.
//
// NOTE(review): no "exp" claim is set, so issued tokens never expire.
func GenerateToken(secret, username string) string {
	notBefore := time.Date(2015, 10, 10, 12, 0, 0, 0, time.UTC).Unix()
	claims := jwt.MapClaims{
		"username": username,
		"nbf":      notBefore,
	}

	signed, signErr := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).
		SignedString([]byte(secret))
	if signErr != nil {
		panic(signErr)
	}
	return signed
}
func JWTAuth(secret string) middleware.Middleware {
return func(handler middleware.Handler) middleware.Handler {
return func(ctx context.Context, req interface{}) (reply interface{}, err error) {
if tr, ok := transport.FromServerContext(ctx); ok {
// Do something on entering
tokenString := tr.RequestHeader().Get("Authorization")
auths := strings.SplitN(tokenString, " ", 2)
if len(auths) != 2 || !strings.EqualFold(auths[0], "Token") {
return nil, errors.New("jwt token missing")
}
token, err := jwt.Parse(auths[1], func(token *jwt.Token) (interface{}, error) {
// Don't forget to validate the alg is what you expect:
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("Unexpected signing method: %v", token.Header["alg"])
}
// hmacSampleSecret is a []byte containing your secret, e.g. []byte("my_secret_key")
return []byte(secret), nil
})
if err != nil {
return nil, err
}
if claims, ok := token.Claims.(jwt.MapClaims); ok && token.Valid {
// fmt.Println(claims["foo"], claims["nbf"])
spew.Dump(claims["username"])
} else {
// fmt.Println(err)
return nil, errors.New("Token Invalid")
}
defer func() {
// Do something on exiting
}()
}
return handler(ctx, req)
}
}
}
测试文件路径: internal\pkg\middleware\auth\auth_test.go
package auth
import (
"testing"
"github.com/davecgh/go-spew/spew"
)
// TestGenerateToken checks that GenerateToken produces a compact JWS: three
// base64url segments separated by exactly two dots.
//
// The original called GenerateToken with a single argument, which does not
// compile against the (secret, username) signature. It also ended with
// panic(1), which made the test fail unconditionally.
func TestGenerateToken(t *testing.T) {
	tk := GenerateToken("secret", "bypo")
	// Print the token for manual inspection (shown with `go test -v`).
	spew.Dump(tk)

	dots := 0
	for _, c := range tk {
		if c == '.' {
			dots++
		}
	}
	if dots != 2 {
		t.Fatalf("GenerateToken returned malformed token %q", tk)
	}
}
注册中间件
package server
// 路径 internal\server\http.go
import (
"context"
"fmt"
v1 "kratos-realworld/api/realworld/v1"
"kratos-realworld/internal/conf"
"kratos-realworld/internal/pkg/middleware/auth"
"kratos-realworld/internal/service"
nethttp "net/http"
"github.com/go-kratos/kratos/v2/log"
"github.com/go-kratos/kratos/v2/middleware/recovery"
"github.com/go-kratos/kratos/v2/middleware/selector"
"github.com/go-kratos/kratos/v2/transport/http"
"github.com/gorilla/handlers"
)
// NewSkipRoutersMatcher returns a selector.MatchFunc that decides whether a
// selector-wrapped middleware (here: JWT auth) applies to an operation. It
// returns false for the listed public operations, which skips the middleware,
// and true for every other operation.
func NewSkipRoutersMatcher() selector.MatchFunc {
	// Matching is done on the Kratos operation name, NOT on the HTTP route.
	// Operations use the gRPC full-method form /package.Service/Method.
	publicOperations := []string{
		"/realworld.v1.RealWorld/Login",
		"/realworld.v1.RealWorld/Register",
		"/realworld.v1.RealWorld/GetArticle",
		"/realworld.v1.RealWorld/ListArticles",
		"/realworld.v1.RealWorld/GetComments",
		"/realworld.v1.RealWorld/GetTags",
		"/realworld.v1.RealWorld/GetProfile",
	}
	skip := make(map[string]struct{}, len(publicOperations))
	for _, op := range publicOperations {
		skip[op] = struct{}{}
	}

	return func(_ context.Context, operation string) bool {
		_, isPublic := skip[operation]
		return !isPublic
	}
}
// NewHTTPServer builds the Kratos HTTP server for the RealWorld API and
// registers the RealWorld HTTP handlers on it.
//
// Operation-level middleware (matched on Kratos operation names):
//   - recovery.Recovery();
//   - auth.JWTAuth(jwtc.Secret), wrapped in selector so that operations skipped
//     by NewSkipRoutersMatcher are served without a token.
//
// net/http-level filters:
//   - a demo filter that prints on entry and exit of every request;
//   - gorilla/handlers CORS handling.
//
// Network, address and timeout come from c.Http when they are set. The logger
// parameter is currently unused.
func NewHTTPServer(c *conf.Server, jwtc *conf.JWT, greeter *service.RealWorldService, logger log.Logger) *http.Server {
	var opts = []http.ServerOption{
		http.ErrorEncoder(errorEncoder),
		http.Middleware(
			recovery.Recovery(),
			// Register middleware. Using auth.JWTAuth directly would protect
			// every operation:
			// auth.JWTAuth(jwtc.Secret),
			// selector applies it only to operations the matcher does not skip.
			selector.Server(auth.JWTAuth(jwtc.Secret)).Match(NewSkipRoutersMatcher()).Build(),
		),
		http.Filter(
			func(h nethttp.Handler) nethttp.Handler {
				return nethttp.HandlerFunc(func(w nethttp.ResponseWriter, r *nethttp.Request) {
					fmt.Println("route filter in")
					h.ServeHTTP(w, r)
					fmt.Println("route filter out")
				})
			},
			// CORS. DELETE must be listed here. gorilla/handlers answers a
			// preflight for an unlisted method with 405, so browsers could not
			// send cross-origin DELETE requests (presumably used by the
			// RealWorld delete/unfollow/unfavorite endpoints).
			handlers.CORS(
				handlers.AllowedHeaders([]string{"X-Requested-With", "Content-Type", "Authorization"}),
				handlers.AllowedMethods([]string{"GET", "POST", "PUT", "DELETE", "HEAD", "OPTIONS"}),
				handlers.AllowedOrigins([]string{"*"}),
			),
		),
	}
	if c.Http.Network != "" {
		opts = append(opts, http.Network(c.Http.Network))
	}
	if c.Http.Addr != "" {
		opts = append(opts, http.Address(c.Http.Addr))
	}
	if c.Http.Timeout != nil {
		opts = append(opts, http.Timeout(c.Http.Timeout.AsDuration()))
	}
	srv := http.NewServer(opts...)
	v1.RegisterRealWorldHTTPServer(srv, greeter)
	return srv
}