2025-03-21 22:16:43 +11:00

210 lines
5.2 KiB
Go

package client
import (
"context"
"fmt"
"log"
"net"
"os"
"strings"
"sync"
"time"
pb "net-tunnel/pkg/proto"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
)
// Client is the tunnel client. It holds the gRPC connection to the server,
// the control stream opened on it, the proxy configurations to register, and
// the table of currently tracked proxied connections.
type Client struct {
// serverAddr and serverPort are formatted as "addr:port" to build the gRPC
// target in Start.
serverAddr string
serverPort int
// conn is the gRPC client connection created by Start; nil before that.
conn *grpc.ClientConn
// client is the TunnelService stub built on top of conn.
client pb.TunnelServiceClient
// stream is the current control stream; it is replaced each time the
// control connection is (re-)established.
stream pb.TunnelService_EstablishControlConnectionClient
// ctx is the context the control stream is opened with; cancel is invoked
// by Stop.
ctx context.Context
cancel context.CancelFunc
// proxies is the set of proxy configs sent to the server when the control
// connection is established.
proxies *pb.ProxyConfigs
// mu is held by AddProxy while it updates proxies.
mu sync.Mutex
connections sync.Map // connection ID -> *ConnectionState
}
// NewClient builds a Client for the tunnel server at serverAddr:serverPort.
//
// The returned client carries a cancellable background context (canceled by
// Stop) and an empty proxy configuration set. Nothing is dialed here; call
// Start to connect.
func NewClient(serverAddr string, serverPort int) *Client {
	c := &Client{
		serverAddr: serverAddr,
		serverPort: serverPort,
		proxies:    &pb.ProxyConfigs{},
	}
	// The zero value of sync.Map is ready to use, so connections needs no
	// explicit initialisation.
	c.ctx, c.cancel = context.WithCancel(context.Background())
	return c
}
// AddProxy stores a proxy configuration under name, replacing any existing
// entry with the same name.
//
// The configuration is only recorded locally; it is sent to the server when
// the control connection is established. The update is performed while
// holding the client's mutex.
func (c *Client) AddProxy(name string, proxyType pb.ProxyType, localIP string, localPort, remotePort int32) {
	entry := &pb.ProxyConfig{
		Name:       name,
		Type:       proxyType,
		LocalIp:    localIP,
		LocalPort:  localPort,
		RemotePort: remotePort,
	}

	c.mu.Lock()
	defer c.mu.Unlock()

	// The config map is created lazily on first use.
	if c.proxies.Configs == nil {
		c.proxies.Configs = map[string]*pb.ProxyConfig{}
	}
	c.proxies.Configs[name] = entry
}
// Start connects to the tunnel server and opens the control connection.
//
// It creates the gRPC client connection (using insecure transport
// credentials), builds the TunnelService stub on it, and then establishes the
// control stream, which registers every proxy added so far.
//
// If the control connection cannot be established, the gRPC connection
// created here is closed again so that a failed Start does not leak it.
//
// The returned error wraps the underlying cause, so callers can inspect it
// with errors.Is / errors.As.
func (c *Client) Start() error {
	// Connect to the server.
	addr := fmt.Sprintf("%s:%d", c.serverAddr, c.serverPort)
	conn, err := grpc.NewClient(addr, grpc.WithTransportCredentials(insecure.NewCredentials()))
	if err != nil {
		return fmt.Errorf("failed to connect to server: %w", err)
	}
	c.conn = conn
	c.client = pb.NewTunnelServiceClient(conn)

	// Establish the control connection.
	if err := c.establishControlConnection(); err != nil {
		// Best-effort cleanup: the establish error is the one worth
		// reporting, so the Close error is intentionally dropped.
		_ = conn.Close()
		// Clear conn so a later Stop does not close it a second time.
		c.conn = nil
		return err
	}
	return nil
}
// Stop shuts the client down.
//
// Every tracked connection is closed and removed from the connection table,
// then the client context is canceled, and finally the gRPC connection is
// closed if one was created.
func (c *Client) Stop() {
	// Close and forget each tracked connection; returning true keeps the
	// iteration going over the whole table.
	closeAndForget := func(id, state any) bool {
		state.(*ConnectionState).Close()
		c.connections.Delete(id)
		return true
	}
	c.connections.Range(closeAndForget)

	// Tear down the control connection.
	c.cancel()
	if c.conn == nil {
		return
	}
	c.conn.Close()
}
// establishControlConnection opens the control stream and registers the
// client's proxies on it.
//
// The stream is opened with the client context and stored on the client, the
// full proxy configuration set is sent as a RegisterConfigs message, and a
// goroutine running receiveMessages is started to consume server messages.
//
// Errors wrap the underlying cause with %w so that callers can still inspect
// it (errors.Is / errors.As) after the context text is prepended.
func (c *Client) establishControlConnection() error {
	stream, err := c.client.EstablishControlConnection(c.ctx)
	if err != nil {
		return fmt.Errorf("failed to establish control connection: %w", err)
	}
	c.stream = stream

	// Register all proxies with the server.
	registration := &pb.Message{
		Content: &pb.Message_RegisterConfigs{
			RegisterConfigs: c.proxies,
		},
	}
	if err := stream.Send(registration); err != nil {
		return fmt.Errorf("failed to send proxy config: %w", err)
	}

	// Consume server messages in the background.
	go c.receiveMessages()
	return nil
}
// receiveMessages reads messages from the control stream and dispatches each
// one to its handler until the stream fails.
//
// Handlers run synchronously on this goroutine, so messages are handled in
// the order they are received.
//
// When Recv returns an error the loop ends. If the client context has been
// canceled (Stop was called) the goroutine exits right away; otherwise it
// waits 5 seconds, giving up early if the client is stopped in the meantime,
// and makes a single attempt to re-establish the control connection. A
// successful attempt starts a new receiveMessages goroutine on the new
// stream, so this goroutine always returns after a stream error.
func (c *Client) receiveMessages() {
	for {
		msg, err := c.stream.Recv()
		if err != nil {
			log.Printf("Control connection closed: %v", err)

			// Back off before reconnecting, but do not sleep (or try to
			// reconnect on a canceled context) once the client is stopped.
			select {
			case <-c.ctx.Done():
				return
			case <-time.After(5 * time.Second):
			}

			if err := c.establishControlConnection(); err != nil {
				log.Printf("Failed to re-establish control connection: %v", err)
			}
			return
		}

		switch msg.GetContent().(type) {
		case *pb.Message_ProxyData:
			c.handleProxyData(msg)
		case *pb.Message_RegisterProxiesError:
			c.handleRegisterProxiesError(msg)
		case *pb.Message_RegisterConfigs:
			c.handleRegisterConfigs(msg)
		case *pb.Message_ProxyError:
			c.handleProxyError(msg)
		}
	}
}
// handleProxyData delivers one chunk of proxied data to the local service.
//
// The message's connection ID selects an already tracked connection. When
// none exists, a new connection is dialed to the LocalIp/LocalPort from the
// message's proxy config, using the lower-cased proxy type name as the
// network, and is stored under that ID. TCP connections get keep-alive
// (30s period) and TCP_NODELAY enabled. The payload is then written to the
// connection and StartReading is called on it.
//
// Failure handling:
//   - a message with no payload struct, or with no proxy config for a
//     connection that is not yet tracked, is logged and dropped;
//   - a failed dial is logged and the data is dropped;
//   - a failed write closes the connection and removes it from the table.
//
// The dial uses the client context, so Stop aborts a dial that is still in
// progress instead of leaving the receive loop blocked on it.
//
// NOTE(review): StartReading is invoked for every chunk, including on
// connections that already exist, so it is assumed to be idempotent — confirm
// against ConnectionState.
func (c *Client) handleProxyData(msg *pb.Message) {
	proxyData := msg.GetProxyData()
	if proxyData == nil {
		log.Printf("Received proxy data message without payload")
		return
	}
	log.Printf("Received proxy data for connection: %v", proxyData.ConnId)

	var connState *ConnectionState
	if existing, ok := c.connections.Load(proxyData.ConnId); ok {
		connState = existing.(*ConnectionState)
	} else {
		// The proxy config is only needed to open a new connection, so it
		// is validated here rather than for every chunk.
		cfg := proxyData.ProxyConfig
		if cfg == nil {
			log.Printf("Failed to connect to proxy: missing proxy config for connection %v", proxyData.ConnId)
			return
		}
		network := strings.ToLower(cfg.Type.String())
		hostPort := net.JoinHostPort(cfg.LocalIp, fmt.Sprintf("%d", cfg.LocalPort))

		var dialer net.Dialer
		conn, err := dialer.DialContext(c.ctx, network, hostPort)
		if err != nil {
			log.Printf("Failed to connect to proxy: %v", err)
			return
		}
		if network == "tcp" {
			if tcpConn, ok := conn.(*net.TCPConn); ok {
				// Socket tuning is best-effort: the connection is still
				// usable if an option cannot be applied.
				_ = tcpConn.SetKeepAlive(true)
				_ = tcpConn.SetKeepAlivePeriod(30 * time.Second)
				_ = tcpConn.SetNoDelay(true)
			}
		}

		connState = NewConnectionState(conn, proxyData.ConnId, proxyData.ProxyConfig, c.stream)
		c.connections.Store(proxyData.ConnId, connState)
	}

	if err := connState.WriteData(proxyData.Data); err != nil {
		log.Printf("Failed to write data: %v", err)
		connState.Close()
		c.connections.Delete(proxyData.ConnId)
		return
	}
	connState.StartReading()
}
// handleRegisterProxiesError handles a proxy registration failure reported by
// the server: it logs the name of the proxy config carried in the message and
// then terminates the process with exit status 1.
func (c *Client) handleRegisterProxiesError(msg *pb.Message) {
	rejected := msg.GetRegisterProxiesError().ProxyConfig
	log.Printf("Register proxies error: %v", rejected.Name)
	// os.Exit ends the process immediately; deferred functions do not run.
	os.Exit(1)
}
// handleRegisterConfigs logs the name of every proxy config contained in a
// RegisterConfigs message received from the server.
func (c *Client) handleRegisterConfigs(msg *pb.Message) {
	for proxyName := range msg.GetRegisterConfigs().Configs {
		log.Printf("Register config: %v", proxyName)
	}
}
// handleProxyError reacts to a proxy error reported by the server for a
// single connection: if that connection ID is tracked, the connection is
// closed and removed from the table. Unknown IDs are ignored.
func (c *Client) handleProxyError(msg *pb.Message) {
	id := msg.GetProxyError().ConnId

	entry, tracked := c.connections.Load(id)
	if !tracked {
		return
	}
	entry.(*ConnectionState).Close()
	c.connections.Delete(id)
}