add: cancel transfer

This commit is contained in:
2026-02-04 15:06:41 +08:00
parent c2f3c2c3df
commit 68533dad31
9 changed files with 529 additions and 221 deletions

View File

@@ -3,74 +3,24 @@ package transfer
import (
"archive/tar"
"bytes"
"context"
"errors"
"fmt"
"io"
"log/slog"
"mesh-drop/internal/discovery"
"net/http"
"os"
"path/filepath"
"strings"
"sync"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/wailsapp/wails/v3/pkg/application"
)
type Service struct {
app *application.App
port int
savePath string // 默认下载目录
// pendingRequests 存储等待用户确认的通道
// Key: TransferID, Value: Transfer
transferList sync.Map
discoveryService *discovery.Service
}
func NewService(app *application.App, port int, defaultSavePath string, discoveryService *discovery.Service) *Service {
gin.SetMode(gin.ReleaseMode)
return &Service{
app: app,
port: port,
savePath: defaultSavePath,
discoveryService: discoveryService,
}
}
func init() {
application.RegisterEvent[application.Void]("transfer:refreshList")
}
func (s *Service) GetPort() int {
return s.port
}
func (s *Service) Start() {
r := gin.Default()
transfer := r.Group("/transfer")
{
transfer.POST("/ask", s.handleAsk)
transfer.PUT("/upload/:id", s.handleUpload)
}
go func() {
addr := fmt.Sprintf(":%d", s.port)
slog.Info("Transfer service listening", "address", addr, "component", "transfer")
if err := r.Run(addr); err != nil {
slog.Error("Transfer service error", "error", err, "component", "transfer")
}
}()
}
// handleAsk 处理接收文件请求
func (s *Service) handleAsk(c *gin.Context) {
var task Transfer
// Gin 的 BindJSON 自动处理 JSON 解析
if err := c.ShouldBindJSON(&task); err != nil {
c.JSON(http.StatusBadRequest, TransferAskResponse{
ID: task.ID,
@@ -113,7 +63,8 @@ func (s *Service) handleAsk(c *gin.Context) {
Accepted: decision.Accepted,
Token: task.Token,
})
case <-c.Done():
s.app.Event.Emit("transfer:refreshList")
case <-c.Request.Context().Done():
// 发送端放弃
task.Status = TransferStatusCanceled
s.transferList.Store(task.ID, task)
@@ -147,6 +98,7 @@ func (s *Service) handleUpload(c *gin.Context) {
c.JSON(http.StatusBadRequest, TransferUploadResponse{
ID: id,
Message: "Invalid request: missing id or token",
Status: TransferStatusError,
})
return
}
@@ -157,16 +109,24 @@ func (s *Service) handleUpload(c *gin.Context) {
c.JSON(http.StatusUnauthorized, TransferUploadResponse{
ID: id,
Message: "Invalid request: task not found",
Status: TransferStatusError,
})
return
}
task := val.(Transfer)
ctx, cancel := context.WithCancel(c.Request.Context())
s.cancelMap.Store(task.ID, cancel)
defer func() {
s.cancelMap.Delete(task.ID)
cancel()
}()
// 校验 token
if task.Token != token {
c.JSON(http.StatusUnauthorized, TransferUploadResponse{
ID: id,
Message: "Token mismatch",
Status: TransferStatusError,
})
return
}
@@ -176,6 +136,7 @@ func (s *Service) handleUpload(c *gin.Context) {
c.JSON(http.StatusForbidden, TransferUploadResponse{
ID: id,
Message: "Invalid task status",
Status: TransferStatusError,
})
return
}
@@ -190,6 +151,11 @@ func (s *Service) handleUpload(c *gin.Context) {
savePath = s.savePath
}
ctxReader := &ContextReader{
ctx: ctx,
r: c.Request.Body,
}
switch task.ContentType {
case ContentTypeFile:
destPath := filepath.Join(savePath, task.FileName)
@@ -199,6 +165,7 @@ func (s *Service) handleUpload(c *gin.Context) {
c.JSON(http.StatusInternalServerError, TransferUploadResponse{
ID: task.ID,
Message: "Receiver failed to create file",
Status: TransferStatusError,
})
slog.Error("Failed to create file", "error", err, "component", "transfer")
task.Status = TransferStatusError
@@ -209,22 +176,22 @@ func (s *Service) handleUpload(c *gin.Context) {
return
}
defer file.Close()
s.receive(c, &task, file)
s.receive(c, &task, file, ctxReader)
case ContentTypeText:
var buf bytes.Buffer
s.receive(c, &task, &buf)
s.receive(c, &task, &buf, ctxReader)
task.Text = buf.String()
s.transferList.Store(task.ID, task)
s.app.Event.Emit("transfer:refreshList")
case ContentTypeFolder:
s.receiveFolder(c, savePath, &task)
s.receiveFolder(c, savePath, &task, ctxReader)
}
}
func (s *Service) receive(c *gin.Context, task *Transfer, writer io.Writer) {
func (s *Service) receive(c *gin.Context, task *Transfer, writer io.Writer, ctxReader io.Reader) {
// 包装 reader用于计算进度
reader := &PassThroughReader{
Reader: c.Request.Body,
Reader: ctxReader,
total: task.FileSize,
callback: func(current, total int64, speed float64) {
task.Progress = Progress{
@@ -239,16 +206,42 @@ func (s *Service) receive(c *gin.Context, task *Transfer, writer io.Writer) {
_, err := io.Copy(writer, reader)
if err != nil {
// 文件写入失败,直接报错,任务结束
// 发送端断线,任务取消
if c.Request.Context().Err() != nil {
slog.Info("Sender canceled transfer (Network/Context disconnected)", "id", task.ID, "raw_err", err)
task.ErrorMsg = "Sender disconnected"
task.Status = TransferStatusCanceled
s.transferList.Store(task.ID, *task)
s.app.Event.Emit("transfer:refreshList")
return
}
// 用户取消传输
if errors.Is(err, context.Canceled) {
slog.Info("User canceled transfer", "component", "transfer")
task.ErrorMsg = "User canceled transfer"
task.Status = TransferStatusCanceled
// 通知发送端
c.JSON(http.StatusOK, TransferUploadResponse{
ID: task.ID,
Message: "File transfer canceled",
Status: TransferStatusCanceled,
})
s.transferList.Store(task.ID, *task)
s.app.Event.Emit("transfer:refreshList")
return
}
// 接收端写文件失败
c.JSON(http.StatusInternalServerError, TransferUploadResponse{
ID: task.ID,
Message: "Failed to write file",
Status: TransferStatusError,
})
slog.Error("Failed to write file", "error", err, "component", "transfer")
task.Status = TransferStatusError
task.ErrorMsg = fmt.Errorf("failed to write file: %v", err).Error()
s.transferList.Store(task.ID, *task)
// 通知前端传输失败
s.app.Event.Emit("transfer:refreshList")
return
}
@@ -256,6 +249,7 @@ func (s *Service) receive(c *gin.Context, task *Transfer, writer io.Writer) {
c.JSON(http.StatusOK, TransferUploadResponse{
ID: task.ID,
Message: "File received successfully",
Status: TransferStatusCompleted,
})
// 传输成功,任务结束
task.Status = TransferStatusCompleted
@@ -263,20 +257,26 @@ func (s *Service) receive(c *gin.Context, task *Transfer, writer io.Writer) {
s.app.Event.Emit("transfer:refreshList")
}
func (s *Service) receiveFolder(c *gin.Context, savePath string, task *Transfer) {
func (s *Service) receiveFolder(c *gin.Context, savePath string, task *Transfer, ctxReader io.Reader) {
// 创建根目录
destPath := filepath.Join(savePath, task.FileName)
if err := os.MkdirAll(destPath, 0755); err != nil {
c.JSON(http.StatusInternalServerError, TransferUploadResponse{
ID: task.ID,
Message: "Receiver failed to create folder",
Status: TransferStatusError,
})
slog.Error("Failed to create folder", "error", err, "component", "transfer")
task.Status = TransferStatusError
task.ErrorMsg = fmt.Errorf("receiver failed to create folder: %v", err).Error()
s.transferList.Store(task.ID, *task)
s.app.Event.Emit("transfer:refreshList")
return
}
// 包装 reader用于计算进度
reader := &PassThroughReader{
Reader: c.Request.Body,
Reader: ctxReader,
total: task.FileSize,
callback: func(current, total int64, speed float64) {
task.Progress = Progress{
@@ -289,18 +289,56 @@ func (s *Service) receiveFolder(c *gin.Context, savePath string, task *Transfer)
},
}
handleError := func(err error, stage string) bool {
if err == nil {
return false
}
if c.Request.Context().Err() != nil {
slog.Info("Transfer canceled by sender (Network disconnect)", "id", task.ID, "stage", stage)
task.Status = TransferStatusCanceled
task.ErrorMsg = "Sender disconnected"
// 发送端已断开,无需也不应再发送 c.JSON
s.transferList.Store(task.ID, *task)
s.app.Event.Emit("transfer:refreshList")
return true
}
if errors.Is(err, context.Canceled) {
slog.Info("Transfer canceled by user", "id", task.ID, "stage", stage)
task.Status = TransferStatusCanceled
task.ErrorMsg = "User canceled transfer"
// 通知发送端(虽然此时连接可能即将关闭,但尽力通知)
c.JSON(http.StatusOK, TransferUploadResponse{
ID: task.ID,
Message: "File transfer canceled",
Status: TransferStatusCanceled,
})
s.transferList.Store(task.ID, *task)
s.app.Event.Emit("transfer:refreshList")
return true
}
slog.Error("Transfer failed", "error", err, "stage", stage)
task.Status = TransferStatusError
task.ErrorMsg = fmt.Sprintf("Failed at %s: %v", stage, err)
c.JSON(http.StatusInternalServerError, TransferUploadResponse{
ID: task.ID,
Message: fmt.Sprintf("Transfer failed: %v", err),
Status: TransferStatusError,
})
s.transferList.Store(task.ID, *task)
s.app.Event.Emit("transfer:refreshList")
return true
}
tr := tar.NewReader(reader)
for {
header, err := tr.Next()
if err == io.EOF {
break
}
if err != nil {
c.JSON(http.StatusInternalServerError, TransferUploadResponse{
ID: task.ID,
Message: "Stream error",
})
slog.Error("Tar stream error", "error", err)
if handleError(err, "read_tar_header") {
return
}
@@ -325,12 +363,9 @@ func (s *Service) receiveFolder(c *gin.Context, savePath string, task *Transfer)
if _, err := io.Copy(f, tr); err != nil {
f.Close()
slog.Error("Failed to write file", "path", target, "error", err)
c.JSON(http.StatusInternalServerError, TransferUploadResponse{
ID: task.ID,
Message: "Write error",
})
return
if handleError(err, "write_file_content") {
return
}
}
f.Close()
}
@@ -346,12 +381,3 @@ func (s *Service) receiveFolder(c *gin.Context, savePath string, task *Transfer)
s.transferList.Store(task.ID, *task)
s.app.Event.Emit("transfer:refreshList")
}
func (s *Service) GetTransferList() []Transfer {
var requests []Transfer
s.transferList.Range(func(key, value any) bool {
requests = append(requests, value.(Transfer))
return true
})
return requests
}