Files
mesh-drop/internal/discovery/service.go
nite 20a25e8c49 refine i18n, fill in miss parts
fix resetTrust cant recover send button in UI
2026-02-07 14:11:57 +08:00

375 lines
8.9 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package discovery
import (
"encoding/json"
"fmt"
"log/slog"
"mesh-drop/internal/config"
"mesh-drop/internal/security"
"net"
"runtime"
"sort"
"sync"
"time"
"github.com/wailsapp/wails/v3/pkg/application"
)
// Network parameters of the discovery protocol.
const (
	// DiscoveryPort is the UDP port presence packets are broadcast to and
	// received on.
	DiscoveryPort = 9988

	// HeartbeatRate is the interval between two presence broadcasts.
	HeartbeatRate = time.Second

	// PeerTimeout is how long a route may go without a heartbeat before the
	// cleanup loop expires it.
	PeerTimeout = 2 * time.Second
)
// Service implements LAN peer discovery: it broadcasts signed presence
// packets for this node, listens for the packets of other nodes, and keeps a
// table of the peers that are currently reachable.
type Service struct {
	// app is used to emit "peers:update" events to the frontend.
	app *application.App
	// ID is this node's identifier, taken from config at construction time.
	ID string
	// config supplies the host name, key pair and trusted-key list.
	config *config.Config
	// FileServerPort is the local file-server port advertised to peers.
	FileServerPort int
	// peers holds the known peers, keyed by Peer.ID. Intended to be accessed
	// only while holding peersMutex.
	peers      map[string]*Peer
	peersMutex sync.RWMutex
	// self describes the local node; built once in NewService.
	self Peer
}
// NewService creates a discovery Service for the local node.
//
// cfg supplies the node ID, host name and public key; app is used to emit
// "peers:update" events to the frontend; port is the local file-server port
// advertised to peers in every presence packet.
//
// The returned service starts no goroutines; call Start to begin discovery.
func NewService(cfg *config.Config, app *application.App, port int) *Service {
	// Read the ID once so Service.ID and self.ID are guaranteed identical.
	id := cfg.GetID()
	return &Service{
		app:            app,
		ID:             id,
		config:         cfg,
		FileServerPort: port,
		peers:          make(map[string]*Peer),
		self: Peer{
			ID:        id,
			Name:      cfg.GetHostName(),
			Port:      port,
			OS:        OS(runtime.GOOS),
			PublicKey: cfg.PublicKey,
		},
	}
}
// GetLocalIPs lists the IPv4 addresses assigned to every network interface
// that is up and is not a loopback interface.
//
// The bool result is false only when the interface list itself cannot be
// read. Interfaces whose addresses cannot be read, and addresses that do not
// parse as CIDR, are skipped. The slice is nil when nothing matches.
func GetLocalIPs() ([]string, bool) {
	ifaces, err := net.Interfaces()
	if err != nil {
		slog.Error("Failed to get network interfaces", "error", err, "component", "discovery")
		return nil, false
	}

	var result []string
	for _, iface := range ifaces {
		isUp := iface.Flags&net.FlagUp != 0
		isLoopback := iface.Flags&net.FlagLoopback != 0
		if !isUp || isLoopback {
			continue
		}

		addrs, addrErr := iface.Addrs()
		if addrErr != nil {
			continue
		}
		for _, a := range addrs {
			if ip, _, parseErr := net.ParseCIDR(a.String()); parseErr == nil && ip.To4() != nil {
				result = append(result, ip.String())
			}
		}
	}
	return result, true
}
// GetLocalIPInSameSubnet returns this host's IPv4 address on the subnet that
// contains receiverIP. Only interfaces that are up and not loopback are
// considered.
//
// It returns ("", false) when receiverIP is not a valid IP address, when the
// interface list cannot be read, or when no local IPv4 subnet contains it.
func (s *Service) GetLocalIPInSameSubnet(receiverIP string) (string, bool) {
	// Parse once up front instead of once per candidate address.
	target := net.ParseIP(receiverIP)
	if target == nil {
		slog.Error("Invalid receiver IP", "receiverIP", receiverIP, "component", "discovery")
		return "", false
	}

	interfaces, err := net.Interfaces()
	if err != nil {
		slog.Error("Failed to get network interfaces", "error", err, "component", "discovery")
		return "", false
	}
	for _, iface := range interfaces {
		// Skip interfaces that are down, and loopback interfaces.
		if iface.Flags&net.FlagUp == 0 || iface.Flags&net.FlagLoopback != 0 {
			continue
		}
		addrs, err := iface.Addrs()
		if err != nil {
			continue
		}
		for _, addr := range addrs {
			ip, ipNet, err := net.ParseCIDR(addr.String())
			if err != nil || ip.To4() == nil {
				continue
			}
			if ipNet.Contains(target) {
				return ip.String(), true
			}
		}
	}
	slog.Error("Failed to get local IP in same subnet", "receiverIP", receiverIP, "component", "discovery")
	return "", false
}
// startBroadcasting announces this node once per HeartbeatRate. It is run as
// a goroutine by Start and its loop never exits.
//
// On each tick it does the following:
//   - Builds a PresencePacket, re-reading the host name from config so
//     renames propagate.
//   - Signs the packet with the node's private key.
//   - Sends it to the directed broadcast address of every IPv4 subnet on
//     every interface that is up and not loopback.
//
// A tick is skipped, and retried on the next one, if the interfaces cannot
// be listed or the packet cannot be signed or encoded.
func (s *Service) startBroadcasting() {
	ticker := time.NewTicker(HeartbeatRate)
	defer ticker.Stop()

	for range ticker.C {
		interfaces, err := net.Interfaces()
		if err != nil {
			slog.Error("Failed to get network interfaces", "error", err, "component", "discovery")
			continue
		}

		packet := PresencePacket{
			ID:        s.ID,
			Name:      s.config.GetHostName(),
			Port:      s.FileServerPort,
			OS:        OS(runtime.GOOS),
			PublicKey: s.config.PublicKey,
		}
		// Sign the payload so receivers can verify it against PublicKey.
		sig, err := security.Sign(s.config.PrivateKey, packet.SignPayload())
		if err != nil {
			slog.Error("Failed to sign discovery packet", "error", err, "component", "discovery")
			continue
		}
		packet.Signature = sig

		data, err := json.Marshal(packet)
		if err != nil {
			// Never broadcast an empty or partial payload.
			slog.Error("Failed to encode discovery packet", "error", err, "component", "discovery")
			continue
		}

		for _, iface := range interfaces {
			// Skip interfaces that are down, and loopback interfaces.
			if iface.Flags&net.FlagUp == 0 || iface.Flags&net.FlagLoopback != 0 {
				continue
			}
			addrs, err := iface.Addrs()
			if err != nil {
				continue
			}
			for _, addr := range addrs {
				ip, ipNet, err := net.ParseCIDR(addr.String())
				if err != nil {
					continue
				}
				ip4 := ip.To4()
				if ip4 == nil {
					continue
				}
				// Normalise a 16-byte mask to 4 bytes so indexing below can
				// never run past the 4-byte address.
				mask := ipNet.Mask
				if len(mask) == net.IPv6len {
					mask = mask[net.IPv6len-net.IPv4len:]
				}
				if len(mask) != net.IPv4len {
					continue
				}
				// Directed broadcast address: set all host bits,
				// e.g. 192.168.1.5/24 -> 192.168.1.255.
				broadcast := make(net.IP, net.IPv4len)
				for i := range broadcast {
					broadcast[i] = ip4[i] | ^mask[i]
				}
				slog.Debug("Broadcast IP", "ip", broadcast.String(), "component", "discovery")
				s.sendPacketTo(broadcast.String(), DiscoveryPort, data)
			}
		}
	}
}
// sendPacketTo sends data as one UDP datagram to ip:port over a short-lived
// socket.
//
// Delivery is best-effort: resolve, dial and write failures are logged and
// otherwise ignored. The periodic broadcaster calls this again on the next
// heartbeat anyway.
func (s *Service) sendPacketTo(ip string, port int, data []byte) {
	// JoinHostPort brackets IPv6 literals, which a plain "%s:%d" would not.
	addr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(ip, fmt.Sprint(port)))
	if err != nil {
		slog.Error("Failed to resolve packet destination", "ip", ip, "port", port, "error", err, "component", "discovery")
		return
	}
	conn, err := net.DialUDP("udp", nil, addr)
	if err != nil {
		slog.Error("Failed to open UDP socket", "addr", addr.String(), "error", err, "component", "discovery")
		return
	}
	defer conn.Close()

	if _, err := conn.Write(data); err != nil {
		slog.Error("Failed to send packet", "error", err, "component", "discovery")
	}
}
// startListening receives presence packets on DiscoveryPort and passes every
// authentic packet from another node to handleHeartbeat. It is run as a
// goroutine by Start and returns only if the listening socket cannot be set
// up.
//
// A packet is ignored in any of these cases:
//   - It cannot be decoded as JSON.
//   - It carries this node's own ID.
//   - Its signature does not verify against the PublicKey it carries.
//
// The signature only proves possession of that key, so the key is also
// compared with the trusted key recorded for the sender's ID, if there is
// one. A mismatch is reported to handleHeartbeat as trustMismatch.
func (s *Service) startListening() {
	// A UDP payload can never exceed 65535 bytes, so this buffer cannot
	// truncate (and thereby silently drop) an oversized presence packet.
	const maxDatagramSize = 64 * 1024

	addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf(":%d", DiscoveryPort))
	if err != nil {
		slog.Error("Failed to resolve listen address", "error", err, "component", "discovery")
		return
	}
	conn, err := net.ListenUDP("udp", addr)
	if err != nil {
		slog.Error("Failed to start listening", "error", err, "component", "discovery")
		return
	}
	defer conn.Close()

	buf := make([]byte, maxDatagramSize)
	for {
		n, remoteAddr, err := conn.ReadFromUDP(buf)
		if err != nil {
			continue
		}
		var packet PresencePacket
		if err := json.Unmarshal(buf[:n], &packet); err != nil {
			continue
		}
		// Ignore our own broadcasts.
		if packet.ID == s.ID {
			continue
		}

		// Verify the signature.
		valid, err := security.Verify(packet.PublicKey, packet.SignPayload(), packet.Signature)
		if err != nil || !valid {
			slog.Warn("Received invalid discovery packet signature", "id", packet.ID, "ip", remoteAddr.IP.String(), "component", "discovery")
			continue
		}

		// Check identity consistency (guards against ID spoofing).
		trustMismatch := false
		trustedKeys := s.config.GetTrusted()
		if knownKey, ok := trustedKeys[packet.ID]; ok {
			if knownKey != packet.PublicKey {
				slog.Warn("SECURITY ALERT: Peer ID mismatch with known public key (Spoofing attempt?)", "id", packet.ID, "known_key", knownKey, "received_key", packet.PublicKey, "component", "discovery")
				// Do not trust this packet's metadata; just flag the peer.
				// The user can re-establish trust manually.
				trustMismatch = true
			}
		} else {
			// The ID is not (or no longer) trusted. If the user reset trust
			// after a mismatch, clear the sticky flag here; otherwise
			// handleHeartbeat would keep the peer marked as mismatched.
			// s.peers is shared with handleHeartbeat and startCleanup, so
			// it must only be touched while holding peersMutex.
			s.peersMutex.Lock()
			if peer, ok := s.peers[packet.ID]; ok {
				peer.TrustMismatch = false
			}
			s.peersMutex.Unlock()
		}
		s.handleHeartbeat(packet, remoteAddr.IP.String(), trustMismatch)
	}
}
// handleHeartbeat records a verified presence packet pkt received from
// source address ip. It then pushes the refreshed peer list to the frontend
// via the "peers:update" event.
//
// For an unknown pkt.ID, a new peer is created with TrustMismatch set to
// trustMismatch. For a known peer, the route for ip is refreshed. The
// advertised metadata (Name, OS, Port, PublicKey) is updated only when
// trustMismatch is false, so a packet that fails the trusted-key check
// cannot rewrite what the UI shows.
//
// TrustMismatch is sticky here: a later matching packet does not clear it,
// which stops a spoofer from resetting it.
func (s *Service) handleHeartbeat(pkt PresencePacket, ip string, trustMismatch bool) {
	s.peersMutex.Lock()
	now := time.Now()
	peer, exists := s.peers[pkt.ID]
	if !exists {
		// Newly discovered node.
		peer = &Peer{
			ID:   pkt.ID,
			Name: pkt.Name,
			Routes: map[string]*RouteState{
				ip: {
					IP:       ip,
					LastSeen: now,
				},
			},
			Port:          pkt.Port,
			OS:            pkt.OS,
			PublicKey:     pkt.PublicKey,
			TrustMismatch: trustMismatch,
		}
		s.peers[peer.ID] = peer
		slog.Info("New device found", "name", pkt.Name, "ip", ip, "component", "discovery")
	} else {
		if !trustMismatch {
			peer.Name = pkt.Name
			peer.OS = pkt.OS
			// Refresh Port too. A peer that restarts its file server on a
			// different port within PeerTimeout must stay reachable.
			peer.Port = pkt.Port
			peer.PublicKey = pkt.PublicKey
		}
		peer.Routes[ip] = &RouteState{
			IP:       ip,
			LastSeen: now,
		}
		peer.TrustMismatch = peer.TrustMismatch || trustMismatch
	}
	s.peersMutex.Unlock()

	// Notify the frontend. There is no debouncing yet: every heartbeat
	// triggers an update.
	s.app.Event.Emit("peers:update", s.GetPeers())
}
// startCleanup runs the offline-cleanup loop. Every two seconds it expires
// routes that have not been refreshed within PeerTimeout, then removes peers
// that are left with no routes. When anything was removed, it pushes the
// updated peer list to the frontend.
//
// It is run as a goroutine by Start and its loop never exits.
func (s *Service) startCleanup() {
	ticker := time.NewTicker(2 * time.Second)
	defer ticker.Stop()

	for range ticker.C {
		s.peersMutex.Lock()
		changed := false
		now := time.Now()
		for id, peer := range s.peers {
			// Deleting map entries while ranging over the map is allowed in Go.
			for ip, route := range peer.Routes {
				if now.Sub(route.LastSeen) > PeerTimeout {
					delete(peer.Routes, ip)
					changed = true
					// Only this route is gone; the device may still be
					// reachable through another one.
					slog.Info("Route offline", "name", peer.Name, "ip", ip, "component", "discovery")
				}
			}
			if len(peer.Routes) == 0 {
				delete(s.peers, id)
				changed = true
				slog.Info("Device offline", "name", peer.Name, "component", "discovery")
			}
		}
		s.peersMutex.Unlock()

		if changed {
			s.app.Event.Emit("peers:update", s.GetPeers())
		}
	}
}
// Start launches the broadcast, listen and cleanup loops, each in its own
// goroutine, and returns immediately.
func (s *Service) Start() {
	loops := []func(){
		s.startBroadcasting,
		s.startListening,
		s.startCleanup,
	}
	for _, loop := range loops {
		go loop()
	}
}
// GetPeerByIP finds a peer that currently has a (non-nil) route for ip.
// If several peers match, which one is returned is unspecified.
//
// The result points at the live entry in the peer table, not a copy.
// NOTE(review): handleHeartbeat and startCleanup keep mutating it under
// peersMutex, so callers should not read it without synchronization.
func (s *Service) GetPeerByIP(ip string) (*Peer, bool) {
	s.peersMutex.RLock()
	defer s.peersMutex.RUnlock()

	var found *Peer
	for _, candidate := range s.peers {
		if route := candidate.Routes[ip]; route != nil {
			found = candidate
			break
		}
	}
	return found, found != nil
}
// GetPeerByID looks up a peer by its ID. The bool reports whether it exists.
//
// The result points at the live entry in the peer table, not a copy.
// NOTE(review): it is mutated by handleHeartbeat under peersMutex after this
// returns, so callers should not read it without synchronization.
func (s *Service) GetPeerByID(id string) (*Peer, bool) {
	s.peersMutex.RLock()
	peer, ok := s.peers[id]
	s.peersMutex.RUnlock()
	return peer, ok
}
// GetPeers returns a snapshot of all currently known peers. It is sorted by
// Name, with ties broken by ID so the order is stable between calls; map
// iteration order alone would reshuffle same-named peers on every update.
//
// Each returned Peer is a deep-enough copy: its Routes map and RouteState
// values are duplicated too. The result is used after the lock is released
// (e.g. as the "peers:update" event payload), while handleHeartbeat and
// startCleanup keep modifying the live Routes maps. Sharing those maps would
// be a data race.
func (s *Service) GetPeers() []Peer {
	s.peersMutex.RLock()
	defer s.peersMutex.RUnlock()

	list := make([]Peer, 0, len(s.peers))
	for _, p := range s.peers {
		snapshot := *p
		if p.Routes != nil {
			snapshot.Routes = make(map[string]*RouteState, len(p.Routes))
			for ip, route := range p.Routes {
				if route == nil {
					snapshot.Routes[ip] = nil
					continue
				}
				routeCopy := *route
				snapshot.Routes[ip] = &routeCopy
			}
		}
		list = append(list, snapshot)
	}
	sort.Slice(list, func(i, j int) bool {
		if list[i].Name != list[j].Name {
			return list[i].Name < list[j].Name
		}
		return list[i].ID < list[j].ID
	})
	return list
}
// GetID returns the identifier this node puts in its presence packets.
func (s *Service) GetID() string {
	id := s.ID
	return id
}
// GetSelf describes the local node.
//
// Name is re-read from config on every call, matching what
// startBroadcasting advertises each tick. A host-name change is therefore
// reflected here instead of returning the name captured in NewService.
func (s *Service) GetSelf() Peer {
	local := s.self
	local.Name = s.config.GetHostName()
	return local
}