fix save history
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
"mesh-drop/internal/security"
|
||||
"os"
|
||||
@@ -8,13 +9,12 @@ import (
|
||||
"sync"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
// WindowState 定义窗口状态
|
||||
type WindowState struct {
|
||||
Width int `mapstructure:"width"`
|
||||
Height int `mapstructure:"height"`
|
||||
Width int `json:"width"`
|
||||
Height int `json:"height"`
|
||||
}
|
||||
|
||||
var Version = "next"
|
||||
@@ -27,24 +27,24 @@ const (
|
||||
)
|
||||
|
||||
type configData struct {
|
||||
WindowState WindowState `mapstructure:"window_state"`
|
||||
ID string `mapstructure:"id"`
|
||||
PrivateKey string `mapstructure:"private_key"`
|
||||
PublicKey string `mapstructure:"public_key"`
|
||||
SavePath string `mapstructure:"save_path"`
|
||||
HostName string `mapstructure:"host_name"`
|
||||
AutoAccept bool `mapstructure:"auto_accept"`
|
||||
SaveHistory bool `mapstructure:"save_history"`
|
||||
TrustedPeer map[string]string `mapstructure:"trusted_peer"` // ID -> PublicKey
|
||||
WindowState WindowState `json:"window_state"`
|
||||
ID string `json:"id"`
|
||||
PrivateKey string `json:"private_key"`
|
||||
PublicKey string `json:"public_key"`
|
||||
SavePath string `json:"save_path"`
|
||||
HostName string `json:"host_name"`
|
||||
AutoAccept bool `json:"auto_accept"`
|
||||
SaveHistory bool `json:"save_history"`
|
||||
TrustedPeer map[string]string `json:"trusted_peer"` // ID -> PublicKey
|
||||
|
||||
Language Language `mapstructure:"language"`
|
||||
CloseToSystray bool `mapstructure:"close_to_systray"`
|
||||
Language Language `json:"language"`
|
||||
CloseToSystray bool `json:"close_to_systray"`
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
v *viper.Viper
|
||||
mu sync.RWMutex
|
||||
data configData
|
||||
mu sync.RWMutex
|
||||
data configData
|
||||
configPath string
|
||||
}
|
||||
|
||||
func GetConfigDir() string {
|
||||
@@ -65,36 +65,45 @@ func GetUserHomeDir() string {
|
||||
|
||||
// New 读取配置
|
||||
func Load(defaultState WindowState) *Config {
|
||||
v := viper.New()
|
||||
configDir := GetConfigDir()
|
||||
err := os.MkdirAll(configDir, 0755)
|
||||
if err != nil {
|
||||
slog.Error("Failed to create config directory", "error", err)
|
||||
}
|
||||
_ = os.MkdirAll(configDir, 0755)
|
||||
configFile := filepath.Join(configDir, "config.json")
|
||||
|
||||
// 设置默认值
|
||||
defaultSavePath := filepath.Join(GetUserHomeDir(), "Downloads")
|
||||
v.SetDefault("window_state", defaultState)
|
||||
v.SetDefault("save_path", defaultSavePath)
|
||||
defaultHostName, err := os.Hostname()
|
||||
if err != nil {
|
||||
defaultHostName = "localhost"
|
||||
}
|
||||
v.SetDefault("host_name", defaultHostName)
|
||||
v.SetDefault("id", uuid.New().String())
|
||||
v.SetDefault("save_history", true)
|
||||
|
||||
v.SetConfigFile(configFile)
|
||||
v.SetConfigType("json")
|
||||
cfgData := configData{
|
||||
WindowState: defaultState,
|
||||
SavePath: defaultSavePath,
|
||||
AutoAccept: false,
|
||||
SaveHistory: true,
|
||||
Language: LanguageEnglish,
|
||||
CloseToSystray: false,
|
||||
ID: uuid.New().String(),
|
||||
HostName: defaultHostName,
|
||||
TrustedPeer: make(map[string]string),
|
||||
}
|
||||
|
||||
// 尝试读取配置
|
||||
if err := v.ReadInConfig(); err != nil {
|
||||
if _, ok := err.(viper.ConfigFileNotFoundError); ok {
|
||||
slog.Info("Config file not found, using defaults")
|
||||
fileBytes, err := os.ReadFile(configFile)
|
||||
if err != nil {
|
||||
if !os.IsNotExist(err) {
|
||||
slog.Error("Failed to read config file", "error", err)
|
||||
} else {
|
||||
slog.Warn("Failed to read config file, using defaults", "error", err)
|
||||
slog.Info("Config file not found, creating new one")
|
||||
}
|
||||
} else {
|
||||
if err := json.Unmarshal(fileBytes, &cfgData); err != nil {
|
||||
slog.Error("Failed to unmarshal config", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
config := Config{
|
||||
data: cfgData,
|
||||
configPath: configFile,
|
||||
}
|
||||
|
||||
// 确保默认保存路径存在
|
||||
@@ -103,16 +112,6 @@ func Load(defaultState WindowState) *Config {
|
||||
slog.Error("Failed to create default save path", "path", defaultSavePath, "error", err)
|
||||
}
|
||||
|
||||
var data configData
|
||||
if err := v.Unmarshal(&data); err != nil {
|
||||
slog.Error("Failed to unmarshal config", "error", err)
|
||||
}
|
||||
|
||||
config := Config{
|
||||
v: v,
|
||||
data: data,
|
||||
}
|
||||
|
||||
// 如果没有密钥对,生成新的
|
||||
if config.data.PrivateKey == "" || config.data.PublicKey == "" {
|
||||
priv, pub, err := security.GenerateKey()
|
||||
@@ -121,12 +120,6 @@ func Load(defaultState WindowState) *Config {
|
||||
} else {
|
||||
config.data.PrivateKey = priv
|
||||
config.data.PublicKey = pub
|
||||
v.Set("private_key", priv)
|
||||
v.Set("public_key", pub)
|
||||
// 保存新生成的密钥
|
||||
if err := config.Save(); err != nil {
|
||||
slog.Error("Failed to save generated keys", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -135,6 +128,11 @@ func Load(defaultState WindowState) *Config {
|
||||
config.data.TrustedPeer = make(map[string]string)
|
||||
}
|
||||
|
||||
// 保存
|
||||
if err := config.Save(); err != nil {
|
||||
slog.Error("Failed to save config", "error", err)
|
||||
}
|
||||
|
||||
return &config
|
||||
}
|
||||
|
||||
@@ -146,21 +144,21 @@ func (c *Config) Save() error {
|
||||
}
|
||||
|
||||
func (c *Config) save() error {
|
||||
configDir := GetConfigDir()
|
||||
if err := os.MkdirAll(configDir, 0755); err != nil {
|
||||
dir := filepath.Dir(c.configPath)
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := c.v.WriteConfig(); err != nil {
|
||||
slog.Error("Failed to write config", "error", err)
|
||||
jsonData, err := json.MarshalIndent(c.data, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 设置配置文件权限为 0600 (仅所有者读写)
|
||||
configFile := c.v.ConfigFileUsed()
|
||||
if configFile != "" {
|
||||
if err := os.Chmod(configFile, 0600); err != nil {
|
||||
slog.Warn("Failed to set config file permissions", "error", err)
|
||||
if c.configPath != "" {
|
||||
if err := os.WriteFile(c.configPath, jsonData, 0600); err != nil {
|
||||
slog.Warn("Failed to write config file", "error", err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
@@ -183,7 +181,6 @@ func (c *Config) update(fn func()) {
|
||||
func (c *Config) SetSavePath(savePath string) {
|
||||
c.update(func() {
|
||||
c.data.SavePath = savePath
|
||||
c.v.Set("save_path", savePath)
|
||||
_ = os.MkdirAll(savePath, 0755)
|
||||
})
|
||||
}
|
||||
@@ -197,7 +194,6 @@ func (c *Config) GetSavePath() string {
|
||||
func (c *Config) SetHostName(hostName string) {
|
||||
c.update(func() {
|
||||
c.data.HostName = hostName
|
||||
c.v.Set("host_name", hostName)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -216,7 +212,6 @@ func (c *Config) GetID() string {
|
||||
func (c *Config) SetAutoAccept(autoAccept bool) {
|
||||
c.update(func() {
|
||||
c.data.AutoAccept = autoAccept
|
||||
c.v.Set("auto_accept", autoAccept)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -229,7 +224,6 @@ func (c *Config) GetAutoAccept() bool {
|
||||
func (c *Config) SetSaveHistory(saveHistory bool) {
|
||||
c.update(func() {
|
||||
c.data.SaveHistory = saveHistory
|
||||
c.v.Set("save_history", saveHistory)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -246,7 +240,6 @@ func (c *Config) GetVersion() string {
|
||||
func (c *Config) SetWindowState(state WindowState) {
|
||||
c.update(func() {
|
||||
c.data.WindowState = state
|
||||
c.v.Set("window_state", state)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -262,7 +255,6 @@ func (c *Config) AddTrust(peerID string, publicKey string) {
|
||||
c.data.TrustedPeer = make(map[string]string)
|
||||
}
|
||||
c.data.TrustedPeer[peerID] = publicKey
|
||||
c.v.Set("trusted_peer", c.data.TrustedPeer)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -275,7 +267,6 @@ func (c *Config) GetTrusted() map[string]string {
|
||||
func (c *Config) RemoveTrust(peerID string) {
|
||||
c.update(func() {
|
||||
delete(c.data.TrustedPeer, peerID)
|
||||
c.v.Set("trusted_peer", c.data.TrustedPeer)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -289,7 +280,6 @@ func (c *Config) IsTrusted(peerID string) bool {
|
||||
func (c *Config) SetLanguage(language Language) {
|
||||
c.update(func() {
|
||||
c.data.Language = language
|
||||
c.v.Set("language", language)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -302,7 +292,6 @@ func (c *Config) GetLanguage() Language {
|
||||
func (c *Config) SetCloseToSystray(closeToSystray bool) {
|
||||
c.update(func() {
|
||||
c.data.CloseToSystray = closeToSystray
|
||||
c.v.Set("close_to_systray", closeToSystray)
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -8,25 +8,36 @@ import (
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
func (s *Service) SaveHistory(transfers []*Transfer) {
|
||||
func (s *Service) SaveHistory() {
|
||||
if !s.config.GetSaveHistory() {
|
||||
return
|
||||
}
|
||||
configDir := config.GetConfigDir()
|
||||
historyPath := filepath.Join(configDir, "history.json")
|
||||
historyJson, err := json.Marshal(transfers)
|
||||
tempPath := historyPath + ".tmp"
|
||||
|
||||
// 序列化传输列表
|
||||
historyJson, err := json.MarshalIndent(s.GetTransferList(), "", " ")
|
||||
if err != nil {
|
||||
slog.Error("Failed to marshal history", "error", err, "component", "transfer")
|
||||
return
|
||||
}
|
||||
file, err := os.OpenFile(historyPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
|
||||
if err != nil {
|
||||
|
||||
// 写入临时文件
|
||||
if err := os.WriteFile(tempPath, historyJson, 0644); err != nil {
|
||||
slog.Error("Failed to write temp history file", "error", err, "component", "transfer")
|
||||
return
|
||||
}
|
||||
defer file.Close()
|
||||
_, err = file.Write(historyJson)
|
||||
if err != nil {
|
||||
slog.Error("Failed to write history", "error", err)
|
||||
|
||||
// 原子性重命名
|
||||
if err := os.Rename(tempPath, historyPath); err != nil {
|
||||
slog.Error("Failed to rename temp history file", "error", err, "component", "transfer")
|
||||
// 清理临时文件
|
||||
_ = os.Remove(tempPath)
|
||||
return
|
||||
}
|
||||
|
||||
slog.Info("History saved successfully", "path", historyPath, "component", "transfer")
|
||||
}
|
||||
|
||||
func (s *Service) LoadHistory() {
|
||||
|
||||
@@ -32,7 +32,7 @@ func (s *Service) handleAsk(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 检查是否已经存在
|
||||
if _, exists := s.transferList.Load(task.ID); exists {
|
||||
if _, exists := s.transfers.Load(task.ID); exists {
|
||||
// 如果已经存在,说明是网络重试,直接忽略
|
||||
return
|
||||
}
|
||||
|
||||
@@ -26,7 +26,7 @@ type Service struct {
|
||||
|
||||
// pendingRequests 存储等待用户确认的通道
|
||||
// Key: TransferID, Value: *Transfer
|
||||
transferList sync.Map
|
||||
transfers sync.Map
|
||||
|
||||
discoveryService *discovery.Service
|
||||
|
||||
@@ -90,9 +90,13 @@ func (s *Service) Start() {
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *Service) GetTransferSyncMap() *sync.Map {
|
||||
return &s.transfers
|
||||
}
|
||||
|
||||
func (s *Service) GetTransferList() []*Transfer {
|
||||
var requests []*Transfer = make([]*Transfer, 0)
|
||||
s.transferList.Range(func(key, value any) bool {
|
||||
s.transfers.Range(func(key, value any) bool {
|
||||
transfer := value.(*Transfer)
|
||||
requests = append(requests, transfer)
|
||||
return true
|
||||
@@ -105,7 +109,7 @@ func (s *Service) GetTransferList() []*Transfer {
|
||||
}
|
||||
|
||||
func (s *Service) GetTransfer(transferID string) (*Transfer, bool) {
|
||||
val, ok := s.transferList.Load(transferID)
|
||||
val, ok := s.transfers.Load(transferID)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
@@ -126,15 +130,13 @@ func (s *Service) CancelTransfer(transferID string) {
|
||||
|
||||
func (s *Service) StoreTransfersToList(transfers []*Transfer) {
|
||||
for _, transfer := range transfers {
|
||||
s.transferList.Store(transfer.ID, transfer)
|
||||
s.transfers.Store(transfer.ID, transfer)
|
||||
}
|
||||
s.SaveHistory(transfers)
|
||||
s.NotifyTransferListUpdate()
|
||||
}
|
||||
|
||||
func (s *Service) StoreTransferToList(transfer *Transfer) {
|
||||
s.transferList.Store(transfer.ID, transfer)
|
||||
s.SaveHistory([]*Transfer{transfer})
|
||||
s.transfers.Store(transfer.ID, transfer)
|
||||
s.NotifyTransferListUpdate()
|
||||
}
|
||||
|
||||
@@ -144,22 +146,20 @@ func (s *Service) NotifyTransferListUpdate() {
|
||||
|
||||
// CleanTransferList 清理完成的 transfer
|
||||
func (s *Service) CleanFinishedTransferList() {
|
||||
s.transferList.Range(func(key, value any) bool {
|
||||
s.transfers.Range(func(key, value any) bool {
|
||||
task := value.(*Transfer)
|
||||
if task.Status == TransferStatusCompleted ||
|
||||
task.Status == TransferStatusError ||
|
||||
task.Status == TransferStatusCanceled ||
|
||||
task.Status == TransferStatusRejected {
|
||||
s.transferList.Delete(key)
|
||||
s.transfers.Delete(key)
|
||||
}
|
||||
return true
|
||||
})
|
||||
s.SaveHistory(s.GetTransferList())
|
||||
s.NotifyTransferListUpdate()
|
||||
}
|
||||
|
||||
func (s *Service) DeleteTransfer(transferID string) {
|
||||
s.transferList.Delete(transferID)
|
||||
s.SaveHistory(s.GetTransferList())
|
||||
s.transfers.Delete(transferID)
|
||||
s.NotifyTransferListUpdate()
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user