2025-03-21 22:16:43 +11:00

178 lines
4.6 KiB
Go

package server
import (
"fmt"
"log"
"net"
"time"
pb "net-tunnel/pkg/proto"
"google.golang.org/grpc"
)
// Server is the tunnel server side. It implements the gRPC TunnelService
// and holds the managers for client control connections and proxies, plus
// the address the gRPC listener binds to.
type Server struct {
	// Embedded generated stub: TunnelService methods not implemented on
	// Server fall back to its "unimplemented" behaviour.
	pb.UnimplementedTunnelServiceServer
	// connManager holds control connections, keyed by client ID.
	connManager *ConnectionManager
	// proxyManager handles proxy registration; NewServer constructs it
	// with the same connManager.
	proxyManager *ProxyManager
	// grpcServer is created in Start; it is nil until Start runs.
	grpcServer *grpc.Server
	// bindAddr is the host/IP part of the listen address.
	bindAddr string
	// bindPort is the TCP port of the listen address.
	bindPort int
}
// NewServer returns a Server configured to listen on bindAddr:bindPort.
// A fresh connection manager is created and the same instance is handed to
// the proxy manager. The gRPC server is not created here; Start does that.
func NewServer(bindAddr string, bindPort int) *Server {
	srv := &Server{
		bindAddr: bindAddr,
		bindPort: bindPort,
	}
	srv.connManager = NewConnectionManager()
	srv.proxyManager = NewProxyManager(srv.connManager)
	return srv
}
// Start opens a TCP listener on bindAddr:bindPort, creates the gRPC server,
// registers this Server as the TunnelService implementation and serves on
// the listener. It blocks until Serve returns.
//
// It returns a wrapped error if the listener cannot be opened; otherwise it
// returns whatever grpc.Server.Serve returns.
func (s *Server) Start() error {
	// net.JoinHostPort brackets IPv6 literals ("::1" -> "[::1]:8080"); a plain
	// "%s:%d" format yields "::1:8080", which net.Listen rejects. For IPv4
	// addresses and hostnames the result is identical to the old format.
	addr := net.JoinHostPort(s.bindAddr, fmt.Sprint(s.bindPort))
	lis, err := net.Listen("tcp", addr)
	if err != nil {
		// %w keeps the underlying error available to errors.Is / errors.As.
		return fmt.Errorf("failed to listen: %w", err)
	}
	s.grpcServer = grpc.NewServer()
	pb.RegisterTunnelServiceServer(s.grpcServer, s)
	log.Printf("Server started, listening on %s", addr)
	return s.grpcServer.Serve(lis)
}
// Stop shuts the gRPC server down via GracefulStop. If Start has not yet
// created the gRPC server, Stop does nothing.
func (s *Server) Stop() {
	srv := s.grpcServer
	if srv == nil {
		return
	}
	srv.GracefulStop()
}
// EstablishControlConnection implements the TunnelService control-stream RPC.
//
// It registers the stream as a control connection under a newly generated
// client ID and then reads messages until Recv fails:
//   - RegisterConfigs messages go to handleProxyRegister;
//   - ProxyData messages go to handleProxyData;
//   - any other content is logged and ignored.
//
// When Recv returns an error, the client's proxies are unregistered. The
// deferred call then removes the connection from the connection manager.
// The Recv error is returned unchanged.
//
// NOTE(review): the ID is built only from time.Now().UnixNano(), so two
// streams opened within the clock's resolution could receive the same ID —
// consider adding a counter.
// NOTE(review): per gRPC-Go convention a clean client close shows up as
// io.EOF from Recv. Handlers usually return nil in that case, whereas this
// one returns the error as-is — confirm this is intended.
func (s *Server) EstablishControlConnection(stream pb.TunnelService_EstablishControlConnectionServer) error {
	id := "client_" + fmt.Sprint(time.Now().UnixNano())
	s.connManager.AddConnection(id, NewControlConnection(stream))
	defer s.connManager.RemoveConnection(id)
	log.Printf("New control connection established: %s", id)

	for {
		incoming, recvErr := stream.Recv()
		if recvErr != nil {
			log.Printf("Control connection closed: %v", recvErr)
			// Drop this client's proxies before the deferred
			// RemoveConnection runs.
			s.proxyManager.UnregisterAllProxies(id)
			return recvErr
		}

		switch payload := incoming.GetContent().(type) {
		case *pb.Message_RegisterConfigs:
			s.handleProxyRegister(id, payload)
		case *pb.Message_ProxyData:
			s.handleProxyData(id, payload)
		default:
			log.Printf("收到未知类型的消息")
		}
	}
}
// handleProxyRegister registers, in order, every proxy config carried by a
// RegisterConfigs message on behalf of clientID. It reports the outcome on
// that client's control connection:
//   - on the first registration failure it sends a RegisterProxiesError
//     naming the failing config and skips the remaining configs;
//   - if every config registers, it sends the original RegisterConfigs
//     message back as the acknowledgement.
//
// If the client has no control connection, it logs the problem and does
// nothing else. Send failures are logged, not returned.
//
// NOTE(review): configs registered before a failure stay registered; confirm
// whether the client is expected to clean them up.
func (s *Server) handleProxyRegister(clientID string, msg *pb.Message_RegisterConfigs) {
	ctrl, found := s.connManager.GetConnection(clientID)
	if !found {
		log.Printf("Control connection not found for client %s", clientID)
		return
	}

	for _, cfg := range msg.RegisterConfigs.GetConfigs() {
		regErr := s.proxyManager.RegisterProxy(clientID, cfg)
		if regErr == nil {
			continue
		}
		log.Printf("Failed to register proxy %s: %v", cfg.Name, regErr)
		reply := &pb.Message{
			Content: &pb.Message_RegisterProxiesError{
				RegisterProxiesError: &pb.RegisterProxiesError{
					ProxyConfig: cfg,
					Error:       regErr.Error(),
				},
			},
		}
		if sendErr := ctrl.Send(reply); sendErr != nil {
			log.Printf("Failed to send message to client %s: %v", clientID, sendErr)
		}
		return
	}

	// Every config registered: echo the request back as the acknowledgement.
	if sendErr := ctrl.Send(&pb.Message{Content: msg}); sendErr != nil {
		log.Printf("Failed to send message to client %s: %v", clientID, sendErr)
	}
}
// handleProxyData delivers the payload of a ProxyData message from clientID
// to the TCP connection it belongs to.
//
// The proxy is looked up by the name in the message's ProxyConfig. For TCP
// proxies the connection is then looked up by ConnId in the proxy's
// TCPConns, and the data is written to it. A ProxyError is sent back on the
// client's control connection in two cases:
//   - the connection is unknown;
//   - the write fails, in which case the connection is also closed and
//     removed from TCPConns.
//
// The message arrives from the network, so nested fields are read through
// the generated nil-safe getters. A ProxyData without a ProxyConfig is
// treated as "proxy not found" instead of panicking, which in a gRPC handler
// would crash the whole server.
//
// NOTE(review): the proxy is looked up by name only; nothing here checks that
// it was registered by clientID — confirm whether ownership is enforced
// elsewhere.
// NOTE(review): Write runs synchronously on the caller's goroutine without a
// deadline, so a peer that stops reading can block this call.
func (s *Server) handleProxyData(clientID string, msg *pb.Message_ProxyData) {
	data := msg.ProxyData
	proxyName := data.GetProxyConfig().GetName()
	connID := data.GetConnId()

	// sendProxyError reports a per-connection failure to the client over its
	// control connection. This is best effort: a missing control connection
	// is ignored, and send failures are only logged.
	sendProxyError := func(reason string) {
		controlConn, ok := s.connManager.GetConnection(clientID)
		if !ok {
			return
		}
		if err := controlConn.Send(&pb.Message{
			Content: &pb.Message_ProxyError{
				ProxyError: &pb.ProxyError{
					ProxyConfig: data.GetProxyConfig(),
					ConnId:      connID,
					Error:       reason,
				},
			},
		}); err != nil {
			log.Printf("Failed to send message to client %s: %v", clientID, err)
		}
	}

	proxyEntry, ok := s.proxyManager.proxies.Load(proxyName)
	if !ok {
		log.Printf("Proxy %s not found", proxyName)
		return
	}
	entry := proxyEntry.(*ProxyEntry)

	switch entry.Config.Type {
	case pb.ProxyType_TCP:
		value, ok := entry.TCPConns.Load(connID)
		if !ok {
			log.Printf("TCP connection %s not found", connID)
			sendProxyError("TCP connection not found")
			return
		}
		tcpConn := value.(net.Conn)
		if _, err := tcpConn.Write(data.GetData()); err != nil {
			log.Printf("Failed to write data to TCP connection: %v", err)
			// The connection is unusable after a failed write: close it,
			// forget it, and tell the client.
			tcpConn.Close()
			entry.TCPConns.Delete(connID)
			sendProxyError("Failed to write data to TCP connection")
		}
	default:
		// Previously such data was dropped silently; log it so the loss is visible.
		log.Printf("Unsupported proxy type %v for proxy %s, dropping data", entry.Config.Type, proxyName)
	}
}