This commit is contained in:
2026-02-27 21:12:56 +08:00
commit a878084cbb
233 changed files with 22988 additions and 0 deletions

1
rust/.gitignore vendored Normal file
View File

@@ -0,0 +1 @@
/target

3531
rust/Cargo.lock generated Normal file

File diff suppressed because it is too large Load Diff

42
rust/Cargo.toml Normal file
View File

@@ -0,0 +1,42 @@
[package]
name = "rust_lib_mesh_drop_flutter"
version = "0.1.0"
edition = "2021"
[lib]
crate-type = ["cdylib", "staticlib"]
[dependencies]
flutter_rust_bridge = "=2.11.1"
serde_json = "1.0"
serde = { version = "1.0", features = ["derive"] }
log = "0.4"
async-tar = { version = "0.6.0", features = [
"runtime-tokio",
], default-features = false }
axum = { version = "0.8.8", features = ["json"] }
axum-server = { version = "0.8.0", features = ["tls-rustls"] }
base64 = "0.22.1"
chrono = "0.4.44"
dashmap = { version = "6.1.0", features = ["serde"] }
dirs = "6.0.0"
ed25519-dalek = { version = "2.2.0", features = ["rand_core"] }
fd-lock = "4.0.4"
futures-util = "0.3.32"
gethostname = "1.1.0"
if-addrs = "0.15.0"
rand = "0.8"
rcgen = "0.14.7"
reqwest = { version = "0.13.2", features = ["stream", "json"] }
rustyline = { version = "17.0.2", features = ["derive"] }
shlex = "1.3.0"
thiserror = "2.0.18"
tokio = { version = "1.49.0", features = ["full"] }
tokio-stream = "0.1.18"
tokio-util = { version = "0.7.18", features = ["io"] }
tracing = "0.1.44"
tracing-subscriber = "0.3.22"
uuid = { version = "1.21.0", features = ["v4"] }
[lints.rust]
unexpected_cfgs = { level = "warn", check-cfg = ['cfg(frb_expand)'] }

135
rust/src/api/commands.rs Normal file
View File

@@ -0,0 +1,135 @@
use crate::discovery::model::Peer;
use crate::event::AppEvent;
use crate::transfer::model::Transfer;
pub async fn create_event_stream(sink: crate::frb_generated::StreamSink<AppEvent>) {
let mut rx = get_appstate().await.event_tx.subscribe();
tokio::spawn(async move {
while let Ok(event) = rx.recv().await {
let _ = sink.add(event);
}
});
}
use crate::get_appstate;
pub async fn get_save_path() -> String {
get_appstate().await.config.get_save_path()
}
pub async fn set_save_path(save_path: String) {
get_appstate().await.config.set_save_path(save_path);
}
pub async fn set_hostname(hostname: String) {
get_appstate().await.config.set_hostname(hostname);
}
pub async fn get_hostname() -> String {
get_appstate().await.config.get_hostname()
}
pub async fn get_auto_accept() -> bool {
get_appstate().await.config.get_auto_accept()
}
pub async fn set_auto_accept(auto_accept: bool) {
get_appstate().await.config.set_auto_accept(auto_accept);
}
pub async fn get_save_history() -> bool {
get_appstate().await.config.get_save_history()
}
pub async fn set_save_history(save_history: bool) {
get_appstate().await.config.set_save_history(save_history);
}
pub async fn get_enable_tls() -> bool {
get_appstate().await.config.get_enable_tls()
}
pub async fn set_enable_tls(enable_tls: bool) {
get_appstate().await.config.set_enable_tls(enable_tls);
}
pub async fn get_peers() -> Result<Vec<Peer>, String> {
Ok(get_appstate().await.discovery.get_peers().await)
}
pub async fn send_file(target: Peer, target_ip: &str, file_path: &str) -> Result<(), String> {
let sender = get_appstate().await.discovery.get_self().await;
match get_appstate()
.await
.transfer
.send_file(target, target_ip, sender, file_path)
.await
{
Ok(_) => Ok(()),
Err(e) => Err(e.to_string()),
}
}
pub async fn send_text(target: Peer, target_ip: &str, text: &str) -> Result<(), String> {
let sender = get_appstate().await.discovery.get_self().await;
match get_appstate()
.await
.transfer
.send_text(target, target_ip, sender, text)
.await
{
Ok(_) => Ok(()),
Err(e) => Err(e.to_string()),
}
}
pub async fn send_folder(target: Peer, target_ip: &str, folder_path: &str) -> Result<(), String> {
let sender = get_appstate().await.discovery.get_self().await;
match get_appstate()
.await
.transfer
.send_folder(target, target_ip, sender, folder_path)
.await
{
Ok(_) => Ok(()),
Err(e) => Err(e.to_string()),
}
}
pub async fn get_transfers() -> Vec<Transfer> {
get_appstate().await.transfer.get_transfers()
}
pub async fn resolve_pending_request(id: &str, accept: bool, path: &str) {
get_appstate()
.await
.transfer
.make_decision(id, accept, path);
}
pub async fn cancel_transfer(id: &str) {
get_appstate().await.transfer.cancel(id);
}
pub async fn delete_transfer(id: &str) {
get_appstate().await.transfer.delete(id);
}
pub async fn clear_transfers() {
get_appstate().await.transfer.clear_transfers();
}
pub async fn is_trusted(peer_id: &str) -> bool {
get_appstate().await.trust.is_trusted(peer_id)
}
pub async fn trust_peer(peer_id: &str) -> Result<(), String> {
if let Some(peer) = get_appstate().await.discovery.get_peer(peer_id).await {
get_appstate().await.trust.trust(peer_id, &peer.public_key);
}
Ok(())
}
pub async fn untrust_peer(peer_id: &str) {
get_appstate().await.trust.untrust(peer_id);
}

1
rust/src/api/mod.rs Normal file
View File

@@ -0,0 +1 @@
pub mod commands;

136
rust/src/config/mod.rs Normal file
View File

@@ -0,0 +1,136 @@
use std::{fs, sync::RwLock};
use crate::error::AppError;
use crate::security::{self, identity::Identity};
pub mod model;
pub struct Config {
data: RwLock<model::ConfigData>,
config_path: String,
}
impl Config {
// Getter
pub fn get_id(&self) -> String {
self.data.read().unwrap().id.clone()
}
pub fn get_hostname(&self) -> String {
self.data.read().unwrap().hostname.clone()
}
pub fn set_hostname(&self, name: String) {
self.data.write().unwrap().hostname = name;
}
pub fn get_public_key(&self) -> String {
self.data.read().unwrap().public_key.clone()
}
pub fn get_private_key(&self) -> String {
self.data.read().unwrap().private_key.clone()
}
pub fn get_cert_pem(&self) -> String {
self.data.read().unwrap().cert_pem.clone()
}
pub fn get_key_pem(&self) -> String {
self.data.read().unwrap().key_pem.clone()
}
pub fn get_save_path(&self) -> String {
self.data.read().unwrap().save_path.clone()
}
pub fn set_save_path(&self, save_path: String) {
self.data.write().unwrap().save_path = save_path;
}
pub fn get_enable_tls(&self) -> bool {
self.data.read().unwrap().enable_tls.clone()
}
pub fn set_enable_tls(&self, enable_tls: bool) {
self.data.write().unwrap().enable_tls = enable_tls;
}
pub fn get_save_history(&self) -> bool {
self.data.read().unwrap().save_history.clone()
}
pub fn set_save_history(&self, save_history: bool) {
self.data.write().unwrap().save_history = save_history;
}
pub fn get_config_dir(&self) -> std::path::PathBuf {
let config_dir = dirs::config_dir()
.unwrap_or(std::env::temp_dir())
.join("mesh-drop");
config_dir
}
pub fn get_auto_accept(&self) -> bool {
self.data.read().unwrap().auto_accept.clone()
}
pub fn set_auto_accept(&self, auto_accept: bool) {
self.data.write().unwrap().auto_accept = auto_accept;
}
}
impl Config {
/// 加载配置文件,后期跨平台(Android/IOS)需要重新设计
pub fn load() -> Result<Self, AppError> {
let config_dir = dirs::config_dir()
.unwrap_or(std::env::temp_dir())
.join("mesh-drop");
let config_file = config_dir.join("config.json");
// 创建目录
fs::create_dir_all(&config_dir)?;
// 尝试读取配置文件
let mut data: model::ConfigData = if config_file.exists() {
let content = fs::read_to_string(&config_file)?;
serde_json::from_str(&content)?
} else {
// 文件不存在 → 创建默认配置
model::ConfigData::default()
};
// 如果密钥为空,生成新的
if data.private_key.is_empty() || data.public_key.is_empty() {
let identity = Identity::new();
data.private_key = identity.private_key_base64();
data.public_key = identity.public_key_base64();
}
// 如果证书为空,生成新的
if data.cert_pem.is_empty() || data.key_pem.is_empty() {
let (cert, key) = security::cert::generate_self_signed()?;
data.cert_pem = cert;
data.key_pem = key;
}
let config = Config {
data: RwLock::new(data),
config_path: config_file.to_str().unwrap().to_string(),
};
// 保存一次配置
config.save()?;
Ok(config)
}
pub fn save(&self) -> Result<(), AppError> {
let data = self
.data
.read()
.map_err(|e| AppError::ConfigError(e.to_string()))?;
let content = serde_json::to_string_pretty(&*data)?;
fs::write(&self.config_path, content)?;
Ok(())
}
}

41
rust/src/config/model.rs Normal file
View File

@@ -0,0 +1,41 @@
use gethostname::gethostname;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
#[derive(Serialize, Deserialize, Debug)]
#[serde(default)]
pub struct ConfigData {
/// uuid
pub id: String,
pub hostname: String,
pub private_key: String,
pub public_key: String,
pub save_path: String,
pub auto_accept: bool,
pub save_history: bool,
pub trusted_peer: HashMap<String, String>, // peer_id -> public_key
pub cert_pem: String, // TLS 证书PEM 格式)
pub key_pem: String, // TLS 私钥PEM 格式)
pub enable_tls: bool,
}
impl Default for ConfigData {
fn default() -> Self {
Self {
id: uuid::Uuid::new_v4().to_string(),
hostname: gethostname().to_str().unwrap_or("localhost").to_string(),
private_key: String::new(), // load 里检测到空再生成
public_key: String::new(),
save_path: dirs::download_dir()
.unwrap_or(std::env::temp_dir())
.to_string_lossy()
.to_string(),
auto_accept: false,
save_history: true,
trusted_peer: HashMap::new(),
cert_pem: String::new(),
key_pem: String::new(),
enable_tls: true,
}
}
}

View File

@@ -0,0 +1,33 @@
use if_addrs::get_if_addrs;
use std::net::Ipv4Addr;
/// 获取所有子网的广播地址
pub fn get_broadcast_addresses() -> Vec<Ipv4Addr> {
let mut addrs = Vec::new();
let interfaces = match get_if_addrs() {
Ok(ifaces) => ifaces,
Err(_) => return addrs,
};
for iface in interfaces {
// 跳过 loopback
if iface.is_loopback() {
continue;
}
// 只处理 IPv4
if let if_addrs::IfAddr::V4(v4) = &iface.addr {
let ip = v4.ip;
let mask = v4.netmask;
// 广播地址 = IP | ~掩码
let broadcast = Ipv4Addr::new(
ip.octets()[0] | !mask.octets()[0],
ip.octets()[1] | !mask.octets()[1],
ip.octets()[2] | !mask.octets()[2],
ip.octets()[3] | !mask.octets()[3],
);
addrs.push(broadcast);
}
}
addrs
}

256
rust/src/discovery/mod.rs Normal file
View File

@@ -0,0 +1,256 @@
use chrono::Utc;
use std::{collections::HashMap, sync::Arc, time::Duration};
use tokio::{net::UdpSocket, sync::RwLock, time::interval};
use crate::{
config::Config,
discovery::model::{Peer, PresencePacket, RouteState},
event::AppEvent,
security::identity::{self, Identity},
trust::TrustStore,
};
pub mod address;
pub mod model;
pub struct DiscoveryService {
id: String,
config: Arc<Config>,
peers: Arc<RwLock<HashMap<String, Peer>>>,
port: u16,
trust: Arc<TrustStore>,
tx: tokio::sync::broadcast::Sender<AppEvent>,
}
impl DiscoveryService {
// Getter
pub async fn get_peers(&self) -> Vec<Peer> {
self.peers.read().await.values().cloned().collect()
}
pub async fn get_peer(&self, id: &str) -> Option<Peer> {
self.peers.read().await.get(id).cloned()
}
pub async fn get_self(&self) -> Peer {
Peer {
id: self.id.clone(),
name: self.config.get_hostname(),
port: self.port,
os: std::env::consts::OS.to_string(),
public_key: self.config.get_public_key(),
trust_mismatch: false,
routes: HashMap::new(),
enable_tls: self.config.get_enable_tls(),
}
}
}
impl DiscoveryService {
pub fn new(
config: Arc<Config>,
port: u16,
trust: Arc<TrustStore>,
tx: tokio::sync::broadcast::Sender<AppEvent>,
) -> Self {
Self {
id: config.get_id(),
config,
peers: Arc::new(RwLock::new(HashMap::new())),
port,
trust,
tx,
}
}
pub fn start(&self) {
// 广播
let config = self.config.clone();
let id = self.id.clone();
let port = self.port;
tokio::spawn(async move {
Self::start_broadcasting(config, id, port).await;
});
// 监听
let config = self.config.clone();
let id = self.id.clone();
let peers = self.peers.clone();
let trust = self.trust.clone();
let tx = self.tx.clone();
tokio::spawn(async move {
Self::start_listening(config, id, peers, trust, tx).await;
});
// 清理
let peers = self.peers.clone();
let tx = self.tx.clone();
tokio::spawn(async move {
Self::start_cleanup(peers, tx).await;
});
}
async fn start_broadcasting(config: Arc<Config>, id: String, port: u16) {
let socket = UdpSocket::bind("0.0.0.0:0")
.await
.expect("Failed to bind socket");
socket.set_broadcast(true).expect("Failed to set broadcast");
let identity = Identity::from_base64(&config.get_private_key(), &config.get_public_key())
.expect("Failed to create identity");
let mut ticker = interval(Duration::from_secs(1));
loop {
ticker.tick().await;
// 构造 PresencePacket
let mut packet = PresencePacket {
id: id.clone(),
name: config.get_hostname(),
port,
os: std::env::consts::OS.to_string(),
public_key: config.get_public_key(),
signature: String::new(),
enable_tls: config.get_enable_tls(),
};
// 签名 packet
let sign_data = packet.sign_payload();
let signature = identity.sign(&sign_data);
packet.signature = signature;
let data = serde_json::to_vec(&packet).unwrap();
// 发送广播
for addr in address::get_broadcast_addresses() {
let target = format!("{}:{}", addr, 9988);
if let Err(e) = socket.send_to(&data, &target).await {
tracing::error!("Failed to send broadcast: {}", e);
} else {
tracing::debug!("Broadcast sent to {}", target);
}
}
}
}
async fn start_listening(
_: Arc<Config>,
id: String,
peers: Arc<RwLock<HashMap<String, Peer>>>,
trust: Arc<TrustStore>,
tx: tokio::sync::broadcast::Sender<AppEvent>,
) {
let socket = UdpSocket::bind(format!("0.0.0.0:{}", 9988))
.await
.expect("Failed to bind discovery listener");
let mut buf = [0u8; 1024];
loop {
// 接受广播
let (len, addr) = match socket.recv_from(&mut buf).await {
Ok(result) => result,
Err(e) => {
tracing::error!("Failed to receive: {}", e);
continue;
}
};
// 反序列化
let packet: PresencePacket = match serde_json::from_slice(&buf[..len]) {
Ok(p) => p,
Err(e) => {
tracing::error!("Failed to deserialize packet: {}", e);
continue;
}
};
// 忽略自己
if packet.id == id {
continue;
}
// 验证签名
let sig = packet.signature.clone();
let sign_data = packet.sign_payload();
match identity::verify(&packet.public_key, &sign_data, &sig) {
Ok(true) => {}
_ => {
tracing::warn!("Invailed signature from {}", addr.ip());
continue;
}
}
// 更新 peers 列表
let ip = addr.ip().to_string();
let mut peers = peers.write().await;
let peer = peers.entry(packet.id.clone()).or_insert_with(|| {
tracing::info!("New device: {}({})", packet.name, ip);
Peer {
id: packet.id.clone(),
name: packet.name.clone(),
routes: HashMap::new(),
port: packet.port,
os: packet.os.clone(),
public_key: packet.public_key.clone(),
trust_mismatch: false,
enable_tls: packet.enable_tls,
}
});
// 检测公钥是否和信任列表一致
if let Some(trusted_key) = trust.get_trusted_key(&packet.id) {
peer.trust_mismatch = trusted_key != packet.public_key;
}
// 更新路由
peer.name = packet.name.clone();
peer.routes.insert(
ip.clone(),
RouteState {
ip,
last_seen: Utc::now().timestamp_millis() as f64,
},
);
peer.enable_tls = packet.enable_tls;
// 发送事件
let _ = tx.send(AppEvent::PeerConnectOrUpdated { peer: peer.clone() });
}
}
async fn start_cleanup(
peers: Arc<RwLock<HashMap<String, Peer>>>,
tx: tokio::sync::broadcast::Sender<AppEvent>,
) {
let mut ticker = interval(Duration::from_secs(2));
loop {
ticker.tick().await;
let mut peers = peers.write().await;
// 移除 peer 中的超时 route
for peer in peers.values_mut() {
peer.routes.retain(|_, route| {
route.last_seen > (Utc::now().timestamp_millis() - 3000) as f64
});
if !peer.routes.is_empty() {
tracing::debug!("Device updated: {}", peer.name);
let _ = tx.send(AppEvent::PeerConnectOrUpdated { peer: peer.clone() });
}
}
// 清理 route 为空的 peer
peers.retain(|_, p| {
if p.routes.is_empty() {
tracing::info!("Device offline: {}", p.name);
let _ = tx.send(AppEvent::PeerDisconnected { id: p.id.clone() });
false
} else {
true
}
});
}
}
}

View File

@@ -0,0 +1,44 @@
use std::collections::HashMap;
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Peer {
pub id: String,
pub name: String,
pub routes: HashMap<String, RouteState>,
pub port: u16,
pub os: String,
pub public_key: String,
pub trust_mismatch: bool,
pub enable_tls: bool,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RouteState {
pub ip: String,
pub last_seen: f64, // timestamp
}
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct PresencePacket {
pub id: String,
pub name: String,
pub port: u16,
pub os: String,
#[serde(rename = "pk")]
pub public_key: String,
#[serde(rename = "sig")]
pub signature: String,
pub enable_tls: bool,
}
impl PresencePacket {
pub fn sign_payload(&self) -> Vec<u8> {
format!(
"{}|{}|{}|{:?}|{}",
self.id, self.name, self.port, self.os, self.public_key
)
.into_bytes()
}
}

26
rust/src/error.rs Normal file
View File

@@ -0,0 +1,26 @@
use base64::DecodeError;
use thiserror::Error;
#[derive(Error, Debug)]
pub enum AppError {
#[error("Base64: unable to decode {0}")]
Base64DecodeError(DecodeError),
#[error("Identity: unable to convert base64 string to private key")]
LoadPrivateKeyError,
#[error("Identity: unable to convert base64 string to public key")]
LoadPublicKeyError,
#[error("Identity: unable to convert base64 string to signature")]
LoadSignatureError,
#[error("Security: unable to generate self-signed certificate: {0}")]
GenerateSelfSignedError(String),
#[error("IO: unable to read file: {0}")]
IoError(#[from] std::io::Error),
#[error("JSON: unable to parse: {0}")]
JsonError(#[from] serde_json::Error),
#[error("Config: unable to load config: {0}")]
ConfigError(String),
#[error("Network: {0}")]
Network(String),
#[error("Canceled: {0}")]
Canceled(String),
}

29
rust/src/event.rs Normal file
View File

@@ -0,0 +1,29 @@
use crate::{discovery::model::Peer, transfer::model::Transfer};
use serde::Serialize;
#[derive(Clone, Serialize, Debug)]
#[serde(tag = "type", content = "payload")]
pub enum AppEvent {
TransferStatusChanged {
transfer: Transfer,
},
TransferProgressChanged {
id: String,
progress: f64,
total: f64,
speed: f64,
},
PeerConnectOrUpdated {
peer: Peer,
},
PeerDisconnected {
id: String,
},
TransferAdded {
transfer: Transfer,
},
TransferRemoved {
id: String,
},
TransferClear,
}

1892
rust/src/frb_generated.rs Normal file

File diff suppressed because it is too large Load Diff

67
rust/src/lib.rs Normal file
View File

@@ -0,0 +1,67 @@
use std::sync::Arc;
use tokio::sync::OnceCell;
use crate::{discovery::DiscoveryService, event::AppEvent};
pub mod api;
mod config;
mod discovery;
mod error;
mod event;
mod frb_generated;
mod security;
mod transfer;
mod trust;
static APPSTATE: OnceCell<AppState> = OnceCell::const_new();
pub struct AppState {
pub config: Arc<config::Config>,
pub trust: Arc<trust::TrustStore>,
pub transfer: Arc<transfer::TransferService>,
pub discovery: Arc<discovery::DiscoveryService>,
pub event_tx: tokio::sync::broadcast::Sender<AppEvent>,
}
pub async fn get_appstate() -> &'static AppState {
APPSTATE
.get_or_init(|| async {
tracing_subscriber::fmt()
.with_max_level(tracing::Level::INFO)
.init();
let config = Arc::new(config::Config::load().expect("Failed to load config"));
tracing::info!("Config loaded");
let trust = Arc::new(trust::TrustStore::new(config.clone()));
let (event_tx, _event_rx) = tokio::sync::broadcast::channel::<AppEvent>(100);
// 启动传输服务
let transfer = Arc::new(transfer::TransferService::new(
config.clone(),
trust.clone(),
event_tx.clone(),
));
let port = transfer.start().await;
// 启动发现服务
let discovery = Arc::new(DiscoveryService::new(
config.clone(),
port,
trust.clone(),
event_tx.clone(),
));
discovery.start();
AppState {
config,
trust,
transfer,
discovery,
event_tx,
}
})
.await
}

21
rust/src/security/cert.rs Normal file
View File

@@ -0,0 +1,21 @@
use crate::error::AppError;
pub fn generate_self_signed() -> Result<(String, String), AppError> {
let cert = rcgen::generate_simple_self_signed(vec!["localhost".to_string()])
.map_err(|e| AppError::GenerateSelfSignedError(e.to_string()))?;
Ok((cert.cert.pem(), cert.signing_key.serialize_pem()))
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_generate_self_signed() {
let result = generate_self_signed();
assert!(result.is_ok());
let (cert, key) = result.unwrap();
assert!(!cert.is_empty());
assert!(!key.is_empty());
}
}

View File

@@ -0,0 +1,121 @@
use base64::{prelude::BASE64_STANDARD, Engine};
use ed25519_dalek::{Signature, Signer, SigningKey, Verifier, VerifyingKey};
use crate::error;
use error::AppError;
use rand::rngs::OsRng;
pub struct Identity {
private_key: SigningKey,
public_key: VerifyingKey,
}
impl Identity {
/// 生成新的 ed25519 密钥对
pub fn new() -> Self {
let mut csprng = OsRng;
let signing_key: SigningKey = SigningKey::generate(&mut csprng);
let verifying_key = signing_key.verifying_key();
Identity {
private_key: signing_key,
public_key: verifying_key,
}
}
/// 使用私钥对数据进行签名
pub fn sign(&self, data: &[u8]) -> String {
let signature = self.private_key.sign(data);
BASE64_STANDARD.encode(signature.to_bytes())
}
/// 导出公钥base64用于广播给其他设备
pub fn public_key_base64(&self) -> String {
BASE64_STANDARD.encode(self.public_key.as_bytes())
}
/// 导出私钥base64用于持久化到配置文件
pub fn private_key_base64(&self) -> String {
BASE64_STANDARD.encode(self.private_key.as_bytes())
}
/// 从 base64 密钥对恢复(从配置文件加载时用)
pub fn from_base64(private_key: &str, public_key: &str) -> Result<Self, AppError> {
let private_bytes = BASE64_STANDARD
.decode(private_key)
.map_err(|e| AppError::Base64DecodeError(e))?;
let public_bytes = BASE64_STANDARD
.decode(public_key)
.map_err(|e| AppError::Base64DecodeError(e))?;
let private_array: [u8; 32] = private_bytes
.as_slice()
.try_into()
.map_err(|_| AppError::LoadPrivateKeyError)?;
let public_array: [u8; 32] = public_bytes
.as_slice()
.try_into()
.map_err(|_| AppError::LoadPublicKeyError)?;
let private_key = SigningKey::from_bytes(&private_array);
let public_key =
VerifyingKey::from_bytes(&public_array).map_err(|_| AppError::LoadPublicKeyError)?;
Ok(Identity {
private_key,
public_key,
})
}
}
pub fn verify(public_key: &str, data: &[u8], signature: &str) -> Result<bool, AppError> {
let public_bytes = BASE64_STANDARD
.decode(public_key)
.map_err(|e| AppError::Base64DecodeError(e))?;
let public_array: [u8; 32] = public_bytes
.as_slice()
.try_into()
.map_err(|_| AppError::LoadPublicKeyError)?;
let public_key =
VerifyingKey::from_bytes(&public_array).map_err(|_| AppError::LoadPublicKeyError)?;
let signature_bytes = BASE64_STANDARD
.decode(signature)
.map_err(|e| AppError::Base64DecodeError(e))?;
let signature_array: [u8; 64] = signature_bytes
.as_slice()
.try_into()
.map_err(|_| AppError::LoadSignatureError)?;
let signature = Signature::from_bytes(&signature_array);
Ok(public_key.verify(data, &signature).is_ok())
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_sign_and_verify() {
let identity = Identity::new();
let data = b"hello meshdrop";
let sig = identity.sign(data);
// 用独立函数验证(模拟从网络收到公钥字符串的场景)
let pub_key = identity.public_key_base64();
let valid = verify(&pub_key, data, &sig).unwrap();
assert!(valid);
// 错误数据应该验证失败
let invalid = verify(&pub_key, b"wrong data", &sig).unwrap();
assert!(!invalid);
}
#[test]
fn test_from_base64_roundtrip() {
// 生成 → 导出 → 导入 → 再签名验证
let identity1 = Identity::new();
let priv_b64 = identity1.private_key_base64();
let pub_b64 = identity1.public_key_base64();
let identity2 = Identity::from_base64(&priv_b64, &pub_b64).unwrap();
let data = b"roundtrip test";
let sig = identity2.sign(data);
let valid = verify(&pub_b64, data, &sig).unwrap();
assert!(valid);
}
}

2
rust/src/security/mod.rs Normal file
View File

@@ -0,0 +1,2 @@
pub mod cert;
pub mod identity;

View File

@@ -0,0 +1,81 @@
use std::{path::PathBuf, sync::Arc};
use dashmap::DashMap;
use crate::{
config::Config,
transfer::model::{CanceledBy, Transfer, TransferStatus},
};
pub fn get_history_path(config: &Arc<Config>) -> PathBuf {
let mut path = config.get_config_dir();
path.push("history.json");
path
}
pub fn save_history(config: &Arc<Config>, transfers: &DashMap<String, Transfer>) {
if !config.get_save_history() {
return;
}
let history_path = get_history_path(config);
let temp_path = history_path.with_extension("json.tmp");
// 取出所有值形成列表
let history_list: Vec<Transfer> = transfers.iter().map(|kv| kv.value().clone()).collect();
let json_data = match serde_json::to_string_pretty(&history_list) {
Ok(data) => data,
Err(e) => {
tracing::error!("Failed to marshal history: {}", e);
return;
}
};
if let Err(e) = std::fs::write(&temp_path, json_data) {
tracing::error!("Failed to write temp history file: {}", e);
return;
}
if let Err(e) = std::fs::rename(&temp_path, &history_path) {
tracing::error!("Failed to rename temp history file: {}", e);
let _ = std::fs::remove_file(temp_path);
return;
}
tracing::info!("History saved successfully to {:?}", history_path);
}
pub fn load_history(config: &Arc<Config>, transfers: &DashMap<String, Transfer>) {
let history_path = get_history_path(config);
if !history_path.exists() {
return;
}
let file = match std::fs::File::open(&history_path) {
Ok(f) => f,
Err(e) => {
tracing::warn!("Could not open history file: {}", e);
return;
}
};
let history: Vec<Transfer> = match serde_json::from_reader(file) {
Ok(h) => h,
Err(e) => {
tracing::error!("Failed to parse history: {}", e);
return;
}
};
for mut transfer in history {
// 在加载的时候,如果发现状态依然是 Pending/Active直接重置为 Canceled
if transfer.status == TransferStatus::Pending || transfer.status == TransferStatus::Active {
transfer.status = TransferStatus::Canceled(CanceledBy::Receiver);
}
transfers.insert(transfer.id.clone(), transfer);
}
tracing::info!("History loaded successfully, {} items", transfers.len());
}

79
rust/src/transfer/mod.rs Normal file
View File

@@ -0,0 +1,79 @@
use std::sync::Arc;
use crate::{
event::AppEvent,
transfer::{model::Transfer, store::TransferStore},
};
pub mod history;
pub mod model;
pub mod progress;
pub mod receiver;
pub mod sender;
pub mod store;
pub mod tar_size_counter;
pub struct TransferService {
pub(super) config: Arc<crate::config::Config>,
pub(super) transfers: Arc<TransferStore>,
pub(super) trust: Arc<crate::trust::TrustStore>,
}
impl TransferService {
pub fn new(
config: Arc<crate::config::Config>,
trust: Arc<crate::trust::TrustStore>,
tx: tokio::sync::broadcast::Sender<AppEvent>,
) -> Self {
let transfers = Arc::new(TransferStore::new(config.clone(), tx.clone()));
transfers.start_speed_sampler();
Self {
config,
transfers,
trust,
}
}
pub async fn start(&self) -> u16 {
let listener = tokio::net::TcpListener::bind("0.0.0.0:0")
.await
.expect("Failed to bind transfer server");
let port = listener.local_addr().unwrap().port();
drop(listener); // 释放端口,以便 axum_server 重新绑定
tracing::info!("Transfer server listening on port {}", port);
let config = self.config.clone();
let transfers = self.transfers.clone();
let trust = self.trust.clone();
tokio::spawn(async move {
receiver::start_server(config, transfers, port, trust).await;
});
port
}
pub fn get_transfers(&self) -> Vec<Transfer> {
self.transfers.get_all_transfers()
}
pub fn clear_transfers(&self) {
self.transfers.clear();
}
pub fn make_decision(&self, id: &str, accepted: bool, save_path: &str) -> bool {
self.transfers.make_decision(id, accepted, save_path)
}
pub fn cancel(&self, id: &str) -> bool {
self.transfers.cancel(id)
}
pub fn delete(&self, id: &str) -> bool {
self.transfers.delete(id)
}
}
impl Drop for TransferService {
fn drop(&mut self) {
self.transfers.shutdown();
}
}

View File

@@ -0,0 +1,88 @@
use serde::{Deserialize, Serialize};
use crate::discovery::model::Peer;
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum TransferStatus {
#[serde(rename = "pending")]
Pending,
#[serde(rename = "accepted")]
Accepted,
#[serde(rename = "rejected")]
Rejected,
#[serde(rename = "completed")]
Completed,
#[serde(rename = "error")]
Error,
#[serde(rename = "canceled")]
Canceled(CanceledBy),
#[serde(rename = "active")]
Active,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TransferType {
#[serde(rename = "send")]
Send,
#[serde(rename = "receive")]
Receive,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum ContentType {
#[serde(rename = "file")]
File,
#[serde(rename = "text")]
Text,
#[serde(rename = "folder")]
Folder,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum CanceledBy {
#[serde(rename = "sender")]
Sender,
#[serde(rename = "receiver")]
Receiver,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Transfer {
pub id: String,
pub create_time: f64,
pub sender: Peer,
pub sender_ip: String,
pub file_name: String,
pub file_size: f64,
pub save_path: String,
pub status: TransferStatus,
#[serde(rename = "type")]
pub r#type: TransferType, // type 是 Rust 关键字,用 r# 转义
pub content_type: ContentType,
pub text: String,
pub error_msg: String,
pub token: String,
#[serde(default)]
pub progress: f64,
#[serde(default)]
pub last_read_time: i64,
#[serde(default)]
pub speed: f64,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct TransferAskResponse {
pub id: String,
pub accepted: bool,
#[serde(skip_serializing_if = "Option::is_none")]
pub token: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub message: Option<String>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct TransferUploadResponse {
pub id: String,
pub message: String,
pub status: TransferStatus,
}

View File

@@ -0,0 +1,94 @@
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio_util::sync::CancellationToken;
use crate::transfer::store::TransferStore;
pub struct ProgressMonitor<T> {
inner: T,
current_size: u64,
transfer_id: String,
transfers: Arc<TransferStore>,
cancel_token: Option<CancellationToken>,
}
impl<T: Unpin> ProgressMonitor<T> {
pub fn new(inner: T, transfer_id: String, transfers: Arc<TransferStore>) -> Self {
Self {
inner,
current_size: 0,
transfer_id,
transfers,
cancel_token: None,
}
}
/// 附加一个取消令牌。当令牌被取消时poll_read 会返回 IO 错误使上层操作中断。
pub fn with_cancel(mut self, token: CancellationToken) -> Self {
self.cancel_token = Some(token);
self
}
}
impl<T: AsyncRead + Unpin> AsyncRead for ProgressMonitor<T> {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
// 在每次读取前同步检查取消状态。
// is_cancelled() 是纯内存检查(原子读),不需要 await可以在同步的 poll 函数中安全调用。
if self.cancel_token.as_ref().is_some_and(|t| t.is_cancelled()) {
return Poll::Ready(Err(std::io::Error::new(
std::io::ErrorKind::Interrupted,
"Transfer canceled",
)));
}
let before_len = buf.filled().len();
match Pin::new(&mut self.inner).poll_read(cx, buf) {
Poll::Ready(Ok(())) => {
let after_len = buf.filled().len();
let bytes_read = after_len - before_len;
if bytes_read > 0 {
self.current_size += bytes_read as u64;
self.transfers
.update_progress(&self.transfer_id, self.current_size);
}
Poll::Ready(Ok(()))
}
other => other,
}
}
}
impl<T: AsyncWrite + Unpin> AsyncWrite for ProgressMonitor<T> {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<std::io::Result<usize>> {
match Pin::new(&mut self.inner).poll_write(cx, buf) {
Poll::Ready(Ok(n)) => {
if n > 0 {
self.current_size += n as u64;
self.transfers
.update_progress(&self.transfer_id, self.current_size);
}
Poll::Ready(Ok(n))
}
other => other,
}
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
Pin::new(&mut self.inner).poll_flush(cx)
}
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
Pin::new(&mut self.inner).poll_shutdown(cx)
}
}

View File

@@ -0,0 +1,346 @@
use std::net::{IpAddr, Ipv4Addr};
use std::sync::Arc;
use crate::config::Config;
use crate::transfer::model::{
CanceledBy, ContentType, Transfer, TransferAskResponse, TransferStatus, TransferType,
TransferUploadResponse,
};
use crate::transfer::progress::ProgressMonitor;
use crate::transfer::store::TransferStore;
use crate::trust::TrustStore;
use axum::body::Body;
use axum::extract::{Path, Query, State};
use axum::http::StatusCode;
use axum::routing::{post, put};
use axum::{Json, Router};
use axum_server::tls_rustls::RustlsConfig;
use futures_util::TryStreamExt;
use serde::Deserialize;
use tokio::io::AsyncWriteExt;
use tokio_stream::StreamExt;
#[derive(Clone)]
pub struct AxumState {
pub config: Arc<Config>,
pub transfers: Arc<TransferStore>,
pub trust: Arc<TrustStore>,
}
pub async fn start_server(
config: Arc<Config>,
transfers: Arc<TransferStore>,
port: u16,
trust: Arc<TrustStore>,
) {
let state = AxumState {
config: config.clone(),
transfers,
trust,
};
let app = Router::new()
.route("/transfer/ask", post(handle_ask))
.route("/transfer/upload/{id}", put(handle_upload))
.with_state(state);
let cert_pem = config.get_cert_pem();
let key_pem = config.get_key_pem();
let tls_config = RustlsConfig::from_pem(cert_pem.into_bytes(), key_pem.into_bytes())
.await
.expect("Failed to load TLS config");
let addr = std::net::SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port);
if config.get_enable_tls() {
tracing::info!("Transfer server (HTTPS) listening on {}", addr);
axum_server::bind_rustls(addr, tls_config)
.serve(app.into_make_service())
.await
.unwrap();
} else {
axum_server::bind(addr)
.serve(app.into_make_service())
.await
.unwrap();
}
}
async fn handle_ask(
State(state): State<AxumState>,
Json(mut transfer): Json<Transfer>,
) -> Json<TransferAskResponse> {
// 如果已存在则拒绝
if let Some(_) = state.transfers.get(&transfer.id) {
return Json(TransferAskResponse {
id: transfer.id,
accepted: false,
token: None,
message: Some("Transfer already exists".to_string()),
});
}
tracing::info!(
"Received transfer request: {} from {}",
transfer.file_name,
transfer.sender.name
);
let token = uuid::Uuid::new_v4().to_string();
let id = transfer.id.clone();
transfer.r#type = TransferType::Receive;
state.transfers.insert(transfer.clone());
let sender_peer_id = transfer.sender.id.clone();
let auto_accept = state.config.get_auto_accept() || state.trust.is_trusted(&sender_peer_id);
let accepted = if auto_accept {
true
} else {
let rx = state.transfers.register_decision(&id);
rx.await.unwrap_or(false) // 无限等待channel关闭时返回false
};
if accepted {
transfer.token = token.clone();
state.transfers.insert(transfer.clone());
state.transfers.update_status(&id, TransferStatus::Accepted);
Json(TransferAskResponse {
id,
accepted: true,
token: Some(token),
message: None,
})
} else {
state.transfers.update_status(&id, TransferStatus::Rejected);
Json(TransferAskResponse {
id,
accepted: false,
token: None,
message: Some("Rejected".to_string()),
})
}
}
#[derive(Debug, Deserialize)]
pub struct UploadQuery {
pub token: String,
}
async fn handle_upload(
State(state): State<AxumState>,
Path(id): Path<String>,
Query(query): Query<UploadQuery>,
body: Body,
) -> (StatusCode, Json<TransferUploadResponse>) {
let (content_type, file_name) = {
let transfer = match state.transfers.get(&id) {
Some(t) => t,
None => {
return (
StatusCode::NOT_FOUND,
Json(TransferUploadResponse {
id,
message: "Transfer not found".to_string(),
status: TransferStatus::Error,
}),
);
}
};
if transfer.token != query.token {
return (
StatusCode::UNAUTHORIZED,
Json(TransferUploadResponse {
id,
message: "Invalid token".to_string(),
status: TransferStatus::Error,
}),
);
}
(transfer.content_type.clone(), transfer.file_name.clone())
};
state.transfers.update_status(&id, TransferStatus::Active);
let result = match content_type {
ContentType::File => {
// 确定保存路径
let save_dir = &state.config.get_save_path();
let file_path = resolve_filename(std::path::Path::new(save_dir), &file_name);
receive_file(&id, &state.transfers.clone(), body, &file_path).await
}
ContentType::Text => receive_text(&id, &state.transfers.clone(), body).await,
ContentType::Folder => {
let save_dir = &state.config.get_save_path();
let folder_path = resolve_filename(std::path::Path::new(save_dir), &file_name);
receive_folder(&id, state.transfers.clone(), body, &folder_path).await
}
};
match result {
Ok(()) => {
state
.transfers
.update_status(&id, TransferStatus::Completed);
tracing::info!("Transfer completed: {}", id);
(
StatusCode::OK,
Json(TransferUploadResponse {
id,
message: "Transfer completed".to_string(),
status: TransferStatus::Completed,
}),
)
}
Err(message) => {
let status = if message == "Transfer canceled" {
TransferStatus::Canceled(CanceledBy::Receiver)
} else {
TransferStatus::Error
};
state.transfers.update_status(&id, status.clone());
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(TransferUploadResponse {
id,
message,
status,
}),
)
}
}
}
async fn receive_file(
transfer_id: &str,
transfers: &Arc<TransferStore>,
body: Body,
file_path: &std::path::Path,
) -> Result<(), String> {
let cancel_token = transfers.register_cancel(transfer_id);
let mut file = match tokio::fs::File::create(&file_path).await {
Ok(f) => f,
Err(_) => {
return Err("Failed to create file".to_string());
}
};
let mut stream = body.into_data_stream();
let mut current_size: u64 = 0;
loop {
tokio::select! {
chunk = stream.next() =>{
match chunk {
Some(Ok(bytes)) => {
current_size+=bytes.len() as u64;
transfers.update_progress(transfer_id, current_size);
file.write_all(&bytes).await.map_err(|_| "Failed to write chunk".to_string())?;
}
Some(Err(e)) => return Err(format!("Stream error: {}", e)),
None=>break,
}
}
_=cancel_token.cancelled()=>{
// 取消任务
let _ = tokio::fs::remove_file(file_path).await;
return Err("Transfer canceled".to_string());
}
}
}
transfers.remove_cancel_token(transfer_id);
Ok(())
}
async fn receive_text(
transfer_id: &str,
transfers: &Arc<TransferStore>,
body: Body,
) -> Result<(), String> {
let cancel_token = transfers.register_cancel(transfer_id);
let mut stream = body.into_data_stream();
let mut text_bytes = Vec::new();
let mut current_size: u64 = 0;
loop {
tokio::select! {
chunk = stream.next() =>{
match chunk {
Some(Ok(bytes)) => {
current_size+=bytes.len() as u64;
transfers.update_progress(transfer_id, current_size);
text_bytes.extend_from_slice(&bytes);
}
Some(Err(e)) => return Err(format!("Stream error: {}", e)),
None=>break,
}
}
_=cancel_token.cancelled()=>{
// 取消任务
return Err("Transfer canceled".to_string());
}
}
}
let text = String::from_utf8_lossy(&text_bytes);
transfers.update_text(transfer_id, &text);
Ok(())
}
fn resolve_filename(dir: &std::path::Path, name: &str) -> std::path::PathBuf {
let path = dir.join(name);
if !path.exists() {
return path;
}
let stem = std::path::Path::new(name)
.file_stem()
.and_then(|s| s.to_str())
.unwrap_or(name);
let ext = std::path::Path::new(name)
.extension()
.and_then(|s| s.to_str());
for i in 1.. {
let new_name = match ext {
Some(e) => format!("{} ({}).{}", stem, i, e),
None => format!("{} ({})", stem, i),
};
let new_path = dir.join(new_name);
if !new_path.exists() {
return new_path;
}
}
unreachable!()
}
async fn receive_folder(
transfer_id: &str,
transfers: Arc<TransferStore>,
body: Body,
folder_path: &std::path::Path,
) -> Result<(), String> {
let cancel_token = transfers.register_cancel(transfer_id);
let stream = body
.into_data_stream()
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e));
let stream_reader = tokio_util::io::StreamReader::new(stream);
// 将 CancellationToken 附加到 ProgressMonitor。
// 在 poll_read 内同步检查 is_cancelled()
// 取消时返回 Poll::Ready(Err(...)),使 unpack() 因 IO 错误而中断。
let monitor = ProgressMonitor::new(stream_reader, transfer_id.to_owned(), transfers.clone())
.with_cancel(cancel_token);
let archive = async_tar::Archive::new(monitor);
let result = archive.unpack(&folder_path).await;
transfers.remove_cancel_token(transfer_id);
result.map_err(|e| {
// 如果是取消导致的错误,清理已创建的目录
if e.kind() == std::io::ErrorKind::Interrupted {
let _ = std::fs::remove_dir_all(folder_path);
"Transfer canceled".to_string()
} else {
format!("Unpack error: {}", e)
}
})
}

404
rust/src/transfer/sender.rs Normal file
View File

@@ -0,0 +1,404 @@
use std::{path::Path, sync::Arc};
use async_tar::Builder;
use reqwest::Client;
use tokio::io::duplex;
use tokio_util::io::ReaderStream;
use crate::{
discovery::model::Peer,
error::AppError,
transfer::{
model::{
ContentType, Transfer, TransferAskResponse, TransferStatus, TransferType,
TransferUploadResponse,
},
progress::ProgressMonitor,
store::TransferStore,
tar_size_counter::{self},
},
};
pub async fn ask(
client: &Client,
target: Peer,
target_ip: &str,
transfer: &Transfer,
) -> Result<TransferAskResponse, AppError> {
let ask_url = format!(
"{}://{}:{}/transfer/ask",
if target.enable_tls { "https" } else { "http" },
target_ip,
target.port
);
tracing::info!("Sending transfer ask to {}", ask_url);
let resp = client
.post(&ask_url)
.json(&transfer)
.send()
.await
.map_err(|e| AppError::Network(e.to_string()))?;
let ask_resp: TransferAskResponse = resp
.json()
.await
.map_err(|e| AppError::Network(e.to_string()))?;
return Ok(ask_resp);
}
pub async fn upload<T: Into<reqwest::Body>>(
transfer_id: &str,
transfers: &Arc<TransferStore>,
client: &Client,
target: Peer,
target_ip: &str,
token: &str,
body: T,
) -> Result<TransferUploadResponse, AppError> {
transfers.update_status(&transfer_id, TransferStatus::Active);
let upload_url = format!(
"{}://{}:{}/transfer/upload/{}?token={}",
if target.enable_tls { "https" } else { "http" },
target_ip,
target.port,
transfer_id,
token
);
tracing::info!("Uploading file to {}", upload_url);
let resp = client
.put(&upload_url)
.body(body)
.send()
.await
.map_err(|e| AppError::Network(e.to_string()))?;
let upload_resp: TransferUploadResponse = resp
.json()
.await
.map_err(|e| AppError::Network(e.to_string()))?;
Ok(upload_resp)
}
impl super::TransferService {
pub async fn send_file(
&self,
target: Peer,
target_ip: &str,
sender: Peer,
file_path: &str,
) -> Result<(), AppError> {
let client = reqwest::Client::builder()
.danger_accept_invalid_certs(true)
.build()
.map_err(|e| AppError::Network(e.to_string()))?;
let path = Path::new(file_path);
let file_name = path
.file_name()
.and_then(|n| n.to_str())
.unwrap_or("unknown")
.to_string();
let metadata = tokio::fs::metadata(path)
.await
.map_err(|e| AppError::IoError(e))?;
let id = uuid::Uuid::new_v4().to_string();
let sender_ip = {
let socket = std::net::UdpSocket::bind("0.0.0.0:0").unwrap();
if let Ok(_) = socket.connect(target_ip) {
socket.local_addr().unwrap().ip().to_string()
} else {
String::new()
}
};
let transfer = Transfer {
id: id.clone(),
create_time: chrono::Utc::now().timestamp() as f64,
sender,
sender_ip,
file_name,
file_size: metadata.len() as f64,
save_path: String::new(),
status: TransferStatus::Pending,
r#type: TransferType::Send,
content_type: ContentType::File,
text: String::new(),
error_msg: String::new(),
token: String::new(),
progress: 0.0,
speed: 0.0,
last_read_time: 0,
};
// 保存
self.transfers.insert(transfer.clone());
let ask_resp = ask(&client, target.clone(), target_ip, &transfer).await?;
if !ask_resp.accepted {
self.transfers.update_status(&id, TransferStatus::Rejected);
tracing::info!("Transfer rejected: {}", ask_resp.id);
return Ok(());
}
let token = ask_resp.token.unwrap_or_default();
tracing::info!("Transfer accepted: {}", transfer.id);
let file = tokio::fs::File::open(path)
.await
.map_err(|e| AppError::IoError(e))?;
let cancel_token = self.transfers.register_cancel(&id);
let monitor = ProgressMonitor::new(file, id.clone(), self.transfers.clone())
.with_cancel(cancel_token.clone());
let stream = ReaderStream::new(monitor);
let body = reqwest::Body::wrap_stream(stream);
let upload_resp = upload(
&id,
&self.transfers,
&client,
target.clone(),
target_ip,
&token,
body,
)
.await
.map_err(|e| {
self.transfers.remove_cancel_token(&id);
if cancel_token.is_cancelled() {
// 如果这里发生错误时 cancel_token 已经被取消了,则认为错误是由任务取消引起的
self.transfers.update_status(
&id,
TransferStatus::Canceled(crate::transfer::model::CanceledBy::Sender),
);
AppError::Canceled(id.to_string())
} else {
// 其他错误
self.transfers.update_status(&id, TransferStatus::Error);
e
}
})?;
self.transfers
.update_status(&id, upload_resp.status.clone());
tracing::info!("Transfer upload response: {}", upload_resp.message);
Ok(())
}
pub async fn send_text(
&self,
target: Peer,
target_ip: &str,
sender: Peer,
text: &str,
) -> Result<(), AppError> {
let client = reqwest::Client::builder()
.danger_accept_invalid_certs(true)
.build()
.map_err(|e| AppError::Network(e.to_string()))?;
let id = uuid::Uuid::new_v4().to_string();
let sender_ip = {
let socket = std::net::UdpSocket::bind("0.0.0.0:0").unwrap();
if let Ok(_) = socket.connect(target_ip) {
socket.local_addr().unwrap().ip().to_string()
} else {
String::new()
}
};
let transfer = Transfer {
id: id.clone(),
create_time: chrono::Utc::now().timestamp() as f64,
sender,
sender_ip,
file_name: "text".to_string(),
file_size: text.len() as f64,
save_path: String::new(),
status: TransferStatus::Pending,
r#type: TransferType::Send,
content_type: ContentType::Text,
text: "".to_string(), // 保持流程一致性
error_msg: String::new(),
token: String::new(),
progress: 0.0,
speed: 0.0,
last_read_time: 0,
};
// 保存
self.transfers.insert(transfer.clone());
let ask_resp = ask(&client, target.clone(), target_ip, &transfer).await?;
if !ask_resp.accepted {
self.transfers.update_status(&id, TransferStatus::Rejected);
tracing::info!("Transfer rejected: {}", ask_resp.id);
return Ok(());
}
let token = ask_resp.token.unwrap_or_default();
tracing::info!("Transfer accepted: {}", transfer.id);
let cursor = std::io::Cursor::new(text.to_string());
let cancel_token = self.transfers.register_cancel(&id);
let monitor = ProgressMonitor::new(cursor, id.clone(), self.transfers.clone())
.with_cancel(cancel_token.clone());
let stream = ReaderStream::new(monitor);
let body = reqwest::Body::wrap_stream(stream);
let upload_resp = upload(
&id,
&self.transfers,
&client,
target.clone(),
target_ip,
&token,
body,
)
.await
.map_err(|e| {
self.transfers.remove_cancel_token(&id);
if cancel_token.is_cancelled() {
// 如果这里发生错误时 cancel_token 已经被取消了,则认为错误是由任务取消引起的
self.transfers.update_status(
&id,
TransferStatus::Canceled(crate::transfer::model::CanceledBy::Sender),
);
AppError::Canceled(id.to_string())
} else {
// 其他错误
self.transfers.update_status(&id, TransferStatus::Error);
e
}
})?;
self.transfers
.update_status(&id, upload_resp.status.clone());
tracing::info!("Transfer upload response: {}", upload_resp.message);
Ok(())
}
pub async fn send_folder(
&self,
target: Peer,
target_ip: &str,
sender: Peer,
folder_path: &str,
) -> Result<(), AppError> {
let client = reqwest::Client::builder()
.danger_accept_invalid_certs(true)
.build()
.map_err(|e| AppError::Network(e.to_string()))?;
let path = Path::new(folder_path);
let folder_name = path
.file_name()
.and_then(|n| n.to_str())
.unwrap_or("unknown")
.to_string();
let total_size = tar_size_counter::get_folder_tar_size(path, &folder_name).await?;
let id = uuid::Uuid::new_v4().to_string();
let sender_ip = {
let socket = std::net::UdpSocket::bind("0.0.0.0:0").unwrap();
if let Ok(_) = socket.connect(target_ip) {
socket.local_addr().unwrap().ip().to_string()
} else {
String::new()
}
};
let transfer = Transfer {
id: id.clone(),
create_time: chrono::Utc::now().timestamp() as f64,
sender,
sender_ip,
file_name: folder_name.clone(),
file_size: total_size as f64,
save_path: String::new(),
status: TransferStatus::Pending,
r#type: TransferType::Send,
content_type: ContentType::Folder,
text: String::new(),
error_msg: String::new(),
token: String::new(),
progress: 0.0,
speed: 0.0,
last_read_time: 0,
};
// 保存
self.transfers.insert(transfer.clone());
let ask_resp = ask(&client, target.clone(), target_ip, &transfer).await?;
if !ask_resp.accepted {
self.transfers.update_status(&id, TransferStatus::Rejected);
tracing::info!("Transfer rejected: {}", ask_resp.id);
return Ok(());
}
let token = ask_resp.token.unwrap_or_default();
tracing::info!("Transfer accepted: {}", transfer.id);
let (tx, rx) = duplex(65536);
let path_clone = path.to_path_buf();
tokio::spawn(async move {
let mut builder = Builder::new(tx);
match builder.append_dir_all(".", path_clone).await {
Ok(_) => {
let _ = builder.into_inner().await;
tracing::info!("Tar stream build finished");
}
Err(e) => {
tracing::error!("Error building tar stream: {:?}", e);
}
}
});
let cancel_token = self.transfers.register_cancel(&id);
let monitor = ProgressMonitor::new(rx, id.clone(), self.transfers.clone())
.with_cancel(cancel_token.clone());
let stream = ReaderStream::new(monitor);
let body = reqwest::Body::wrap_stream(stream);
let upload_resp = upload(
&id,
&self.transfers,
&client,
target.clone(),
target_ip,
&token,
body,
)
.await
.map_err(|e| {
self.transfers.remove_cancel_token(&id);
if cancel_token.is_cancelled() {
// 如果这里发生错误时 cancel_token 已经被取消了,则认为错误是由任务取消引起的
self.transfers.update_status(
&id,
TransferStatus::Canceled(crate::transfer::model::CanceledBy::Sender),
);
AppError::Canceled(id.to_string())
} else {
// 其他错误
self.transfers.update_status(&id, TransferStatus::Error);
e
}
})?;
self.transfers
.update_status(&id, upload_resp.status.clone());
tracing::info!("Transfer upload response: {}", upload_resp.message);
Ok(())
}
}

271
rust/src/transfer/store.rs Normal file
View File

@@ -0,0 +1,271 @@
use std::sync::Arc;
use std::time::{Duration, Instant};
use crate::config::Config;
use crate::event::AppEvent;
use crate::transfer::{
history,
model::{Transfer, TransferStatus},
};
use dashmap::DashMap;
use tokio::sync::oneshot;
use tokio_util::sync::CancellationToken;
#[derive(Clone, Copy)]
struct SpeedSnapshot {
last_bytes: f64,
last_at: Instant,
ema_speed: f64,
stall_ticks: u8,
}
pub struct TransferStore {
pub(crate) inner: DashMap<String, Transfer>,
config: Arc<Config>,
decision: DashMap<String, oneshot::Sender<bool>>,
cancel_tokens: DashMap<String, CancellationToken>,
event_tx: tokio::sync::broadcast::Sender<AppEvent>,
speed_snapshots: DashMap<String, SpeedSnapshot>,
}
impl TransferStore {
pub fn new(config: Arc<Config>, event_tx: tokio::sync::broadcast::Sender<AppEvent>) -> Self {
let store = Self {
inner: DashMap::new(),
config,
decision: DashMap::new(),
cancel_tokens: DashMap::new(),
event_tx,
speed_snapshots: DashMap::new(),
};
store.load();
store
}
pub fn get(&self, id: &str) -> Option<dashmap::mapref::one::Ref<'_, String, Transfer>> {
self.inner.get(id)
}
pub fn get_all_transfers(&self) -> Vec<Transfer> {
self.inner.iter().map(|r| r.value().clone()).collect()
}
pub fn clear(&self) {
self.inner.retain(|_k, v| {
!(v.status == TransferStatus::Completed
|| matches!(v.status, TransferStatus::Canceled(_))
|| v.status == TransferStatus::Rejected
|| v.status == TransferStatus::Error)
});
let _ = self.event_tx.send(AppEvent::TransferClear);
self.save();
}
/// 更新状态并保存(每次 status 变更都落盘)
pub fn update_status(&self, id: &str, status: TransferStatus) {
if let Some(mut t) = self.inner.get_mut(id) {
t.status = status;
}
self.save();
let _ = self.event_tx.send(AppEvent::TransferStatusChanged {
transfer: self.inner.get(id).unwrap().value().clone(),
});
}
pub fn update_progress(&self, id: &str, progress: u64) {
if let Some(mut t) = self.inner.get_mut(id) {
t.progress = progress as f64;
}
}
pub fn start_speed_sampler(self: &Arc<Self>) {
let store = Arc::clone(self);
tokio::spawn(async move {
let mut ticker = tokio::time::interval(Duration::from_millis(500));
loop {
ticker.tick().await;
store.sample_and_emit_speeds();
}
});
}
fn sample_and_emit_speeds(&self) {
let ids: Vec<String> = self.inner.iter().map(|r| r.key().clone()).collect();
for id in ids {
let now = Instant::now();
let mut emit: Option<(f64, f64, f64)> = None;
let mut remove_snapshot = false;
if let Some(mut t) = self.inner.get_mut(&id) {
let is_terminal = matches!(
t.status,
TransferStatus::Completed
| TransferStatus::Rejected
| TransferStatus::Error
| TransferStatus::Canceled(_)
);
if is_terminal {
if t.speed != 0.0 {
t.speed = 0.0;
emit = Some((t.progress, t.file_size, 0.0));
}
remove_snapshot = true;
} else {
let current_bytes = t.progress.max(0.0);
let mut snap = self
.speed_snapshots
.get(&id)
.map(|s| *s)
.unwrap_or(SpeedSnapshot {
last_bytes: current_bytes,
last_at: now,
ema_speed: 0.0,
stall_ticks: 0,
});
let dt = now.saturating_duration_since(snap.last_at).as_secs_f64();
if dt > 0.0 {
let delta = (current_bytes - snap.last_bytes).max(0.0);
let inst_speed = delta / dt;
if delta <= 0.0 {
snap.stall_ticks = snap.stall_ticks.saturating_add(1);
if snap.stall_ticks >= 2 {
snap.ema_speed = 0.0;
}
} else {
snap.stall_ticks = 0;
snap.ema_speed = if snap.ema_speed <= 0.0 {
inst_speed
} else {
snap.ema_speed * 0.65 + inst_speed * 0.35
};
}
snap.last_bytes = current_bytes;
snap.last_at = now;
t.speed = snap.ema_speed;
emit = Some((t.progress, t.file_size, t.speed));
}
self.speed_snapshots.insert(id.clone(), snap);
}
}
if remove_snapshot {
self.speed_snapshots.remove(&id);
}
if let Some((progress, total, speed)) = emit {
let _ = self.event_tx.send(AppEvent::TransferProgressChanged {
id: id.clone(),
progress,
total,
speed,
});
}
}
}
fn load(&self) {
history::load_history(&self.config, &self.inner);
}
fn save(&self) {
// 内部判断是否保存历史记录
history::save_history(&self.config, &self.inner);
tracing::info!("Transfer history saved");
}
/// 退出前:把所有 Pending/Active 置为 Canceled然后保存
pub fn shutdown(&self) {
for mut entry in self.inner.iter_mut() {
let t = entry.value_mut();
if t.status == TransferStatus::Pending || t.status == TransferStatus::Active {
let canceled_by = match t.r#type {
crate::transfer::model::TransferType::Send => {
crate::transfer::model::CanceledBy::Sender
}
crate::transfer::model::TransferType::Receive => {
crate::transfer::model::CanceledBy::Receiver
}
};
t.status = TransferStatus::Canceled(canceled_by);
}
}
self.save();
}
/// 插入新 transfer 并保存
pub fn insert(&self, transfer: Transfer) {
let id = transfer.id.clone();
self.inner.insert(id.clone(), transfer);
self.save();
let _ = self.event_tx.send(AppEvent::TransferAdded {
transfer: self.inner.get(&id).unwrap().value().clone(),
});
}
/// 注册用户决策 oneshot 通道
pub fn register_decision(&self, id: &str) -> oneshot::Receiver<bool> {
let (tx, rx) = oneshot::channel();
self.decision.insert(id.to_string(), tx);
rx
}
/// 用户决策
pub fn make_decision(&self, id: &str, accepted: bool, save_path: &str) -> bool {
if let Some((_, tx)) = self.decision.remove(id) {
let _ = tx.send(accepted);
if accepted {
// 把 save_path 存到 transfer 中
if let Some(mut t) = self.inner.get_mut(id) {
t.save_path = save_path.to_string();
}
}
true
} else {
false
}
}
/// 注册用户取消任务 token
pub fn register_cancel(&self, id: &str) -> CancellationToken {
let token = CancellationToken::new();
self.cancel_tokens.insert(id.to_string(), token.clone());
token
}
/// 用户取消任务
pub fn cancel(&self, id: &str) -> bool {
if let Some((_, token)) = self.cancel_tokens.remove(id) {
token.cancel();
true
} else {
false
}
}
pub fn delete(&self, id: &str) -> bool {
self.inner.remove(id);
self.speed_snapshots.remove(id);
self.save();
let _ = self
.event_tx
.send(AppEvent::TransferRemoved { id: id.to_string() });
true
}
pub fn remove_cancel_token(&self, id: &str) {
self.cancel_tokens.remove(id);
}
pub fn update_text(&self, id: &str, text: &str) {
if let Some(mut t) = self.inner.get_mut(id) {
t.text = text.to_string();
}
}
}

View File

@@ -0,0 +1,48 @@
use std::{path::Path, task::Poll};
use async_tar::Builder;
use tokio::io::AsyncWrite;
pub struct TarSizeCounter {
size: u64,
}
impl TarSizeCounter {
pub fn new() -> Self {
Self { size: 0 }
}
}
impl AsyncWrite for TarSizeCounter {
fn poll_write(
mut self: std::pin::Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> Poll<std::io::Result<usize>> {
self.size += buf.len() as u64;
Poll::Ready(Ok(buf.len()))
}
fn poll_flush(
self: std::pin::Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
) -> Poll<std::io::Result<()>> {
Poll::Ready(Ok(()))
}
fn poll_shutdown(
self: std::pin::Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
) -> Poll<std::io::Result<()>> {
Poll::Ready(Ok(()))
}
}
/// 精确计算打包该文件夹后的 tar 流大小
pub async fn get_folder_tar_size(path: &Path, folder_name: &str) -> std::io::Result<u64> {
let counter = TarSizeCounter::new();
let mut builder = Builder::new(counter);
builder.append_dir_all(folder_name, path).await?;
let counter = builder.into_inner().await?;
Ok(counter.size)
}

96
rust/src/trust/mod.rs Normal file
View File

@@ -0,0 +1,96 @@
use dashmap::DashMap;
use std::{path::PathBuf, sync::Arc};
/// 信任列表peer_id → public_keybase64
pub struct TrustStore {
inner: DashMap<String, String>, // id → public_key
path: PathBuf, // 持久化文件路径
}
impl TrustStore {
pub fn new(config: Arc<crate::config::Config>) -> Self {
let path = config.get_config_dir().join("trust.json");
let mut store = TrustStore {
inner: DashMap::new(),
path,
};
store.load();
store
}
pub fn is_trusted(&self, peer_id: &str) -> bool {
self.inner.contains_key(peer_id)
}
/// 获取信任的公钥,用于 mismatch 检测
pub fn get_trusted_key(&self, peer_id: &str) -> Option<String> {
self.inner.get(peer_id).map(|v| v.value().clone())
}
/// 添加信任(用户主动执行)
pub fn trust(&self, peer_id: &str, public_key: &str) {
self.inner
.insert(peer_id.to_string(), public_key.to_string());
self.save();
}
pub fn untrust(&self, peer_id: &str) -> bool {
let removed = self.inner.remove(peer_id).is_some();
if removed {
self.save();
}
removed
}
// pub fn list(&self) -> Vec<(String, String)> {
// self.inner
// .iter()
// .map(|r| (r.key().clone(), r.value().clone()))
// .collect()
// }
fn save(&self) {
let tmp_path = self.path.with_extension("json.tmp");
let json_data = match serde_json::to_string_pretty(&self.inner) {
Ok(data) => data,
Err(e) => {
tracing::error!("Failed to serialize trust list: {}", e);
return;
}
};
if let Err(e) = std::fs::write(&tmp_path, &json_data) {
tracing::error!("Failed to write temp trust file: {}", e);
return;
}
if let Err(e) = std::fs::rename(&tmp_path, &self.path) {
tracing::error!("Failed to rename trust file: {}", e);
}
}
fn load(&mut self) {
match std::fs::File::open(&self.path) {
// 首次运行,文件不存在,正常情况
Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
tracing::debug!("Trust file not found, starting with empty trust list");
}
Err(e) => {
tracing::error!("Failed to open trust file: {}", e);
}
Ok(file) => {
let reader = std::io::BufReader::new(file);
match serde_json::from_reader(reader) {
Ok(data) => {
self.inner = data;
tracing::info!("Trust list loaded, {} entries", self.inner.len());
}
Err(e) => {
tracing::error!("Failed to parse trust file: {}", e);
}
}
}
}
}
}