package client import ( "context" "fmt" "log" "net" "os" "sync" "time" pb "net-tunnel/pkg/proto" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" ) // Client 客户端 type Client struct { serverAddr string serverPort int conn *grpc.ClientConn client pb.TunnelServiceClient stream pb.TunnelService_ConnectClient ctx context.Context cancel context.CancelFunc proxies *pb.ProxyConfigs mu sync.Mutex tcpConnections sync.Map // map[connID]*TCPConnectionState udpConnections sync.Map // map[connID]*net.UDPConn } // NewClient 创建一个新的客户端 func NewClient(serverAddr string, serverPort int) *Client { ctx, cancel := context.WithCancel(context.Background()) return &Client{ serverAddr: serverAddr, serverPort: serverPort, ctx: ctx, cancel: cancel, proxies: &pb.ProxyConfigs{}, tcpConnections: sync.Map{}, udpConnections: sync.Map{}, } } // AddProxy 添加一个代理配置 func (c *Client) AddProxy(name string, proxyType pb.ProxyType, localIP string, localPort, remotePort int32) { c.mu.Lock() defer c.mu.Unlock() if c.proxies.Configs == nil { c.proxies.Configs = make(map[string]*pb.ProxyConfig) } c.proxies.Configs[name] = &pb.ProxyConfig{ Name: name, Type: proxyType, LocalIp: localIP, LocalPort: localPort, RemotePort: remotePort, } } // Start 启动客户端 func (c *Client) Start() error { // 连接到服务端 addr := fmt.Sprintf("%s:%d", c.serverAddr, c.serverPort) conn, err := grpc.NewClient(addr, grpc.WithTransportCredentials(insecure.NewCredentials())) if err != nil { return fmt.Errorf("failed to connect to server: %v", err) } c.conn = conn c.client = pb.NewTunnelServiceClient(conn) // 建立控制连接 if err := c.connect(); err != nil { return err } return nil } // Stop 停止客户端 func (c *Client) Stop() { // 关闭所有连接 c.tcpConnections.Range(func(key, value any) bool { connState := value.(*TCPConnectionState) connState.Close() c.tcpConnections.Delete(key) return true }) c.udpConnections.Range(func(key, value any) bool { connState := value.(*net.UDPConn) connState.Close() c.udpConnections.Delete(key) return true }) // 关闭控制连接 c.cancel() if c.conn != nil { c.conn.Close() } } // connect 建立控制连接 func (c *Client) connect() error { stream, err := c.client.Connect(c.ctx) if err != nil { return fmt.Errorf("failed to establish control connection: %v", err) } c.stream = stream // 注册所有代理 if err := stream.Send(&pb.Message{ Content: &pb.Message_RegisterConfigs{ RegisterConfigs: c.proxies, }, }); err != nil { return fmt.Errorf("failed to send proxy config: %v", err) } // 接收服务端消息 go c.receiveMessages() return nil } // receiveMessages 接收服务端消息 func (c *Client) receiveMessages() { for { msg, err := c.stream.Recv() if err != nil { log.Printf("Control connection closed: %v", err) // 尝试重连 time.Sleep(5 * time.Second) if err := c.connect(); err != nil { log.Printf("Failed to re-establish control connection: %v", err) return } return } switch v := msg.GetContent().(type) { case *pb.Message_ProxyData: switch v.ProxyData.ProxyConfig.Type { case pb.ProxyType_TCP: c.handleTCPData(v) case pb.ProxyType_UDP: c.handleUDPData(v) } case *pb.Message_RegisterProxiesError: c.handleRegisterProxiesError(msg) case *pb.Message_RegisterConfigs: c.handleRegisterConfigs(msg) case *pb.Message_ProxyError: c.handleProxyError(msg) } } } func (c *Client) handleRegisterProxiesError(msg *pb.Message) { registerProxiesError := msg.GetRegisterProxiesError() log.Printf("Register proxies error: %v", registerProxiesError.ProxyConfig.Name) os.Exit(1) } func (c *Client) handleRegisterConfigs(msg *pb.Message) { registerConfigs := msg.GetRegisterConfigs() for name := range registerConfigs.Configs { log.Printf("Register config: %v", name) } } func (c *Client) handleProxyError(msg *pb.Message) { proxyError := msg.GetProxyError() connState, ok := c.tcpConnections.Load(proxyError.ConnId) if ok { connState.(*TCPConnectionState).Close() c.tcpConnections.Delete(proxyError.ConnId) } }