diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 199d4d93..33a12a77 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -82,6 +82,34 @@ jobs: name: pony-api path: target/release/api + build-auth: + runs-on: ubuntu-22.04 + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Install latest rust toolchain + uses: actions-rs/toolchain@v1 + with: + toolchain: nightly + default: true + override: true + + - name: Update apt package index + run: sudo apt-get update + + - name: Install protobuf-compiler + run: sudo apt-get install -y protobuf-compiler + + - name: Build API + run: cargo build --release --bin auth --no-default-features + + - name: Upload API binary + uses: actions/upload-artifact@v4 + with: + name: pony-auth + path: target/release/auth + collect-binaries: runs-on: ubuntu-latest needs: [build-agent, build-api] @@ -111,6 +139,12 @@ jobs: name: pony-api path: collected/api + - name: Download auth + uses: actions/download-artifact@v4 + with: + name: pony-auth + path: collected/auth + - name: Upload combined binaries uses: actions/upload-artifact@v4 with: diff --git a/Cargo.lock b/Cargo.lock index eb2b2163..6392f5a4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2112,7 +2112,7 @@ checksum = "953ec861398dccce10c670dfeaf3ec4911ca479e9c02154b3a215178c5f566f2" [[package]] name = "pony" -version = "0.1.48" +version = "0.2.0" dependencies = [ "anyhow", "async-trait", diff --git a/Cargo.toml b/Cargo.toml index ad9bdb42..a11012c8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pony" -version = "0.1.48" +version = "0.2.0" edition = "2021" build = "build.rs" @@ -36,7 +36,7 @@ prost = { version = "0.13" } prost-derive = { version = "0.13" } rand = "0.8" reqwest = { version = "0.12", features = ["json"] } -rkyv = { version = "0.7", features = ["std", "alloc", "validation", "uuid"] } +rkyv = { version = "0.7", features = ["std", "alloc", "validation", "uuid", ] } serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" serde_urlencoded = "0.7" @@ -80,6 +80,10 @@ path = "src/bin/agent/main.rs" name = "api" path = "src/bin/api/main.rs" +[[bin]] +name = "auth" +path = "src/bin/auth/main.rs" + [[bin]] name = "utils" path = "src/bin/utils.rs" diff --git a/dev/dev.sql b/dev/dev.sql index a9b0634f..cfa4cf40 100644 --- a/dev/dev.sql +++ b/dev/dev.sql @@ -185,5 +185,17 @@ ADD COLUMN bonus_days INTEGER DEFAULT NULL; +ALTER TABLE subscriptions +ALTER COLUMN refer_code SET NOT NULL; + +ALTER TABLE subscriptions +ADD CONSTRAINT subscriptions_refer_code_unique UNIQUE (refer_code); + + +ALTER TYPE proto ADD VALUE 'hysteria2'; + +ALTER TABLE connections ADD COLUMN token UUID DEFAULT NULL; + +ALTER TABLE inbounds ADD COLUMN h2 JSONB DEFAULT NULL; diff --git a/dev/user-state.html b/dev/user-state.html index fa87881f..ddd6ea13 100644 --- a/dev/user-state.html +++ b/dev/user-state.html @@ -246,7 +246,7 @@

Subscriptions

const ul = document.createElement('ul'); data.forEach(item => { const li = document.createElement('li'); - li.textContent = item; + li.textContent = ` ${item} `; ul.appendChild(li); }); connsElement.appendChild(ul); @@ -266,7 +266,7 @@

Subscriptions

data.forEach(sub => { const li = document.createElement('li'); li.textContent = - `ID: ${sub.id}, expires: ${sub.expires_at}, deleted: ${sub.is_deleted}, ref:_by ${sub.referred_by}`; + `ID: ${sub.id}, expires: ${sub.expires_at}, deleted: ${sub.is_deleted}, ref_by: ${sub.referred_by}`; ul.appendChild(li); }); subsElement.appendChild(ul); diff --git a/src/bin/agent/core/http.rs b/src/bin/agent/core/http.rs index 6a84286e..c76934c0 100644 --- a/src/bin/agent/core/http.rs +++ b/src/bin/agent/core/http.rs @@ -1,11 +1,11 @@ use async_trait::async_trait; +use pony::http::requests::ConnTypeParam; +use pony::Tag; use reqwest::Client as HttpClient; use reqwest::StatusCode; use reqwest::Url; use pony::http::requests::NodeRequest; -use pony::http::requests::NodeType; -use pony::http::requests::NodeTypeParam; use pony::ConnectionBaseOp; use pony::NodeStorageOp; use pony::SubscriptionOp; @@ -15,11 +15,12 @@ use super::Agent; #[async_trait] pub trait ApiRequests { - async fn register_node( + async fn register_node(&self, _endpoint: String, _token: String) -> Result<()>; + async fn get_connections( &self, - _endpoint: String, - _token: String, - node_type: NodeType, + endpoint: String, + token: String, + proto: Tag, last_update: Option, ) -> Result<()>; } @@ -31,11 +32,11 @@ where C: ConnectionBaseOp + Send + Sync + Clone + 'static, S: SubscriptionOp + Send + Sync + Clone + 'static, { - async fn register_node( + async fn get_connections( &self, endpoint: String, token: String, - node_type: NodeType, + proto: Tag, last_update: Option, ) -> Result<()> { let node = { @@ -46,6 +47,51 @@ where .clone() }; + let env = node.env; + + let conn_type_param = ConnTypeParam { + proto: proto, + last_update: last_update, + env: env, + }; + + let mut endpoint_url = Url::parse(&endpoint)?; + endpoint_url + .path_segments_mut() + .map_err(|_| PonyError::Custom("Invalid API endpoint".to_string()))? + .push("connections"); + let endpoint_str = endpoint_url.to_string(); + + let res = HttpClient::new() + .get(&endpoint_str) + .query(&conn_type_param) + .header("Content-Type", "application/json") + .header("Authorization", format!("Bearer {}", token)) + .send() + .await?; + + let status = res.status(); + let body = res.text().await?; + if status.is_success() { + log::debug!("Connections Request Accepted: {:?}", status); + Ok(()) + } else { + log::error!("Connections Request failed: {} - {}", status, body); + Err(PonyError::Custom( + format!("Connections Request failed: {} - {}", status, body).into(), + )) + } + } + + async fn register_node(&self, endpoint: String, token: String) -> Result<()> { + let node = { + let mem = self.memory.read().await; + mem.nodes + .get_self() + .expect("No node available to register") + .clone() + }; + let mut endpoint_url = Url::parse(&endpoint)?; endpoint_url .path_segments_mut() @@ -58,11 +104,6 @@ where Err(e) => log::error!("Error serializing node '{}': {}", node.hostname, e), } - let node_type_param = NodeTypeParam { - node_type: Some(node_type), - last_update, - }; - let node_request = NodeRequest { env: node.env.clone(), hostname: node.hostname.clone(), @@ -77,7 +118,6 @@ where let res = HttpClient::new() .post(&endpoint_str) - .query(&node_type_param) .header("Content-Type", "application/json") .header("Authorization", format!("Bearer {}", token)) .json(&node_request) diff --git a/src/bin/agent/core/service.rs b/src/bin/agent/core/service.rs index b8cd0d6c..5eba29f2 100644 --- a/src/bin/agent/core/service.rs +++ b/src/bin/agent/core/service.rs @@ -11,12 +11,13 @@ use tokio::task::JoinHandle; use tokio::time::sleep; use tokio::time::Duration; +use pony::config::h2::H2Settings; +use pony::config::h2::HysteriaServerConfig; use pony::config::settings::AgentSettings; use pony::config::settings::NodeConfig; use pony::config::wireguard::WireguardSettings; use pony::config::xray::Config as XrayConfig; use pony::http::debug; -use pony::http::requests::NodeType; use pony::memory::connection::wireguard::IpAddrMaskSerializable; use pony::memory::connection::wireguard::Param as WgParam; use pony::memory::node::Node; @@ -93,6 +94,32 @@ pub async fn run(settings: AgentSettings) -> Result<()> { (None, None) }; + // Init Hysteria2 + let h2_config = if settings.h2.enabled { + match HysteriaServerConfig::from_file(&settings.h2.path) { + Ok(cfg) => { + if let Err(e) = cfg.validate() { + log::error!("Hysteria2 config validation failed: {}", e); + None + } else { + match H2Settings::try_from(cfg) { + Ok(settings) => Some(settings), + Err(e) => { + log::error!("Hysteria2 validation error: {}", e); + None + } + } + } + } + Err(e) => { + log::error!("Failed to load Hysteria2 config: {}", e); + None + } + } + } else { + None + }; + let subscriber = ZmqSubscriber::new( &settings.zmq.endpoint, &settings.node.uuid, @@ -100,7 +127,7 @@ pub async fn run(settings: AgentSettings) -> Result<()> { ); let node_config = NodeConfig::from_raw(settings.node.clone()); - let node = Node::new(node_config?, xray_config, wg_config.clone()); + let node = Node::new(node_config?, xray_config, wg_config.clone(), h2_config); let memory: Arc> = Arc::new(RwLock::new(MemoryCache::with_node(node.clone()))); @@ -119,7 +146,8 @@ pub async fn run(settings: AgentSettings) -> Result<()> { let snapshot_timestamp = if Path::new(&snapshot_manager.snapshot_path).exists() { match snapshot_manager.load_snapshot().await { Ok(Some(timestamp)) => { - log::info!("Loaded connections snapshot from {}", timestamp); + let count = snapshot_manager.count().await; + log::info!("Loaded connections snapshot from {} {}", timestamp, count); Some(timestamp) } Ok(None) => { @@ -144,10 +172,16 @@ pub async fn run(settings: AgentSettings) -> Result<()> { )); loop { interval.tick().await; - if let Err(e) = snapshot_manager.create_snapshot().await { + if let Err(e) = measure_time( + snapshot_manager.create_snapshot(), + "Snapshot took".to_string(), + ) + .await + { log::error!("Failed to create snapshot: {}", e); } else { - log::info!("Connections snapshot saved successfully"); + let count = snapshot_manager.count().await; + log::info!("Connections snapshot saved successfully {}", count); } } }); @@ -169,33 +203,36 @@ pub async fn run(settings: AgentSettings) -> Result<()> { tokio::time::sleep(std::time::Duration::from_millis(500)).await; - let node_type = if settings.wg.enabled && settings.xray.enabled { - NodeType::All - } else if settings.wg.enabled { - NodeType::Wireguard - } else if settings.xray.enabled { - NodeType::Xray - } else { - panic!("At least Wg or Xray should be enabled"); - }; - if !settings.agent.local { let _ = { let settings = settings.clone(); - log::debug!("Register node task {:?}", node_type); - if let Err(e) = agent - .register_node( - settings.api.endpoint.clone(), - settings.api.token.clone(), - node_type, - snapshot_timestamp, - ) + + match agent + .register_node(settings.api.endpoint.clone(), settings.api.token.clone()) .await { - log::error!( - "-->>Cannot register node, use setting local mode for running no deps\n {:?}", + Ok(_) => { + let tags: Vec<_> = node + .inbounds + .keys() + .filter(|k| !matches!(k, Tag::Hysteria2)) + .collect(); + + for tag in tags { + agent + .get_connections( + settings.api.endpoint.clone(), + settings.api.token.clone(), + *tag, + snapshot_timestamp, + ) + .await? + } + } + Err(e) => log::error!( + "Cannot register node, use setting local mode for running no deps\n {:?}", e - ); + ), } }; } @@ -307,7 +344,7 @@ pub async fn run(settings: AgentSettings) -> Result<()> { let mut mem = agent.memory.write().await; for (tag, _) in node.inbounds { let proto = Proto::new_xray(&tag); - let conn = Connection::new(proto, None); + let conn = Connection::new(proto, None, None); let _ = mem.connections.insert(conn_id, conn.into()); let _ = xray_handler_client.create(&conn_id, tag, None).await; } @@ -329,6 +366,7 @@ pub async fn run(settings: AgentSettings) -> Result<()> { inbound.as_inbound_response(), &node.label, node.address, + &None ) { println!("->> {tag} ➜ {:?}\n", conn); let qrcode = QrCode::new(conn).unwrap(); @@ -382,7 +420,7 @@ pub async fn run(settings: AgentSettings) -> Result<()> { }; let wg_params = WgParam::new(next.clone()); let proto = Proto::new_wg(&wg_params, &node.uuid); - let conn = Connection::new(proto, None); + let conn = Connection::new(proto, None, None); let _ = mem.connections.insert(conn_id, conn.clone().into()); if let Err(e) = wg_api.create(&wg_params.keys.pubkey, next) { diff --git a/src/bin/agent/core/tasks.rs b/src/bin/agent/core/tasks.rs index 775cd485..6351ee05 100644 --- a/src/bin/agent/core/tasks.rs +++ b/src/bin/agent/core/tasks.rs @@ -1,8 +1,10 @@ use async_trait::async_trait; use defguard_wireguard_rs::net::IpAddrMask; +use futures::future::try_join_all; use pony::Proto; use rkyv::AlignedVec; - +use rkyv::Infallible; +use tokio::time::Duration; use tonic::Status; use pony::xray_op::client::HandlerActions; @@ -26,6 +28,7 @@ use super::Agent; pub trait Tasks { async fn run_subscriber(&self) -> Result<()>; async fn handle_message(&self, msg: Message) -> Result<()>; + async fn handle_messages_batch(&self, msg: Vec) -> Result<()>; } #[async_trait] @@ -37,7 +40,6 @@ where { async fn run_subscriber(&self) -> Result<()> { let sub = self.subscriber.clone(); - assert!(self.subscriber.topics.contains(&"all".to_string())); loop { @@ -71,23 +73,36 @@ where } } + if payload_bytes.is_empty() { + log::warn!("SUB: Empty payload, skipping"); + continue; + } + let mut aligned = AlignedVec::new(); aligned.extend_from_slice(&payload_bytes); - let archived = unsafe { rkyv::archived_root::(&aligned) }; + let archived = match { rkyv::check_archived_root::>(&aligned) } { + Ok(a) => a, + Err(e) => { + log::error!("SUB: Invalid rkyv root: {:?}", e); + log::error!("SUB: Payload bytes (hex) = {}", hex::encode(payload_bytes)); + continue; + } + }; - match archived.deserialize(&mut rkyv::Infallible) { - Ok(message) => { - log::debug!("SUB: Successfully deserialized message: {}", message); - if let Err(err) = self.handle_message(message).await { - log::error!("ZMQ SUB: Failed to handle message: {}", err); + match archived.deserialize(&mut Infallible) { + Ok(messages) => { + if let Err(err) = self.handle_messages_batch(messages).await { + log::error!("SUB: Failed to handle messages: {}", err); } } Err(err) => { - log::error!("ZMQ SUB: Failed to deserialize message: {}", err); + log::error!("SUB: Failed to deserialize messages: {}", err); log::error!("SUB: Payload bytes (hex) = {}", hex::encode(payload_bytes)); } } + + tokio::time::sleep(Duration::from_millis(10)).await; } } @@ -95,9 +110,8 @@ where match msg.action { Action::Create | Action::Update => { let conn_id: uuid::Uuid = msg.conn_id.clone().into(); - let tag = msg.tag; - match tag { + match msg.tag { Tag::Wireguard => { let wg = msg .wg @@ -113,7 +127,11 @@ where }; let proto = Proto::new_wg(&wg, &node_id); - let conn = Connection::new(proto, None); + let conn = Connection::new( + proto, + msg.expires_at.map(Into::into), + msg.subscription_id, + ); { let mut mem = self.memory.write().await; @@ -145,22 +163,26 @@ where log::debug!("Created {}", conn); - Ok(()) + return Ok(()); } Tag::VlessTcpReality | Tag::VlessGrpcReality | Tag::VlessXhttpReality | Tag::Vmess => { - let proto = Proto::new_xray(&tag); - let conn = Connection::new(proto, None); + let proto = Proto::new_xray(&msg.tag); + let conn = Connection::new( + proto, + msg.expires_at.map(Into::into), + msg.subscription_id, + ); let client = self.xray_handler_client.as_ref().ok_or_else(|| { PonyError::Grpc(Status::unavailable("Xray handler unavailable")) })?; client - .create(&conn_id.clone(), tag, None) + .create(&conn_id.clone(), msg.tag, None) .await .map_err(|err| { PonyError::Custom(format!( @@ -180,19 +202,23 @@ where )) })?; - Ok(()) + return Ok(()); } Tag::Shadowsocks => { if let Some(password) = msg.password { let proto = Proto::new_ss(&password); - let conn = Connection::new(proto, None); + let conn = Connection::new( + proto, + msg.expires_at.map(Into::into), + msg.subscription_id, + ); let client = self.xray_handler_client.as_ref().ok_or_else(|| { PonyError::Grpc(Status::unavailable("Xray handler unavailable")) })?; client - .create(&conn_id.clone(), tag, Some(password)) + .create(&conn_id.clone(), msg.tag, Some(password)) .await .map_err(|err| { PonyError::Custom(format!( @@ -211,13 +237,16 @@ where )) })?; - Ok(()) + return Ok(()); } else { - Err(PonyError::Custom( + return Err(PonyError::Custom( "Password not provided for Shadowsocks user".into(), - )) + )); } } + Tag::Hysteria2 => { + return Err(PonyError::Custom("Hysteria2 is not supported".into())) + } } } @@ -241,7 +270,7 @@ where let mut mem = self.memory.write().await; let _ = mem.connections.remove(&conn_id); - Ok(()) + return Ok(()); } Tag::VlessTcpReality | Tag::VlessGrpcReality @@ -265,7 +294,10 @@ where let mut mem = self.memory.write().await; let _ = mem.connections.remove(&conn_id); - Ok(()) + return Ok(()); + } + Tag::Hysteria2 => { + return Err(PonyError::Custom("Hysteria2 is not supported".into())) } } } @@ -275,8 +307,21 @@ where .await .map_err(|e| PonyError::Custom(format!("Couldn't reset stat: {}", e)))?; log::debug!("Reset stat for {}", msg.conn_id); - Ok(()) + return Ok(()); } } } + + async fn handle_messages_batch(&self, messages: Vec) -> Result<()> { + log::debug!("Got {} messages", messages.len()); + + let handles: Vec<_> = messages + .into_iter() + .map(|msg| self.handle_message(msg)) + .collect(); + + let _results = try_join_all(handles).await?; + + Ok(()) + } } diff --git a/src/bin/api/core/http/handlers/connection.rs b/src/bin/api/core/http/handlers/connection.rs index ff85f3a1..b018b8d2 100644 --- a/src/bin/api/core/http/handlers/connection.rs +++ b/src/bin/api/core/http/handlers/connection.rs @@ -1,12 +1,19 @@ use chrono::DateTime; use chrono::Utc; use defguard_wireguard_rs::net::IpAddrMask; +use rkyv::to_bytes; +use warp::http::StatusCode; use pony::http::helpers as http; use pony::http::requests::ConnCreateRequest; use pony::http::requests::ConnQueryParam; +use pony::http::requests::ConnTypeParam; use pony::http::requests::ConnUpdateRequest; +use pony::http::IdResponse; use pony::http::IpParseError; +use pony::http::MyRejection; +use pony::http::ResponseMessage; +use pony::memory::tag::ProtoTag; use pony::utils; use pony::zmq::publisher::Publisher as ZmqPublisher; use pony::Connection; @@ -24,6 +31,93 @@ use pony::WgParam; use crate::core::sync::tasks::SyncOp; use crate::core::sync::MemSync; +/// Handler get connection +// GET /connections +pub async fn get_connections_handler( + conn_req: ConnTypeParam, + publisher: ZmqPublisher, + memory: MemSync, +) -> Result +where + N: NodeStorageOp + Sync + Send + Clone + 'static, + C: ConnectionApiOp + + ConnectionBaseOp + + Sync + + Send + + Clone + + 'static + + From + + PartialEq, + Connection: From, + S: SubscriptionOp + + Send + + Sync + + Clone + + 'static + + std::cmp::PartialEq + + std::convert::From, +{ + let mem = memory.memory.read().await; + let proto = conn_req.proto; + let env = conn_req.env; + + let last_update = conn_req.last_update; + + let connections_to_send: Vec<_> = mem + .connections + .iter() + .filter(|(_, conn)| { + if conn.get_deleted() { + return false; + } + + if conn.get_proto().proto() != proto { + return false; + } + + if let Some(ts) = last_update { + conn.get_modified_at().and_utc().timestamp() as u64 >= ts + } else { + true + } + }) + .collect(); + + log::debug!( + "Sending {} {:?} connections to auth", + connections_to_send.len(), + proto + ); + + let messages: Vec<_> = connections_to_send + .iter() + .map(|(conn_id, conn)| conn.as_create_message(conn_id)) + .collect(); + + let bytes = to_bytes::<_, 1024>(&messages).map_err(|e| { + log::error!("Serialization error: {}", e); + warp::reject::custom(MyRejection(Box::new(e))) + })?; + + publisher + .send_binary(&env, bytes.as_ref()) + .await + .map_err(|e| { + log::error!("Publish error: {}", e); + warp::reject::custom(MyRejection(Box::new(e))) + })?; + + let resp = ResponseMessage::> { + status: StatusCode::OK.as_u16(), + message: "Ok".to_string(), + response: None, + }; + Ok(warp::reply::with_status( + warp::reply::json(&resp), + StatusCode::OK, + )) +} + /// Handler creates connection // POST /connection pub async fn create_connection_handler( @@ -50,12 +144,17 @@ where + std::cmp::PartialEq + std::convert::From, { - let env = conn_req.env.clone(); - let conn_id = uuid::Uuid::new_v4(); - let mem = memory.memory.read().await; + if let Err(e) = conn_req.validate() { + return Ok(http::bad_request(&e)); + } + let expired_at: Option> = conn_req + .days + .map(|days| Utc::now() + chrono::Duration::days(days.into())); + + let mem = memory.memory.read().await; if let Some(sub_id) = conn_req.subscription_id { - if let None = mem.subscriptions.find_by_id(&sub_id) { + if mem.subscriptions.find_by_id(&sub_id).is_none() { return Ok(http::bad_request(&format!( "Subscription {} not found", sub_id @@ -63,222 +162,187 @@ where } } - let expired_at: Option> = conn_req - .days - .map(|days| Utc::now() + chrono::Duration::days(days.into())); - - if conn_req.password.is_some() && conn_req.wg.is_some() { - return Ok(http::bad_request( - "Cannot specify both password (Shadowsocks) and wg (WireGuard)", - )); - } - - if !conn_req.proto.is_wireguard() && conn_req.wg.is_some() { - return Ok(http::bad_request( - "Wg params are allowed only for Wireguard proto", - )); - } - - if !conn_req.proto.is_wireguard() && conn_req.node_id.is_some() { - return Ok(http::bad_request( - "node_id param are allowed only for Wireguard proto", - )); - } - - let node_id = if conn_req.proto.is_wireguard() { - match conn_req.node_id { - Some(node_id) => { - let node_valid = mem.nodes.get_by_id(&node_id).is_some_and(|n| { - n.env == conn_req.env && n.inbounds.contains_key(&conn_req.proto) - }); - - if !node_valid { - return Ok(http::bad_request( + let proto = match conn_req.proto { + ProtoTag::Wireguard => { + let node_id = { + match conn_req.node_id { + Some(node_id) => { + let node_valid = mem.nodes.get_by_id(&node_id).is_some_and(|n| { + n.env == conn_req.env && n.inbounds.contains_key(&conn_req.proto) + }); + + if !node_valid { + return Ok(http::bad_request( "node_id doesn't exist, mismatched env or missing WireGuard inbound", )); - } - - Some(node_id) - } - None => { - mem.nodes - .select_least_loaded_node(&conn_req.env, &conn_req.proto, &mem.connections) - } - } - } else { - None - }; + } - if conn_req.password.is_some() && !conn_req.proto.is_shadowsocks() { - return Ok(http::bad_request(&format!( - "Password is only allowed for Shadowsocks, but got {:?}", - conn_req.proto - ))); - } + node_id + } + None => { + if let Some(node_id) = mem.nodes.select_least_loaded_node( + &conn_req.env, + &conn_req.proto, + &mem.connections, + ) { + node_id + } else { + return Ok(http::not_found("Node not found for WireGuard connection")); + } + } + } + }; - if let (Some(wg), Some(node_id)) = (&conn_req.wg, node_id) { - let address_taken = mem.connections.values().any(|c| { - if let Proto::Wireguard { - param, - node_id: existing_node_id, - } = c.get_proto() - { - existing_node_id == node_id && param.address.addr == wg.address.addr + let wg_param = if let Some(wg_param) = conn_req.wg { + if wg_param.address.cidr > 32 { + return Ok(http::bad_request("Invalid CIDR: must be 0..=32")); + } + let address_taken = mem.connections.values().any(|c| { + if let Proto::Wireguard { + param, + node_id: existing_node_id, + } = c.get_proto() + { + existing_node_id == node_id && param.address.addr == wg_param.address.addr + } else { + false + } + }); + if address_taken { + return Ok(http::conflict("Address already taken for this node_id")); + } + if let Some(node) = mem.nodes.get_by_id(&node_id) { + if let Some(inbound) = node.inbounds.get(&conn_req.proto) { + if let Some(wg_settings) = &inbound.wg { + let ip = wg_param + .address + .addr + .parse() + .map_err(|e| warp::reject::custom(IpParseError(e)))?; + if !utils::ip_in_mask(&wg_settings.network, ip) { + return Ok(http::bad_request(&format!( + "Address out of node netmask {}", + wg_settings.network + ))); + } + } + } + } + wg_param } else { - false - } - }); - - if wg.address.cidr > 32 { - return Ok(http::bad_request("Invalid CIDR: must be 0..=32")); - } + let node = match mem.nodes.get_by_id(&node_id) { + Some(n) => n, + None => { + return Ok(http::not_found("Node not found")); + } + }; - if address_taken { - return Ok(http::conflict("Address already taken for this node_id")); - } + let inbound = match node.inbounds.get(&conn_req.proto) { + Some(i) => i, + None => { + return Ok(http::not_found("Inbound for proto not found")); + } + }; - if let Some(node) = mem.nodes.get_by_id(&node_id) { - if let Some(inbound) = node.inbounds.get(&conn_req.proto) { - if let Some(wg_settings) = &inbound.wg { - let ip = wg - .address - .addr - .parse() - .map_err(|e| warp::reject::custom(IpParseError(e)))?; - if !utils::ip_in_mask(&wg_settings.network, ip) { - return Ok(http::bad_request(&format!( - "Address out of node netmask {}", - wg_settings.network - ))); + let wg_settings = match &inbound.wg { + Some(wg) => wg, + None => { + return Ok(http::bad_request("WireGuard settings missing")); } - } - } - } - } + }; - let wg_param = if conn_req.proto.is_wireguard() && conn_req.wg.is_none() { - let node_id = match node_id { - Some(id) => id, - None => { - return Ok(http::bad_request("Missing node_id for WireGuard")); - } - }; + let base_ip = node + .inbounds + .get(&Tag::Wireguard) + .and_then(|inb| inb.wg.as_ref()) + .map(|wg| wg.address) + .and_then(utils::increment_ip); + + let max_ip = mem + .connections + .iter() + .filter(|(_, conn)| conn.get_proto().proto() == Tag::Wireguard) + .filter_map(|(_, conn)| { + conn.get_wireguard() + .and_then(|wg| wg.address.addr.parse().ok()) + }) + .max(); + + let next_ip = match max_ip + .and_then(utils::increment_ip) + .or_else(|| base_ip.and_then(utils::increment_ip)) + .map(std::net::IpAddr::V4) + { + Some(ip) => { + log::debug!("IP Gen: {:?} {:?} {:?}", base_ip, max_ip, ip); + ip + } + None => { + return Ok(http::bad_request("Failed to generate next IP")); + } + }; - let node = match mem.nodes.get_by_id(&node_id) { - Some(n) => n, - None => { - return Ok(http::not_found("Node not found")); - } - }; + if !utils::ip_in_mask(&wg_settings.network, next_ip) { + return Ok(http::bad_request(&format!( + "Generated address {} is out of node netmask {}", + next_ip, wg_settings.network + ))); + } - let inbound = match node.inbounds.get(&conn_req.proto) { - Some(i) => i, - None => { - return Ok(http::bad_request("Inbound for proto not found")); - } - }; + WgParam::new(IpAddrMask::new(next_ip, 32)) + }; - let wg_settings = match &inbound.wg { - Some(wg) => wg, - None => { - return Ok(http::bad_request("WireGuard settings missing")); - } - }; - - let base_ip = node - .inbounds - .get(&Tag::Wireguard) - .and_then(|inb| inb.wg.as_ref()) - .map(|wg| wg.address) - .and_then(utils::increment_ip); - - let max_ip = mem - .connections - .iter() - .filter(|(_, conn)| conn.get_proto().proto() == Tag::Wireguard) - .filter_map(|(_, conn)| { - conn.get_wireguard() - .and_then(|wg| wg.address.addr.parse().ok()) - }) - .max(); - - let next_ip = match max_ip - .and_then(utils::increment_ip) - .or_else(|| base_ip.and_then(utils::increment_ip)) - .map(std::net::IpAddr::V4) - { - Some(ip) => { - log::debug!("IP Gen: {:?} {:?} {:?}", base_ip, max_ip, ip); - ip + Proto::Wireguard { + param: wg_param, + node_id: node_id, } - None => { - return Ok(http::bad_request("Failed to generate next IP")); - } - }; - - if !utils::ip_in_mask(&wg_settings.network, next_ip) { - return Ok(http::bad_request(&format!( - "Generated address {} is out of node netmask {}", - next_ip, wg_settings.network - ))); } - - Some(WgParam::new(IpAddrMask::new(next_ip, 32))) - } else { - None + ProtoTag::Shadowsocks => Proto::Shadowsocks { + password: conn_req.password.unwrap(), + }, + ProtoTag::VlessTcpReality + | ProtoTag::VlessGrpcReality + | ProtoTag::VlessXhttpReality + | ProtoTag::Vmess => Proto::Xray(conn_req.proto), + ProtoTag::Hysteria2 => Proto::Hysteria2 { + token: conn_req.token.unwrap(), + }, }; drop(mem); - log::debug!("WG params {:?}", wg_param); - let proto = if let Some(wg) = &conn_req.wg { - Proto::Wireguard { - param: wg.clone(), - node_id: node_id.unwrap(), - } - } else if let Some(wg) = wg_param { - Proto::Wireguard { - param: wg.clone(), - node_id: node_id.unwrap(), - } - } else if let Some(password) = &conn_req.password { - Proto::Shadowsocks { - password: password.clone(), - } - } else { - Proto::Xray(conn_req.proto) - }; - let conn: Connection = Connection::new( - &env, + &conn_req.env, conn_req.subscription_id, ConnectionStat::default(), proto, - node_id, expired_at, ) .into(); log::debug!("New connection to create {}", conn); - + let conn_id = uuid::Uuid::new_v4(); let msg = conn.as_create_message(&conn_id); + let mut messages = vec![]; + messages.push(msg); + match SyncOp::add_conn(&memory, &conn_id, conn.clone()).await { Ok(StorageOperationStatus::Ok(id)) => { - let bytes = match rkyv::to_bytes::<_, 1024>(&msg) { + let bytes = match rkyv::to_bytes::<_, 1024>(&messages) { Ok(b) => b, Err(e) => { return Ok(http::internal_error(&format!("Serialization error: {}", e))); } }; - if let Some(node_id) = conn.node_id { - let _ = publisher - .send_binary(&node_id.to_string(), bytes.as_ref()) - .await; + let topic = if let Some(node_id) = conn.get_wireguard_node_id() { + node_id.to_string() } else { - let _ = publisher.send_binary(&env, bytes.as_ref()).await; - } + conn.get_env() + }; + + let _ = publisher.send_binary(&topic, bytes.as_ref()).await; return Ok(http::success_response( format!("Connection {} has been created", id), diff --git a/src/bin/api/core/http/handlers/node.rs b/src/bin/api/core/http/handlers/node.rs index 564bc912..510a64d7 100644 --- a/src/bin/api/core/http/handlers/node.rs +++ b/src/bin/api/core/http/handlers/node.rs @@ -1,14 +1,10 @@ -use rkyv::to_bytes; use warp::http::StatusCode; use pony::http::requests::NodeIdParam; use pony::http::requests::NodeRequest; use pony::http::requests::NodeResponse; -use pony::http::requests::NodeType; -use pony::http::requests::NodeTypeParam; use pony::http::requests::NodesQueryParams; use pony::http::IdResponse; -use pony::http::MyRejection; use pony::http::ResponseMessage; use pony::memory::node::Status as NodeStatus; use pony::Connection; @@ -16,9 +12,7 @@ use pony::ConnectionApiOp; use pony::ConnectionBaseOp; use pony::NodeStorageOp; use pony::OperationStatus as StorageOperationStatus; -use pony::Publisher as ZmqPublisher; use pony::SubscriptionOp; -use pony::Tag; use crate::core::clickhouse::score::NodeScore; use crate::core::clickhouse::ChContext; @@ -29,9 +23,7 @@ use crate::core::sync::MemSync; // POST /node pub async fn post_node_handler( node_req: NodeRequest, - node_param: NodeTypeParam, memory: MemSync, - publisher: ZmqPublisher, ) -> Result where N: NodeStorageOp + Sync + Send + Clone + 'static, @@ -56,9 +48,6 @@ where let node = node_req.clone().as_node(); let node_id = node_req.uuid; - let last_update = node_param.last_update; - - let node_type = node_param.node_type.unwrap_or(NodeType::All); let status = SyncOp::add_node(&memory, &node_id, node.clone()).await; @@ -69,58 +58,6 @@ where let _ = SyncOp::update_node_status(&memory, &node_id, &node.env, NodeStatus::Online).await; - let mem = memory.memory.read().await; - - let connections_to_send: Vec<_> = mem - .connections - .iter() - .filter(|(_, conn)| { - let matches_type = !conn.get_deleted() - && match node_type { - NodeType::Wireguard => { - conn.get_proto().proto() == Tag::Wireguard - && Some(node_id) == conn.get_wireguard_node_id() - } - NodeType::Xray => conn.get_proto().proto() != Tag::Wireguard, - NodeType::All => match conn.get_proto().proto() { - Tag::Wireguard => conn.get_wireguard_node_id() == Some(node_id), - _ => true, - }, - }; - - let matches_time = if let Some(ts) = last_update { - conn.get_modified_at().and_utc().timestamp() as u64 >= ts - } else { - true - }; - - matches_type && matches_time - }) - .collect(); - - log::info!( - "Sending {} connections to node {}", - connections_to_send.len(), - node_id - ); - - for (conn_id, conn) in connections_to_send { - let message = conn.as_create_message(conn_id); - - let bytes = to_bytes::<_, 1024>(&message).map_err(|e| { - log::error!("Serialization error: {}", e); - warp::reject::custom(MyRejection(Box::new(e))) - })?; - - publisher - .send_binary(&node_id.to_string(), bytes.as_ref()) - .await - .map_err(|e| { - log::error!("Publish error: {}", e); - warp::reject::custom(MyRejection(Box::new(e))) - })?; - } - ResponseMessage::> { status: StatusCode::OK.as_u16(), message: "Ok".to_string(), diff --git a/src/bin/api/core/http/handlers/sub.rs b/src/bin/api/core/http/handlers/sub.rs index bad2c472..1eed8fe8 100644 --- a/src/bin/api/core/http/handlers/sub.rs +++ b/src/bin/api/core/http/handlers/sub.rs @@ -2,6 +2,7 @@ use base64::Engine; use chrono::DateTime; use chrono::Utc; +use pony::http::requests::TagReq; use warp::http::Response; use warp::http::StatusCode; @@ -384,7 +385,7 @@ ul {{
-

Подписка

+

Подписка на Рилзопровод

Статус: {status_text}
@@ -400,14 +401,25 @@ ul {{

Ссылки для подключения

-Универсальная ссылка (рекомендуется) +Xray Универсальная ссылка (рекомендуется) +
+ +
+
+Hysteria2(Beta) Универсальная ссылка (рекомендуется)
-

Дополнительные форматы:

+ +

Дополнительные форматы Xray:

    -
  • TXT для v2ray
  • -
  • Clash — для Clash / Clash Meta
  • +
  • TXT для v2ray
  • +
  • Clash — для Clash / Clash Meta
  • +
+ +

Дополнительные форматы Hysteria2:

+
    +
  • TXT для v2ray

@@ -599,6 +611,26 @@ where let env = sub_param.env; + let mut tags = vec![]; + tags = match sub_param.proto { + TagReq::Xray => { + tags.push(Tag::VlessTcpReality); + tags.push(Tag::VlessGrpcReality); + tags.push(Tag::VlessXhttpReality); + tags.push(Tag::Vmess); + tags.push(Tag::Shadowsocks); + tags + } + TagReq::Wireguard => { + tags.push(Tag::Wireguard); + tags + } + TagReq::Hysteria2 => { + tags.push(Tag::Hysteria2); + tags + } + }; + if let Some(conns) = conns { for (conn_id, conn) in conns { if conn.get_deleted() { @@ -608,6 +640,14 @@ where continue; } + let proto = conn.get_proto().proto(); + + if !tags.contains(&proto) { + continue; + } + + let token = conn.get_token(); + if let Some(nodes) = mem.nodes.get_by_env(&conn.get_env()) { for node in nodes.iter() { if let Some(inbound) = &node.inbounds.get(&conn.get_proto().proto()) { @@ -616,6 +656,7 @@ where conn_id, node.label.clone(), node.address, + token, )); } } @@ -634,7 +675,7 @@ where "clash" => { let mut proxies = vec![]; - for (inbound, conn_id, label, address) in &inbounds_by_node { + for (inbound, conn_id, label, address, _) in &inbounds_by_node { if let Some(proxy) = generate_proxy_config(inbound, *conn_id, *address, label) { proxies.push(proxy) } @@ -655,8 +696,16 @@ where "txt" => { let links = inbounds_by_node .iter() - .filter_map(|(inbound, conn_id, label, ip)| { - utils::create_conn_link(inbound.tag, conn_id, inbound.clone(), label, *ip).ok() + .filter_map(|(inbound, conn_id, label, ip, token)| { + utils::create_conn_link( + inbound.tag, + conn_id, + inbound.clone(), + label, + *ip, + token, + ) + .ok() }) .collect::>(); @@ -671,8 +720,16 @@ where _ => { let links = inbounds_by_node .iter() - .filter_map(|(inbound, conn_id, label, ip)| { - utils::create_conn_link(inbound.tag, conn_id, inbound.clone(), label, *ip).ok() + .filter_map(|(inbound, conn_id, label, ip, token)| { + utils::create_conn_link( + inbound.tag, + conn_id, + inbound.clone(), + label, + *ip, + token, + ) + .ok() }) .collect::>(); diff --git a/src/bin/api/core/http/routes.rs b/src/bin/api/core/http/routes.rs index b1f84c23..5e2f9851 100644 --- a/src/bin/api/core/http/routes.rs +++ b/src/bin/api/core/http/routes.rs @@ -83,9 +83,7 @@ where .and(warp::path::end()) .and(auth.clone()) .and(warp::body::json::()) - .and(warp::query::()) .and(with_state(self.sync.clone())) - .and(publisher(self.publisher.clone())) .and_then(post_node_handler); let get_node_route = warp::get() @@ -160,6 +158,15 @@ where .and(with_state(self.sync.clone())) .and_then(get_connection_handler); + let get_connections_route = warp::get() + .and(warp::path("connections")) + .and(warp::path::end()) + .and(auth.clone()) + .and(warp::query::()) + .and(publisher(self.publisher.clone())) + .and(with_state(self.sync.clone())) + .and_then(get_connections_handler); + let post_connection_route = warp::post() .and(warp::path("connection")) .and(warp::path::end()) @@ -203,6 +210,7 @@ where .or(post_node_register_route) // Connection .or(get_connection_route) + .or(get_connections_route) .or(post_connection_route) .or(delete_connection_route) .or(put_connection_route) diff --git a/src/bin/api/core/postgres/connection.rs b/src/bin/api/core/postgres/connection.rs index 50b7444d..ce16cab9 100644 --- a/src/bin/api/core/postgres/connection.rs +++ b/src/bin/api/core/postgres/connection.rs @@ -32,6 +32,7 @@ pub struct ConnRow { pub wg: Option, pub node_id: Option, pub proto: Tag, + pub token: Option, is_deleted: bool, } @@ -55,6 +56,7 @@ impl From<(uuid::Uuid, Conn)> for ConnRow { wg: conn.get_wireguard().cloned(), node_id: conn.get_wireguard_node_id(), proto: conn.get_proto().proto(), + token: conn.get_token(), is_deleted: conn.is_deleted, } } @@ -85,6 +87,13 @@ impl TryFrom for Conn { Proto::new_ss(&password) } + Tag::Hysteria2 => { + let token = row + .token + .ok_or_else(|| PonyError::Custom("Missing Hysteria2 token".into()))?; + Proto::new_hysteria2(&token) + } + tag => Proto::new_xray(&tag), }; @@ -97,7 +106,6 @@ impl TryFrom for Conn { modified_at: row.modified_at, expired_at: row.expired_at, is_deleted: row.is_deleted, - node_id: row.node_id, }) } } @@ -119,6 +127,7 @@ impl PgConn { SELECT id, password, + token, env, created_at, modified_at, @@ -151,6 +160,7 @@ impl PgConn { let modified_at: NaiveDateTime = row.get("modified_at"); let expired_at: Option> = row.get("expired_at"); let subscription_id: Option = row.get("subscription_id"); + let token: Option = row.get("token"); let online: i64 = row.get("online"); let uplink: i64 = row.get("uplink"); let downlink: i64 = row.get("downlink"); @@ -174,6 +184,7 @@ impl PgConn { ConnRow { conn_id, password, + token, env, created_at, modified_at, @@ -244,11 +255,12 @@ impl PgConn { wg_privkey, wg_pubkey, wg_address, - node_id + node_id, + token ) VALUES ( $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, - $11, $12, $13, $14, $15, $16 + $11, $12, $13, $14, $15, $16, $17 ) "; @@ -272,6 +284,7 @@ impl PgConn { &conn.wg.as_ref().map(|w| &w.keys.pubkey), &conn.wg.as_ref().map(|w| w.address.to_string()), &conn.node_id, + &conn.token, ], ) .await; diff --git a/src/bin/api/core/postgres/node.rs b/src/bin/api/core/postgres/node.rs index 6d7fc3f4..ba6a9bce 100644 --- a/src/bin/api/core/postgres/node.rs +++ b/src/bin/api/core/postgres/node.rs @@ -9,6 +9,7 @@ use chrono::DateTime; use chrono::Utc; use tokio::sync::Mutex; +use pony::config::h2::H2Settings; use pony::config::wireguard::WireguardSettings; use pony::config::xray::Inbound; use pony::memory::node::Node; @@ -78,12 +79,12 @@ impl PgNode { INSERT INTO inbounds ( id, node_id, tag, port, stream_settings, uplink, downlink, conn_count, - wg_pubkey, wg_privkey, wg_interface, wg_network, wg_address, dns + wg_pubkey, wg_privkey, wg_interface, wg_network, wg_address, dns, h2 ) VALUES ( $1, $2, $3, $4, $5, $6, $7, $8, - $9, $10, $11, $12, $13, $14 + $9, $10, $11, $12, $13, $14, $15 ) ON CONFLICT (node_id, tag) DO UPDATE SET port = EXCLUDED.port, @@ -96,12 +97,14 @@ impl PgNode { wg_interface = EXCLUDED.wg_interface, wg_network = EXCLUDED.wg_network, wg_address = EXCLUDED.wg_address, - dns = EXCLUDED.dns + dns = EXCLUDED.dns, + h2 = EXCLUDED.h2 "; for inbound in node.inbounds.values() { let inbound_id = uuid::Uuid::new_v4(); let stream_settings = serde_json::to_value(&inbound.stream_settings)?; + let h2_settings = serde_json::to_value(&inbound.h2)?; let (wg_pubkey, wg_privkey, wg_interface, wg_network, wg_address, dns) = inbound .wg @@ -141,6 +144,7 @@ impl PgNode { &wg_network, &wg_address, &dns, + &h2_settings, ], ) .await?; @@ -160,7 +164,7 @@ impl PgNode { n.id AS node_id, n.uuid, n.env, n.hostname, n.address, n.status, n.created_at, n.modified_at, n.label, n.interface, n.cores, n.max_bandwidth_bps, i.id AS inbound_id, i.tag, i.port, i.stream_settings, i.uplink, i.downlink, - i.conn_count, i.wg_pubkey, i.wg_privkey, i.wg_interface, i.wg_network, i.wg_address, i.dns + i.conn_count, i.wg_pubkey, i.wg_privkey, i.wg_interface, i.wg_network, i.wg_address, i.dns, i.h2 FROM nodes n LEFT JOIN inbounds i ON n.id = i.node_id", &[], @@ -201,6 +205,11 @@ impl PgNode { }); let inbound_id: Option = row.get("inbound_id"); + let h2: Option = row + .get::<_, Option>("h2") + .map(|v| serde_json::from_value(v).ok()) + .flatten(); + if let Some(ipv4_addr) = to_ipv4(address) { let node_entry = nodes_map.entry(node_id).or_insert_with(|| Node { uuid, @@ -257,6 +266,7 @@ impl PgNode { downlink: row.get("downlink"), conn_count: row.get("conn_count"), wg, + h2, }; node_entry.inbounds.insert(inbound.tag, inbound); diff --git a/src/bin/api/core/postgres/subscription.rs b/src/bin/api/core/postgres/subscription.rs index 39eb1db8..cdd55451 100644 --- a/src/bin/api/core/postgres/subscription.rs +++ b/src/bin/api/core/postgres/subscription.rs @@ -2,7 +2,6 @@ use std::sync::Arc; use tokio::sync::Mutex; use pony::memory::subscription::Subscription; -use pony::utils::get_uuid_last_octet_simple; use pony::Result; use super::PgClientManager; diff --git a/src/bin/api/core/sync/tasks.rs b/src/bin/api/core/sync/tasks.rs index 3573e898..deaf37be 100644 --- a/src/bin/api/core/sync/tasks.rs +++ b/src/bin/api/core/sync/tasks.rs @@ -463,7 +463,7 @@ where } if let Some(bonus_days) = sub_req.bonus_days { - sub.set_bonus_days(bonus_days); // 👈 см. ниже + sub.set_bonus_days(bonus_days); } if let Some(ref_by) = sub_req.referred_by.clone() { diff --git a/src/bin/api/core/tasks.rs b/src/bin/api/core/tasks.rs index 565b0310..20db3086 100644 --- a/src/bin/api/core/tasks.rs +++ b/src/bin/api/core/tasks.rs @@ -91,10 +91,11 @@ impl Tasks for Api>, Connection, Subscription> { } }; - let key = conn - .node_id - .map(|id| id.to_string()) - .unwrap_or_else(|| conn.get_env()); + let key = if let Some(node_id) = conn.get_wireguard_node_id() { + node_id.to_string() + } else { + conn.get_env() + }; let _ = publisher.send_binary(&key, bytes.as_ref()).await; } @@ -143,10 +144,11 @@ impl Tasks for Api>, Connection, Subscription> { for (conn_id, conn) in conns_to_delete { let msg = conn.as_delete_message(&conn_id); if let Ok(bytes) = rkyv::to_bytes::<_, 1024>(&msg) { - let key = conn - .node_id - .map(|id| id.to_string()) - .unwrap_or_else(|| conn.get_env()); + let key = if let Some(node_id) = conn.get_wireguard_node_id() { + node_id.to_string() + } else { + conn.get_env() + }; let _ = publisher.send_binary(&key, bytes.as_ref()).await; } @@ -203,10 +205,11 @@ impl Tasks for Api>, Connection, Subscription> { for (conn_id, conn) in conns_to_restore { let msg = conn.as_update_message(&conn_id); if let Ok(bytes) = rkyv::to_bytes::<_, 1024>(&msg) { - let key = conn - .node_id - .map(|id| id.to_string()) - .unwrap_or_else(|| conn.get_env()); + let key = if let Some(node_id) = conn.get_wireguard_node_id() { + node_id.to_string() + } else { + conn.get_env() + }; let _ = publisher.send_binary(&key, bytes.as_ref()).await; } diff --git a/src/bin/auth/core/http.rs b/src/bin/auth/core/http.rs new file mode 100644 index 00000000..3ab95cc0 --- /dev/null +++ b/src/bin/auth/core/http.rs @@ -0,0 +1,155 @@ +use pony::http::requests::ConnTypeParam; +use pony::memory::cache::Cache; +use pony::ConnectionBaseOp; +use pony::ConnectionStorageBaseOp; +use pony::NodeStorageOp; +use pony::SubscriptionOp; +use pony::Tag; +use serde::Deserialize; +use serde::Serialize; +use std::net::{IpAddr, Ipv4Addr, SocketAddr}; +use std::sync::Arc; +use tokio::sync::RwLock; +use warp::Filter; + +#[derive(Deserialize)] +pub struct AuthRequest { + addr: String, + pub auth: uuid::Uuid, + tx: u64, +} + +#[derive(Serialize)] +pub struct AuthResponse { + pub ok: bool, + #[serde(skip_serializing_if = "Option::is_none")] + pub id: Option, +} + +pub async fn start_auth_server( + memory: Arc>>, + ipaddr: Ipv4Addr, + port: u16, +) where + N: NodeStorageOp + Sync + Send + Clone + 'static, + S: SubscriptionOp + Sync + Send + Clone + 'static + std::cmp::PartialEq + serde::Serialize, + C: ConnectionBaseOp + Sync + Send + Clone + 'static + std::fmt::Display, +{ + let health_check = warp::path("health-check").map(|| "Server OK"); + + let auth_route = warp::post() + .and(warp::path("auth")) + .and(warp::body::json()) + .and(warp::any().map(move || memory.clone())) + .and_then(auth_handler); + + let routes = health_check + .or(auth_route) + .with(warp::cors().allow_any_origin()); + + warp::serve(routes) + .run(SocketAddr::new(IpAddr::V4(ipaddr), port)) + .await; +} + +pub async fn auth_handler( + req: AuthRequest, + memory: Arc>>, +) -> Result +where + N: NodeStorageOp + Sync + Send + Clone + 'static, + S: SubscriptionOp + Sync + Send + Clone + 'static + std::cmp::PartialEq + serde::Serialize, + C: ConnectionBaseOp + Sync + Send + Clone + 'static + std::fmt::Display, +{ + log::debug!("Auth req {} {} {}", req.auth, req.addr, req.tx); + let mem = memory.read().await; + if let Some(id) = mem.connections.validate_token(&req.auth) { + return Ok(warp::reply::json(&AuthResponse { + ok: true, + id: Some(id.to_string()), + })); + } else { + return Ok(warp::reply::json(&AuthResponse { + ok: false, + id: None, + })); + } +} + +use async_trait::async_trait; +use reqwest::Client as HttpClient; +use reqwest::Url; + +use pony::{PonyError, Result as PonyResult}; + +use super::AuthService; + +#[async_trait] +pub trait ApiRequests { + async fn get_connections( + &self, + endpoint: String, + token: String, + proto: Tag, + last_update: Option, + ) -> PonyResult<()>; +} + +#[async_trait] +impl ApiRequests for AuthService +where + T: NodeStorageOp + Send + Sync + Clone, + C: ConnectionBaseOp + Send + Sync + Clone + 'static, + S: SubscriptionOp + Send + Sync + Clone + 'static, +{ + async fn get_connections( + &self, + endpoint: String, + token: String, + proto: Tag, + last_update: Option, + ) -> PonyResult<()> { + let node = { + let mem = self.memory.read().await; + mem.nodes + .get_self() + .expect("No node available to register") + .clone() + }; + + let env = node.env; + + let conn_type_param = ConnTypeParam { + proto: proto, + last_update: last_update, + env: env, + }; + + let mut endpoint_url = Url::parse(&endpoint)?; + endpoint_url + .path_segments_mut() + .map_err(|_| PonyError::Custom("Invalid API endpoint".to_string()))? + .push("connections"); + let endpoint_str = endpoint_url.to_string(); + + let res = HttpClient::new() + .get(&endpoint_str) + .query(&conn_type_param) + .header("Content-Type", "application/json") + .header("Authorization", format!("Bearer {}", token)) + .send() + .await?; + + let status = res.status(); + let body = res.text().await?; + if status.is_success() { + log::debug!("Connections Request Accepted: {:?}", status); + Ok(()) + } else { + log::error!("Connections Request failed: {} - {}", status, body); + Err(PonyError::Custom( + format!("Connections Request failed: {} - {}", status, body).into(), + )) + } + } +} diff --git a/src/bin/auth/core/mod.rs b/src/bin/auth/core/mod.rs new file mode 100644 index 00000000..4f72e0cc --- /dev/null +++ b/src/bin/auth/core/mod.rs @@ -0,0 +1,38 @@ +use pony::Subscription; +use std::sync::Arc; +use tokio::sync::RwLock; + +use pony::memory::node::Node; +use pony::zmq::subscriber::Subscriber as ZmqSubscriber; +use pony::BaseConnection as Connection; +use pony::ConnectionBaseOp; +use pony::MemoryCache; +use pony::NodeStorageOp; +use pony::SubscriptionOp; + +pub mod http; +pub mod service; +pub mod tasks; + +pub type AuthServiceState = MemoryCache; + +pub struct AuthService +where + N: NodeStorageOp + Send + Sync + Clone + 'static, + C: ConnectionBaseOp + Send + Sync + Clone + 'static, + S: SubscriptionOp + Send + Sync + Clone + 'static, +{ + pub memory: Arc>>, + pub subscriber: ZmqSubscriber, +} + +impl AuthService +where + N: NodeStorageOp + Send + Sync + Clone + 'static, + C: ConnectionBaseOp + Send + Sync + Clone + 'static, + S: SubscriptionOp + Send + Sync + Clone + 'static, +{ + pub fn new(memory: Arc>>, subscriber: ZmqSubscriber) -> Self { + Self { memory, subscriber } + } +} diff --git a/src/bin/auth/core/service.rs b/src/bin/auth/core/service.rs new file mode 100644 index 00000000..e3ca555f --- /dev/null +++ b/src/bin/auth/core/service.rs @@ -0,0 +1,186 @@ +use std::net::Ipv4Addr; +use std::path::Path; +use std::sync::Arc; +use tokio::signal; +use tokio::sync::broadcast; +use tokio::sync::RwLock; +use tokio::task::JoinHandle; +use tokio::time::Duration; + +use pony::config::settings::AuthServiceSettings; +use pony::config::settings::NodeConfig; +use pony::http::debug; +use pony::memory::node::Node; +use pony::MemoryCache; +use pony::Result; +use pony::SnapshotManager; +use pony::Subscriber as ZmqSubscriber; + +use super::AuthService; +use crate::core::http::start_auth_server; +use crate::core::http::ApiRequests; +use crate::core::tasks::Tasks; +use crate::core::AuthServiceState; + +pub async fn run(settings: AuthServiceSettings) -> Result<()> { + let mut tasks: Vec> = vec![]; + let (shutdown_tx, _) = broadcast::channel::<()>(1); + + let node_config = NodeConfig::from_raw(settings.node.clone()); + let node = Node::new(node_config?, None, None, None); + + let memory: Arc> = + Arc::new(RwLock::new(MemoryCache::with_node(node.clone()))); + + let subscriber = ZmqSubscriber::new( + &settings.zmq.endpoint, + &settings.node.uuid, + &settings.node.env, + ); + + let auth = Arc::new(AuthService::new(memory.clone(), subscriber)); + + let snapshot_manager = + SnapshotManager::new(settings.clone().auth.snapshot_path, memory.clone()); + + let snapshot_timestamp = if Path::new(&snapshot_manager.snapshot_path).exists() { + match snapshot_manager.load_snapshot().await { + Ok(Some(timestamp)) => { + log::info!("Loaded connections snapshot from {}", timestamp); + Some(timestamp) + } + Ok(None) => { + log::warn!("Snapshot file exists but couldn't be loaded"); + None + } + Err(e) => { + log::error!("Failed to load snapshot: {}", e); + log::info!("Starting fresh due to snapshot load error"); + None + } + } + } else { + log::info!("No snapshot found, starting fresh"); + None + }; + + let snapshot_manager = snapshot_manager.clone(); + tokio::spawn(async move { + let mut interval = tokio::time::interval(std::time::Duration::from_secs( + settings.auth.snapshot_interval, + )); + loop { + interval.tick().await; + if let Err(e) = snapshot_manager.create_snapshot().await { + log::error!("Failed to create snapshot: {}", e); + } else { + log::info!("Connections snapshot saved successfully"); + } + } + }); + + let _ = { + log::info!("ZMQ listener starting..."); + + let zmq_task = tokio::spawn({ + let auth = auth.clone(); + let mut shutdown = shutdown_tx.subscribe(); + async move { + tokio::select! { + _ = auth.run_subscriber() => {}, + _ = shutdown.recv() => {}, + } + } + }); + tasks.push(zmq_task); + + tokio::time::sleep(std::time::Duration::from_millis(500)).await; + }; + + let _ = { + let settings = settings.clone(); + if let Err(e) = auth + .get_connections( + settings.api.endpoint.clone(), + settings.api.token.clone(), + pony::Tag::Hysteria2, + snapshot_timestamp, + ) + .await + { + log::error!("-->>Cannot register auth service, \n {:?}", e); + } + }; + + let _ = { + let mut shutdown = shutdown_tx.subscribe(); + let memory = memory.clone(); + let addr = settings + .auth + .web_server + .unwrap_or(Ipv4Addr::new(127, 0, 0, 1)); + let port = settings.auth.web_port; + + let auth_handle = tokio::spawn(async move { + tokio::select! { + _ = start_auth_server(memory, addr, port) => {}, + _ = shutdown.recv() => {}, + } + }); + tasks.push(auth_handle); + }; + + let token = Arc::new(settings.api.token.clone()); + if settings.debug.enabled { + log::debug!( + "Running debug server: localhost:{}", + settings.debug.web_port + ); + let mut shutdown = shutdown_tx.subscribe(); + let memory = memory.clone(); + let addr = settings + .debug + .web_server + .unwrap_or(Ipv4Addr::new(127, 0, 0, 1)); + let port = settings.debug.web_port; + let token = token.clone(); + + let debug_handle = tokio::spawn(async move { + tokio::select! { + _ = debug::start_ws_server(memory, addr, port, token) => {}, + _ = shutdown.recv() => {}, + } + }); + tasks.push(debug_handle); + } + + wait_all_tasks_or_ctrlc(tasks, shutdown_tx).await; + Ok(()) +} + +async fn wait_all_tasks_or_ctrlc(tasks: Vec>, shutdown_tx: broadcast::Sender<()>) { + tokio::select! { + _ = async { + for (i, task) in tasks.into_iter().enumerate() { + match task.await { + Ok(_) => { + log::info!("Task {i} completed successfully"); + } + + Err(e) => { + log::error!("Task {i} panicked: {:?}", e); + let _ = shutdown_tx.send(()); + tokio::time::sleep(Duration::from_secs(5)).await; + std::process::exit(1); + } + } + } + } => {} + _ = signal::ctrl_c() => { + log::info!("🛑 Ctrl+C received. Shutting down..."); + let _ = shutdown_tx.send(()); + tokio::time::sleep(Duration::from_secs(5)).await; + std::process::exit(0); + } + } +} diff --git a/src/bin/auth/core/tasks.rs b/src/bin/auth/core/tasks.rs new file mode 100644 index 00000000..303fa3f9 --- /dev/null +++ b/src/bin/auth/core/tasks.rs @@ -0,0 +1,147 @@ +use async_trait::async_trait; +use pony::ConnectionStorageBaseOp; +use pony::Proto; +use rkyv::AlignedVec; +use rkyv::Infallible; +use tokio::time::Duration; + +use pony::Action; +use pony::BaseConnection as Connection; +use pony::ConnectionBaseOp; +use pony::Message; +use pony::NodeStorageOp; +use pony::SubscriptionOp; +use pony::Topic; +use pony::{PonyError, Result}; + +use rkyv::Deserialize; + +use super::AuthService; + +#[async_trait] +pub trait Tasks { + async fn run_subscriber(&self) -> Result<()>; + async fn handle_messages_batch(&self, msg: Vec) -> Result<()>; +} + +#[async_trait] +impl Tasks for AuthService +where + T: NodeStorageOp + Send + Sync + Clone, + C: ConnectionBaseOp + Send + Sync + Clone + 'static + From, + S: SubscriptionOp + Send + Sync + Clone + 'static + std::cmp::PartialEq, +{ + async fn run_subscriber(&self) -> Result<()> { + let sub = self.subscriber.clone(); + assert!(self.subscriber.topics.contains(&"all".to_string())); + + loop { + let Some((topic_bytes, payload_bytes)) = sub.recv().await else { + log::warn!("SUB: No multipart message received"); + continue; + }; + + let topic_str = std::str::from_utf8(&topic_bytes).unwrap_or(""); + log::debug!("SUB: Topic string: {:?}", topic_str); + log::debug!("SUB: Payload {} bytes", payload_bytes.len()); + + match Topic::from_raw(topic_str) { + Topic::Init(uuid) if uuid != self.subscriber.topics[0] => { + log::warn!("SUB: Skipping init for another node: {}", uuid); + continue; + } + Topic::Updates(env) if env != self.subscriber.topics[1] => { + log::warn!("SUB: Skipping update for another env: {}", env); + continue; + } + Topic::Unknown(raw) => { + log::warn!("SUB: Unknown topic: {}", raw); + continue; + } + Topic::All => { + log::debug!("SUB: Message for 'All' topic received"); + } + topic => { + log::debug!("SUB: Accepted topic: {:?}", topic); + } + } + + if payload_bytes.is_empty() { + log::warn!("SUB: Empty payload, skipping"); + continue; + } + + let mut aligned = AlignedVec::new(); + aligned.extend_from_slice(&payload_bytes); + + let archived = match { rkyv::check_archived_root::>(&aligned) } { + Ok(a) => a, + Err(e) => { + log::error!("SUB: Invalid rkyv root: {:?}", e); + log::error!("SUB: Payload bytes (hex) = {}", hex::encode(payload_bytes)); + continue; + } + }; + + match archived.deserialize(&mut Infallible) { + Ok(messages) => { + if let Err(err) = self.handle_messages_batch(messages).await { + log::error!("SUB: Failed to handle messages: {}", err); + } + } + Err(err) => { + log::error!("SUB: Failed to deserialize messages: {}", err); + log::error!("SUB: Payload bytes (hex) = {}", hex::encode(payload_bytes)); + } + } + + tokio::time::sleep(Duration::from_millis(10)).await; + } + } + + async fn handle_messages_batch(&self, messages: Vec) -> Result<()> { + let mut mem = self.memory.write().await; + + log::debug!("Got {} messages", messages.len()); + + for msg in messages { + let conn_id: uuid::Uuid = msg.conn_id.into(); + + let res: Result<()> = match msg.action { + Action::Create | Action::Update => { + if let Some(token) = msg.token { + let proto = Proto::new_hysteria2(&token); + let conn = Connection::new( + proto, + msg.expires_at.map(Into::into), + msg.subscription_id, + ); + mem.connections + .add(&conn_id, conn.into()) + .map(|_| ()) + .map_err(|err| { + PonyError::Custom(format!( + "Failed to add conn {}: {}", + conn_id, err + )) + }) + } else { + log::debug!("Skipped message {:?}", msg); + Ok(()) + } + } + + Action::Delete => { + let _ = mem.connections.remove(&conn_id); + Ok(()) + } + + _ => Ok(()), + }; + + res?; + } + + Ok(()) + } +} diff --git a/src/bin/auth/main.rs b/src/bin/auth/main.rs new file mode 100644 index 00000000..6c9e51fa --- /dev/null +++ b/src/bin/auth/main.rs @@ -0,0 +1,85 @@ +use fern::Dispatch; + +use pony::config::settings::AuthServiceSettings; +use pony::config::settings::Settings; +use pony::utils::*; + +mod core; + +fn main() -> Result<(), Box> { + println!(" "); + println!("░▒▓████████▓▒░▒▓███████▓▒░░▒▓█▓▒░░▒▓█▓▒░▒▓███████▓▒░ "); + println!("░▒▓█▓▒░ ░▒▓█▓▒░░▒▓█▓▒░▒▓█▓▒░░▒▓█▓▒░▒▓█▓▒░░▒▓█▓▒░ "); + println!("░▒▓█▓▒░ ░▒▓█▓▒░░▒▓█▓▒░▒▓█▓▒░░▒▓█▓▒░▒▓█▓▒░░▒▓█▓▒░ "); + println!("░▒▓██████▓▒░ ░▒▓███████▓▒░░▒▓███████▓▒░░▒▓█▓▒░░▒▓█▓▒░ "); + println!("░▒▓█▓▒░ ░▒▓█▓▒░░▒▓█▓▒░▒▓█▓▒░░▒▓█▓▒░▒▓█▓▒░░▒▓█▓▒░ "); + println!("░▒▓█▓▒░ ░▒▓█▓▒░░▒▓█▓▒░▒▓█▓▒░░▒▓█▓▒░▒▓█▓▒░░▒▓█▓▒░ "); + println!("░▒▓█▓▒░ ░▒▓█▓▒░░▒▓█▓▒░▒▓█▓▒░░▒▓█▓▒░▒▓█▓▒░░▒▓█▓▒░ "); + println!(" "); + println!(" "); + println!("░▒▓█▓▒░▒▓███████▓▒░▒▓████████▓▒░▒▓████████▓▒░▒▓███████▓▒░░▒▓███████▓▒░░▒▓████████▓▒░▒▓████████▓▒░ "); + println!("░▒▓█▓▒░▒▓█▓▒░░▒▓█▓▒░ ░▒▓█▓▒░ ░▒▓█▓▒░ ░▒▓█▓▒░░▒▓█▓▒░▒▓█▓▒░░▒▓█▓▒░▒▓█▓▒░ ░▒▓█▓▒░ "); + println!("░▒▓█▓▒░▒▓█▓▒░░▒▓█▓▒░ ░▒▓█▓▒░ ░▒▓█▓▒░ ░▒▓█▓▒░░▒▓█▓▒░▒▓█▓▒░░▒▓█▓▒░▒▓█▓▒░ ░▒▓█▓▒░ "); + println!("░▒▓█▓▒░▒▓█▓▒░░▒▓█▓▒░ ░▒▓█▓▒░ ░▒▓██████▓▒░ ░▒▓███████▓▒░░▒▓█▓▒░░▒▓█▓▒░▒▓██████▓▒░ ░▒▓█▓▒░ "); + println!("░▒▓█▓▒░▒▓█▓▒░░▒▓█▓▒░ ░▒▓█▓▒░ ░▒▓█▓▒░ ░▒▓█▓▒░░▒▓█▓▒░▒▓█▓▒░░▒▓█▓▒░▒▓█▓▒░ ░▒▓█▓▒░ "); + println!("░▒▓█▓▒░▒▓█▓▒░░▒▓█▓▒░ ░▒▓█▓▒░ ░▒▓█▓▒░ ░▒▓█▓▒░░▒▓█▓▒░▒▓█▓▒░░▒▓█▓▒░▒▓█▓▒░ ░▒▓█▓▒░ "); + println!("░▒▓█▓▒░▒▓█▓▒░░▒▓█▓▒░ ░▒▓█▓▒░ ░▒▓████████▓▒░▒▓█▓▒░░▒▓█▓▒░▒▓█▓▒░░▒▓█▓▒░▒▓████████▓▒░ ░▒▓█▓▒░ "); + println!(" "); + println!(" "); + println!("░▒▓████████▓▒░▒▓███████▓▒░░▒▓████████▓▒░▒▓████████▓▒░▒▓███████▓▒░ ░▒▓██████▓▒░░▒▓██████████████▓▒░ "); + println!("░▒▓█▓▒░ ░▒▓█▓▒░░▒▓█▓▒░▒▓█▓▒░ ░▒▓█▓▒░ ░▒▓█▓▒░░▒▓█▓▒░▒▓█▓▒░░▒▓█▓▒░▒▓█▓▒░░▒▓█▓▒░░▒▓█▓▒░ "); + println!("░▒▓█▓▒░ ░▒▓█▓▒░░▒▓█▓▒░▒▓█▓▒░ ░▒▓█▓▒░ ░▒▓█▓▒░░▒▓█▓▒░▒▓█▓▒░░▒▓█▓▒░▒▓█▓▒░░▒▓█▓▒░░▒▓█▓▒░ "); + println!("░▒▓██████▓▒░ ░▒▓███████▓▒░░▒▓██████▓▒░ ░▒▓██████▓▒░ ░▒▓█▓▒░░▒▓█▓▒░▒▓█▓▒░░▒▓█▓▒░▒▓█▓▒░░▒▓█▓▒░░▒▓█▓▒░ "); + println!("░▒▓█▓▒░ ░▒▓█▓▒░░▒▓█▓▒░▒▓█▓▒░ ░▒▓█▓▒░ ░▒▓█▓▒░░▒▓█▓▒░▒▓█▓▒░░▒▓█▓▒░▒▓█▓▒░░▒▓█▓▒░░▒▓█▓▒░ "); + println!("░▒▓█▓▒░ ░▒▓█▓▒░░▒▓█▓▒░▒▓█▓▒░ ░▒▓█▓▒░ ░▒▓█▓▒░░▒▓█▓▒░▒▓█▓▒░░▒▓█▓▒░▒▓█▓▒░░▒▓█▓▒░░▒▓█▓▒░ "); + println!("░▒▓█▓▒░ ░▒▓█▓▒░░▒▓█▓▒░▒▓████████▓▒░▒▓████████▓▒░▒▓███████▓▒░ ░▒▓██████▓▒░░▒▓█▓▒░░▒▓█▓▒░░▒▓█▓▒░ "); + println!(" "); + println!(" "); + + #[cfg(feature = "debug")] + console_subscriber::init(); + + let config_path = &std::env::args() + .nth(1) + .expect("required config path as an argument"); + println!("Config file {}", config_path); + + let settings = AuthServiceSettings::new(config_path); + + settings.validate().expect("Wrong settings file"); + println!(">>> Settings: {:?}", settings.clone()); + + Dispatch::new() + .format(|out, message, record| { + out.finish(format_args!( + "[{}][{}][{}] {}", + record.level(), + human_readable_date(current_timestamp() as u64), + record.target(), + message + )) + }) + .level(level_from_settings(&settings.logging.level)) + .chain(std::io::stdout()) + .apply() + .unwrap(); + + let num_cpus = std::thread::available_parallelism()?.get(); + + let worker_threads = if num_cpus <= 1 { 1 } else { num_cpus * 2 }; + log::info!( + "🧠 CPU cores: {}, configured worker threads: {}", + num_cpus, + worker_threads + ); + + let runtime = tokio::runtime::Builder::new_multi_thread() + .worker_threads(worker_threads) + .enable_all() + .build() + .unwrap(); + + runtime.block_on(core::service::run(settings))?; + + Ok(()) +} diff --git a/src/config/h2.rs b/src/config/h2.rs new file mode 100644 index 00000000..5ac56463 --- /dev/null +++ b/src/config/h2.rs @@ -0,0 +1,148 @@ +use crate::PonyError; +use crate::Result; +use serde::Deserialize; +use serde::Serialize; +use std::fs::File; +use std::io::Read; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct HysteriaServerConfig { + pub listen: Option, + pub acme: Option, + pub auth: Option, + pub obfs: Option, + pub masquerade: Option, +} + +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] +pub struct H2AuthInfo { + pub auth_type: String, + pub has_password: bool, + pub has_url: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AcmeConfig { + pub domains: Option>, + pub email: Option, + + #[serde(rename = "type")] + pub r#type: Option, + pub dir: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AuthConfig { + pub r#type: Option, + pub password: Option, + pub url: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct HysteriaObfs { + pub r#type: Option, + pub password: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Masquerade { + pub r#type: String, +} + +impl HysteriaServerConfig { + pub fn from_file(path: &str) -> anyhow::Result { + let mut file = File::open(path)?; + let mut contents = String::new(); + file.read_to_string(&mut contents)?; + + let config: HysteriaServerConfig = serde_yaml::from_str(&contents)?; + Ok(config) + } +} + +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] +pub struct H2Settings { + pub host: String, + pub port: u16, + pub sni: Option, + pub insecure: bool, + pub obfs: Option, + pub alpn: Option>, + pub up_mbps: Option, + pub down_mbps: Option, + pub auth_info: Option, +} +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] +pub struct H2Obfs { + pub r#type: String, + pub password: String, +} + +impl HysteriaServerConfig { + pub fn validate(&self) -> Result<()> { + if self.listen.is_none() { + return Err(PonyError::Custom("Hysteria2: listen is required".into())); + } + + let auth = self + .auth + .as_ref() + .ok_or_else(|| PonyError::Custom("Hysteria2: auth section is required".into()))?; + + if auth.password.clone().unwrap_or("".to_string()).is_empty() { + return Err(PonyError::Custom( + "Hysteria2: auth.password is required".into(), + )); + } + + Ok(()) + } +} + +impl TryFrom for H2Settings { + type Error = PonyError; + + fn try_from(server: HysteriaServerConfig) -> std::result::Result { + let listen = server + .listen + .ok_or_else(|| PonyError::Custom("Hysteria2: listen missing".into()))?; + + let port = listen + .split(':') + .last() + .unwrap_or("443") + .parse::() + .map_err(|_| PonyError::Custom("Hysteria2: invalid port".into()))?; + + let host = server + .acme + .as_ref() + .and_then(|a| a.domains.as_ref()) + .and_then(|d| d.first()) + .cloned() + .ok_or_else(|| PonyError::Custom("Hysteria2: acme.domains missing".into()))?; + + let auth_info = server.auth.map(|a| H2AuthInfo { + auth_type: a.r#type.unwrap_or_else(|| "unknown".into()), + has_password: a.password.is_some(), + has_url: a.url.is_some(), + }); + + let obfs = server.obfs.map(|o| H2Obfs { + r#type: o.r#type.unwrap_or_default(), + password: o.password.unwrap_or_default(), + }); + + Ok(H2Settings { + host: host.clone(), + port, + sni: Some(host), + insecure: false, + alpn: Some(vec!["h2".into(), "http/1.1".into()]), + obfs, + up_mbps: None, + down_mbps: None, + auth_info, + }) + } +} diff --git a/src/config/mod.rs b/src/config/mod.rs index 11ff1442..7e7d3468 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -1,3 +1,4 @@ +pub mod h2; pub mod settings; pub mod wireguard; pub mod xray; diff --git a/src/config/settings.rs b/src/config/settings.rs index c8359bf9..08755319 100644 --- a/src/config/settings.rs +++ b/src/config/settings.rs @@ -68,6 +68,12 @@ fn default_debug_web_server() -> Option { fn default_debug_web_port() -> u16 { 3001 } +fn default_auth_web_server() -> Option { + Some(Ipv4Addr::new(127, 0, 0, 1)) +} +fn default_auth_web_port() -> u16 { + 3005 +} fn default_api_web_listen() -> Option { Some(Ipv4Addr::new(127, 0, 0, 1)) } @@ -75,7 +81,7 @@ fn default_api_web_port() -> u16 { 3005 } fn default_api_token() -> String { - "supetsecrettoken".to_string() + "token".to_string() } fn default_label() -> String { "🏴‍☠️🏴‍☠️🏴‍☠️ dev".to_string() @@ -86,11 +92,11 @@ fn default_stat_job_interval() -> u64 { } fn default_snapshot_interval() -> u64 { - 300 + 30 } fn default_snapshot_path() -> String { - "snapshots/agent_snapshot.bin".to_string() + "snapshots/snapshot.bin".to_string() } fn default_metrics_interval() -> u64 { @@ -135,6 +141,10 @@ fn default_web_host() -> String { "https://frkn.org".to_string() } +fn default_h2_config_path() -> String { + "dev/h2.yaml".to_string() +} + #[derive(Clone, Debug, Deserialize, Default)] pub struct ApiServiceConfig { #[serde(default = "default_api_web_listen")] @@ -175,6 +185,18 @@ pub struct ApiAccessConfig { pub token: String, } +#[derive(Clone, Debug, Deserialize, Default)] +pub struct AuthServiceConfig { + #[serde(default = "default_snapshot_interval")] + pub snapshot_interval: u64, + #[serde(default = "default_snapshot_path")] + pub snapshot_path: String, + #[serde(default = "default_auth_web_server")] + pub web_server: Option, + #[serde(default = "default_auth_web_port")] + pub web_port: u16, +} + #[derive(Clone, Debug, Deserialize, Default)] pub struct AgentConfig { #[serde(default = "default_disabled")] @@ -264,13 +286,8 @@ impl NodeConfig { raw.hostname.unwrap() }; - // Если указан адрес - используем его let (address, interface) = if let Some(user_address) = raw.address { - // Пользователь указал адрес (скорее всего внешний/публичный) - - // Для интерфейса используем либо указанный, либо дефолтный let interface = if let Some(ref interface_name) = raw.default_interface { - // Проверяем существование интерфейса let interfaces = get_interfaces(); if let Some(_interface) = interfaces.iter().find(|i| &i.name == interface_name) { interface_name.clone() @@ -280,11 +297,9 @@ impl NodeConfig { )); } } else { - // Используем дефолтный интерфейс match get_default_interface() { Ok(interface) => interface.name, Err(e) => { - // Если не можем получить дефолтный интерфейс, используем placeholder eprintln!( "Warning: Cannot get default interface: {}. Using 'default'.", e @@ -296,7 +311,6 @@ impl NodeConfig { (user_address, interface) } else if let Some(ref interface_name) = raw.default_interface { - // Пользователь указал только интерфейс, адрес берем с интерфейса let interfaces = get_interfaces(); if let Some(interface) = interfaces.iter().find(|i| &i.name == interface_name) { match interface.ipv4.first() { @@ -370,7 +384,7 @@ pub struct XrayConfig { #[derive(Clone, Default, Debug, Deserialize)] pub struct WgConfig { - #[serde(default = "default_enabled")] + #[serde(default = "default_disabled")] pub enabled: bool, #[serde(default = "default_wg_port")] pub port: u16, @@ -383,6 +397,14 @@ pub struct WgConfig { pub dns: Option>, } +#[derive(Clone, Default, Debug, Deserialize)] +pub struct H2Config { + #[serde(default = "default_disabled")] + pub enabled: bool, + #[serde(default = "default_h2_config_path")] + pub path: String, +} + #[derive(Clone, Debug, Deserialize, Default)] pub struct ZmqSubscriberConfig { #[serde(default = "default_zmq_sub_endpoint")] @@ -457,6 +479,22 @@ pub struct ApiSettings { pub carbon: CarbonConfig, } +#[derive(Clone, Debug, Deserialize)] +pub struct AuthServiceSettings { + #[serde(default)] + pub debug: DebugConfig, + #[serde(default)] + pub logging: LoggingConfig, + #[serde(default)] + pub auth: AuthServiceConfig, + #[serde(default)] + pub zmq: ZmqSubscriberConfig, + #[serde(default)] + pub node: NodeConfigRaw, + #[serde(default)] + pub api: ApiAccessConfig, +} + #[derive(Clone, Debug, Deserialize)] pub struct AgentSettings { #[serde(default)] @@ -472,6 +510,8 @@ pub struct AgentSettings { #[serde(default)] pub wg: WgConfig, #[serde(default)] + pub h2: H2Config, + #[serde(default)] pub zmq: ZmqSubscriberConfig, #[serde(default)] pub node: NodeConfigRaw, @@ -492,3 +532,10 @@ impl Settings for ApiSettings { Ok(()) } } + +impl Settings for AuthServiceSettings { + fn validate(&self) -> Result<()> { + self.zmq.clone().validate()?; + Ok(()) + } +} diff --git a/src/config/xray.rs b/src/config/xray.rs index a61e8ebb..b1cd8d52 100644 --- a/src/config/xray.rs +++ b/src/config/xray.rs @@ -1,11 +1,11 @@ use serde::{Deserialize, Serialize}; use std::{fs::File, io::Read}; +use crate::config::h2::H2Settings; use crate::config::wireguard::WireguardSettings; use crate::http::requests::InboundResponse; use crate::memory::node::Stat as InboundStat; use crate::memory::tag::ProtoTag as Tag; - use crate::Result; #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] @@ -82,6 +82,7 @@ pub struct Inbound { pub downlink: Option, pub conn_count: Option, pub wg: Option, + pub h2: Option, } impl Inbound { @@ -91,6 +92,7 @@ impl Inbound { stream_settings: self.stream_settings.clone(), tag: self.tag, wg: self.wg.clone(), + h2: self.h2.clone(), } } diff --git a/src/error.rs b/src/error.rs index 992a3ae0..70d3bfbe 100644 --- a/src/error.rs +++ b/src/error.rs @@ -97,6 +97,12 @@ impl From for PonyError { } } +impl From for PonyError { + fn from(err: anyhow::Error) -> Self { + PonyError::Custom(err.to_string()) + } +} + impl From> for PonyError { fn from(err: tokio::sync::mpsc::error::SendError) -> Self { PonyError::Custom(format!("SendError: {:?}", err)) diff --git a/src/h2_op/mod.rs b/src/h2_op/mod.rs new file mode 100644 index 00000000..5a2d6f10 --- /dev/null +++ b/src/h2_op/mod.rs @@ -0,0 +1,52 @@ +use crate::http::requests::InboundResponse; +use crate::{PonyError, Result}; +use url::Url; + +pub fn hysteria2_conn( + inbound: &InboundResponse, + label: &str, + token: &Option, +) -> Result { + let port = inbound.port; + let h2 = inbound.h2.as_ref().ok_or(PonyError::Custom( + "Hysteria2: H2Settings missing".to_string(), + ))?; + + if let Some(inb) = &inbound.h2 { + let hostname = inb.host.clone(); + + let obfs_type = h2 + .obfs + .as_ref() + .map(|o| o.r#type.clone()) + .unwrap_or_default(); + let obfs_pass = h2 + .obfs + .as_ref() + .map(|o| o.password.clone()) + .unwrap_or_default(); + + let alpn = h2.alpn.as_ref().map(|v| v.join(",")).unwrap_or_default(); + + if let Some(token) = token { + let mut url = Url::parse(&format!("hysteria://{token}@{hostname}:{port}"))?; + url.query_pairs_mut() + .append_pair("host", &h2.host) + .append_pair("sni", h2.sni.as_deref().unwrap_or("")) + .append_pair("insecure", &h2.insecure.to_string()) + .append_pair("obfs", &obfs_type) + .append_pair("obfs-pass", &obfs_pass) + .append_pair("alpn", &alpn) + .append_pair("up-mbps", &h2.up_mbps.unwrap_or(0).to_string()) + .append_pair("down-mbps", &h2.down_mbps.unwrap_or(0).to_string()); + + url.set_fragment(Some(label)); + + return Ok(url.to_string()); + } else { + Err(PonyError::Custom("Token is not valid".to_string()).into()) + } + } else { + Err(PonyError::Custom("H2 Inbound is not valid".to_string()).into()) + } +} diff --git a/src/http/debug.rs b/src/http/debug.rs index 0cfe8b3e..f68e9c33 100644 --- a/src/http/debug.rs +++ b/src/http/debug.rs @@ -105,7 +105,7 @@ pub async fn handle_debug_connection( memory: Arc>>, ) where N: NodeStorageOp + Sync + Send + Clone + 'static, - C: ConnectionBaseOp + Sync + Send + Clone + 'static + std::fmt::Display, + C: ConnectionBaseOp + Sync + Send + Clone + 'static + fmt::Display, S: SubscriptionOp + Sync + Send + Clone + 'static + std::cmp::PartialEq + serde::Serialize, { let (mut sender, mut receiver) = socket.split(); @@ -127,7 +127,11 @@ pub async fn handle_debug_connection( // COMMENT(@qezz): A `match` would probably work better here. if req.kind == "get_connections" { let memory = memory.read().await; - let conns: Vec<_> = memory.connections.keys().collect(); + let conns: Vec<_> = memory + .connections + .iter() + .map(|(k, v)| (k, v.get_proto().proto())) + .collect(); let data = serde_json::to_string(&conns).unwrap(); let response = Response { kind: Kind::Conns.to_string(), diff --git a/src/http/requests.rs b/src/http/requests.rs index 02e87d89..bbb281b8 100644 --- a/src/http/requests.rs +++ b/src/http/requests.rs @@ -3,6 +3,7 @@ use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::net::Ipv4Addr; +use crate::config::h2::H2Settings; use crate::config::wireguard::WireguardSettings; use crate::config::xray::Inbound; use crate::config::xray::StreamSettings; @@ -19,6 +20,16 @@ fn default_env() -> String { "dev".to_string() } +fn default_proto() -> TagReq { + TagReq::Xray +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub enum TagReq { + Xray, + Wireguard, + Hysteria2, +} #[derive(Clone, Debug, Deserialize, Serialize)] pub struct SubIdQueryParam { pub id: uuid::Uuid, @@ -32,6 +43,8 @@ pub struct SubQueryParam { pub format: String, #[serde(default = "default_env")] pub env: String, + #[serde(default = "default_proto")] + pub proto: TagReq, } #[derive(Debug, Deserialize)] @@ -64,16 +77,10 @@ pub struct ConnQueryParam { } #[derive(Clone, Debug, Deserialize, Serialize)] -pub enum NodeType { - Xray, - Wireguard, - All, -} - -#[derive(Clone, Debug, Deserialize, Serialize)] -pub struct NodeTypeParam { - pub node_type: Option, +pub struct ConnTypeParam { + pub proto: Tag, pub last_update: Option, + pub env: String, } #[derive(Clone, Debug, Deserialize, Serialize)] @@ -128,11 +135,13 @@ pub struct InboundResponse { pub port: u16, pub stream_settings: Option, pub wg: Option, + pub h2: Option, } #[derive(Serialize, Deserialize, Clone, Debug)] pub struct ConnCreateRequest { pub env: String, + pub token: Option, pub password: Option, pub subscription_id: Option, pub proto: Tag, @@ -141,6 +150,41 @@ pub struct ConnCreateRequest { pub days: Option, } +impl ConnCreateRequest { + pub fn validate(&self) -> Result<(), String> { + if self.password.is_some() && self.wg.is_some() { + return Err("Cannot specify both password and wg".into()); + } + if self.token.is_some() && self.wg.is_some() { + return Err("Cannot specify both token and wg".into()); + } + if self.token.is_some() && self.password.is_some() { + return Err("Cannot specify both token and password".into()); + } + if !self.proto.is_wireguard() && self.wg.is_some() { + return Err("Wg params only allowed for Wireguard".into()); + } + + if !self.proto.is_wireguard() && self.node_id.is_some() { + return Err("node_id only allowed for Wireguard".into()); + } + if self.proto.is_shadowsocks() && self.password.is_none() { + return Err("Password required for Shadowsocks".into()); + } + if !self.proto.is_shadowsocks() && self.password.is_some() { + return Err("Password only allowed for Shadowsocks".into()); + } + if !self.proto.is_hysteria2() && self.token.is_some() { + return Err("Token only allowed for Hysteria2".into()); + } + if self.proto.is_hysteria2() && self.token.is_none() { + return Err("Token required for Hysteria2".into()); + } + + Ok(()) + } +} + #[derive(Serialize, Deserialize, Clone, Debug)] pub struct ConnUpdateRequest { pub env: Option, diff --git a/src/lib.rs b/src/lib.rs index 29601ec2..a5413017 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,6 @@ pub mod config; pub mod error; +pub mod h2_op; pub mod http; pub mod memory; pub mod metrics; diff --git a/src/memory/connection/base.rs b/src/memory/connection/base.rs index 02e43dfc..ead518e9 100644 --- a/src/memory/connection/base.rs +++ b/src/memory/connection/base.rs @@ -25,7 +25,11 @@ pub struct Base { } impl Base { - pub fn new(proto: Proto, expired_at: Option>) -> Self { + pub fn new( + proto: Proto, + expired_at: Option>, + sub_id: Option, + ) -> Self { let now = Utc::now().naive_utc(); Self { @@ -33,7 +37,7 @@ impl Base { created_at: now, modified_at: now, expired_at: expired_at, - subscription_id: None, + subscription_id: sub_id, proto, is_deleted: false, } diff --git a/src/memory/connection/conn.rs b/src/memory/connection/conn.rs index b30c5107..dac5bbad 100644 --- a/src/memory/connection/conn.rs +++ b/src/memory/connection/conn.rs @@ -1,10 +1,10 @@ use chrono::DateTime; use chrono::NaiveDateTime; use chrono::Utc; +use std::fmt; use serde::Deserialize; use serde::Serialize; -use std::fmt; use super::op::api::Operations as ApiOps; use super::op::base::Operations as BasOps; @@ -21,27 +21,10 @@ pub struct Conn { pub modified_at: NaiveDateTime, pub expired_at: Option>, pub is_deleted: bool, - pub node_id: Option, -} - -impl PartialEq for Conn { - fn eq(&self, other: &Self) -> bool { - self.get_subscription_id() == other.get_subscription_id() - && self.get_proto() == other.get_proto() - && self.get_deleted() == other.get_deleted() - && self.get_env() == other.get_env() - && self.get_deleted() == other.get_deleted() - } } impl fmt::Display for Conn { - // COMMENT(qezz): This feels like a Debug implementation. That can be either derived, - // or the helper functions can be used, so one doesn't need to manually match the - // indentation - // - // More details for manual implementation: https://doc.rust-lang.org/std/fmt/trait.Debug.html fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - // COMMENT(qezz): there's `writeln!()` write!(f, "Connection {{\n")?; if let Some(subscription_id) = self.subscription_id { @@ -49,29 +32,33 @@ impl fmt::Display for Conn { } else { write!(f, " subscription_id: None,\n")?; } - write!( - f, - " node_id: {},\n", - self.node_id.unwrap_or(uuid::Uuid::default()) - )?; write!(f, " env: {},\n", self.env)?; - write!(f, " conn stat: {}\n", self.stat)?; + write!(f, " conn stat: {}\n", self.stat)?; write!(f, " created_at: {},\n", self.created_at)?; write!(f, " modified_at: {},\n", self.modified_at)?; write!(f, " expired_at: {:?},\n", self.expired_at)?; write!(f, " proto: {:?},\n", self.proto)?; - write!(f, "deleted: {}\n", self.is_deleted)?; + write!(f, " deleted: {}\n", self.is_deleted)?; write!(f, "}}") } } +impl PartialEq for Conn { + fn eq(&self, other: &Self) -> bool { + self.get_subscription_id() == other.get_subscription_id() + && self.get_proto() == other.get_proto() + && self.get_deleted() == other.get_deleted() + && self.get_env() == other.get_env() + && self.get_deleted() == other.get_deleted() + } +} + impl Conn { pub fn new( env: &str, subscription_id: Option, stat: Stat, proto: Proto, - node_id: Option, expired_at: Option>, ) -> Self { let now = Utc::now().naive_utc(); @@ -85,7 +72,6 @@ impl Conn { proto: proto, subscription_id: subscription_id, is_deleted: false, - node_id: node_id, } } } diff --git a/src/memory/connection/op/api.rs b/src/memory/connection/op/api.rs index 51b19bc1..d094c8b1 100644 --- a/src/memory/connection/op/api.rs +++ b/src/memory/connection/op/api.rs @@ -37,18 +37,29 @@ impl Operations for Conn { }; let tag = self.proto.proto(); + let expires_at = self.expired_at; + + let token = match &self.proto { + Proto::Hysteria2 { token } => Some(*token), + _ => None, + }; let wg = match &self.proto { Proto::Wireguard { param, .. } => Some(param.clone()), _ => None, }; + let sub_id = self.subscription_id; + Message { conn_id: (*conn_id).into(), + subscription_id: sub_id, action: Action::Create, password, + token, tag: tag, wg, + expires_at: expires_at.map(Into::into), } } @@ -59,18 +70,29 @@ impl Operations for Conn { }; let tag = self.proto.proto(); + let expires_at = self.expired_at; let wg = match &self.proto { Proto::Wireguard { param, .. } => Some(param.clone()), _ => None, }; + let token = match &self.proto { + Proto::Hysteria2 { token } => Some(*token), + _ => None, + }; + + let sub_id = self.subscription_id; + Message { conn_id: (*conn_id).into(), + subscription_id: sub_id, action: Action::Update, password, + token, tag: tag, wg, + expires_at: expires_at.map(Into::into), } } @@ -81,18 +103,29 @@ impl Operations for Conn { }; let tag = self.proto.proto(); + let expires_at = self.expired_at; let wg = match &self.proto { Proto::Wireguard { param, .. } => Some(param.clone()), _ => None, }; + let token = match &self.proto { + Proto::Hysteria2 { token } => Some(*token), + _ => None, + }; + + let sub_id = self.subscription_id; + Message { conn_id: (*conn_id).into(), + subscription_id: sub_id, action: Action::Delete, password, + token, tag: tag, wg, + expires_at: expires_at.map(Into::into), } } } diff --git a/src/memory/connection/op/base.rs b/src/memory/connection/op/base.rs index 3f1bde4d..5e64cdeb 100644 --- a/src/memory/connection/op/base.rs +++ b/src/memory/connection/op/base.rs @@ -38,6 +38,7 @@ pub trait Operations { fn get_wireguard(&self) -> Option<&WgParam>; fn get_wireguard_node_id(&self) -> Option; fn get_password(&self) -> Option; + fn get_token(&self) -> Option; fn set_password(&mut self, password: Option) -> Result<()>; } @@ -120,6 +121,13 @@ impl Operations for Base { } } + fn get_token(&self) -> Option { + match &self.proto { + Proto::Hysteria2 { token } => Some(token.clone()), + _ => None, + } + } + fn get_wireguard(&self) -> Option<&WgParam> { match &self.proto { Proto::Wireguard { param, .. } => Some(param), @@ -237,4 +245,10 @@ impl Operations for Conn { )), } } + fn get_token(&self) -> Option { + match &self.proto { + Proto::Hysteria2 { token } => Some(token.clone()), + _ => None, + } + } } diff --git a/src/memory/connection/proto.rs b/src/memory/connection/proto.rs index 7b0cf873..85d36f14 100644 --- a/src/memory/connection/proto.rs +++ b/src/memory/connection/proto.rs @@ -13,6 +13,7 @@ pub enum Proto { Wireguard { param: WgParam, node_id: uuid::Uuid }, Shadowsocks { password: String }, Xray(Tag), + Hysteria2 { token: uuid::Uuid }, } impl Proto { @@ -20,10 +21,18 @@ impl Proto { match self { Proto::Wireguard { .. } => Tag::Wireguard, Proto::Shadowsocks { .. } => Tag::Shadowsocks, + Proto::Hysteria2 { .. } => Tag::Hysteria2, Proto::Xray(tag) => *tag, } } + pub fn token(&self) -> Option { + match self { + Proto::Hysteria2 { token } => Some(*token), + _ => None, + } + } + pub fn new_wg(param: &WgParam, node_id: &uuid::Uuid) -> Self { Proto::Wireguard { param: param.clone(), @@ -36,6 +45,9 @@ impl Proto { password: password.to_string(), } } + pub fn new_hysteria2(token: &uuid::Uuid) -> Self { + Proto::Hysteria2 { token: *token } + } pub fn new_xray(tag: &Tag) -> Self { Proto::Xray(*tag) @@ -52,4 +64,8 @@ impl Proto { pub fn is_shadowsocks(&self) -> bool { matches!(self, Proto::Shadowsocks { .. }) } + + pub fn is_hysteria2(&self) -> bool { + matches!(self, Proto::Hysteria2 { .. }) + } } diff --git a/src/memory/connection/wireguard.rs b/src/memory/connection/wireguard.rs index 4d77e849..0705058e 100644 --- a/src/memory/connection/wireguard.rs +++ b/src/memory/connection/wireguard.rs @@ -11,6 +11,7 @@ use x25519_dalek::{PublicKey, StaticSecret}; #[derive( Archive, Clone, Debug, Serialize, Deserialize, PartialEq, RkyvDeserialize, RkyvSerialize, )] +#[archive(check_bytes)] pub struct Keys { pub privkey: String, pub pubkey: String, @@ -34,6 +35,7 @@ impl Default for Keys { #[derive( Archive, Serialize, Deserialize, RkyvSerialize, RkyvDeserialize, Clone, Debug, PartialEq, )] +#[archive(check_bytes)] pub struct IpAddrMaskSerializable { pub addr: String, pub cidr: u8, @@ -65,6 +67,7 @@ impl fmt::Display for IpAddrMaskSerializable { #[derive( Archive, Clone, Debug, Serialize, Deserialize, RkyvDeserialize, RkyvSerialize, PartialEq, )] +#[archive(check_bytes)] pub struct Param { pub keys: Keys, pub address: IpAddrMaskSerializable, diff --git a/src/memory/node.rs b/src/memory/node.rs index 41ec211e..6a903adb 100644 --- a/src/memory/node.rs +++ b/src/memory/node.rs @@ -8,6 +8,7 @@ use postgres_types::{FromSql, ToSql}; use serde::{Deserialize, Serialize}; use super::tag::ProtoTag as Tag; +use crate::config::h2::H2Settings; use crate::config::settings::NodeConfig; use crate::config::wireguard::WireguardSettings; use crate::config::xray::{Config as XrayConfig, Inbound}; @@ -72,6 +73,7 @@ impl Node { settings: NodeConfig, xray_config: Option, wg_config: Option, + h2_config: Option, ) -> Self { let now = Utc::now(); let mut inbounds: HashMap = HashMap::new(); @@ -98,6 +100,23 @@ impl Node { downlink: None, conn_count: None, wg: wg_config, + h2: None, + }, + ); + } + + if let Some(ref config) = h2_config { + inbounds.insert( + Tag::Hysteria2, + Inbound { + port: config.port, + tag: Tag::Hysteria2, + stream_settings: None, + uplink: None, + downlink: None, + conn_count: None, + wg: None, + h2: h2_config, }, ); } diff --git a/src/memory/snapshot.rs b/src/memory/snapshot.rs index 7fafef55..afc42ef5 100644 --- a/src/memory/snapshot.rs +++ b/src/memory/snapshot.rs @@ -154,6 +154,11 @@ where Ok(Some(archived.timestamp)) } + + pub async fn count(&self) -> usize { + let mem = self.memory.read().await; + mem.connections.0.len() + } } impl> Connections { diff --git a/src/memory/storage/connection.rs b/src/memory/storage/connection.rs index 37bc999f..ff1fb70d 100644 --- a/src/memory/storage/connection.rs +++ b/src/memory/storage/connection.rs @@ -11,6 +11,7 @@ use super::super::tag::ProtoTag as Tag; use crate::error::{PonyError, Result}; use crate::http::requests::ConnUpdateRequest; use crate::Connection; +use crate::Proto; pub trait ApiOp where @@ -37,6 +38,7 @@ where fn update_downlink(&mut self, conn_id: &uuid::Uuid, new_downlink: i64) -> Result<()>; fn update_online(&mut self, conn_id: &uuid::Uuid, new_online: i64) -> Result<()>; fn update_stats(&mut self, conn_id: &uuid::Uuid, stats: ConnectionStat) -> Result<()>; + fn validate_token(&self, token: &uuid::Uuid) -> Option; } impl BaseOp for Connections @@ -47,6 +49,17 @@ where self.0.len() } + fn validate_token(&self, token: &uuid::Uuid) -> Option { + self.iter() + .find(|(_, conn)| { + matches!( + conn.get_proto(), + Proto::Hysteria2 { token: t } if t == *token + ) + }) + .map(|(id, _)| *id) + } + fn add(&mut self, conn_id: &uuid::Uuid, new_conn: C) -> Result { match self.entry(*conn_id) { Entry::Occupied(_) => return Ok(OperationStatus::AlreadyExist(*conn_id)), diff --git a/src/memory/tag.rs b/src/memory/tag.rs index bd35fd50..222f6163 100644 --- a/src/memory/tag.rs +++ b/src/memory/tag.rs @@ -36,6 +36,8 @@ pub enum ProtoTag { Shadowsocks, #[serde(rename = "Wireguard")] Wireguard, + #[serde(rename = "Hysteria2")] + Hysteria2, } impl fmt::Display for ProtoTag { @@ -47,6 +49,7 @@ impl fmt::Display for ProtoTag { ProtoTag::Vmess => write!(f, "Vmess"), ProtoTag::Shadowsocks => write!(f, "Shadowsocks"), ProtoTag::Wireguard => write!(f, "Wireguard"), + ProtoTag::Hysteria2 => write!(f, "Hysteria2"), } } } @@ -59,6 +62,9 @@ impl ProtoTag { pub fn is_shadowsocks(&self) -> bool { *self == ProtoTag::Shadowsocks } + pub fn is_hysteria2(&self) -> bool { + *self == ProtoTag::Hysteria2 + } } impl std::str::FromStr for ProtoTag { @@ -72,6 +78,7 @@ impl std::str::FromStr for ProtoTag { "Vmess" => Ok(ProtoTag::Vmess), "Shadowsocks" => Ok(ProtoTag::Shadowsocks), "Wireguard" => Ok(ProtoTag::Wireguard), + "Hysteria2" => Ok(ProtoTag::Hysteria2), _ => Err(()), } } diff --git a/src/utils.rs b/src/utils.rs index 585af8c0..57ce1193 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -11,6 +11,7 @@ use std::time::Instant; use tokio::time::{sleep, Duration as TokioDuration}; use url::Url; +use crate::h2_op::hysteria2_conn; use crate::http::requests::InboundResponse; use crate::memory::tag::ProtoTag as Tag; use crate::xray_op::vless::vless_grpc_conn; @@ -146,11 +147,13 @@ pub fn create_conn_link( inbound: InboundResponse, label: &str, address: Ipv4Addr, + token: &Option, ) -> Result { let raw_link = match tag { Tag::VlessTcpReality => vless_xtls_conn(conn_id, address, inbound.clone(), label), Tag::VlessGrpcReality => vless_grpc_conn(conn_id, address, inbound.clone(), label), Tag::VlessXhttpReality => vless_xhttp_conn(conn_id, address, inbound.clone(), label), + Tag::Hysteria2 => hysteria2_conn(&inbound.clone(), label, token), Tag::Vmess => vmess_tcp_conn(conn_id, address, inbound.clone(), label), _ => return Err(PonyError::Custom("Cannot complete conn line".into())), }?; diff --git a/src/xray_op/client.rs b/src/xray_op/client.rs index e83778f6..998d1a3d 100644 --- a/src/xray_op/client.rs +++ b/src/xray_op/client.rs @@ -109,7 +109,7 @@ impl HandlerActions for Arc> { "Create SS user error, password not provided".to_string(), )) } - _ => Err(PonyError::Custom("Not supported by Xray".into())), + _ => Err(PonyError::Custom("Not supported Proto".into())), } } @@ -145,9 +145,7 @@ impl HandlerActions for Arc> { "Remove SS user error, password not provided".to_string(), )) } - Tag::Wireguard => Err(crate::PonyError::Custom( - "Removing Wireguard is not implemented".to_string(), - )), + _ => Err(PonyError::Custom("Not supported Proto".into())), } } } diff --git a/src/zmq/message.rs b/src/zmq/message.rs index 3e6f896a..9bb090fb 100644 --- a/src/zmq/message.rs +++ b/src/zmq/message.rs @@ -1,10 +1,35 @@ use crate::memory::tag::ProtoTag; +use chrono::{DateTime, Utc}; + use rkyv::{Archive, Deserialize, Serialize}; use serde::{Deserialize as SerdeDes, Serialize as SerdeSer}; use std::fmt; use crate::memory::connection::wireguard::Param as WgParam; +#[derive(Archive, Deserialize, Serialize, Debug, Clone)] +#[archive(check_bytes)] +#[archive_attr(derive(Debug))] +pub struct RkyvDateTime { + timestamp: i64, + nanos: u32, +} + +impl From> for RkyvDateTime { + fn from(dt: DateTime) -> Self { + Self { + timestamp: dt.timestamp(), + nanos: dt.timestamp_subsec_nanos(), + } + } +} + +impl From for DateTime { + fn from(rkyv_dt: RkyvDateTime) -> Self { + DateTime::from_timestamp(rkyv_dt.timestamp, rkyv_dt.nanos).expect("Invalid timestamp") + } +} + #[derive(Archive, Serialize, Deserialize, SerdeSer, SerdeDes, Debug, Clone)] #[archive(check_bytes)] pub enum Action { @@ -25,7 +50,7 @@ impl fmt::Display for Action { } } -#[derive(Archive, Serialize, Deserialize, SerdeSer, SerdeDes, Clone, Debug)] +#[derive(Archive, Serialize, Deserialize, Clone, Debug)] #[archive(check_bytes)] pub struct Message { pub conn_id: uuid::Uuid, @@ -33,13 +58,16 @@ pub struct Message { pub tag: ProtoTag, pub wg: Option, pub password: Option, + pub token: Option, + pub expires_at: Option, + pub subscription_id: Option, } impl fmt::Display for Message { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!( f, - "{} | {} | {} | {} | {}", + "{} | {} | {} | {} | {} | {} | {}", self.conn_id.clone(), self.action, self.tag, @@ -50,6 +78,14 @@ impl fmt::Display for Message { match &self.password { Some(pw) => pw.as_ref(), None => "-", + }, + match &self.token { + Some(t) => t.to_string(), + None => "-".to_string(), + }, + match &self.expires_at { + Some(exp_at) => format!("{:?}", exp_at), + None => "-".to_string(), } ) }