210 lines
5.2 KiB
Go
210 lines
5.2 KiB
Go
package client
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"log"
|
|
"net"
|
|
"os"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
pb "net-tunnel/pkg/proto"
|
|
|
|
"google.golang.org/grpc"
|
|
"google.golang.org/grpc/credentials/insecure"
|
|
)
|
|
|
|
// Client 客户端
|
|
type Client struct {
|
|
serverAddr string
|
|
serverPort int
|
|
conn *grpc.ClientConn
|
|
client pb.TunnelServiceClient
|
|
stream pb.TunnelService_EstablishControlConnectionClient
|
|
ctx context.Context
|
|
cancel context.CancelFunc
|
|
proxies *pb.ProxyConfigs
|
|
mu sync.Mutex
|
|
connections sync.Map // map[connID]*ConnectionState
|
|
}
|
|
|
|
// NewClient 创建一个新的客户端
|
|
func NewClient(serverAddr string, serverPort int) *Client {
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
return &Client{
|
|
serverAddr: serverAddr,
|
|
serverPort: serverPort,
|
|
ctx: ctx,
|
|
cancel: cancel,
|
|
proxies: &pb.ProxyConfigs{},
|
|
connections: sync.Map{},
|
|
}
|
|
}
|
|
|
|
// AddProxy 添加一个代理配置
|
|
func (c *Client) AddProxy(name string, proxyType pb.ProxyType, localIP string, localPort, remotePort int32) {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
if c.proxies.Configs == nil {
|
|
c.proxies.Configs = make(map[string]*pb.ProxyConfig)
|
|
}
|
|
c.proxies.Configs[name] = &pb.ProxyConfig{
|
|
Name: name,
|
|
Type: proxyType,
|
|
LocalIp: localIP,
|
|
LocalPort: localPort,
|
|
RemotePort: remotePort,
|
|
}
|
|
}
|
|
|
|
// Start 启动客户端
|
|
func (c *Client) Start() error {
|
|
// 连接到服务端
|
|
addr := fmt.Sprintf("%s:%d", c.serverAddr, c.serverPort)
|
|
conn, err := grpc.NewClient(addr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
|
if err != nil {
|
|
return fmt.Errorf("failed to connect to server: %v", err)
|
|
}
|
|
c.conn = conn
|
|
|
|
c.client = pb.NewTunnelServiceClient(conn)
|
|
|
|
// 建立控制连接
|
|
if err := c.establishControlConnection(); err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Stop 停止客户端
|
|
func (c *Client) Stop() {
|
|
// 关闭所有连接
|
|
c.connections.Range(func(key, value interface{}) bool {
|
|
connState := value.(*ConnectionState)
|
|
connState.Close()
|
|
c.connections.Delete(key)
|
|
return true
|
|
})
|
|
|
|
// 关闭控制连接
|
|
c.cancel()
|
|
if c.conn != nil {
|
|
c.conn.Close()
|
|
}
|
|
}
|
|
|
|
// establishControlConnection 建立控制连接
|
|
func (c *Client) establishControlConnection() error {
|
|
stream, err := c.client.EstablishControlConnection(c.ctx)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to establish control connection: %v", err)
|
|
}
|
|
c.stream = stream
|
|
|
|
// 注册所有代理
|
|
if err := stream.Send(&pb.Message{
|
|
Content: &pb.Message_RegisterConfigs{
|
|
RegisterConfigs: c.proxies,
|
|
},
|
|
}); err != nil {
|
|
return fmt.Errorf("failed to send proxy config: %v", err)
|
|
}
|
|
|
|
// 接收服务端消息
|
|
go c.receiveMessages()
|
|
|
|
return nil
|
|
}
|
|
|
|
// receiveMessages 接收服务端消息
|
|
func (c *Client) receiveMessages() {
|
|
for {
|
|
msg, err := c.stream.Recv()
|
|
if err != nil {
|
|
log.Printf("Control connection closed: %v", err)
|
|
// 尝试重连
|
|
time.Sleep(5 * time.Second)
|
|
if err := c.establishControlConnection(); err != nil {
|
|
log.Printf("Failed to re-establish control connection: %v", err)
|
|
return
|
|
}
|
|
return
|
|
}
|
|
|
|
switch msg.GetContent().(type) {
|
|
case *pb.Message_ProxyData:
|
|
c.handleProxyData(msg)
|
|
case *pb.Message_RegisterProxiesError:
|
|
c.handleRegisterProxiesError(msg)
|
|
case *pb.Message_RegisterConfigs:
|
|
c.handleRegisterConfigs(msg)
|
|
case *pb.Message_ProxyError:
|
|
c.handleProxyError(msg)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (c *Client) handleProxyData(msg *pb.Message) {
|
|
proxyData := msg.GetProxyData()
|
|
log.Printf("Received proxy data for connection: %v", proxyData.ConnId)
|
|
|
|
// 处理代理数据
|
|
hostPort := net.JoinHostPort(proxyData.ProxyConfig.LocalIp, fmt.Sprintf("%d", proxyData.ProxyConfig.LocalPort))
|
|
|
|
existingConn, ok := c.connections.Load(proxyData.ConnId)
|
|
var connState *ConnectionState
|
|
if ok {
|
|
connState = existingConn.(*ConnectionState)
|
|
} else {
|
|
conn, err := net.Dial(strings.ToLower(proxyData.ProxyConfig.Type.String()), hostPort)
|
|
if err != nil {
|
|
log.Printf("Failed to connect to proxy: %v", err)
|
|
return
|
|
}
|
|
switch strings.ToLower(proxyData.ProxyConfig.Type.String()) {
|
|
case "tcp":
|
|
if tcpConn, ok := conn.(*net.TCPConn); ok {
|
|
_ = tcpConn.SetKeepAlive(true)
|
|
_ = tcpConn.SetKeepAlivePeriod(30 * time.Second)
|
|
_ = tcpConn.SetNoDelay(true)
|
|
}
|
|
}
|
|
connState = NewConnectionState(conn, proxyData.ConnId, proxyData.ProxyConfig, c.stream)
|
|
c.connections.Store(proxyData.ConnId, connState)
|
|
}
|
|
|
|
err := connState.WriteData(proxyData.Data)
|
|
if err != nil {
|
|
log.Printf("Failed to write data: %v", err)
|
|
connState.Close()
|
|
c.connections.Delete(proxyData.ConnId)
|
|
return
|
|
}
|
|
connState.StartReading()
|
|
}
|
|
|
|
func (c *Client) handleRegisterProxiesError(msg *pb.Message) {
|
|
registerProxiesError := msg.GetRegisterProxiesError()
|
|
log.Printf("Register proxies error: %v", registerProxiesError.ProxyConfig.Name)
|
|
os.Exit(1)
|
|
}
|
|
|
|
func (c *Client) handleRegisterConfigs(msg *pb.Message) {
|
|
registerConfigs := msg.GetRegisterConfigs()
|
|
for name := range registerConfigs.Configs {
|
|
log.Printf("Register config: %v", name)
|
|
}
|
|
}
|
|
|
|
func (c *Client) handleProxyError(msg *pb.Message) {
|
|
proxyError := msg.GetProxyError()
|
|
connState, ok := c.connections.Load(proxyError.ConnId)
|
|
if ok {
|
|
connState.(*ConnectionState).Close()
|
|
c.connections.Delete(proxyError.ConnId)
|
|
}
|
|
}
|