package client

import (
	"context"
	"fmt"
	"log"
	"net"
	"os"
	"strconv"
	"strings"
	"sync"
	"time"

	pb "net-tunnel/pkg/proto"

	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"
)

const (
	// controlReconnectDelay is the pause between attempts to re-establish a
	// lost control connection.
	controlReconnectDelay = 5 * time.Second
	// localDialTimeout bounds how long dialing a local proxy target may block
	// the control-message loop (dialing happens on that goroutine).
	localDialTimeout = 10 * time.Second
	// localKeepAlivePeriod is the TCP keep-alive period applied to local TCP
	// connections.
	localKeepAlivePeriod = 30 * time.Second
)

// Client is the tunnel client. It holds one gRPC control stream to the server
// and a set of local connections keyed by the ConnId the server sends in
// proxy-data messages.
type Client struct {
	serverAddr string
	serverPort int
	conn       *grpc.ClientConn
	client     pb.TunnelServiceClient

	// stream is the current control stream and streamCancel releases it.
	// Both are replaced by every successful (re)connect.
	stream       pb.TunnelService_EstablishControlConnectionClient
	streamCancel context.CancelFunc

	// ctx is cancelled by Stop; every control stream and local dial derives
	// from it.
	ctx    context.Context
	cancel context.CancelFunc

	// mu guards proxies.
	mu      sync.Mutex
	proxies *pb.ProxyConfigs

	// connections maps a connection ID to its *ConnectionState.
	connections sync.Map
}

// NewClient creates a client for the server at serverAddr:serverPort.
// No network activity happens until Start is called.
func NewClient(serverAddr string, serverPort int) *Client {
	ctx, cancel := context.WithCancel(context.Background())
	return &Client{
		serverAddr: serverAddr,
		serverPort: serverPort,
		ctx:        ctx,
		cancel:     cancel,
		proxies:    &pb.ProxyConfigs{},
	}
}

// AddProxy registers (or replaces) the proxy configuration stored under name.
// The full set of configurations is sent to the server whenever a control
// connection is established, so proxies added after Start are only
// registered on the next reconnect.
func (c *Client) AddProxy(name string, proxyType pb.ProxyType, localIP string, localPort, remotePort int32) {
	c.mu.Lock()
	defer c.mu.Unlock()

	if c.proxies.Configs == nil {
		c.proxies.Configs = make(map[string]*pb.ProxyConfig)
	}
	c.proxies.Configs[name] = &pb.ProxyConfig{
		Name:       name,
		Type:       proxyType,
		LocalIp:    localIP,
		LocalPort:  localPort,
		RemotePort: remotePort,
	}
}

// Start creates the gRPC client, opens the control stream, registers all
// proxies and starts the background loop that handles control messages.
// It returns an error if any of these steps fails; in that case the gRPC
// connection is closed again.
func (c *Client) Start() error {
	addr := fmt.Sprintf("%s:%d", c.serverAddr, c.serverPort)
	conn, err := grpc.NewClient(addr, grpc.WithTransportCredentials(insecure.NewCredentials()))
	if err != nil {
		return fmt.Errorf("failed to connect to server: %w", err)
	}
	c.conn = conn
	c.client = pb.NewTunnelServiceClient(conn)

	if err := c.establishControlConnection(); err != nil {
		// Don't leak the gRPC connection when the control handshake fails.
		_ = conn.Close()
		c.conn = nil
		return err
	}
	return nil
}

// Stop shuts the client down: it cancels the client context (ending the
// control stream and any reconnect attempts), closes every local connection
// and closes the gRPC connection.
func (c *Client) Stop() {
	// Cancel first so the receive loop stops reconnecting and new local dials
	// fail while existing connections are being torn down.
	c.cancel()
	c.closeAllConnections()
	if c.conn != nil {
		// The close error is not actionable during shutdown.
		_ = c.conn.Close()
	}
}

// closeAllConnections closes and forgets every tracked local connection.
func (c *Client) closeAllConnections() {
	c.connections.Range(func(key, value any) bool {
		value.(*ConnectionState).Close()
		c.connections.Delete(key)
		return true
	})
}

// establishControlConnection opens a new control stream, sends the current
// proxy configurations on it and starts a receiveMessages goroutine for it.
// On error the new stream is released and nothing is started.
func (c *Client) establishControlConnection() error {
	// A per-stream context lets a failed or finished stream be released
	// without cancelling the whole client.
	streamCtx, streamCancel := context.WithCancel(c.ctx)
	stream, err := c.client.EstablishControlConnection(streamCtx)
	if err != nil {
		streamCancel()
		return fmt.Errorf("failed to establish control connection: %w", err)
	}

	// Hold mu while the configs are marshalled so AddProxy cannot mutate the
	// map concurrently.
	c.mu.Lock()
	err = stream.Send(&pb.Message{
		Content: &pb.Message_RegisterConfigs{
			RegisterConfigs: c.proxies,
		},
	})
	c.mu.Unlock()
	if err != nil {
		streamCancel()
		return fmt.Errorf("failed to send proxy config: %w", err)
	}

	c.stream = stream
	c.streamCancel = streamCancel
	go c.receiveMessages()
	return nil
}

// receiveMessages reads and dispatches control messages from the current
// stream until Recv fails. It then releases the stream and, unless Stop has
// been called, reconnects. A successful reconnect starts a fresh
// receiveMessages goroutine, and this one exits.
func (c *Client) receiveMessages() {
	// Capture the stream this goroutine owns; a reconnect replaces c.stream.
	stream, releaseStream := c.stream, c.streamCancel
	for {
		msg, err := stream.Recv()
		if err != nil {
			releaseStream()
			if c.ctx.Err() != nil {
				// Stop was called; the Recv error is just the cancellation.
				return
			}
			log.Printf("Control connection closed: %v", err)
			c.reconnectControl()
			return
		}
		c.dispatchMessage(msg)
	}
}

// reconnectControl drops all local connections, since they were created with
// the now-dead stream. It then retries establishControlConnection every
// controlReconnectDelay until an attempt succeeds or the client is stopped.
func (c *Client) reconnectControl() {
	c.closeAllConnections()
	for {
		select {
		case <-c.ctx.Done():
			return
		case <-time.After(controlReconnectDelay):
		}

		err := c.establishControlConnection()
		if err == nil {
			return
		}
		if c.ctx.Err() != nil {
			return
		}
		log.Printf("Failed to re-establish control connection: %v", err)
	}
}

// dispatchMessage routes a control message to its handler by payload type.
func (c *Client) dispatchMessage(msg *pb.Message) {
	switch content := msg.GetContent().(type) {
	case *pb.Message_ProxyData:
		c.handleProxyData(msg)
	case *pb.Message_RegisterProxiesError:
		c.handleRegisterProxiesError(msg)
	case *pb.Message_RegisterConfigs:
		c.handleRegisterConfigs(msg)
	case *pb.Message_ProxyError:
		c.handleProxyError(msg)
	default:
		log.Printf("Ignoring unsupported control message: %T", content)
	}
}

// handleProxyData writes a chunk of server data to the local connection for
// its ConnId. On the first chunk for an ID it dials the local target named
// by the message's proxy config. If the write fails, the local connection
// is closed and forgotten.
func (c *Client) handleProxyData(msg *pb.Message) {
	proxyData := msg.GetProxyData()
	if proxyData == nil {
		return
	}
	connID := proxyData.GetConnId()
	log.Printf("Received proxy data for connection: %v", connID)

	var connState *ConnectionState
	if existing, ok := c.connections.Load(connID); ok {
		connState = existing.(*ConnectionState)
	} else {
		cfg := proxyData.GetProxyConfig()
		if cfg == nil {
			log.Printf("Failed to connect to proxy: no proxy config for connection %v", connID)
			return
		}
		conn, err := c.dialLocalTarget(cfg)
		if err != nil {
			log.Printf("Failed to connect to proxy: %v", err)
			return
		}
		connState = NewConnectionState(conn, connID, cfg, c.stream)
		c.connections.Store(connID, connState)
	}

	if err := connState.WriteData(proxyData.GetData()); err != nil {
		log.Printf("Failed to write data: %v", err)
		connState.Close()
		c.connections.Delete(connID)
		return
	}
	// NOTE(review): StartReading runs for every data message, not only the
	// first per connection, so it presumably must be idempotent. Confirm in
	// ConnectionState.
	connState.StartReading()
}

// dialLocalTarget dials cfg's local address. The network name is the
// lower-cased proxy type (e.g. "tcp"). The dial is bounded by
// localDialTimeout and aborted by Stop. TCP connections get keep-alive and
// TCP_NODELAY enabled.
func (c *Client) dialLocalTarget(cfg *pb.ProxyConfig) (net.Conn, error) {
	network := strings.ToLower(cfg.GetType().String())
	addr := net.JoinHostPort(cfg.GetLocalIp(), strconv.Itoa(int(cfg.GetLocalPort())))

	dialer := net.Dialer{Timeout: localDialTimeout}
	conn, err := dialer.DialContext(c.ctx, network, addr)
	if err != nil {
		return nil, err
	}
	if tcpConn, ok := conn.(*net.TCPConn); ok {
		// Tuning is best-effort; failing to apply it must not drop the connection.
		_ = tcpConn.SetKeepAlive(true)
		_ = tcpConn.SetKeepAlivePeriod(localKeepAlivePeriod)
		_ = tcpConn.SetNoDelay(true)
	}
	return conn, nil
}

// handleRegisterProxiesError logs the name of the proxy the server rejected
// and terminates the process with exit status 1.
//
// NOTE(review): os.Exit from library code skips deferred cleanup in callers;
// consider surfacing this error through Start or a callback instead.
func (c *Client) handleRegisterProxiesError(msg *pb.Message) {
	registerProxiesError := msg.GetRegisterProxiesError()
	log.Printf("Register proxies error: %v", registerProxiesError.GetProxyConfig().GetName())
	os.Exit(1)
}

// handleRegisterConfigs logs the name of every config the server reported
// as registered.
func (c *Client) handleRegisterConfigs(msg *pb.Message) {
	for name := range msg.GetRegisterConfigs().GetConfigs() {
		log.Printf("Register config: %v", name)
	}
}

// handleProxyError closes and forgets the local connection named by the
// server-reported ConnId, if one is tracked.
func (c *Client) handleProxyError(msg *pb.Message) {
	proxyError := msg.GetProxyError()
	if proxyError == nil {
		return
	}
	if connState, ok := c.connections.LoadAndDelete(proxyError.GetConnId()); ok {
		connState.(*ConnectionState).Close()
	}
}