183 lines
4.3 KiB
Go

package client
import (
"context"
"fmt"
"log"
"net"
"os"
"sync"
"time"
pb "net-tunnel/pkg/proto"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
)
// Client is a tunnel client. It dials the server over gRPC, keeps a control
// stream open on which it registers its proxies and receives proxy traffic,
// and tracks the local TCP/UDP connections it is currently proxying.
type Client struct {
	// serverAddr and serverPort identify the tunnel server dialed by Start.
	serverAddr string
	serverPort int
	// conn is the gRPC connection created by Start and closed by Stop.
	conn *grpc.ClientConn
	// client is the TunnelService stub bound to conn.
	client pb.TunnelServiceClient
	// stream is the current control stream, (re)assigned by connect.
	stream pb.TunnelService_ConnectClient
	// ctx is the client's lifetime context; the control stream is opened
	// with it and Stop cancels it through cancel.
	ctx    context.Context
	cancel context.CancelFunc
	// proxies holds the proxy configurations sent to the server each time
	// the control stream is opened.
	proxies *pb.ProxyConfigs
	// mu is locked by AddProxy while it updates proxies.
	mu             sync.Mutex
	tcpConnections sync.Map // map[connID]*TCPConnectionState
	udpConnections sync.Map // map[connID]*net.UDPConn
}
// NewClient returns a Client for the tunnel server at serverAddr:serverPort.
// Nothing is dialed until Start is called. Proxies are added with AddProxy.
// The client owns a cancellable context that Stop cancels.
func NewClient(serverAddr string, serverPort int) *Client {
	// The sync.Map fields are ready to use at their zero value.
	c := &Client{
		serverAddr: serverAddr,
		serverPort: serverPort,
		proxies:    &pb.ProxyConfigs{},
	}
	c.ctx, c.cancel = context.WithCancel(context.Background())
	return c
}
// AddProxy records the proxy called name, replacing any earlier proxy with
// the same name. The stored configurations are sent to the server whenever
// the control stream is opened. A proxy added after Start is therefore only
// registered on the next (re)connect.
func (c *Client) AddProxy(name string, proxyType pb.ProxyType, localIP string, localPort, remotePort int32) {
	cfg := &pb.ProxyConfig{
		Name:       name,
		Type:       proxyType,
		LocalIp:    localIP,
		LocalPort:  localPort,
		RemotePort: remotePort,
	}

	c.mu.Lock()
	defer c.mu.Unlock()
	// The map is created lazily on the first proxy.
	if c.proxies.Configs == nil {
		c.proxies.Configs = make(map[string]*pb.ProxyConfig)
	}
	c.proxies.Configs[name] = cfg
}
// Start creates the gRPC connection to the tunnel server (without transport
// security) and opens the control stream via connect. The control stream
// registers every proxy added so far.
//
// If the control stream cannot be established, the gRPC connection is closed
// again before the error is returned, so a failed Start does not leak it.
func (c *Client) Start() error {
	conn, err := grpc.NewClient(serverTarget(c.serverAddr, c.serverPort),
		grpc.WithTransportCredentials(insecure.NewCredentials()))
	if err != nil {
		return fmt.Errorf("failed to connect to server: %w", err)
	}
	c.conn = conn
	c.client = pb.NewTunnelServiceClient(conn)

	if err := c.connect(); err != nil {
		// Nothing else would release the connection after a failed Start.
		// Clearing c.conn keeps a later Stop from closing it a second time.
		conn.Close()
		c.conn = nil
		return err
	}
	return nil
}

// serverTarget joins host and port into a dial target. IPv6 literals are
// bracketed ("::1" -> "[::1]:port"), which a plain "%s:%d" format does not
// do. A host that is already bracketed is accepted unchanged. Hostnames and
// IPv4 addresses produce the same "host:port" string as before.
func serverTarget(host string, port int) string {
	if n := len(host); n >= 2 && host[0] == '[' && host[n-1] == ']' {
		host = host[1 : n-1]
	}
	return net.JoinHostPort(host, fmt.Sprint(port))
}
// Stop shuts the client down, in this order:
//  1. It cancels the client context, which the control stream was opened
//     with.
//  2. It closes every tracked TCP and UDP connection.
//  3. It closes the gRPC connection.
//
// Entries are removed with LoadAndDelete, and Stop closes only the entries it
// actually removed, so it never closes a connection that another goroutine
// already took out of the map.
func (c *Client) Stop() {
	// Cancel first so the control stream stops delivering messages while the
	// local connections are being torn down. A handler that is already
	// running can still race with the loops below.
	if c.cancel != nil {
		c.cancel()
	}

	c.tcpConnections.Range(func(key, _ any) bool {
		if v, loaded := c.tcpConnections.LoadAndDelete(key); loaded {
			v.(*TCPConnectionState).Close()
		}
		return true
	})
	c.udpConnections.Range(func(key, _ any) bool {
		if v, loaded := c.udpConnections.LoadAndDelete(key); loaded {
			v.(*net.UDPConn).Close()
		}
		return true
	})

	if c.conn != nil {
		c.conn.Close()
	}
}
// connect opens a control stream on c.client using the client context. It
// sends the current proxy configurations as the first message and then
// starts receiveMessages in a new goroutine to consume the stream.
//
// It is called by Start and again by the receive loop after the stream
// drops. Errors are wrapped with %w so callers can inspect the gRPC cause.
func (c *Client) connect() error {
	stream, err := c.client.Connect(c.ctx)
	if err != nil {
		return fmt.Errorf("failed to establish control connection: %w", err)
	}
	c.stream = stream

	// Hold mu while the configs are serialized by Send. AddProxy mutates the
	// same map under mu, so reading it unlocked is a data race.
	// NOTE(review): if code elsewhere holds mu while calling connect, this
	// would deadlock; none of the visible callers do.
	c.mu.Lock()
	err = stream.Send(&pb.Message{
		Content: &pb.Message_RegisterConfigs{
			RegisterConfigs: c.proxies,
		},
	})
	c.mu.Unlock()
	if err != nil {
		return fmt.Errorf("failed to send proxy config: %w", err)
	}

	go c.receiveMessages()
	return nil
}
// receiveMessages reads messages from the control stream assigned by connect
// and dispatches each one by content type.
//
// When Recv fails, one of two things happens:
//   - If the client context has been cancelled (Stop was called), it returns
//     quietly.
//   - Otherwise it retries the connection until a connect succeeds or the
//     client is stopped. A successful connect starts a fresh receiveMessages
//     goroutine, and this one exits.
func (c *Client) receiveMessages() {
	// connect assigns c.stream before starting this goroutine, and a new
	// goroutine is started for every new stream, so read the field once.
	stream := c.stream
	for {
		msg, err := stream.Recv()
		if err != nil {
			if c.ctx.Err() != nil {
				return
			}
			log.Printf("Control connection closed: %v", err)
			c.redialControl()
			return
		}
		c.dispatchControlMessage(msg)
	}
}

// redialControl calls connect every 5 seconds until it succeeds or the
// client context is cancelled. The wait is interruptible so Stop does not
// have to wait out the delay.
func (c *Client) redialControl() {
	const retryDelay = 5 * time.Second
	timer := time.NewTimer(retryDelay)
	defer timer.Stop()
	for {
		select {
		case <-c.ctx.Done():
			return
		case <-timer.C:
		}
		err := c.connect()
		if err == nil {
			return
		}
		if c.ctx.Err() != nil {
			return
		}
		log.Printf("Failed to re-establish control connection: %v", err)
		// timer.C was drained above, so Reset is safe here.
		timer.Reset(retryDelay)
	}
}

// dispatchControlMessage routes one server message to its handler. Proxy
// data without a ProxyConfig is logged and dropped instead of dereferencing
// a nil pointer. Unknown content types are ignored.
func (c *Client) dispatchControlMessage(msg *pb.Message) {
	switch v := msg.GetContent().(type) {
	case *pb.Message_ProxyData:
		cfg := v.ProxyData.GetProxyConfig()
		if cfg == nil {
			log.Print("Dropping proxy data without proxy config")
			return
		}
		switch cfg.Type {
		case pb.ProxyType_TCP:
			c.handleTCPData(v)
		case pb.ProxyType_UDP:
			c.handleUDPData(v)
		}
	case *pb.Message_RegisterProxiesError:
		c.handleRegisterProxiesError(msg)
	case *pb.Message_RegisterConfigs:
		c.handleRegisterConfigs(msg)
	case *pb.Message_ProxyError:
		c.handleProxyError(msg)
	}
}
// handleRegisterProxiesError logs the name of the proxy whose registration
// the server rejected, then terminates the process with exit status 1.
//
// The getters are nil-safe, so a message without a ProxyConfig still gets
// logged (with an empty name) instead of panicking first.
//
// NOTE(review): os.Exit in a library package skips deferred cleanup and
// takes the decision away from the embedding program. Consider surfacing
// an error to the caller instead.
func (c *Client) handleRegisterProxiesError(msg *pb.Message) {
	name := msg.GetRegisterProxiesError().GetProxyConfig().GetName()
	log.Printf("Register proxies error: %v", name)
	os.Exit(1)
}
// handleRegisterConfigs logs the name of every proxy configuration in a
// RegisterConfigs message from the server. Map iteration order is
// unspecified, so the log order may vary.
func (c *Client) handleRegisterConfigs(msg *pb.Message) {
	configs := msg.GetRegisterConfigs().Configs
	for proxyName := range configs {
		log.Printf("Register config: %v", proxyName)
	}
}
// handleProxyError closes and forgets the tracked TCP connection whose ID is
// carried by a ProxyError message. A message without a ProxyError payload is
// ignored.
//
// LoadAndDelete removes the entry atomically, so only one goroutine can
// obtain, and therefore close, a given connection. With the previous
// Load-then-Delete, a concurrent Stop could close it twice.
//
// NOTE(review): only tcpConnections is consulted. Confirm whether the server
// can also report errors for UDP connection IDs.
func (c *Client) handleProxyError(msg *pb.Message) {
	proxyError := msg.GetProxyError()
	if proxyError == nil {
		return
	}
	if state, loaded := c.tcpConnections.LoadAndDelete(proxyError.ConnId); loaded {
		state.(*TCPConnectionState).Close()
	}
}