Files
mesh-drop/internal/transfer/server.go
2026-02-04 02:21:23 +08:00

272 lines
6.6 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package transfer
import (
	"bytes"
	"fmt"
	"io"
	"log/slog"
	"mesh-drop/internal/discovery"
	"net/http"
	"os"
	"path/filepath"
	"sync"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/google/uuid"
	"github.com/wailsapp/wails/v3/pkg/application"
)
// Service runs the transfer HTTP endpoints (see Start) and tracks every known
// transfer in transferList.
type Service struct {
	app      *application.App
	port     int
	savePath string // default download directory; used when a transfer has no SavePath of its own
	// transferList holds all known transfers.
	// Key: transfer ID (string), Value: Transfer.
	// Readers in this file type-assert entries to Transfer (not *Transfer),
	// so entries must be stored as Transfer values.
	transferList sync.Map
	// NOTE(review): not referenced in this file; presumably used by other
	// files in package transfer.
	discoveryService *discovery.Service
}
// NewService builds a transfer Service that will listen on port and, unless a
// transfer specifies its own SavePath, store received files in
// defaultSavePath. The HTTP server is not started until Start is called.
func NewService(app *application.App, port int, defaultSavePath string, discoveryService *discovery.Service) *Service {
	// Put gin into release mode before Start creates its engine.
	// gin.SetMode is a package-level (process-wide) gin setting.
	gin.SetMode(gin.ReleaseMode)

	svc := new(Service)
	svc.app = app
	svc.port = port
	svc.savePath = defaultSavePath
	svc.discoveryService = discoveryService
	return svc
}
// init registers the payload-less "transfer:refreshList" event with Wails.
// This file emits it whenever the transfer list changes, so the frontend
// can refresh its view.
func init() {
	const refreshListEvent = "transfer:refreshList"
	application.RegisterEvent[application.Void](refreshListEvent)
}
// GetPort reports the TCP port the transfer HTTP server listens on, as
// passed to NewService.
func (s *Service) GetPort() int {
	port := s.port
	return port
}
// Start registers the transfer HTTP routes and serves them on s.port in a
// background goroutine. It returns immediately; a listen or serve error is
// only logged.
//
// Routes: POST /transfer/ask (handleAsk) and PUT /transfer/upload/:id
// (handleUpload).
func (s *Service) Start() {
	r := gin.Default()
	transfer := r.Group("/transfer")
	{
		transfer.POST("/ask", s.handleAsk)
		transfer.PUT("/upload/:id", s.handleUpload)
	}

	srv := &http.Server{
		Addr:    fmt.Sprintf(":%d", s.port),
		Handler: r,
		// Peers on the network can connect to this listener, so bound the
		// time allowed to send request headers (slowloris protection).
		// ReadTimeout/WriteTimeout are intentionally unset: uploads can be
		// large, and handleAsk holds the response open until the user decides.
		ReadHeaderTimeout: 10 * time.Second,
	}

	go func() {
		slog.Info("Transfer service listening", "address", srv.Addr, "component", "transfer")
		if err := srv.ListenAndServe(); err != nil {
			slog.Error("Transfer service error", "error", err, "component", "transfer")
		}
	}()
}
// handleAsk handles POST /transfer/ask: a peer asks whether we accept a
// transfer.
//
// The JSON body is decoded into a Transfer and registered as a pending
// receive-type task, and the frontend is notified via "transfer:refreshList".
// The handler then blocks until either:
//   - the user decides (ResolvePendingRequest sends on DecisionChan). On
//     acceptance a fresh upload token is issued and returned to the sender.
//   - the sender's request context is done (e.g. it disconnected). The task
//     is then marked Canceled.
//
// Malformed bodies or an empty ID get 400. A request whose ID is already
// tracked gets 409 and is not registered again.
func (s *Service) handleAsk(c *gin.Context) {
	var task Transfer
	if err := c.ShouldBindJSON(&task); err != nil || task.ID == "" {
		c.JSON(http.StatusBadRequest, TransferAskResponse{
			ID:      task.ID,
			Message: "Invalid request",
		})
		return
	}

	task.Type = TransferTypeReceive
	task.Status = TransferStatusPending
	// Buffered so a decision sent just after this handler stopped waiting
	// (sender went away) cannot block the sender of the decision.
	task.DecisionChan = make(chan Decision, 1)

	// LoadOrStore makes the duplicate check and registration atomic. A
	// duplicate is most likely a network retry; answer explicitly instead of
	// returning an empty 200 the sender cannot interpret.
	if _, loaded := s.transferList.LoadOrStore(task.ID, task); loaded {
		c.JSON(http.StatusConflict, TransferAskResponse{
			ID:      task.ID,
			Message: "Duplicate request",
		})
		return
	}

	// Notify the Wails frontend that a new request is waiting.
	s.app.Event.Emit("transfer:refreshList")

	// gin's c.Done() returns nil unless engine.ContextWithFallback is set, so
	// it would never fire. Watch the underlying request context instead.
	select {
	case decision := <-task.DecisionChan:
		if decision.Accepted {
			task.Status = TransferStatusAccepted
			task.SavePath = decision.SavePath
			task.Token = uuid.New().String()
		} else {
			task.Status = TransferStatusRejected
		}
		s.transferList.Store(task.ID, task)
		s.app.Event.Emit("transfer:refreshList")
		c.JSON(http.StatusOK, TransferAskResponse{
			ID:       task.ID,
			Accepted: decision.Accepted,
			Token:    task.Token,
		})
	case <-c.Request.Context().Done():
		// The sender gave up (disconnected or its request was canceled).
		task.Status = TransferStatusCanceled
		s.transferList.Store(task.ID, task)
		s.app.Event.Emit("transfer:refreshList")
	}
}
// ResolvePendingRequest is called from outside (the app/frontend) to deliver
// the user's decision for a pending incoming transfer to the handleAsk call
// waiting on it.
//
// It returns true when the decision was handed over. It returns false when:
//   - no task with that ID exists;
//   - the stored entry is not a Transfer value;
//   - the task is no longer Pending (already decided or canceled);
//   - nothing can currently take the decision.
//
// It never blocks. With an unbuffered DecisionChan, the hand-over succeeds
// only while handleAsk is waiting on the channel.
func (s *Service) ResolvePendingRequest(id string, accept bool, savePath string) bool {
	val, ok := s.transferList.Load(id)
	if !ok {
		return false
	}
	task, ok := val.(Transfer)
	if !ok || task.Status != TransferStatusPending {
		return false
	}

	decision := Decision{
		ID:       id,
		Accepted: accept,
		SavePath: savePath,
	}
	select {
	case task.DecisionChan <- decision:
		return true
	default:
		// Nil channel, handler already gone, or a decision already queued:
		// report failure instead of hanging the caller forever.
		return false
	}
}
// handleUpload handles PUT /transfer/upload/:id?token=...: the sender streams
// the payload of a transfer the user previously accepted.
//
// The request is refused with:
//   - 400 if the id or token is missing;
//   - 401 if no task has that ID, or the token differs from the one issued
//     by handleAsk;
//   - 403 if the task is not in the Accepted state.
//
// Otherwise the task is marked Active. The body is then written to a file in
// the save directory (ContentTypeFile) or collected in memory
// (ContentTypeText). For those types, receive writes the final response and
// status. Failures detected here (unsafe file name, file creation error,
// unsupported content type) mark the task as failed and answer the sender
// directly.
func (s *Service) handleUpload(c *gin.Context) {
	id := c.Param("id")
	token := c.Query("token")
	if id == "" || token == "" {
		c.JSON(http.StatusBadRequest, TransferUploadResponse{
			ID:      id,
			Message: "Invalid request: missing id or token",
		})
		return
	}

	val, ok := s.transferList.Load(id)
	if !ok {
		c.JSON(http.StatusUnauthorized, TransferUploadResponse{
			ID:      id,
			Message: "Invalid request: task not found",
		})
		return
	}
	task := val.(Transfer)

	if task.Token != token {
		c.JSON(http.StatusUnauthorized, TransferUploadResponse{
			ID:      id,
			Message: "Token mismatch",
		})
		return
	}
	if task.Status != TransferStatusAccepted {
		c.JSON(http.StatusForbidden, TransferUploadResponse{
			ID:      id,
			Message: "Invalid task status",
		})
		return
	}

	task.Status = TransferStatusActive
	s.transferList.Store(task.ID, task)
	s.app.Event.Emit("transfer:refreshList")

	// fail answers the sender with code/msg, records errMsg on the task, marks
	// it as failed and notifies the frontend.
	fail := func(code int, msg, errMsg string) {
		c.JSON(code, TransferUploadResponse{
			ID:      task.ID,
			Message: msg,
		})
		task.Status = TransferStatusError
		task.ErrorMsg = errMsg
		s.transferList.Store(task.ID, task)
		s.app.Event.Emit("transfer:refreshList")
	}

	savePath := task.SavePath
	if savePath == "" {
		savePath = s.savePath
	}

	switch task.ContentType {
	case ContentTypeFile:
		// FileName comes from the remote peer. Keep only its last element so
		// names like "../../.bashrc" or absolute paths cannot escape savePath.
		name := filepath.Base(task.FileName)
		if name == "." || !filepath.IsLocal(name) {
			slog.Error("Rejected unsafe file name", "fileName", task.FileName, "component", "transfer")
			fail(http.StatusBadRequest, "Invalid file name",
				fmt.Sprintf("invalid file name: %q", task.FileName))
			return
		}
		destPath := filepath.Join(savePath, name)
		file, err := os.Create(destPath)
		if err != nil {
			slog.Error("Failed to create file", "error", err, "component", "transfer")
			fail(http.StatusInternalServerError, "Receiver failed to create file",
				fmt.Sprintf("receiver failed to create file: %v", err))
			return
		}
		s.receive(c, &task, file)
		if err := file.Close(); err != nil {
			slog.Error("Failed to close file", "error", err, "component", "transfer")
		}
		// receive sets Status to Error when the copy fails. Do not leave a
		// truncated file behind that looks like a complete download.
		if task.Status == TransferStatusError {
			if err := os.Remove(destPath); err != nil {
				slog.Warn("Failed to remove partial file", "path", destPath, "error", err, "component", "transfer")
			}
		}
	case ContentTypeText:
		var buf bytes.Buffer
		s.receive(c, &task, &buf)
		task.Text = buf.String()
		s.transferList.Store(task.ID, task)
		s.app.Event.Emit("transfer:refreshList")
	case ContentTypeFolder:
		// TODO: s.receiveFolder(c, savePath, task). Until it exists, fail
		// explicitly instead of leaving the task Active with an empty 200.
		fail(http.StatusNotImplemented, "Folder transfer is not supported",
			"folder transfer is not supported")
	default:
		fail(http.StatusBadRequest, "Unsupported content type",
			fmt.Sprintf("unsupported content type: %v", task.ContentType))
	}
}
// receive copies the upload request body into writer and reports progress.
//
// task is updated in place:
//   - Progress changes on every PassThroughReader callback.
//   - Status (and ErrorMsg on failure) changes when the copy ends.
//
// Each update is stored in transferList as a Transfer value, not a
// *Transfer, because every reader of transferList type-asserts entries to
// Transfer. The frontend is notified after each store.
//
// It writes the HTTP response itself: 500 if the copy fails, 200 otherwise.
func (s *Service) receive(c *gin.Context, task *Transfer, writer io.Writer) {
	// Wrap the body so progress can be computed while streaming.
	reader := &PassThroughReader{
		Reader: c.Request.Body,
		total:  task.FileSize,
		callback: func(current, total int64, speed float64) {
			task.Progress = Progress{
				Current: current,
				Total:   total,
				Speed:   speed,
			}
			s.transferList.Store(task.ID, *task)
			s.app.Event.Emit("transfer:refreshList")
		},
	}

	if _, err := io.Copy(writer, reader); err != nil {
		// Reading the body or writing the destination failed: the task ends here.
		c.JSON(http.StatusInternalServerError, TransferUploadResponse{
			ID:      task.ID,
			Message: "Failed to write file",
		})
		slog.Error("Failed to write file", "error", err, "component", "transfer")
		task.Status = TransferStatusError
		task.ErrorMsg = fmt.Sprintf("failed to write file: %v", err)
		s.transferList.Store(task.ID, *task)
		s.app.Event.Emit("transfer:refreshList")
		return
	}

	c.JSON(http.StatusOK, TransferUploadResponse{
		ID:      task.ID,
		Message: "File received successfully",
	})
	task.Status = TransferStatusCompleted
	s.transferList.Store(task.ID, *task)
	s.app.Event.Emit("transfer:refreshList")
}
// GetTransferList returns a snapshot of every transfer currently tracked.
//
// The result is never nil, so JSON encoding yields [] rather than null when
// there are no transfers. Order is unspecified (sync.Map iteration order).
// Entries are expected to be Transfer values. A *Transfer entry is
// dereferenced rather than triggering a failed-type-assertion panic. Only
// these two types are ever stored in transferList, so other values are skipped.
func (s *Service) GetTransferList() []Transfer {
	requests := make([]Transfer, 0)
	s.transferList.Range(func(_, value any) bool {
		switch t := value.(type) {
		case Transfer:
			requests = append(requests, t)
		case *Transfer:
			if t != nil {
				requests = append(requests, *t)
			}
		}
		return true
	})
	return requests
}