253 lines
6.1 KiB
Go

package server
import (
"fmt"
"io"
"log"
"net"
"sync"
"time"
"net-tunnel/pkg/proto"
)
// ProxyEntry represents one registered proxy: the configuration it was
// registered with, the ID of the client that owns it, and the server-side
// sockets/connections opened on its behalf.
type ProxyEntry struct {
Config *proto.ProxyConfig // configuration passed to RegisterProxy
ClientID string // owning client; used to look up its connection via clientManager.Get
TCPListener net.Listener // set by startTCPProxy; nil for UDP proxies
TCPConns sync.Map // map[connID string]net.Conn of accepted, still-open TCP connections
UDPConn *net.UDPConn // set by startUDPProxy; nil for TCP proxies
UDPAddrs sync.Map // map[addr.String()]*net.UDPAddr of peers seen on UDPConn
}
// ProxyManager owns every registered proxy and their listening sockets.
type ProxyManager struct {
proxies sync.Map // map[proxyName]*ProxyEntry
clientManager *ClientManager // used to look up a proxy owner's connection by client ID
}
// NewProxyManager returns a ProxyManager with an empty proxy registry that
// resolves proxy owners' connections through connManager.
func NewProxyManager(connManager *ClientManager) *ProxyManager {
	// The zero value of sync.Map is an empty, ready-to-use map, so only the
	// client manager needs to be set explicitly.
	pm := new(ProxyManager)
	pm.clientManager = connManager
	return pm
}
// RegisterProxy opens a server-side listener for config on RemotePort and
// records it under config.Name as owned by clientID.
//
// TCP proxies listen via startTCPProxy and UDP proxies via startUDPProxy.
// An error is returned when config is nil, a proxy with the same name is
// already registered, the proxy type is unsupported, or the listener cannot
// be opened.
func (m *ProxyManager) RegisterProxy(clientID string, config *proto.ProxyConfig) error {
	if config == nil {
		return fmt.Errorf("proxy config is nil")
	}
	// Cheap early rejection so we don't open a socket needlessly; the
	// authoritative uniqueness check is the LoadOrStore below.
	if _, exists := m.proxies.Load(config.Name); exists {
		return fmt.Errorf("proxy %s already registered", config.Name)
	}

	entry := &ProxyEntry{
		Config:   config,
		ClientID: clientID,
	}

	var err error
	switch config.Type {
	case proto.ProxyType_TCP:
		err = m.startTCPProxy(entry)
	case proto.ProxyType_UDP:
		err = m.startUDPProxy(entry)
	default:
		return fmt.Errorf("unsupported proxy type: %v", config.Type)
	}
	if err != nil {
		return err
	}

	// Another goroutine may have registered the same name while our socket was
	// being opened. Publish atomically; if we lost the race, release the socket
	// we just opened rather than overwriting (and leaking) the winner's entry.
	// Closing it also makes the accept/read goroutine started above exit.
	if _, loaded := m.proxies.LoadOrStore(config.Name, entry); loaded {
		if entry.TCPListener != nil {
			entry.TCPListener.Close()
		}
		if entry.UDPConn != nil {
			entry.UDPConn.Close()
		}
		return fmt.Errorf("proxy %s already registered", config.Name)
	}

	log.Printf("Registered proxy: %s (type: %s, port: %d)",
		config.Name, config.Type, config.RemotePort)
	return nil
}
// UnregisterProxy tears down the proxy named proxyName if it is owned by
// clientID. Unknown names are ignored. A request from a client that does not
// own the proxy is logged and ignored, so one client cannot close another
// client's proxy.
func (m *ProxyManager) UnregisterProxy(clientID, proxyName string) {
	value, ok := m.proxies.Load(proxyName)
	if !ok {
		return
	}
	if entry := value.(*ProxyEntry); entry.ClientID != clientID {
		log.Printf("Client %s may not unregister proxy %s owned by client %s",
			clientID, proxyName, entry.ClientID)
		return
	}
	m.closeProxy(proxyName)
}
// UnregisterAllProxies tears down every registered proxy whose owning
// ClientID equals clientID. Entries owned by other clients are left alone.
func (m *ProxyManager) UnregisterAllProxies(clientID string) {
	m.proxies.Range(func(name, value any) bool {
		entry, isEntry := value.(*ProxyEntry)
		if !isEntry || entry.ClientID != clientID {
			// Not ours: keep iterating.
			return true
		}
		// sync.Map permits deleting entries while ranging over it.
		m.closeProxy(name.(string))
		return true
	})
}
// closeProxy removes the proxy registered as proxyName and releases
// everything it holds:
//   - its TCP listener or UDP socket;
//   - every tracked TCP connection;
//   - the recorded UDP peer addresses.
//
// Unknown names are ignored. Close errors are ignored: this is best-effort
// teardown.
func (m *ProxyManager) closeProxy(proxyName string) {
	// LoadAndDelete makes teardown happen exactly once even when several callers
	// (e.g. UnregisterProxy and UnregisterAllProxies) race on the same name.
	value, ok := m.proxies.LoadAndDelete(proxyName)
	if !ok {
		return
	}
	entry := value.(*ProxyEntry)

	// Closing the listener/socket makes the accept or read loop exit.
	if entry.TCPListener != nil {
		entry.TCPListener.Close()
	}
	if entry.UDPConn != nil {
		entry.UDPConn.Close()
	}
	entry.TCPConns.Range(func(key, conn any) bool {
		conn.(net.Conn).Close()
		entry.TCPConns.Delete(key)
		return true
	})
	entry.UDPAddrs.Range(func(key, _ any) bool {
		entry.UDPAddrs.Delete(key)
		return true
	})
	log.Printf("Unregistered proxy: %s", entry.Config.Name)
}
// startTCPProxy opens a TCP listener on all interfaces at
// entry.Config.RemotePort and stores it in entry.TCPListener. It then starts
// handleTCPConnections in a new goroutine. The underlying listen error is
// wrapped with %w so callers can inspect it with errors.Is/errors.As.
func (m *ProxyManager) startTCPProxy(entry *ProxyEntry) error {
	addr := fmt.Sprintf(":%d", entry.Config.RemotePort)
	listener, err := net.Listen("tcp", addr)
	if err != nil {
		return fmt.Errorf("failed to listen on port %d: %w", entry.Config.RemotePort, err)
	}
	entry.TCPListener = listener
	// The accept loop runs until the listener is closed (see closeProxy).
	go m.handleTCPConnections(entry)
	return nil
}
// startUDPProxy binds a UDP socket on all interfaces at
// entry.Config.RemotePort and stores it in entry.UDPConn. It then starts
// handleUDPPackets in a new goroutine. Underlying errors are wrapped with %w
// so callers can inspect them with errors.Is/errors.As.
func (m *ProxyManager) startUDPProxy(entry *ProxyEntry) error {
	addr := fmt.Sprintf(":%d", entry.Config.RemotePort)
	udpAddr, err := net.ResolveUDPAddr("udp", addr)
	if err != nil {
		return fmt.Errorf("failed to resolve address %s: %w", addr, err)
	}
	conn, err := net.ListenUDP("udp", udpAddr)
	if err != nil {
		return fmt.Errorf("failed to listen on UDP port %d: %w", entry.Config.RemotePort, err)
	}
	entry.UDPConn = conn
	// The read loop runs until the socket is closed (see closeProxy).
	go m.handleUDPPackets(entry)
	return nil
}
// handleTCPConnections accepts connections on entry.TCPListener until Accept
// returns an error; closing the listener is one way that happens. Each
// accepted connection gets an ID of the form "<remote addr>_<unix nanos>".
// It is tracked in entry.TCPConns under that ID and served by
// handleTCPConnection in its own goroutine.
func (m *ProxyManager) handleTCPConnections(entry *ProxyEntry) {
	for {
		conn, acceptErr := entry.TCPListener.Accept()
		if acceptErr != nil {
			// Any Accept failure, including a closed listener, ends the loop.
			log.Printf("TCP listener for %s closed: %v", entry.Config.Name, acceptErr)
			return
		}

		remote := conn.RemoteAddr().String()
		id := fmt.Sprintf("%s_%d", remote, time.Now().UnixNano())
		entry.TCPConns.Store(id, conn)
		log.Printf("TCP connection for %s accepted", id)

		go m.handleTCPConnection(entry, id, conn)
	}
}
// handleUDPPackets reads datagrams from entry.UDPConn until a read fails;
// closing the socket is one way that happens. Each sender address is recorded
// in entry.UDPAddrs keyed by addr.String(). Every datagram is forwarded to the
// owning client by handleUDPPacket in its own goroutine.
func (m *ProxyManager) handleUDPPackets(entry *ProxyEntry) {
	// Sized for the largest possible UDP datagram so payloads are never
	// silently truncated (the previous 4096-byte buffer dropped the excess).
	buffer := make([]byte, 65535)
	for {
		n, addr, err := entry.UDPConn.ReadFromUDP(buffer)
		if err != nil {
			// The socket may have been closed by closeProxy.
			log.Printf("UDP connection for %s closed: %v", entry.Config.Name, err)
			return
		}

		// buffer is overwritten by the next ReadFromUDP while the forwarding
		// goroutine may still be using the data, so hand it a private copy.
		packet := make([]byte, n)
		copy(packet, buffer[:n])

		// Log only the first datagram from a peer instead of every packet.
		key := addr.String()
		if _, seen := entry.UDPAddrs.LoadOrStore(key, addr); !seen {
			log.Printf("UDP connection for %s accepted", key)
		}
		go m.handleUDPPacket(entry, packet, addr)
	}
}
// handleTCPConnection forwards bytes read from a public TCP connection to the
// owning client as proto.ProxyData messages tagged with connID.
//
// It returns in any of these cases:
//   - the client's connection cannot be found;
//   - the TCP connection reaches EOF or a read error;
//   - sending to the client fails.
//
// On return the connection is closed and removed from entry.TCPConns.
func (m *ProxyManager) handleTCPConnection(entry *ProxyEntry, connID string, conn net.Conn) {
	if tcpConn, ok := conn.(*net.TCPConn); ok {
		// Socket tuning is best-effort; failures must not drop the connection.
		_ = tcpConn.SetKeepAlive(true)
		_ = tcpConn.SetKeepAlivePeriod(30 * time.Second)
		_ = tcpConn.SetNoDelay(true)
	}
	defer func() {
		conn.Close()
		entry.TCPConns.Delete(connID)
		log.Printf("TCP connection for %s closed", connID)
	}()

	c, ok := m.clientManager.Get(entry.ClientID)
	if !ok {
		log.Printf("Control connection not found for client %s", entry.ClientID)
		return
	}

	// The read loop runs directly in this goroutine: spawning a single worker
	// and immediately waiting on it (as before) added nothing.
	buffer := make([]byte, 4096)
	for {
		n, err := conn.Read(buffer)
		// Per the io.Reader contract, bytes returned together with an error
		// must be processed before the error is handled.
		if n > 0 {
			// The next Read reuses buffer, and Send may keep a reference to Data
			// after it returns (NOTE(review): Send's retention behavior is not
			// visible here), so each message gets its own copy.
			payload := make([]byte, n)
			copy(payload, buffer[:n])
			sendErr := c.Send(&proto.Message{
				Content: &proto.Message_ProxyData{
					ProxyData: &proto.ProxyData{
						ConnId:      connID,
						ProxyConfig: entry.Config,
						Data:        payload,
					},
				},
			})
			if sendErr != nil {
				log.Printf("Failed to send proxy data: %v", sendErr)
				return
			}
		}
		if err != nil {
			// EOF is a normal close; the deferred cleanup already logs it, so
			// only unexpected read errors are reported here.
			if err != io.EOF {
				log.Printf("Failed to read from TCP connection: %v", err)
			}
			return
		}
	}
}
// handleUDPPacket forwards one datagram received from addr to the client that
// owns entry. It is sent as a proto.ProxyData message whose ConnId is
// addr.String(). Failures (unknown client, send error) are logged, not
// returned.
func (m *ProxyManager) handleUDPPacket(entry *ProxyEntry, data []byte, addr *net.UDPAddr) {
	client, found := m.clientManager.Get(entry.ClientID)
	if !found {
		log.Printf("Control connection not found for client %s", entry.ClientID)
		return
	}

	payload := &proto.ProxyData{
		ConnId:      addr.String(),
		ProxyConfig: entry.Config,
		Data:        data,
	}
	msg := &proto.Message{Content: &proto.Message_ProxyData{ProxyData: payload}}

	if sendErr := client.Send(msg); sendErr != nil {
		log.Printf("Failed to send proxy data: %v", sendErr)
	}
}