diff --git a/src/config.rs b/src/config.rs index 0d400c2..84da369 100644 --- a/src/config.rs +++ b/src/config.rs @@ -17,7 +17,10 @@ impl AppConfig { .unwrap_or_else(|_| "60309".to_string()) .parse::() .map_err(|_| "PORT must be a number")?; - let write_api_key = env::var("WRITE_KEY").ok(); + // Prefer WRITE_API_KEY, keep WRITE_KEY as backward-compatible fallback. + let write_api_key = env::var("WRITE_API_KEY") + .ok() + .or_else(|| env::var("WRITE_KEY").ok()); Ok(Self { database_url, diff --git a/src/event.rs b/src/event.rs index 4cd9f79..b8bc298 100644 --- a/src/event.rs +++ b/src/event.rs @@ -1,6 +1,8 @@ use tokio::sync::mpsc; use uuid::Uuid; +const EVENT_CHANNEL_CAPACITY: usize = 4096; + #[derive(Debug, Clone)] pub enum ReloadEvent { SourceCreate { @@ -24,7 +26,7 @@ pub enum ReloadEvent { } pub struct EventManager { - sender: mpsc::UnboundedSender, + sender: mpsc::Sender, } impl EventManager { @@ -33,7 +35,7 @@ impl EventManager { connection_manager: std::sync::Arc, ws_manager: Option>, ) -> Self { - let (sender, mut receiver) = mpsc::unbounded_channel::(); + let (sender, mut receiver) = mpsc::channel::(EVENT_CHANNEL_CAPACITY); let ws_manager_clone = ws_manager.clone(); tokio::spawn(async move { @@ -138,34 +140,27 @@ impl EventManager { value_changed, }; - // 克隆 monitor,用于并行执行 - let monitor_for_ws = monitor.clone(); - let monitor_for_db = monitor.clone(); + // Process in event worker directly to avoid per-point spawn overhead. 
+ if let Err(e) = connection_manager_clone + .update_point_monitor_data(monitor.clone()) + .await + { + tracing::error!( + "Failed to update point monitor data for point {}: {}", + point_id, + e + ); + } - // 并行执行 update_point_monitor_data 和 send_to_public,不等待完成 - let cm_clone = connection_manager_clone.clone(); - tokio::spawn(async move { - // 更新监控数据 - if let Err(e) = cm_clone.update_point_monitor_data(monitor_for_db).await { - tracing::error!("Failed to update point monitor data for point {}: {}", point_id, e); + if let Some(ws_manager) = &ws_manager_clone { + let ws_message = crate::websocket::WsMessage::PointNewValue(monitor); + if let Err(e) = ws_manager.send_to_public(ws_message).await { + tracing::error!( + "Failed to send WebSocket message to public room: {}", + e + ); } - }); - - let ws_clone = ws_manager_clone.clone(); - tokio::spawn(async move { - // 发送WebSocket消息 - if let Some(ws_manager) = ws_clone { - let ws_message = crate::websocket::WsMessage::PointNewValue(monitor_for_ws); - if let Err(e) = ws_manager.send_to_public(ws_message).await { - tracing::error!("Failed to send WebSocket message to public room: {}", e); - } - - // 暂时注释掉 send_to_client,因为现在信息只需发送到 public - // if let Err(e) = ws_manager.send_to_client(point_id, ws_message).await { - // tracing::error!("Failed to send WebSocket message to client room {}: {}", point_id, e); - // } - } - }); + } } else { tracing::warn!("Point not found for source {} client_handle {}", source_id, client_handle); } @@ -178,8 +173,23 @@ impl EventManager { } pub fn send(&self, event: ReloadEvent) -> Result<(), String> { - self.sender - .send(event) - .map_err(|e| format!("Failed to send event: {}", e)) + match self.sender.try_send(event) { + Ok(()) => Ok(()), + Err(tokio::sync::mpsc::error::TrySendError::Closed(e)) => { + Err(format!("Failed to send event: channel closed ({e:?})")) + } + Err(tokio::sync::mpsc::error::TrySendError::Full(ReloadEvent::PointNewValue(payload))) => { + // High-frequency telemetry is 
lossy by design under sustained pressure. + tracing::warn!( + "Dropping PointNewValue due to full event queue: source={}, client_handle={}", + payload.source_id, + payload.client_handle + ); + Ok(()) + } + Err(tokio::sync::mpsc::error::TrySendError::Full(e)) => { + Err(format!("Failed to send event: queue full ({e:?})")) + } + } } } diff --git a/src/handler/point.rs b/src/handler/point.rs index 2900217..51e6899 100644 --- a/src/handler/point.rs +++ b/src/handler/point.rs @@ -2,7 +2,7 @@ use axum::{Json, extract::{Path, Query, State}, http::HeaderMap, response::IntoR use serde::{Deserialize, Serialize}; use uuid::Uuid; use validator::Validate; -use sqlx::Row; +use sqlx::{Row, QueryBuilder}; use crate::util::{response::ApiErr, pagination::{PaginatedResponse, PaginationParams}}; @@ -69,12 +69,7 @@ pub async fn get_point( Path(point_id): Path, ) -> Result { let pool = &state.pool; - let point = sqlx::query_as::<_, Point>( - r#"SELECT * FROM point WHERE id = $1"#, - ) - .bind(point_id) - .fetch_optional(pool) - .await?; + let point = crate::service::get_point_by_id(pool, point_id).await?; Ok(Json(point)) } @@ -136,51 +131,26 @@ pub async fn update_point( return Err(ApiErr::NotFound("Point not found".to_string(), None)); } - // Build dynamic UPDATE SQL for provided fields. 
- let mut updates = Vec::new(); - let mut values: Vec = Vec::new(); - let mut param_count = 1; + let mut qb = QueryBuilder::new("UPDATE point SET "); + let mut sep = qb.separated(", "); if let Some(name) = &payload.name { - updates.push(format!("name = ${}", param_count)); - values.push(name.clone()); - param_count += 1; + sep.push("name = ").push_bind(name); } - if let Some(description) = &payload.description { - updates.push(format!("description = ${}", param_count)); - values.push(description.clone()); - param_count += 1; + sep.push("description = ").push_bind(description); } - if let Some(unit) = &payload.unit { - updates.push(format!("unit = ${}", param_count)); - values.push(unit.clone()); - param_count += 1; + sep.push("unit = ").push_bind(unit); } - if let Some(tag_id) = &payload.tag_id { - updates.push(format!("tag_id = ${}", param_count)); - values.push(tag_id.to_string()); - param_count += 1; + sep.push("tag_id = ").push_bind(tag_id); } - // Always update timestamp. - updates.push("updated_at = NOW()".to_string()); + sep.push("updated_at = NOW()"); - let sql = format!( - "UPDATE point SET {} WHERE id = ${}", - updates.join(", "), - param_count - ); - values.push(point_id.to_string()); - - let mut query = sqlx::query(&sql); - for value in &values { - query = query.bind(value); - } - - query.execute(pool).await?; + qb.push(" WHERE id = ").push_bind(point_id); + qb.build().execute(pool).await?; Ok(Json(serde_json::json!({"ok_msg": "Point updated successfully"}))) } @@ -489,7 +459,7 @@ pub async fn batch_set_point_value( return Err(ApiErr::Forbidden( "write permission denied".to_string(), Some(serde_json::json!({ - "hint": "set WRITE_API_KEY and pass header X-Write-Key" + "hint": "set WRITE_API_KEY (or legacy WRITE_KEY) and pass header X-Write-Key" })), )); } diff --git a/src/handler/source.rs b/src/handler/source.rs index 843b015..effa3fd 100644 --- a/src/handler/source.rs +++ b/src/handler/source.rs @@ -58,18 +58,47 @@ impl TreeNode { #[derive(Debug, 
Serialize, Clone)] pub struct SourceWithStatus { #[serde(flatten)] - pub source: Source, + pub source: SourcePublic, pub is_connected: bool, pub last_error: Option, #[serde(serialize_with = "crate::util::datetime::option_utc_to_local_str")] pub last_time: Option>, } +#[derive(Debug, Serialize, Clone)] +pub struct SourcePublic { + pub id: Uuid, + pub name: String, + pub protocol: String, + pub endpoint: String, + pub security_policy: Option, + pub security_mode: Option, + pub enabled: bool, + #[serde(serialize_with = "crate::util::datetime::utc_to_local_str")] + pub created_at: DateTime, + #[serde(serialize_with = "crate::util::datetime::utc_to_local_str")] + pub updated_at: DateTime, +} + +impl From for SourcePublic { + fn from(source: Source) -> Self { + Self { + id: source.id, + name: source.name, + protocol: source.protocol, + endpoint: source.endpoint, + security_policy: source.security_policy, + security_mode: source.security_mode, + enabled: source.enabled, + created_at: source.created_at, + updated_at: source.updated_at, + } + } +} + pub async fn get_source_list(State(state): State) -> Result { let pool = &state.pool; - let sources: Vec = sqlx::query_as( - r#"SELECT * FROM source where enabled is true"#, - ).fetch_all(pool).await?; + let sources: Vec = crate::service::get_all_enabled_sources(pool).await?; // 获取所有连接状态 let status_map: std::collections::HashMap, Option>)> = @@ -87,7 +116,7 @@ pub async fn get_source_list(State(state): State) -> Result( + r#" + INSERT INTO node ( + id, + source_id, + external_id, + namespace_uri, + namespace_index, + identifier_type, + identifier, + browse_name, + display_name, + node_class, + parent_id ) - .bind(node_uuid) - .bind(source_id) - .bind(&node_id_str) - .bind(&namespace_uri) - .bind(namespace_index.map(|v| v as i32)) - .bind(&identifier_type) - .bind(&identifier) - .bind(&browse_name) - .bind(&display_name) - .bind(&node_class) - .execute(tx.as_mut()) - .await - .context("Failed to execute UPSERT query")?; - } + 
VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11) + ON CONFLICT(source_id, external_id) DO UPDATE SET + namespace_uri = excluded.namespace_uri, + namespace_index = excluded.namespace_index, + identifier_type = excluded.identifier_type, + identifier = excluded.identifier, + browse_name = excluded.browse_name, + display_name = excluded.display_name, + node_class = excluded.node_class, + parent_id = COALESCE(excluded.parent_id, node.parent_id), + updated_at = NOW() + RETURNING id + "#, + ) + .bind(Uuid::new_v4()) + .bind(source_id) + .bind(&node_id_str) + .bind(&namespace_uri) + .bind(namespace_index.map(|v| v as i32)) + .bind(&identifier_type) + .bind(&identifier) + .bind(&browse_name) + .bind(&display_name) + .bind(&node_class) + .bind(effective_parent_id) + .fetch_one(tx.as_mut()) + .await + .context("Failed to execute UPSERT query")?; processed_nodes.insert(node_id_str.clone(), ()); - queue.push_back((node_id_obj.clone(), Some(node_uuid))); + queue.push_back((node_id_obj.clone(), Some(persisted_node_id))); Ok(()) } diff --git a/src/service.rs b/src/service.rs index f5bcdc9..70e3997 100644 --- a/src/service.rs +++ b/src/service.rs @@ -17,6 +17,18 @@ pub async fn get_all_enabled_sources(pool: &PgPool) -> Result, sqlx: .await } +pub async fn get_point_by_id( + pool: &PgPool, + point_id: uuid::Uuid, +) -> Result, sqlx::Error> { + query_as::<_, crate::model::Point>( + r#"SELECT * FROM point WHERE id = $1"#, + ) + .bind(point_id) + .fetch_optional(pool) + .await +} + pub async fn get_points_grouped_by_source( pool: &PgPool, point_ids: &[uuid::Uuid], diff --git a/src/util/response.rs b/src/util/response.rs index 6c41484..ab59fe8 100644 --- a/src/util/response.rs +++ b/src/util/response.rs @@ -68,13 +68,7 @@ impl IntoResponse for ApiErr { impl From for ApiErr { fn from(err: Error) -> Self { tracing::error!("Error: {:?}; root_cause: {}", err, err.root_cause()); - ApiErr::Internal( - err.to_string(), - Some(serde_json::json!({ - "root_cause": err.root_cause().to_string(), - 
"chain": err.chain().map(|e| e.to_string()).collect::<Vec<String>>() - })), - ) + ApiErr::Internal("internal server error".to_string(), None) } }