221 lines
5.8 KiB
Go
221 lines
5.8 KiB
Go
package server
|
|
|
|
import (
|
|
"fmt"
|
|
"log"
|
|
"net"
|
|
"time"
|
|
|
|
pb "net-tunnel/pkg/proto"
|
|
|
|
"google.golang.org/grpc"
|
|
)
|
|
|
|
// Server 表示服务端
|
|
type Server struct {
|
|
pb.UnimplementedTunnelServiceServer
|
|
clientManager *ClientManager
|
|
proxyManager *ProxyManager
|
|
grpcServer *grpc.Server
|
|
bindAddr string
|
|
bindPort int
|
|
}
|
|
|
|
// NewServer 创建一个新的服务端
|
|
func NewServer(bindAddr string, bindPort int) *Server {
|
|
clientManager := NewClientManager()
|
|
return &Server{
|
|
clientManager: clientManager,
|
|
proxyManager: NewProxyManager(clientManager),
|
|
bindAddr: bindAddr,
|
|
bindPort: bindPort,
|
|
}
|
|
}
|
|
|
|
// Start 启动服务端
|
|
func (s *Server) Start() error {
|
|
addr := fmt.Sprintf("%s:%d", s.bindAddr, s.bindPort)
|
|
lis, err := net.Listen("tcp", addr)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to listen: %v", err)
|
|
}
|
|
|
|
s.grpcServer = grpc.NewServer()
|
|
pb.RegisterTunnelServiceServer(s.grpcServer, s)
|
|
|
|
log.Printf("Server started, listening on %s", addr)
|
|
return s.grpcServer.Serve(lis)
|
|
}
|
|
|
|
// Stop 停止服务端
|
|
func (s *Server) Stop() {
|
|
if s.grpcServer != nil {
|
|
s.grpcServer.GracefulStop()
|
|
}
|
|
}
|
|
|
|
// Connect 实现 gRPC 服务方法
|
|
func (s *Server) Connect(stream pb.TunnelService_ConnectServer) error {
|
|
// 创建控制连接
|
|
clientID := "client_" + fmt.Sprint(time.Now().UnixNano())
|
|
conn := NewClientConnection(stream)
|
|
|
|
// 添加客户端
|
|
s.clientManager.Add(clientID, conn)
|
|
defer s.clientManager.Remove(clientID)
|
|
|
|
log.Printf("New control connection established: %s", clientID)
|
|
|
|
// 接收客户端消息
|
|
for {
|
|
msg, err := stream.Recv()
|
|
if err != nil {
|
|
log.Printf("Control connection closed: %v", err)
|
|
// 清理该客户端的代理
|
|
s.proxyManager.UnregisterAllProxies(clientID)
|
|
return err
|
|
}
|
|
|
|
// 处理不同类型的消息
|
|
switch content := msg.GetContent().(type) {
|
|
case *pb.Message_RegisterConfigs:
|
|
// 处理代理配置注册
|
|
s.handleProxyRegister(clientID, content)
|
|
case *pb.Message_ProxyData:
|
|
s.handleProxyData(clientID, content)
|
|
default:
|
|
log.Printf("收到未知类型的消息")
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *Server) handleProxyRegister(clientID string, msg *pb.Message_RegisterConfigs) {
|
|
hasError := false
|
|
|
|
conn, ok := s.clientManager.Get(clientID)
|
|
if !ok {
|
|
log.Printf("Control connection not found for client %s", clientID)
|
|
return
|
|
}
|
|
|
|
for _, config := range msg.RegisterConfigs.GetConfigs() {
|
|
if err := s.proxyManager.RegisterProxy(clientID, config); err != nil {
|
|
log.Printf("Failed to register proxy %s: %v", config.Name, err)
|
|
msg := &pb.Message_RegisterProxiesError{
|
|
RegisterProxiesError: &pb.RegisterProxiesError{
|
|
ProxyConfig: config,
|
|
Error: err.Error(),
|
|
},
|
|
}
|
|
|
|
if err := conn.Send(&pb.Message{
|
|
Content: msg,
|
|
}); err != nil {
|
|
log.Printf("Failed to send message to client %s: %v", clientID, err)
|
|
}
|
|
hasError = true
|
|
break
|
|
}
|
|
}
|
|
if !hasError {
|
|
if err := conn.Send(&pb.Message{
|
|
Content: msg,
|
|
}); err != nil {
|
|
log.Printf("Failed to send message to client %s: %v", clientID, err)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *Server) handleProxyData(clientID string, msg *pb.Message_ProxyData) {
|
|
switch msg.ProxyData.ProxyConfig.Type {
|
|
case pb.ProxyType_TCP:
|
|
s.handleTCPProxyData(clientID, msg)
|
|
case pb.ProxyType_UDP:
|
|
s.handleUDPProxyData(clientID, msg)
|
|
}
|
|
}
|
|
|
|
func (s *Server) handleUDPProxyData(clientID string, msg *pb.Message_ProxyData) {
|
|
proxyEntry, ok := s.proxyManager.proxies.Load(msg.ProxyData.ProxyConfig.Name)
|
|
if !ok {
|
|
log.Printf("Proxy %s not found", msg.ProxyData.ProxyConfig.Name)
|
|
return
|
|
}
|
|
entry := proxyEntry.(*ProxyEntry)
|
|
addr, ok := entry.UDPAddrs.Load(msg.ProxyData.ConnId)
|
|
if !ok {
|
|
log.Printf("UDP connection %s not found", msg.ProxyData.ConnId)
|
|
return
|
|
}
|
|
_, err := entry.UDPConn.WriteToUDP(msg.ProxyData.Data, addr.(*net.UDPAddr))
|
|
if err != nil {
|
|
log.Printf("Failed to write data to UDP connection: %v", err)
|
|
entry.UDPAddrs.Delete(msg.ProxyData.ConnId)
|
|
controlConn, ok := s.clientManager.Get(clientID)
|
|
if !ok {
|
|
log.Printf("Control connection not found for client %s", clientID)
|
|
return
|
|
}
|
|
if err := controlConn.Send(&pb.Message{
|
|
Content: &pb.Message_ProxyError{
|
|
ProxyError: &pb.ProxyError{
|
|
ProxyConfig: msg.ProxyData.ProxyConfig,
|
|
ConnId: msg.ProxyData.ConnId,
|
|
Error: "Failed to write data to UDP connection",
|
|
},
|
|
},
|
|
}); err != nil {
|
|
log.Printf("Failed to send message to client %s: %v", clientID, err)
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *Server) handleTCPProxyData(clientID string, msg *pb.Message_ProxyData) {
|
|
proxyEntry, ok := s.proxyManager.proxies.Load(msg.ProxyData.ProxyConfig.Name)
|
|
if !ok {
|
|
log.Printf("Proxy %s not found", msg.ProxyData.ProxyConfig.Name)
|
|
return
|
|
}
|
|
entry := proxyEntry.(*ProxyEntry)
|
|
conn, ok := entry.TCPConns.Load(msg.ProxyData.ConnId)
|
|
if !ok {
|
|
log.Printf("TCP connection %s not found", msg.ProxyData.ConnId)
|
|
controlConn, ok := s.clientManager.Get(clientID)
|
|
if ok {
|
|
if err := controlConn.Send(&pb.Message{
|
|
Content: &pb.Message_ProxyError{
|
|
ProxyError: &pb.ProxyError{
|
|
ProxyConfig: msg.ProxyData.ProxyConfig,
|
|
ConnId: msg.ProxyData.ConnId,
|
|
Error: "TCP connection not found",
|
|
},
|
|
},
|
|
}); err != nil {
|
|
log.Printf("Failed to send message to client %s: %v", clientID, err)
|
|
}
|
|
}
|
|
return
|
|
}
|
|
_, err := conn.(net.Conn).Write(msg.ProxyData.Data)
|
|
if err != nil {
|
|
log.Printf("Failed to write data to TCP connection: %v", err)
|
|
conn.(net.Conn).Close()
|
|
entry.TCPConns.Delete(msg.ProxyData.ConnId)
|
|
controlConn, ok := s.clientManager.Get(clientID)
|
|
if ok {
|
|
if err := controlConn.Send(&pb.Message{
|
|
Content: &pb.Message_ProxyError{
|
|
ProxyError: &pb.ProxyError{
|
|
ProxyConfig: msg.ProxyData.ProxyConfig,
|
|
ConnId: msg.ProxyData.ConnId,
|
|
Error: "Failed to write data to TCP connection",
|
|
},
|
|
},
|
|
}); err != nil {
|
|
log.Printf("Failed to send message to client %s: %v", clientID, err)
|
|
}
|
|
}
|
|
}
|
|
}
|