五步用golang封装自己web框架httper:第三步,实现中间件

提供一个新思路,由于中间件作用是作用在各个路由对请求进行预处理或后续处理的http.Handler,所以可以分类进行注册中间件。

  1. 在监听服务启动时——即调用Start()方法时,注册根serveMux实例保存的中间件,这样可以在不管是否存在匹配路由情况下,均调用跨域或日志等中间件。
  2. 在路由分组时——即调用Group()方法时,保存除根路由组以外父路由组中间件,存储在新的serveMux实例中。
  3. 在注册路由时注册该serveMux中的中间件。

    isRootGroup用于判断是否为根serveMux,实现注册根路由组中间件

    useMiddleware()方法用于反向嵌套中间件

http.go

type Middleware func(next http.Handler) http.Handler

type serveMux struct {
    mux *http.ServeMux
    // root handle
    IsBreakRootRouter bool
    NotFoundHandler   http.Handler
    // group
    routers     map[string]string
    groupPrefix string
    isRootGroup bool
    middlewares []Middleware
}

func New() *serveMux {
    return &serveMux{
        mux: http.NewServeMux(),
        // root handle
        IsBreakRootRouter: true,
        NotFoundHandler:   http.NotFoundHandler(),
        // group
        routers:     make(map[string]string),
        isRootGroup: true,
        middlewares: make([]Middleware, 0),
    }
}

func (sm *serveMux) Start(addr string) error {
    var handler http.Handler
    handler = sm.mux
    if sm.IsBreakRootRouter {
        handler = sm.breakRootRouterHandler(sm.mux)
    }
    // use middlewares of the root serve mux
    handler = sm.useMiddlewares(handler, sm.middlewares...)
    return http.ListenAndServe(addr, handler)
}

func (sm *serveMux) useMiddlewares(handler http.Handler, middlewares ...Middleware) http.Handler {
    l := len(sm.middlewares) - 1
    // use middlewares reversely
    for i := l; i >= 0; i-- {
        handler = middlewares[i](handler)
    }
    return handler
}

Use方法用于添加中间件到当前serveMux实例

func (sm *serveMux) Use(middleware ...Middleware) {
    sm.middlewares = append(sm.middlewares, middleware...)
}

修改Group方法,保存新除了根路由组中间件外父路由组的中间件

func (sm *serveMux) Group(prefix string) *serveMux {
    middlewares := make([]Middleware, 0)
    // add middlewares of parent serve mux,but not root serve mux
    if !sm.isRootGroup {
        middlewares = append(middlewares, sm.middlewares...)
    }
    newServeMux := &serveMux{
        mux:               sm.mux,
        IsBreakRootRouter: sm.IsBreakRootRouter,
        NotFoundHandler:   sm.NotFoundHandler,
        routers:           sm.routers,
        groupPrefix:       sm.groupPrefix + prefix,
        isRootGroup:       false,
        middlewares:       middlewares,
    }
    return newServeMux
}

注册路由时,注册serveMux实例中的中间件

func (sm *serveMux) registerRouter(method string, router string, handler http.Handler) {
    // use middlewares of not root serve mux
    if !sm.isRootGroup {
        handler = sm.useMiddlewares(handler, sm.middlewares...)
    }
    router = fmt.Sprintf("%s %s%s", method, sm.groupPrefix, router)
    sm.routers[router] = funcName(handler)
    sm.mux.Handle(router, handler)
}

测试

http_test.go

func TestServeMux_Run(t *testing.T) {
    mux := httper.New()
    mux.POST("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        _, err := w.Write([]byte(fmt.Sprintf("url: %s", r.URL.Path)))
        if err != nil {
            http.Error(w, err.Error(), http.StatusInternalServerError)
        }
    }))
    api := mux.Group("/api")
    api.Use(MiddleTest1(t), MiddleTest2(t))
    v1 := api.Group("/v1")
    v1.Use(MiddleTest3(t))
    v1.POST("/{value}", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        value := r.PathValue("value")
        if value == "" {
            http.Error(w, "value is empty", http.StatusBadRequest)
            return
        }
        _, err := w.Write([]byte(fmt.Sprintf("path value: %s", value)))
        if err != nil {
            http.Error(w, err.Error(), http.StatusInternalServerError)
        }
    }))
    for r, f := range mux.Routes() {
        t.Log("register router:", r, f)
    }
    t.Fatal(mux.Start(":8000"))
}

func MiddleTest1(t *testing.T) httper.Middleware {
    return func(next http.Handler) http.Handler {
        return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
            t.Log("middle 1 start")
            next.ServeHTTP(w, r)
            t.Log("middle 1 end")
        })
    }
}

func MiddleTest2(t *testing.T) httper.Middleware {
    return func(next http.Handler) http.Handler {
        return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
            t.Log("middle 2 start")
            next.ServeHTTP(w, r)
            t.Log("middle 2 end")
        })
    }
}

func MiddleTest3(t *testing.T) httper.Middleware {
    return func(next http.Handler) http.Handler {
        return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
            t.Log("middle 3 start")
            next.ServeHTTP(w, r)
            t.Log("middle 3 end")
        })
    }
}

测试结果

POST 127.0.0.1:8000/api

http_test.go:42: middle 1 start
http_test.go:52: middle 2 start
http_test.go:62: middle 3 start
http_test.go:64: middle 3 end
http_test.go:54: middle 2 end
http_test.go:44: middle 1 end
本作品采用《CC 协议》,转载必须注明作者和本文链接
讨论数量: 0
(= ̄ω ̄=)··· 暂无内容!

讨论应以学习和精进为目的。请勿发布不友善或者负能量的内容,与人为善,比聪明更重要!