feat: trust peer

This commit is contained in:
2026-02-07 03:17:37 +08:00
parent d8ffc5eea5
commit f3adb56bd0
19 changed files with 438 additions and 155 deletions

View File

@@ -2,6 +2,7 @@ package config
import (
"log/slog"
"mesh-drop/internal/security"
"os"
"path/filepath"
"sync"
@@ -19,18 +20,21 @@ type WindowState struct {
Maximised bool `mapstructure:"maximised"`
}
var Version = "0.0.2"
var Version = "next"
type Config struct {
v *viper.Viper
mu sync.RWMutex
WindowState WindowState `mapstructure:"window_state"`
ID string `mapstructure:"id"`
SavePath string `mapstructure:"save_path"`
HostName string `mapstructure:"host_name"`
AutoAccept bool `mapstructure:"auto_accept"`
SaveHistory bool `mapstructure:"save_history"`
WindowState WindowState `mapstructure:"window_state"`
ID string `mapstructure:"id"`
PrivateKey string `mapstructure:"private_key"`
PublicKey string `mapstructure:"public_key"`
SavePath string `mapstructure:"save_path"`
HostName string `mapstructure:"host_name"`
AutoAccept bool `mapstructure:"auto_accept"`
SaveHistory bool `mapstructure:"save_history"`
TrustedPeer map[string]string `mapstructure:"trusted_peer"` // ID -> PublicKey
}
// 默认窗口配置
@@ -104,6 +108,28 @@ func Load() *Config {
config.v = v
// 如果没有密钥对,生成新的
if config.PrivateKey == "" || config.PublicKey == "" {
priv, pub, err := security.GenerateKey()
if err != nil {
slog.Error("Failed to generate identity keys", "error", err)
} else {
config.PrivateKey = priv
config.PublicKey = pub
v.Set("private_key", priv)
v.Set("public_key", pub)
// 保存新生成的密钥
if err := config.Save(); err != nil {
slog.Error("Failed to save generated keys", "error", err)
}
}
}
// 初始化 TrustedPeer map if nil
if config.TrustedPeer == nil {
config.TrustedPeer = make(map[string]string)
}
return &config
}
@@ -111,7 +137,10 @@ func Load() *Config {
func (c *Config) Save() error {
c.mu.RLock()
defer c.mu.RUnlock()
return c.save()
}
func (c *Config) save() error {
configDir := GetConfigDir()
if err := os.MkdirAll(configDir, 0755); err != nil {
return err
@@ -122,6 +151,14 @@ func (c *Config) Save() error {
return err
}
// 设置配置文件权限为 0600 (仅所有者读写)
configFile := c.v.ConfigFileUsed()
if configFile != "" {
if err := os.Chmod(configFile, 0600); err != nil {
slog.Warn("Failed to set config file permissions", "error", err)
}
}
return nil
}
@@ -133,6 +170,7 @@ func (c *Config) SetSavePath(savePath string) {
c.SavePath = savePath
c.v.Set("save_path", savePath)
_ = os.MkdirAll(savePath, 0755)
_ = c.save()
}
func (c *Config) GetSavePath() string {
@@ -146,6 +184,7 @@ func (c *Config) SetHostName(hostName string) {
defer c.mu.Unlock()
c.HostName = hostName
c.v.Set("host_name", hostName)
_ = c.save()
}
func (c *Config) GetHostName() string {
@@ -165,6 +204,7 @@ func (c *Config) SetAutoAccept(autoAccept bool) {
defer c.mu.Unlock()
c.AutoAccept = autoAccept
c.v.Set("auto_accept", autoAccept)
_ = c.save()
}
func (c *Config) GetAutoAccept() bool {
@@ -178,6 +218,7 @@ func (c *Config) SetSaveHistory(saveHistory bool) {
defer c.mu.Unlock()
c.SaveHistory = saveHistory
c.v.Set("save_history", saveHistory)
_ = c.save()
}
func (c *Config) GetSaveHistory() bool {
@@ -195,6 +236,7 @@ func (c *Config) SetWindowState(state WindowState) {
defer c.mu.Unlock()
c.WindowState = state
c.v.Set("window_state", state)
_ = c.save()
}
func (c *Config) GetWindowState() WindowState {
@@ -202,3 +244,35 @@ func (c *Config) GetWindowState() WindowState {
defer c.mu.RUnlock()
return c.WindowState
}
func (c *Config) AddTrustedPeer(peerID string, publicKey string) {
c.mu.Lock()
defer c.mu.Unlock()
if c.TrustedPeer == nil {
c.TrustedPeer = make(map[string]string)
}
c.TrustedPeer[peerID] = publicKey
c.v.Set("trusted_peer", c.TrustedPeer)
_ = c.save()
}
func (c *Config) GetTrustedPeer() map[string]string {
c.mu.RLock()
defer c.mu.RUnlock()
return c.TrustedPeer
}
func (c *Config) RemoveTrustedPeer(peerID string) {
c.mu.Lock()
defer c.mu.Unlock()
delete(c.TrustedPeer, peerID)
c.v.Set("trusted_peer", c.TrustedPeer)
_ = c.save()
}
func (c *Config) IsTrustedPeer(peerID string) bool {
c.mu.RLock()
defer c.mu.RUnlock()
_, exists := c.TrustedPeer[peerID]
return exists
}

View File

@@ -1,6 +1,9 @@
package discovery
import "time"
import (
"fmt"
"time"
)
// Peer 代表一个可达的网络端点 (Network Endpoint)。
// 注意:一个物理设备 (Device) 可能通过多个网络接口广播,因此会对应多个 Peer 结构体。
@@ -20,6 +23,12 @@ type Peer struct {
Port int `json:"port"`
OS OS `json:"os"`
PublicKey string `json:"pk"`
// TrustMismatch 指示该节点的公钥与本地信任列表中的公钥不匹配
// 如果为 true说明可能存在 ID 欺骗或密钥轮换
TrustMismatch bool `json:"trust_mismatch"`
}
// RouteState 记录单条路径的状态
@@ -38,8 +47,17 @@ const (
// PresencePacket 是 UDP 广播的载荷
type PresencePacket struct {
ID string `json:"id"`
Name string `json:"name"`
Port int `json:"port"`
OS OS `json:"os"`
ID string `json:"id"`
Name string `json:"name"`
Port int `json:"port"`
OS OS `json:"os"`
PublicKey string `json:"pk"`
Signature string `json:"sig"`
}
// SignPayload 生成用于签名的确定性数据
func (p *PresencePacket) SignPayload() []byte {
// 使用固定格式拼接字段,避免 JSON 序列化的不确定性
// 格式: id|name|port|os|pk
return fmt.Appendf(nil, "%s|%s|%d|%s|%s", p.ID, p.Name, p.Port, p.OS, p.PublicKey)
}

View File

@@ -5,8 +5,10 @@ import (
"fmt"
"log/slog"
"mesh-drop/internal/config"
"mesh-drop/internal/security"
"net"
"runtime"
"sort"
"sync"
"time"
@@ -15,8 +17,8 @@ import (
const (
DiscoveryPort = 9988
HeartbeatRate = 3 * time.Second
PeerTimeout = 10 * time.Second
HeartbeatRate = 1 * time.Second
PeerTimeout = 2 * time.Second
)
type Service struct {
@@ -26,9 +28,11 @@ type Service struct {
config *config.Config
FileServerPort int
// key 使用 peer.id 和 peer.ip 组合而成的 hash
// Key: peer.ID
peers map[string]*Peer
peersMutex sync.RWMutex
self Peer
}
func NewService(config *config.Config, app *application.App, port int) *Service {
@@ -38,10 +42,17 @@ func NewService(config *config.Config, app *application.App, port int) *Service
config: config,
FileServerPort: port,
peers: make(map[string]*Peer),
self: Peer{
ID: config.GetID(),
Name: config.GetHostName(),
Port: port,
OS: OS(runtime.GOOS),
PublicKey: config.PublicKey,
},
}
}
func (s *Service) GetLocalIPs() ([]string, bool) {
func GetLocalIPs() ([]string, bool) {
interfaces, err := net.Interfaces()
if err != nil {
slog.Error("Failed to get network interfaces", "error", err, "component", "discovery")
@@ -114,11 +125,22 @@ func (s *Service) startBroadcasting() {
continue
}
packet := PresencePacket{
ID: s.ID,
Name: s.config.GetHostName(),
Port: s.FileServerPort,
OS: OS(runtime.GOOS),
ID: s.ID,
Name: s.config.GetHostName(),
Port: s.FileServerPort,
OS: OS(runtime.GOOS),
PublicKey: s.config.PublicKey,
}
// 签名
sigData := packet.SignPayload()
sig, err := security.Sign(s.config.PrivateKey, sigData)
if err != nil {
slog.Error("Failed to sign discovery packet", "error", err)
continue
}
packet.Signature = sig
data, _ := json.Marshal(packet)
for _, iface := range interfaces {
// 过滤掉 Down 的接口和 Loopback 接口
@@ -195,12 +217,33 @@ func (s *Service) startListening() {
continue
}
s.handleHeartbeat(packet, remoteAddr.IP.String())
// 验证签名
sig := packet.Signature
sigData := packet.SignPayload()
valid, err := security.Verify(packet.PublicKey, sigData, sig)
if err != nil || !valid {
slog.Warn("Received invalid discovery packet signature", "id", packet.ID, "ip", remoteAddr.IP.String())
continue
}
// 验证身份一致性 (防止 ID 欺骗)
trustMismatch := false
trustedKeys := s.config.GetTrustedPeer()
if knownKey, ok := trustedKeys[packet.ID]; ok {
if knownKey != packet.PublicKey {
slog.Warn("SECURITY ALERT: Peer ID mismatch with known public key (Spoofing attempt?)", "id", packet.ID, "known_key", knownKey, "received_key", packet.PublicKey)
trustMismatch = true
// 当发现 ID 欺骗时,不更新 peer而是标记为 trustMismatch
// 用户可以手动重新添加信任
}
}
s.handleHeartbeat(packet, remoteAddr.IP.String(), trustMismatch)
}
}
// handleHeartbeat 处理心跳包
func (s *Service) handleHeartbeat(pkt PresencePacket, ip string) {
func (s *Service) handleHeartbeat(pkt PresencePacket, ip string, trustMismatch bool) {
s.peersMutex.Lock()
peer, exists := s.peers[pkt.ID]
@@ -215,19 +258,27 @@ func (s *Service) handleHeartbeat(pkt PresencePacket, ip string) {
LastSeen: time.Now(),
},
},
Port: pkt.Port,
OS: pkt.OS,
Port: pkt.Port,
OS: pkt.OS,
PublicKey: pkt.PublicKey,
TrustMismatch: trustMismatch,
}
s.peers[peer.ID] = peer
slog.Info("New device found", "name", pkt.Name, "ip", ip, "component", "discovery")
} else {
// 更新节点
peer.Name = pkt.Name
peer.OS = pkt.OS
// 只有在没有身份不匹配的情况下才更新元数据,防止欺骗攻击导致 UI 闪烁/篡改
if !trustMismatch {
peer.Name = pkt.Name
peer.OS = pkt.OS
peer.PublicKey = pkt.PublicKey
}
peer.Routes[ip] = &RouteState{
IP: ip,
LastSeen: time.Now(),
}
// 如果之前存在不匹配,即使这次匹配了,也不要重置,防止欺骗攻击
peer.TrustMismatch = peer.TrustMismatch || trustMismatch
}
s.peersMutex.Unlock()
@@ -246,7 +297,6 @@ func (s *Service) startCleanup() {
for id, peer := range s.peers {
for ip, route := range peer.Routes {
// 超过10秒没心跳认为下线
if now.Sub(route.LastSeen) > PeerTimeout {
delete(peer.Routes, ip)
changed = true
@@ -274,16 +324,24 @@ func (s *Service) Start() {
go s.startCleanup()
}
func (s *Service) GetPeerByIP(ip string) *Peer {
func (s *Service) GetPeerByIP(ip string) (*Peer, bool) {
s.peersMutex.RLock()
defer s.peersMutex.RUnlock()
for _, p := range s.peers {
if p.Routes[ip] != nil {
return p
return p, true
}
}
return nil
return nil, false
}
func (s *Service) GetPeerByID(id string) (*Peer, bool) {
s.peersMutex.RLock()
defer s.peersMutex.RUnlock()
peer, ok := s.peers[id]
return peer, ok
}
func (s *Service) GetPeers() []Peer {
@@ -294,9 +352,16 @@ func (s *Service) GetPeers() []Peer {
for _, p := range s.peers {
list = append(list, *p)
}
sort.Slice(list, func(i, j int) bool {
return list[i].Name < list[j].Name
})
return list
}
func (s *Service) GetID() string {
return s.ID
}
func (s *Service) GetSelf() Peer {
return s.self
}

View File

@@ -0,0 +1,56 @@
package security
import (
"crypto/ed25519"
"crypto/rand"
"encoding/base64"
"fmt"
)
// GenerateKey 生成新的 Ed25519 密钥对
// 返回 base64 编码的私钥和公钥
func GenerateKey() (string, string, error) {
pub, priv, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
return "", "", err
}
return base64.StdEncoding.EncodeToString(priv), base64.StdEncoding.EncodeToString(pub), nil
}
// Sign 使用私钥对数据进行签名
// privKeyStr: base64 编码的私钥
// data: 要签名的数据
// 返回: base64 编码的签名
func Sign(privKeyStr string, data []byte) (string, error) {
privKeyBytes, err := base64.StdEncoding.DecodeString(privKeyStr)
if err != nil {
return "", fmt.Errorf("invalid private key: %w", err)
}
if len(privKeyBytes) != ed25519.PrivateKeySize {
return "", fmt.Errorf("invalid private key length")
}
signature := ed25519.Sign(ed25519.PrivateKey(privKeyBytes), data)
return base64.StdEncoding.EncodeToString(signature), nil
}
// Verify 使用公钥验证签名
// pubKeyStr: base64 编码的公钥
// data: 原始数据
// sigStr: base64 编码的签名
func Verify(pubKeyStr string, data []byte, sigStr string) (bool, error) {
pubKeyBytes, err := base64.StdEncoding.DecodeString(pubKeyStr)
if err != nil {
return false, fmt.Errorf("invalid public key: %w", err)
}
if len(pubKeyBytes) != ed25519.PublicKeySize {
return false, fmt.Errorf("invalid public key length")
}
sigBytes, err := base64.StdEncoding.DecodeString(sigStr)
if err != nil {
return false, fmt.Errorf("invalid signature: %w", err)
}
return ed25519.Verify(ed25519.PublicKey(pubKeyBytes), data, sigBytes), nil
}

View File

@@ -44,11 +44,7 @@ func (s *Service) SendFile(target *discovery.Peer, targetIP string, filePath str
task := NewTransfer(
taskID,
NewSender(
s.discoveryService.GetID(),
s.config.GetHostName(),
WithReceiverIP(targetIP, s.discoveryService),
),
s.discoveryService.GetSelf(),
WithFileName(filepath.Base(filePath)),
WithFileSize(stat.Size()),
WithType(TransferTypeSend),
@@ -111,11 +107,7 @@ func (s *Service) SendFolder(target *discovery.Peer, targetIP string, folderPath
task := NewTransfer(
taskID,
NewSender(
s.discoveryService.GetID(),
s.config.GetHostName(),
WithReceiverIP(targetIP, s.discoveryService),
),
s.discoveryService.GetSelf(),
WithFileName(filepath.Base(folderPath)),
WithFileSize(size),
WithType(TransferTypeSend),
@@ -164,11 +156,7 @@ func (s *Service) SendText(target *discovery.Peer, targetIP string, text string)
r := bytes.NewReader([]byte(text))
task := NewTransfer(
taskID,
NewSender(
s.discoveryService.GetID(),
s.config.GetHostName(),
WithReceiverIP(targetIP, s.discoveryService),
),
s.discoveryService.GetSelf(),
WithFileSize(int64(len(text))),
WithType(TransferTypeSend),
WithContentType(ContentTypeText),

View File

@@ -1,7 +1,6 @@
package transfer
import (
"log/slog"
"mesh-drop/internal/discovery"
"time"
)
@@ -37,7 +36,7 @@ const (
type Transfer struct {
ID string `json:"id" binding:"required"` // 传输会话 ID
CreateTime int64 `json:"create_time"` // 创建时间
Sender Sender `json:"sender" binding:"required"` // 发送者
Sender discovery.Peer `json:"sender" binding:"required"` // 发送者
FileName string `json:"file_name"` // 文件名
FileSize int64 `json:"file_size"` // 文件大小 (字节)
SavePath string `json:"savePath"` // 保存路径
@@ -53,7 +52,7 @@ type Transfer struct {
type TransferOption func(*Transfer)
func NewTransfer(id string, sender Sender, opts ...TransferOption) *Transfer {
func NewTransfer(id string, sender discovery.Peer, opts ...TransferOption) *Transfer {
t := &Transfer{
ID: id,
CreateTime: time.Now().UnixMilli(),
@@ -122,41 +121,6 @@ func WithToken(token string) TransferOption {
}
}
type Sender struct {
ID string `json:"id" binding:"required"` // 发送者 ID
Name string `json:"name" binding:"required"` // 发送者名称
IP string `json:"ip" binding:"required"` // 发送者 IP
}
type NewSenderOption func(*Sender)
func NewSender(id string, name string, opts ...NewSenderOption) Sender {
s := &Sender{
ID: id,
Name: name,
}
for _, opt := range opts {
opt(s)
}
return *s
}
func WithIP(ip string) NewSenderOption {
return func(s *Sender) {
s.IP = ip
}
}
func WithReceiverIP(ip string, discoveryService *discovery.Service) NewSenderOption {
return func(s *Sender) {
ip, ok := discoveryService.GetLocalIPInSameSubnet(ip)
if !ok {
slog.Error("Failed to get local IP in same subnet", "ip", ip, "component", "transfer-client")
}
s.IP = ip
}
}
// Progress 用户前端传输进度
type Progress struct {
Current int64 `json:"current"` // 当前进度

View File

@@ -43,7 +43,13 @@ func (s *Service) handleAsk(c *gin.Context) {
task.DecisionChan = make(chan Decision, 1)
s.StoreTransferToList(&task)
if s.config.GetAutoAccept() {
// 从本地获取 peer 检查是否 mismatch
peer, ok := s.discoveryService.GetPeerByID(task.Sender.ID)
if ok {
task.Sender.TrustMismatch = peer.TrustMismatch
}
if s.config.GetAutoAccept() || (s.config.IsTrustedPeer(task.Sender.ID) && !task.Sender.TrustMismatch) {
task.DecisionChan <- Decision{
ID: task.ID,
Accepted: true,
@@ -54,7 +60,7 @@ func (s *Service) handleAsk(c *gin.Context) {
_ = s.notifier.SendNotification(notifications.NotificationOptions{
ID: uuid.New().String(),
Title: "File Transfer Request",
Body: fmt.Sprintf("%s(%s) wants to transfer %s", task.Sender.Name, task.Sender.IP, task.FileName),
Body: fmt.Sprintf("%s wants to transfer %s", task.Sender.Name, task.FileName),
})
}
@@ -74,6 +80,11 @@ func (s *Service) handleAsk(c *gin.Context) {
})
} else {
task.Status = TransferStatusRejected
c.JSON(http.StatusOK, TransferAskResponse{
ID: task.ID,
Accepted: false,
Message: "Transfer rejected",
})
}
case <-c.Request.Context().Done():
// 发送端放弃