WebSocket 服务器练习,已通过 Autobahn Testsuite 服务端(fuzzingclient)的全部测试用例
实现
package main
import (
"bufio"
"bytes"
"context"
"crypto/sha1"
"encoding/base64"
"encoding/binary"
"errors"
"io"
"log/slog"
"net/http"
"os"
"strings"
"time"
"unicode/utf8"
"github.com/gobwas/httphead"
"github.com/gobwas/ws"
"github.com/gobwas/ws/wsflate"
"github.com/gobwas/ws/wsutil"
"github.com/panjf2000/gnet/v2"
"github.com/panjf2000/gnet/v2/pkg/logging"
"github.com/panjf2000/gnet/v2/pkg/pool/byteslice"
"github.com/panjf2000/gnet/v2/pkg/pool/goroutine"
)
// NewMessagePayload returns an empty MessagePayload whose fragment list has
// room for capacity fragments before it needs to grow.
//
// The parameter was previously named len, which shadowed the builtin len
// inside the function; it is positional, so callers are unaffected.
func NewMessagePayload(capacity int) *MessagePayload {
	return &MessagePayload{
		fragments: make([][]byte, 0, capacity),
	}
}
// MessagePayload accumulates the unmasked frame payloads that make up one
// WebSocket message until its final frame has arrived.
type MessagePayload struct {
// fragments holds one unmasked copy per received frame, in arrival order.
// The first element is allocated with make; later elements are borrowed from
// gnet's byteslice pool by append and handed back to the pool by reset.
fragments [][]byte
}
// reset empties the payload so it can be reused for the next message.
//
// Every fragment after the first was borrowed from gnet's byteslice pool by
// append and is returned to it here. The first fragment was allocated with
// make and is only dropped, because merge may already have handed it to the
// caller. Each slot is cleared so no backing array stays reachable. A list
// that grew past 16 fragments is replaced by a small fresh one rather than
// keeping its large capacity alive.
func (r *MessagePayload) reset() {
	for idx := range r.fragments {
		if idx != 0 {
			byteslice.Put(r.fragments[idx])
		}
		r.fragments[idx] = nil
	}
	if len(r.fragments) <= 16 {
		// Keep the existing capacity for reuse.
		r.fragments = r.fragments[:0]
		return
	}
	r.fragments = make([][]byte, 0, 4)
}
// append stores an unmasked copy of fragment, which is the raw (still masked)
// payload of one frame; mask is that frame's 4-byte masking key.
//
// The first fragment of a message gets its own heap allocation, because merge
// may return it to the caller directly. Later fragments borrow a buffer from
// gnet's byteslice pool, and reset returns it.
func (r *MessagePayload) append(fragment []byte, mask [4]byte) {
	size := len(fragment)
	var buf []byte
	if len(r.fragments) > 0 {
		buf = byteslice.Get(size)
	} else {
		buf = make([]byte, size)
	}
	copy(buf, fragment)
	// Unmask in place on our private copy, never on the connection buffer.
	ws.Cipher(buf, mask, 0)
	r.fragments = append(r.fragments, buf)
}
// merge returns the whole message payload as one contiguous slice.
//
// With no fragments it returns nil. With exactly one fragment it returns that
// fragment itself without copying; this is safe because append already copied
// it out of the connection buffer. Otherwise it concatenates all fragments
// into a newly allocated slice, so reset can recycle the pooled fragment
// buffers without affecting the result.
func (r *MessagePayload) merge() []byte {
	switch len(r.fragments) {
	case 0:
		return nil
	case 1:
		return r.fragments[0]
	default:
		return bytes.Join(r.fragments, nil)
	}
}
// Codec holds the per-connection WebSocket state: handshake and compression
// flags, the message currently being reassembled, and arbitrary user session
// data. OnOpen stores one Codec as each gnet connection's context.
type Codec struct {
// preMessagePayload parks a partially received fragmented data message while
// a control frame interleaved between its fragments is decoded; nil otherwise.
preMessagePayload *MessagePayload
// currMessagePayload collects the frames of the message currently being decoded.
currMessagePayload *MessagePayload
// upgraded is set once the 101 Switching Protocols response has been written.
upgraded bool
// compression reports that permessage-deflate was negotiated in the handshake.
compression bool
// preMessageOpCode is the opcode of the parked data message; 255 means nothing is parked.
preMessageOpCode ws.OpCode
// currMessageOpCode is the opcode of the message being decoded; 255 means no message is in progress.
currMessageOpCode ws.OpCode
// messageRsv holds the RSV bits (0-7) of the first frame of the current data
// message; 4 (RSV1) marks a compressed message. Control frames always carry
// RSV 0 (decode rejects anything else), so a single field is meant to survive
// interleaved control frames without being parked alongside preMessagePayload.
// NOTE(review): this only holds if decode does not overwrite the field when a
// control frame starts its own message — confirm in decode.
messageRsv byte
// session is arbitrary user data attached via SetSession / GetSession.
session any
// closed reports that the server initiated the close (a close frame was already sent).
closed bool
}
// SetClosed records that the server has initiated closing this connection.
// After this, WriteMessage refuses to send, and an incoming close frame is
// no longer echoed back to the client.
func (codec *Codec) SetClosed() {
	codec.closed = true
}
// IsClosed reports whether the server has already initiated closing this
// connection (see SetClosed).
func (codec *Codec) IsClosed() bool {
	closed := codec.closed
	return closed
}
// IsCompress reports whether the permessage-deflate extension was negotiated
// during the opening handshake.
func (codec *Codec) IsCompress() bool {
	negotiated := codec.compression
	return negotiated
}
// SetSession attaches arbitrary user data to the connection; GetSession
// returns it later.
func (codec *Codec) SetSession(session any) {
	codec.session = session
}
// GetSession returns the value stored with SetSession, or nil if none was set.
func (codec *Codec) GetSession() (session any) {
	session = codec.session
	return session
}
// resetCurrMessage finishes the message that was just handled.
//
// The current payload is always emptied first. If a fragmented data message
// was parked because a control frame arrived between its fragments, that data
// message becomes current again so its remaining fragments can be appended.
// Otherwise the codec returns to the idle state: opcode 255 and RSV 0.
func (codec *Codec) resetCurrMessage() {
	const noMessage ws.OpCode = 255
	codec.currMessagePayload.reset()
	if codec.preMessageOpCode != noMessage {
		// Resume the interrupted fragmented data message.
		codec.currMessageOpCode, codec.currMessagePayload = codec.preMessageOpCode, codec.preMessagePayload
		codec.preMessageOpCode, codec.preMessagePayload = noMessage, nil
		return
	}
	codec.currMessageOpCode = noMessage
	codec.messageRsv = 0
}
// upgrade performs the server side of the WebSocket opening handshake
// (RFC 6455 §4.2) on c.
//
// Until the complete HTTP request header is buffered, it returns a nil
// request and gnet.None, so the caller can wait for more data. A header that
// grows beyond http.DefaultMaxHeaderBytes without terminating is rejected.
// A request without an Upgrade header gets a plain 200 "Hello netsvr" reply.
// onUpgradeCheck, when non-nil, is consulted before the WebSocket-specific
// headers are validated; a non-nil response from it rejects the handshake.
//
// On success it writes the 101 Switching Protocols response, sets
// codec.upgraded (and codec.compression when permessage-deflate was offered),
// and returns the parsed request with gnet.None. Every failure writes an HTTP
// response and returns gnet.Close.
func (codec *Codec) upgrade(onUpgradeCheck func(req *http.Request) *http.Response, c gnet.Conn) (*http.Request, gnet.Action) {
	peek, err := c.Peek(-1)
	if err != nil {
		return nil, gnet.Close
	}
	// Wait until the header is complete (terminated by an empty line), but do
	// not buffer an unbounded header from a misbehaving client.
	if !bytes.HasSuffix(peek, []byte("\r\n\r\n")) {
		if len(peek) >= http.DefaultMaxHeaderBytes {
			writeHTTPResponse(c, http.StatusRequestHeaderFieldsTooLarge, nil, http.StatusText(http.StatusRequestHeaderFieldsTooLarge))
			return nil, gnet.Close
		}
		return nil, gnet.None
	}
	req, err := http.ReadRequest(bufio.NewReader(bytes.NewReader(peek)))
	if err != nil {
		writeHTTPResponse(c, http.StatusBadRequest, nil, http.StatusText(http.StatusBadRequest))
		return nil, gnet.Close
	}
	if req.Method != http.MethodGet {
		writeHTTPResponse(c, http.StatusMethodNotAllowed, nil, http.StatusText(http.StatusMethodNotAllowed))
		return nil, gnet.Close
	}
	// No Upgrade header: this is an ordinary HTTP request, answer it with 200.
	if _, ok := req.Header["Upgrade"]; !ok {
		writeHTTPResponse(c, http.StatusOK, nil, "Hello netsvr")
		return nil, gnet.Close
	}
	// Application-level check (host, path, auth, ...) before the handshake.
	if onUpgradeCheck != nil {
		if resp := onUpgradeCheck(req); resp != nil {
			_ = resp.Write(c)
			return nil, gnet.Close
		}
	}
	if !strings.EqualFold(req.Header.Get("Upgrade"), "websocket") {
		writeHTTPResponse(c, http.StatusBadRequest, nil, "invalid Upgrade header")
		return nil, gnet.Close
	}
	// Connection is a comma-separated token list (some clients send e.g.
	// "keep-alive, Upgrade"), so look for the token case-insensitively.
	if !headerContainsToken(req.Header, "Connection", "upgrade") {
		writeHTTPResponse(c, http.StatusBadRequest, nil, "invalid Connection header")
		return nil, gnet.Close
	}
	if req.Header.Get("Sec-WebSocket-Version") != "13" {
		// 426 Upgrade Required, advertising the supported version.
		writeHTTPResponse(c, http.StatusUpgradeRequired,
			http.Header{"Sec-WebSocket-Version": []string{"13"}},
			http.StatusText(http.StatusUpgradeRequired))
		return nil, gnet.Close
	}
	// Negotiate permessage-deflate.
	if ext := req.Header.Get("Sec-WebSocket-Extensions"); ext != "" {
		options, ok := httphead.ParseOptions([]byte(ext), nil)
		if !ok {
			writeHTTPResponse(c, http.StatusBadRequest, nil, "invalid Sec-WebSocket-Extensions header")
			return nil, gnet.Close
		}
		for _, option := range options {
			if bytes.Equal(option.Name, wsflate.ExtensionNameBytes) {
				codec.compression = true
				break
			}
		}
	}
	key := req.Header.Get("Sec-WebSocket-Key")
	if key == "" {
		writeHTTPResponse(c, http.StatusBadRequest, nil, "missing Sec-WebSocket-Key")
		return nil, gnet.Close
	}
	// RFC 6455 §4.2.1: the key must be a base64-encoded 16-byte nonce.
	if !isValidChallengeKey(key) {
		writeHTTPResponse(c, http.StatusBadRequest, nil, "invalid Sec-WebSocket-Key")
		return nil, gnet.Close
	}
	// Sec-WebSocket-Accept = base64(SHA-1(key + RFC 6455 GUID)).
	h := sha1.New()
	h.Write([]byte(key))
	h.Write([]byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11"))
	acceptKey := base64.StdEncoding.EncodeToString(h.Sum(nil))
	resp := http.Response{
		StatusCode: http.StatusSwitchingProtocols,
		ProtoMajor: 1,
		ProtoMinor: 1,
		Header: http.Header{
			"Upgrade":              []string{"websocket"},
			"Connection":           []string{"Upgrade"},
			"Sec-WebSocket-Accept": []string{acceptKey},
		},
	}
	if codec.compression {
		// *_no_context_takeover: every message is (de)compressed with a fresh
		// LZ77 window, so no per-connection compression context has to be
		// kept in memory. This trades a little compression ratio for much
		// lower memory use with many concurrent connections.
		resp.Header["Sec-WebSocket-Extensions"] = []string{"permessage-deflate; server_no_context_takeover; client_no_context_takeover"}
	}
	// Drop the whole request (including any body a client attached to the
	// GET) so it is not parsed as WebSocket frames.
	if _, err = c.Discard(c.InboundBuffered()); err != nil {
		return nil, gnet.Close
	}
	if err = resp.Write(c); err != nil {
		return nil, gnet.Close
	}
	codec.upgraded = true
	return req, gnet.None
}

// writeHTTPResponse writes a minimal HTTP/1.1 response with a plain-text body
// and "Connection: close"; write errors are ignored because the caller closes
// the connection right afterwards.
func writeHTTPResponse(w io.Writer, status int, header http.Header, body string) {
	resp := http.Response{
		StatusCode:    status,
		ProtoMajor:    1,
		ProtoMinor:    1,
		Header:        header,
		Close:         true,
		ContentLength: int64(len(body)),
		Body:          io.NopCloser(strings.NewReader(body)),
	}
	_ = resp.Write(w)
}

// headerContainsToken reports whether any value of the comma-separated header
// name contains token, compared ASCII case-insensitively.
func headerContainsToken(h http.Header, name, token string) bool {
	for _, value := range h.Values(name) {
		for _, t := range strings.Split(value, ",") {
			if strings.EqualFold(strings.TrimSpace(t), token) {
				return true
			}
		}
	}
	return false
}

// isValidChallengeKey reports whether key is the base64 encoding of exactly
// 16 bytes, as RFC 6455 requires for Sec-WebSocket-Key.
func isValidChallengeKey(key string) bool {
	decoded, err := base64.StdEncoding.DecodeString(key)
	return err == nil && len(decoded) == 16
}
// decode reads the next complete WebSocket message from c's inbound buffer.
//
// Frames are parsed with Peek and only consumed (Discard) once the whole frame
// is buffered. If a frame is incomplete, decode returns io.ErrShortBuffer with
// a zero status code and the caller should wait for more data. Fragments of a
// data message accumulate in codec.currMessagePayload across calls. A control
// frame that arrives between fragments temporarily parks the data message in
// codec.preMessagePayload.
//
// On success it returns the unmasked (and, for compressed data messages,
// inflated) payload; codec.currMessageOpCode holds the message opcode. On a
// protocol violation it returns a non-zero close status code that the caller
// should send to the peer before closing the connection.
func (codec *Codec) decode(c gnet.Conn) ([]byte, ws.StatusCode, error) {
	for {
		peek, err := c.Peek(-1)
		if err != nil {
			if errors.Is(err, io.ErrShortBuffer) {
				return nil, ws.StatusCode(0), io.ErrShortBuffer
			}
			return nil, ws.StatusInternalServerError, err
		}
		if len(peek) < 2 {
			return nil, ws.StatusCode(0), io.ErrShortBuffer
		}
		// Parse the header by hand (like ws.ReadHeader) without consuming input.
		var head ws.Header
		head.Fin = peek[0]&0x80 != 0
		head.Rsv = (peek[0] & 0x70) >> 4
		head.OpCode = ws.OpCode(peek[0] & 0x0f)
		validOp := head.OpCode == ws.OpContinuation ||
			head.OpCode == ws.OpText ||
			head.OpCode == ws.OpBinary ||
			head.OpCode == ws.OpClose ||
			head.OpCode == ws.OpPing ||
			head.OpCode == ws.OpPong
		if !validOp {
			return nil, ws.StatusProtocolError, errors.New("unsupported opcode")
		}
		if head.Rsv != 0 {
			if !codec.compression {
				// No extension negotiated: all RSV bits MUST be 0.
				return nil, ws.StatusProtocolError, errors.New("RSV1 bits must be 0 without extensions")
			} else if head.OpCode.IsControl() {
				// Control frames are never compressed.
				return nil, ws.StatusProtocolError, errors.New("RSV1 must be 0")
			}
		}
		// RFC 6455 §5.5: control frames MUST NOT be fragmented. Reject them as
		// soon as the header is seen rather than when the next frame arrives.
		if head.OpCode.IsControl() && !head.Fin {
			return nil, ws.StatusProtocolError, errors.New("control message MUST NOT be fragmented")
		}
		// A message is already in progress, so this is a follow-up frame.
		if codec.currMessageOpCode != 255 {
			if codec.currMessageOpCode.IsControl() {
				return nil, ws.StatusProtocolError, errors.New("control message MUST NOT be fragmented")
			}
			if head.OpCode.IsControl() {
				// Control frame interleaved between data fragments: park the
				// partially received data message until the control frame is done.
				codec.preMessageOpCode = codec.currMessageOpCode
				codec.preMessagePayload = codec.currMessagePayload
				codec.currMessageOpCode = 255
				codec.currMessagePayload = NewMessagePayload(1)
			} else {
				if head.OpCode != ws.OpContinuation {
					return nil, ws.StatusProtocolError, errors.New("non-first fragment must be continuation")
				}
				// permessage-deflate: only the first frame may carry RSV1.
				if codec.compression && head.Rsv != 0 {
					return nil, ws.StatusProtocolError, errors.New("RSV1 must be 0")
				}
			}
		}
		if peek[1]&0x80 == 0 {
			return nil, ws.StatusProtocolError, errors.New("client must mask data")
		}
		var extra = 0
		length := peek[1] & 0x7f
		switch {
		case length < 126:
			head.Length = int64(length)
			extra = 4 // 4-byte mask
		case length == 126:
			extra = 6 // 2-byte length + 4-byte mask
		case length == 127:
			extra = 12 // 8-byte length + 4-byte mask
		default:
			return nil, ws.StatusProtocolError, errors.New("unexpected payload length bits")
		}
		if len(peek) < 2+extra {
			return nil, ws.StatusCode(0), io.ErrShortBuffer
		}
		peek = peek[2:]
		switch {
		case length == 126:
			head.Length = int64(binary.BigEndian.Uint16(peek[:2]))
			peek = peek[2:]
		case length == 127:
			if peek[0]&0x80 != 0 {
				return nil, ws.StatusProtocolError, errors.New("the most significant bit must be 0")
			}
			head.Length = int64(binary.BigEndian.Uint64(peek[:8]))
			peek = peek[8:]
		}
		if head.OpCode.IsControl() && head.Length > 125 {
			return nil, ws.StatusProtocolError, errors.New("control frame too long")
		}
		// peek now starts at the 4-byte mask (len(peek) >= 4 here). Compare in
		// int64 so a huge declared length cannot overflow int, bypass this check
		// and make the slicing below panic.
		if int64(len(peek)-4) < head.Length {
			return nil, ws.StatusCode(0), io.ErrShortBuffer
		}
		payloadLen := int(head.Length)
		copy(head.Mask[:], peek[:4])
		// The frame is complete. If it starts a new message, remember its
		// opcode and (for data frames only) its RSV bits.
		if codec.currMessageOpCode == 255 {
			if codec.compression && head.Rsv != 4 && head.Rsv != 0 {
				return nil, ws.StatusProtocolError, errors.New("RSV1 must be 0 or 4")
			}
			if head.OpCode == ws.OpContinuation {
				return nil, ws.StatusProtocolError, errors.New("unexpected continuation frame: no message to continue")
			}
			// An interleaved control frame (always RSV 0) must not clobber the
			// RSV1 "compressed" flag of the parked data message.
			if !head.OpCode.IsControl() {
				codec.messageRsv = head.Rsv
			}
			codec.currMessageOpCode = head.OpCode
		}
		codec.currMessagePayload.append(peek[4:4+payloadLen], head.Mask)
		if _, err = c.Discard(2 + extra + payloadLen); err != nil {
			return nil, ws.StatusInternalServerError, err
		}
		if !head.Fin {
			// More fragments follow.
			continue
		}
		return codec.completeMessage()
	}
}

// completeMessage merges the fragments of the message that has just received
// its final frame, then validates it. Compressed data messages are inflated,
// text messages must be valid UTF-8, and a close frame body must carry a
// valid status code followed by a UTF-8 reason.
func (codec *Codec) completeMessage() ([]byte, ws.StatusCode, error) {
	payload := codec.currMessagePayload.merge()
	switch op := codec.currMessageOpCode; op {
	case ws.OpText, ws.OpBinary:
		if codec.compression && codec.messageRsv == 4 {
			var err error
			payload, err = wsflate.DefaultHelper.Decompress(payload)
			if err != nil {
				return nil, ws.StatusInvalidFramePayloadData, errors.New("invalid deflate stream")
			}
		}
		if op == ws.OpText && !utf8.Valid(payload) {
			return nil, ws.StatusInvalidFramePayloadData, errors.New("invalid utf8 in text message")
		}
	case ws.OpClose:
		if len(payload) == 0 {
			break
		}
		if len(payload) == 1 || len(payload) > 125 {
			return nil, ws.StatusProtocolError, errors.New("close frame payload length invalid")
		}
		code := ws.StatusCode(binary.BigEndian.Uint16(payload[:2]))
		if !isValidReceivedCloseCode(code) {
			return nil, ws.StatusProtocolError, errors.New("invalid close status code")
		}
		if !utf8.Valid(payload[2:]) {
			return nil, ws.StatusInvalidFramePayloadData, errors.New("invalid utf8 in close reason")
		}
	}
	return payload, ws.StatusCode(0), nil
}

// isValidReceivedCloseCode reports whether code may legally be received in a
// close frame (RFC 6455 §7.4 and the IANA close code registry). Accepted are
// the registered protocol codes 1000-1003 and 1007-1014, plus 3000-4999 for
// libraries, frameworks and applications. 1004 is reserved; 1005, 1006 and
// 1015 must never be sent on the wire; everything else is unassigned.
func isValidReceivedCloseCode(code ws.StatusCode) bool {
	switch {
	case code >= 1000 && code <= 1003, code >= 1007 && code <= 1014:
		return true
	case code >= 3000 && code <= 4999:
		return true
	}
	return false
}
// Server is a gnet event handler that speaks WebSocket (RFC 6455) with optional
// permessage-deflate compression. Set the On* callbacks and pass it to gnet.Run.
type Server struct {
// Embedded gnet default handler; gnet callbacks not defined on Server fall back to it.
gnet.BuiltinEventEngine
// eng is the running engine, captured in OnBoot and stopped by Shutdown.
eng gnet.Engine
// OnUpgradeCheck runs during the handshake, after the request has been parsed
// and found to be a GET with an Upgrade header, but before the remaining
// WebSocket headers are validated. Use it e.g. to check the Host or path.
// Return nil to continue; a non-nil response is written to the client and
// the connection is closed, rejecting the handshake.
OnUpgradeCheck func(req *http.Request) *http.Response
// OnWebsocketOpen runs after the 101 Switching Protocols response has been
// written. If it returns a non-nil error, a close frame carrying the returned
// status code and err.Error() as reason is sent and the connection is closed.
OnWebsocketOpen func(conn gnet.Conn, req *http.Request) (ws.StatusCode, error)
// OnWebsocketPing runs after the server has answered a client ping with a pong.
OnWebsocketPing func()
// OnWebsocketMessage receives each complete data message (text or binary),
// already unmasked, reassembled from its fragments and decompressed;
// messageType is the message opcode.
OnWebsocketMessage func(conn gnet.Conn, messageType ws.OpCode, messagePtr []byte)
// OnWebsocketClose is invoked from OnClose when the connection is closed;
// see OnClose for exactly which connections trigger it.
OnWebsocketClose func(conn gnet.Conn)
}
// Shutdown stops the gnet engine captured in OnBoot. ctx bounds how long the
// stop may take; the engine's error is returned unchanged.
func (server *Server) Shutdown(ctx context.Context) error {
	err := server.eng.Stop(ctx)
	return err
}
// OnBoot is the gnet start-up callback. It remembers the engine so Shutdown
// can stop it later, and lets the engine continue booting.
func (server *Server) OnBoot(eng gnet.Engine) gnet.Action {
	server.eng = eng // needed by Shutdown
	return gnet.None
}
// OnOpen is the gnet callback for a new connection. It attaches a fresh Codec
// as the connection context, in the pre-handshake state with no message in
// progress (opcode 255) and nothing parked. Nothing is written to the peer.
func (server *Server) OnOpen(c gnet.Conn) ([]byte, gnet.Action) {
	codec := &Codec{
		currMessagePayload: NewMessagePayload(4),
		preMessageOpCode:   255,
		currMessageOpCode:  255,
	}
	c.SetContext(codec)
	return nil, gnet.None
}
// OnClose is the gnet callback for a closed connection.
//
// It calls OnWebsocketClose only for connections that completed the WebSocket
// handshake, i.e. exactly those for which OnWebsocketOpen was called. Plain
// HTTP requests and rejected handshakes therefore no longer produce an
// unpaired close notification. A nil OnWebsocketClose is allowed.
func (server *Server) OnClose(conn gnet.Conn, _ error) (action gnet.Action) {
	if wsCodec, ok := conn.Context().(*Codec); ok && wsCodec.upgraded && server.OnWebsocketClose != nil {
		server.OnWebsocketClose(conn)
	}
	return gnet.None
}
// OnTick performs no periodic work; it only asks gnet for the next tick one
// second from now and keeps the engine running.
func (server *Server) OnTick() (delay time.Duration, action gnet.Action) {
	delay, action = time.Second, gnet.None
	return
}
// OnTraffic is the gnet callback for readable data on c.
//
// Before the handshake it drives Codec.upgrade and, once upgraded, calls
// OnWebsocketOpen. If that callback returns an error, a close frame with its
// status code is sent and the connection is closed. After the handshake, it
// decodes and dispatches every complete message that is buffered:
//   - data messages go to OnWebsocketMessage;
//   - pings are answered with a pong carrying the same payload, then
//     OnWebsocketPing is called;
//   - pongs are ignored;
//   - a close frame is echoed back (unless the server already sent one) and
//     the connection is closed.
//
// Protocol errors reported by decode are answered with a close frame. Nil
// callbacks are skipped instead of panicking.
func (server *Server) OnTraffic(c gnet.Conn) (action gnet.Action) {
	wsCodec, _ := c.Context().(*Codec)
	if !wsCodec.upgraded {
		var req *http.Request
		req, action = wsCodec.upgrade(server.OnUpgradeCheck, c)
		if wsCodec.upgraded && server.OnWebsocketOpen != nil {
			if statusCode, err := server.OnWebsocketOpen(c, req); err != nil {
				payload := ws.NewCloseFrameBody(statusCode, err.Error())
				wsCodec.closed = true
				_ = wsutil.WriteServerMessage(c, ws.OpClose, payload)
				return gnet.Close
			}
		}
		return action
	}
	for {
		completeMessagePayload, statusCode, err := wsCodec.decode(c)
		if err != nil {
			if !statusCode.Empty() {
				// Malformed frame: tell the peer why, then close.
				payload := ws.NewCloseFrameBody(statusCode, err.Error())
				wsCodec.closed = true
				_ = wsutil.WriteServerMessage(c, ws.OpClose, payload)
				return gnet.Close
			}
			if errors.Is(err, io.ErrShortBuffer) {
				// Incomplete frame: wait for more data.
				return gnet.None
			}
			// decode does not return other errors with an empty status code.
			return gnet.Close
		}
		switch op := wsCodec.currMessageOpCode; {
		case op.IsData():
			if server.OnWebsocketMessage != nil {
				server.OnWebsocketMessage(c, op, completeMessagePayload)
			}
		case op == ws.OpClose:
			// If the server already initiated the close, a second close frame
			// would make the client report "Close received after close".
			if !wsCodec.closed {
				// Echo the client's status code and reason.
				_ = wsutil.WriteServerMessage(c, ws.OpClose, completeMessagePayload)
			}
			return gnet.Close
		case op == ws.OpPing:
			if err = wsutil.WriteServerMessage(c, ws.OpPong, completeMessagePayload); err != nil {
				return gnet.Close
			}
			if server.OnWebsocketPing != nil {
				server.OnWebsocketPing()
			}
		case op == ws.OpPong:
			// Unsolicited or answered pongs need no action.
		case op.IsReserved():
			payload := ws.NewCloseFrameBody(ws.StatusUnsupportedData, "unsupported opcode")
			wsCodec.closed = true
			_ = wsutil.WriteServerMessage(c, ws.OpClose, payload)
			return gnet.Close
		}
		wsCodec.resetCurrMessage()
		if c.InboundBuffered() == 0 {
			return gnet.None
		}
		// More buffered data: decode the next message.
	}
}
// WriteMessage 发送数据
func WriteMessage(conn gnet.Conn, messageType ws.OpCode, data []byte) bool {
wsCodec, ok := conn.Context().(*Codec)
if !ok || wsCodec.IsClosed() {
return false
}
var err error
var buff *bytes.Buffer
var frame ws.Frame
//需要压缩,并且数据值得压缩
if wsCodec.IsCompress() {
compressed, err := wsflate.DefaultHelper.Compress(data)
if err != nil {
return false
}
frame = ws.NewFrame(messageType, true, compressed)
frame.Header.Rsv = ws.Rsv(true, false, false)
} else {
//不压缩
frame = ws.NewFrame(messageType, true, data)
}
//创建 frame
buff = bytes.NewBuffer(make([]byte, 0, 4+len(data)))
err = ws.WriteFrame(buff, frame)
if err != nil {
return false
}
//发送数据
err = conn.AsyncWrite(buff.Bytes(), nil)
if err == nil {
return true
}
//发送失败,则关闭连接
_ = conn.Close()
return false
}
func main() {
server := &Server{
OnUpgradeCheck: func(req *http.Request) *http.Response {
return nil
},
OnWebsocketOpen: func(conn gnet.Conn, req *http.Request) (ws.StatusCode, error) {
return ws.StatusCode(0), nil
},
OnWebsocketClose: func(conn gnet.Conn) {
},
OnWebsocketPing: func() {
},
OnWebsocketMessage: func(conn gnet.Conn, messageType ws.OpCode, data []byte) {
fn := func() {
WriteMessage(conn, messageType, data)
}
err := goroutine.DefaultWorkerPool.Submit(fn)
if err != nil {
//提交异步任务失败,则使用goroutine执行,保证数据发送成功
go fn()
}
},
}
err := gnet.Run(
server,
"tcp://0.0.0.0:6636",
gnet.WithMulticore(false),
//gnet.WithLockOSThread(true),
//gnet.WithLoadBalancing(gnet.LeastConnections),
gnet.WithLogLevel(logging.InfoLevel),
)
if err != nil {
slog.Error("gnet.Run error", "error", err)
time.Sleep(time.Millisecond * 100)
os.Exit(1)
}
}
测试(在 Docker 中运行 Autobahn Testsuite 的 fuzzingclient,对监听 6636 端口的本服务端进行测试)
docker run -v "C:/buexplain/testgent:/reports" --name fuzzingclient -d crossbario/autobahn-testsuite tail -f /dev/null
docker exec -it fuzzingclient /bin/bash
echo '{"options":{"failByDrop":false},"outdir":"./reports/servers","servers":[{"agent":"AutobahnServer","url":"ws://host.docker.internal:6636/"}],"cases":["*"],"exclude-cases":[],"exclude-agent-cases":{}}' > /config/fuzzingclient.json
wstest -m fuzzingclient -s /config/fuzzingclient.json
本作品采用《CC 协议》,转载必须注明作者和本文链接
关于 LearnKu