Go 写一个内网穿透工具

系统架构

系统分为两个部分,client 和 server,client运行在内网服务器中,server运行在公网服务器中,当我们想访问内网中的服务,我们通过公网服务器做一个中继。

下面是展示我灵魂画手的时刻了

user发送请求给 server,server和client建立连接,将请求发给client,client再将请求发给本地程序处理(内网中),然后本地程序将处理结果返回给client,client将结果返回给server,server再将结果返回给用户,这样用户就访问到了内网中的程序了。

代码流程

  1. server端监听两个端口,一个用来和user通信,一个和client通信
  2. client启动时连接server端,并启动一个端口监听本地某程序
  3. 当User连接到server端口,将User请求内容发给client
  4. client将从server收到的请求发给本地程序
  5. client将从本地程序收到的内容发给server
  6. server将从client收到的内容发给User即可

  1. 当Server与client没有消息通信,连接会断开
  2. client断开后,再启动会连接不到Server
  3. Server端会因为client断开而引发panic

为了解决这种坑点,加入了心跳包机制,通过5s发送一次心跳包,保持client与server的连接,同时建立一个重连通道,监听该通道,如果当Client被断开后,则往重连通道放一个值,告诉Server端,等待新的Client连接,而避免引发Panic

代码

更详细的我就不说了,直接看代码,代码里面有详细的注释, 排版有问题,直接去github看吧。。。

代码仓库地址: https://github.com/pibigstar/go-proxy

Server端

运行在具有公网IP地址的服务器端

package main
import (
    "flag"
    "fmt"
    "io"
    "net"
    "runtime"
    "strings"
    "time"
)
var (
    localPort  int
    remotePort int
)
func init() {
    flag.IntVar(&localPort, "l", 5200, "the user link port")
    flag.IntVar(&remotePort, "r", 3333, "client listen port")
}
type client struct {
    conn net.Conn
    // 数据传输通道
    read  chan []byte
    write chan []byte
    // 异常退出通道
    exit chan error
    // 重连通道
    reConn chan bool
}
// 从Client端读取数据
func (c *client) Read() {
    // 如果10秒钟内没有消息传输,则Read函数会返回一个timeout的错误
    _ = c.conn.SetReadDeadline(time.Now().Add(time.Second * 10))
    for {
        data := make([]byte, 10240)
        n, err := c.conn.Read(data)
        if err != nil && err != io.EOF {
            if strings.Contains(err.Error(), "timeout") {
                // 设置读取时间为3秒,3秒后若读取不到, 则err会抛出timeout,然后发送心跳
                _ = c.conn.SetReadDeadline(time.Now().Add(time.Second * 3))
                c.conn.Write([]byte("pi"))
                continue
            }
            fmt.Println("读取出现错误...")
            c.exit <- err
        }
        // 收到心跳包,则跳过
        if data[0] == 'p' && data[1] == 'i' {
            fmt.Println("server收到心跳包")
            continue
        }
        c.read <- data[:n]
    }
}
// 将数据写入到Client端
func (c *client) Write() {
    for {
        select {
        case data := <-c.write:
            _, err := c.conn.Write(data)
            if err != nil && err != io.EOF {
                c.exit <- err
            }
        }
    }
}
type user struct {
    conn net.Conn
    // 数据传输通道
    read  chan []byte
    write chan []byte
    // 异常退出通道
    exit chan error
}
// 从User端读取数据
func (u *user) Read() {
    _ = u.conn.SetReadDeadline(time.Now().Add(time.Second * 200))
    for {
        data := make([]byte, 10240)
        n, err := u.conn.Read(data)
        if err != nil && err != io.EOF {
            u.exit <- err
        }
        u.read <- data[:n]
    }
}
// 将数据写给User端
func (u *user) Write() {
    for {
        select {
        case data := <-u.write:
            _, err := u.conn.Write(data)
            if err != nil && err != io.EOF {
                u.exit <- err
            }
        }
    }
}
func main() {
    flag.Parse()
    defer func() {
        err := recover()
        if err != nil {
            fmt.Println(err)
        }
    }()
    clientListener, err := net.Listen("tcp", fmt.Sprintf(":%d", remotePort))
    if err != nil {
        panic(err)
    }
    fmt.Printf("监听:%d端口, 等待client连接... \n", remotePort)
    // 监听User来连接
    userListener, err := net.Listen("tcp", fmt.Sprintf(":%d", localPort))
    if err != nil {
        panic(err)
    }
    fmt.Printf("监听:%d端口, 等待user连接.... \n", localPort)
    for {
        // 有Client来连接了
        clientConn, err := clientListener.Accept()
        if err != nil {
            panic(err)
        }
        fmt.Printf("有Client连接: %s \n", clientConn.RemoteAddr())
        client := &client{
            conn:   clientConn,
            read:   make(chan []byte),
            write:  make(chan []byte),
            exit:   make(chan error),
            reConn: make(chan bool),
        }
        userConnChan := make(chan net.Conn)
        go AcceptUserConn(userListener, userConnChan)
        go HandleClient(client, userConnChan)
        <-client.reConn
        fmt.Println("重新等待新的client连接..")
    }
}
func HandleClient(client *client, userConnChan chan net.Conn) {
    go client.Read()
    go client.Write()
    for {
        select {
        case err := <-client.exit:
            fmt.Printf("client出现错误, 开始重试, err: %s \n", err.Error())
            client.reConn <- true
            runtime.Goexit()
        case userConn := <-userConnChan:
            user := &user{
                conn:  userConn,
                read:  make(chan []byte),
                write: make(chan []byte),
                exit:  make(chan error),
            }
            go user.Read()
            go user.Write()
            go handle(client, user)
        }
    }
}
// 将两个Socket通道链接
// 1. 将从user收到的信息发给client
// 2. 将从client收到信息发给user
func handle(client *client, user *user) {
    for {
        select {
        case userRecv := <-user.read:
            // 收到从user发来的信息
            client.write <- userRecv
        case clientRecv := <-client.read:
            // 收到从client发来的信息
            user.write <- clientRecv
        case err := <-client.exit:
            fmt.Println("client出现错误, 关闭连接", err.Error())
            _ = client.conn.Close()
            _ = user.conn.Close()
            client.reConn <- true
            // 结束当前goroutine
            runtime.Goexit()
        case err := <-user.exit:
            fmt.Println("user出现错误,关闭连接", err.Error())
            _ = user.conn.Close()
        }
    }
}
// 等待user连接
func AcceptUserConn(userListener net.Listener, connChan chan net.Conn) {
    userConn, err := userListener.Accept()
    if err != nil {
        panic(err)
    }
    fmt.Printf("user connect: %s \n", userConn.RemoteAddr())
    connChan <- userConn
}

Client端

运行在需要内网穿透的客户端中

package main
import (
    "flag"
    "fmt"
    "io"
    "net"
    "runtime"
    "strings"
    "time"
)
var (
    host       string
    localPort  int
    remotePort int
)
func init() {
    flag.StringVar(&host, "h", "127.0.0.1", "remote server ip")
    flag.IntVar(&localPort, "l", 8080, "the local port")
    flag.IntVar(&remotePort, "r", 3333, "remote server port")
}
type server struct {
    conn net.Conn
    // 数据传输通道
    read  chan []byte
    write chan []byte
    // 异常退出通道
    exit chan error
    // 重连通道
    reConn chan bool
}
// 从Server端读取数据
func (s *server) Read() {
    // 如果10秒钟内没有消息传输,则Read函数会返回一个timeout的错误
    _ = s.conn.SetReadDeadline(time.Now().Add(time.Second * 10))
    for {
        data := make([]byte, 10240)
        n, err := s.conn.Read(data)
        if err != nil && err != io.EOF {
            // 读取超时,发送一个心跳包过去
            if strings.Contains(err.Error(), "timeout") {
                // 3秒发一次心跳
                _ = s.conn.SetReadDeadline(time.Now().Add(time.Second * 3))
                s.conn.Write([]byte("pi"))
                continue
            }
            fmt.Println("从server读取数据失败, ", err.Error())
            s.exit <- err
            runtime.Goexit()
        }
        // 如果收到心跳包, 则跳过
        if data[0] == 'p' && data[1] == 'i' {
            fmt.Println("client收到心跳包")
            continue
        }
        s.read <- data[:n]
    }
}
// 将数据写入到Server端
func (s *server) Write() {
    for {
        select {
        case data := <-s.write:
            _, err := s.conn.Write(data)
            if err != nil && err != io.EOF {
                s.exit <- err
            }
        }
    }
}
type local struct {
    conn net.Conn
    // 数据传输通道
    read  chan []byte
    write chan []byte
    // 有异常退出通道
    exit chan error
}
func (l *local) Read() {
    for {
        data := make([]byte, 10240)
        n, err := l.conn.Read(data)
        if err != nil {
            l.exit <- err
        }
        l.read <- data[:n]
    }
}
func (l *local) Write() {
    for {
        select {
        case data := <-l.write:
            _, err := l.conn.Write(data)
            if err != nil {
                l.exit <- err
            }
        }
    }
}
func main() {
    flag.Parse()
    target := net.JoinHostPort(host, fmt.Sprintf("%d", remotePort))
    for {
        serverConn, err := net.Dial("tcp", target)
        if err != nil {
            panic(err)
        }
        fmt.Printf("已连接server: %s \n", serverConn.RemoteAddr())
        server := &server{
            conn:   serverConn,
            read:   make(chan []byte),
            write:  make(chan []byte),
            exit:   make(chan error),
            reConn: make(chan bool),
        }
        go server.Read()
        go server.Write()
        go handle(server)
        <-server.reConn
        _ = server.conn.Close()
    }
}
func handle(server *server) {
    // 等待server端发来的信息,也就是说user来请求server了
    data := <-server.read
    localConn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", localPort))
    if err != nil {
        panic(err)
    }
    local := &local{
        conn:  localConn,
        read:  make(chan []byte),
        write: make(chan []byte),
        exit:  make(chan error),
    }
    go local.Read()
    go local.Write()
    local.write <- data
    for {
        select {
        case data := <-server.read:
            local.write <- data
        case data := <-local.read:
            server.write <- data
        case err := <-server.exit:
            fmt.Printf("server have err: %s", err.Error())
            _ = server.conn.Close()
            _ = local.conn.Close()
            server.reConn <- true
        case err := <-local.exit:
            fmt.Printf("server have err: %s", err.Error())
            _ = local.conn.Close()
        }
    }
}
本作品采用《CC 协议》,转载必须注明作者和本文链接
本帖由系统于 3个月前 自动加精
讨论数量: 8

啥叫内网穿透啊

7个月前 评论
pibigstar (楼主) 7个月前

自建内网穿透的话 自己要有服务器,这样我为什么不把服务直接放在线上,个人觉得这个东西只有在自己没有服务器,使用第三方服务做穿透才有意义~ :joy:

7个月前 评论
pibigstar (楼主) 7个月前

顶一个,兄弟牛逼

3个月前 评论
pibigstar (楼主) 3个月前
wangchunbo

冲你这画画。我必须给你点赞!

1个月前 评论

@wangchunbo 灵魂级画手 :smile:

1个月前 评论

server和client之间是长连接吗

2周前 评论

请勿发布不友善或者负能量的内容。与人为善,比聪明更重要!