fix: harden event handling and source safety

Improve runtime resilience by bounding the reload event queue and processing telemetry updates without per-point spawned tasks. Also reduce security risk by sanitizing source responses, avoiding internal error detail leaks, and standardizing write-key configuration with backward compatibility.

Made-with: Cursor
This commit is contained in:
caoqianming 2026-03-13 14:22:16 +08:00
parent 6f215162a3
commit 5406568969
6 changed files with 159 additions and 233 deletions

View File

@@ -17,7 +17,10 @@ impl AppConfig {
.unwrap_or_else(|_| "60309".to_string()) .unwrap_or_else(|_| "60309".to_string())
.parse::<u16>() .parse::<u16>()
.map_err(|_| "PORT must be a number")?; .map_err(|_| "PORT must be a number")?;
let write_api_key = env::var("WRITE_KEY").ok(); // Prefer WRITE_API_KEY, keep WRITE_KEY as backward-compatible fallback.
let write_api_key = env::var("WRITE_API_KEY")
.ok()
.or_else(|| env::var("WRITE_KEY").ok());
Ok(Self { Ok(Self {
database_url, database_url,

View File

@@ -1,6 +1,8 @@
use tokio::sync::mpsc; use tokio::sync::mpsc;
use uuid::Uuid; use uuid::Uuid;
const EVENT_CHANNEL_CAPACITY: usize = 4096;
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub enum ReloadEvent { pub enum ReloadEvent {
SourceCreate { SourceCreate {
@@ -24,7 +26,7 @@ pub enum ReloadEvent {
} }
pub struct EventManager { pub struct EventManager {
sender: mpsc::UnboundedSender<ReloadEvent>, sender: mpsc::Sender<ReloadEvent>,
} }
impl EventManager { impl EventManager {
@@ -33,7 +35,7 @@ impl EventManager {
connection_manager: std::sync::Arc<crate::connection::ConnectionManager>, connection_manager: std::sync::Arc<crate::connection::ConnectionManager>,
ws_manager: Option<std::sync::Arc<crate::websocket::WebSocketManager>>, ws_manager: Option<std::sync::Arc<crate::websocket::WebSocketManager>>,
) -> Self { ) -> Self {
let (sender, mut receiver) = mpsc::unbounded_channel::<ReloadEvent>(); let (sender, mut receiver) = mpsc::channel::<ReloadEvent>(EVENT_CHANNEL_CAPACITY);
let ws_manager_clone = ws_manager.clone(); let ws_manager_clone = ws_manager.clone();
tokio::spawn(async move { tokio::spawn(async move {
@@ -138,34 +140,27 @@ impl EventManager {
value_changed, value_changed,
}; };
// 克隆 monitor用于并行执行 // Process in event worker directly to avoid per-point spawn overhead.
let monitor_for_ws = monitor.clone(); if let Err(e) = connection_manager_clone
let monitor_for_db = monitor.clone(); .update_point_monitor_data(monitor.clone())
.await
{
tracing::error!(
"Failed to update point monitor data for point {}: {}",
point_id,
e
);
}
// 并行执行 update_point_monitor_data 和 send_to_public不等待完成 if let Some(ws_manager) = &ws_manager_clone {
let cm_clone = connection_manager_clone.clone(); let ws_message = crate::websocket::WsMessage::PointNewValue(monitor);
tokio::spawn(async move { if let Err(e) = ws_manager.send_to_public(ws_message).await {
// 更新监控数据 tracing::error!(
if let Err(e) = cm_clone.update_point_monitor_data(monitor_for_db).await { "Failed to send WebSocket message to public room: {}",
tracing::error!("Failed to update point monitor data for point {}: {}", point_id, e); e
);
} }
}); }
let ws_clone = ws_manager_clone.clone();
tokio::spawn(async move {
// 发送WebSocket消息
if let Some(ws_manager) = ws_clone {
let ws_message = crate::websocket::WsMessage::PointNewValue(monitor_for_ws);
if let Err(e) = ws_manager.send_to_public(ws_message).await {
tracing::error!("Failed to send WebSocket message to public room: {}", e);
}
// 暂时注释掉 send_to_client因为现在信息只需发送到 public
// if let Err(e) = ws_manager.send_to_client(point_id, ws_message).await {
// tracing::error!("Failed to send WebSocket message to client room {}: {}", point_id, e);
// }
}
});
} else { } else {
tracing::warn!("Point not found for source {} client_handle {}", source_id, client_handle); tracing::warn!("Point not found for source {} client_handle {}", source_id, client_handle);
} }
@@ -178,8 +173,23 @@ impl EventManager {
} }
pub fn send(&self, event: ReloadEvent) -> Result<(), String> { pub fn send(&self, event: ReloadEvent) -> Result<(), String> {
self.sender match self.sender.try_send(event) {
.send(event) Ok(()) => Ok(()),
.map_err(|e| format!("Failed to send event: {}", e)) Err(tokio::sync::mpsc::error::TrySendError::Closed(e)) => {
Err(format!("Failed to send event: channel closed ({e:?})"))
}
Err(tokio::sync::mpsc::error::TrySendError::Full(ReloadEvent::PointNewValue(payload))) => {
// High-frequency telemetry is lossy by design under sustained pressure.
tracing::warn!(
"Dropping PointNewValue due to full event queue: source={}, client_handle={}",
payload.source_id,
payload.client_handle
);
Ok(())
}
Err(tokio::sync::mpsc::error::TrySendError::Full(e)) => {
Err(format!("Failed to send event: queue full ({e:?})"))
}
}
} }
} }

View File

@@ -2,7 +2,7 @@ use axum::{Json, extract::{Path, Query, State}, http::HeaderMap, response::IntoR
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use uuid::Uuid; use uuid::Uuid;
use validator::Validate; use validator::Validate;
use sqlx::Row; use sqlx::{Row, QueryBuilder};
use crate::util::{response::ApiErr, pagination::{PaginatedResponse, PaginationParams}}; use crate::util::{response::ApiErr, pagination::{PaginatedResponse, PaginationParams}};
@@ -69,12 +69,7 @@ pub async fn get_point(
Path(point_id): Path<Uuid>, Path(point_id): Path<Uuid>,
) -> Result<impl IntoResponse, ApiErr> { ) -> Result<impl IntoResponse, ApiErr> {
let pool = &state.pool; let pool = &state.pool;
let point = sqlx::query_as::<_, Point>( let point = crate::service::get_point_by_id(pool, point_id).await?;
r#"SELECT * FROM point WHERE id = $1"#,
)
.bind(point_id)
.fetch_optional(pool)
.await?;
Ok(Json(point)) Ok(Json(point))
} }
@@ -136,51 +131,26 @@ pub async fn update_point(
return Err(ApiErr::NotFound("Point not found".to_string(), None)); return Err(ApiErr::NotFound("Point not found".to_string(), None));
} }
// Build dynamic UPDATE SQL for provided fields. let mut qb = QueryBuilder::new("UPDATE point SET ");
let mut updates = Vec::new(); let mut sep = qb.separated(", ");
let mut values: Vec<String> = Vec::new();
let mut param_count = 1;
if let Some(name) = &payload.name { if let Some(name) = &payload.name {
updates.push(format!("name = ${}", param_count)); sep.push("name = ").push_bind(name);
values.push(name.clone());
param_count += 1;
} }
if let Some(description) = &payload.description { if let Some(description) = &payload.description {
updates.push(format!("description = ${}", param_count)); sep.push("description = ").push_bind(description);
values.push(description.clone());
param_count += 1;
} }
if let Some(unit) = &payload.unit { if let Some(unit) = &payload.unit {
updates.push(format!("unit = ${}", param_count)); sep.push("unit = ").push_bind(unit);
values.push(unit.clone());
param_count += 1;
} }
if let Some(tag_id) = &payload.tag_id { if let Some(tag_id) = &payload.tag_id {
updates.push(format!("tag_id = ${}", param_count)); sep.push("tag_id = ").push_bind(tag_id);
values.push(tag_id.to_string());
param_count += 1;
} }
// Always update timestamp. sep.push("updated_at = NOW()");
updates.push("updated_at = NOW()".to_string());
let sql = format!( qb.push(" WHERE id = ").push_bind(point_id);
"UPDATE point SET {} WHERE id = ${}", qb.build().execute(pool).await?;
updates.join(", "),
param_count
);
values.push(point_id.to_string());
let mut query = sqlx::query(&sql);
for value in &values {
query = query.bind(value);
}
query.execute(pool).await?;
Ok(Json(serde_json::json!({"ok_msg": "Point updated successfully"}))) Ok(Json(serde_json::json!({"ok_msg": "Point updated successfully"})))
} }
@@ -489,7 +459,7 @@ pub async fn batch_set_point_value(
return Err(ApiErr::Forbidden( return Err(ApiErr::Forbidden(
"write permission denied".to_string(), "write permission denied".to_string(),
Some(serde_json::json!({ Some(serde_json::json!({
"hint": "set WRITE_API_KEY and pass header X-Write-Key" "hint": "set WRITE_API_KEY (or legacy WRITE_KEY) and pass header X-Write-Key"
})), })),
)); ));
} }

View File

@@ -58,18 +58,47 @@ impl TreeNode {
#[derive(Debug, Serialize, Clone)] #[derive(Debug, Serialize, Clone)]
pub struct SourceWithStatus { pub struct SourceWithStatus {
#[serde(flatten)] #[serde(flatten)]
pub source: Source, pub source: SourcePublic,
pub is_connected: bool, pub is_connected: bool,
pub last_error: Option<String>, pub last_error: Option<String>,
#[serde(serialize_with = "crate::util::datetime::option_utc_to_local_str")] #[serde(serialize_with = "crate::util::datetime::option_utc_to_local_str")]
pub last_time: Option<DateTime<Utc>>, pub last_time: Option<DateTime<Utc>>,
} }
#[derive(Debug, Serialize, Clone)]
pub struct SourcePublic {
pub id: Uuid,
pub name: String,
pub protocol: String,
pub endpoint: String,
pub security_policy: Option<String>,
pub security_mode: Option<String>,
pub enabled: bool,
#[serde(serialize_with = "crate::util::datetime::utc_to_local_str")]
pub created_at: DateTime<Utc>,
#[serde(serialize_with = "crate::util::datetime::utc_to_local_str")]
pub updated_at: DateTime<Utc>,
}
impl From<Source> for SourcePublic {
fn from(source: Source) -> Self {
Self {
id: source.id,
name: source.name,
protocol: source.protocol,
endpoint: source.endpoint,
security_policy: source.security_policy,
security_mode: source.security_mode,
enabled: source.enabled,
created_at: source.created_at,
updated_at: source.updated_at,
}
}
}
pub async fn get_source_list(State(state): State<AppState>) -> Result<impl IntoResponse, ApiErr> { pub async fn get_source_list(State(state): State<AppState>) -> Result<impl IntoResponse, ApiErr> {
let pool = &state.pool; let pool = &state.pool;
let sources: Vec<Source> = sqlx::query_as( let sources: Vec<Source> = crate::service::get_all_enabled_sources(pool).await?;
r#"SELECT * FROM source where enabled is true"#,
).fetch_all(pool).await?;
// 获取所有连接状态 // 获取所有连接状态
let status_map: std::collections::HashMap<Uuid, (bool, Option<String>, Option<DateTime<Utc>>)> = let status_map: std::collections::HashMap<Uuid, (bool, Option<String>, Option<DateTime<Utc>>)> =
@@ -87,7 +116,7 @@ pub async fn get_source_list(State(state): State<AppState>) -> Result<impl IntoR
.map(|(connected, error, time)| (*connected, error.clone(), *time)) .map(|(connected, error, time)| (*connected, error.clone(), *time))
.unwrap_or((false, None, None)); .unwrap_or((false, None, None));
SourceWithStatus { SourceWithStatus {
source, source: source.into(),
is_connected, is_connected,
last_error, last_error,
last_time, last_time,
@@ -444,159 +473,67 @@ async fn process_reference(
let display_name = ref_desc.display_name.text.to_string(); let display_name = ref_desc.display_name.text.to_string();
let node_class = format!("{:?}", ref_desc.node_class); let node_class = format!("{:?}", ref_desc.node_class);
let now = Utc::now(); let effective_parent_id = if let Some(pid) = parent_id {
let node_uuid = Uuid::new_v4(); let parent_exists = sqlx::query(r#"SELECT 1 FROM node WHERE id = $1"#)
.bind(pid)
// ?? 关键优化:直接 UPSERT避免 SELECT .fetch_optional(tx.as_mut())
// 注意:如果 parent_id 存在,则必须确保该父节点已存在于数据库中 .await?;
// 否则会触发外键约束失败 if parent_exists.is_some() {
if parent_id.is_some() { Some(pid)
// 检查父节点是否已存在于数据库中
let parent_exists = sqlx::query(
r#"SELECT 1 FROM node WHERE id = $1"#,
)
.bind(parent_id.unwrap())
.fetch_optional(tx.as_mut())
.await?;
if parent_exists.is_none() {
// 如果父节点不存在,则暂时不设置 parent_id
// 这样可以避免外键约束失败
sqlx::query(
r#"
INSERT INTO node (
id,
source_id,
external_id,
namespace_uri,
namespace_index,
identifier_type,
identifier,
browse_name,
display_name,
node_class,
parent_id,
)
VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9,$10,NULL)
ON CONFLICT(source_id, external_id) DO UPDATE SET
namespace_uri = excluded.namespace_uri,
namespace_index = excluded.namespace_index,
identifier_type = excluded.identifier_type,
identifier = excluded.identifier,
browse_name = excluded.browse_name,
display_name = excluded.display_name,
node_class = excluded.node_class,
updated_at = NOW()
"#
)
.bind(node_uuid)
.bind(source_id)
.bind(&node_id_str)
.bind(&namespace_uri)
.bind(namespace_index.map(|v| v as i32))
.bind(&identifier_type)
.bind(&identifier)
.bind(&browse_name)
.bind(&display_name)
.bind(&node_class)
.execute(tx.as_mut())
.await
.context("Failed to execute UPSERT query")?;
} else { } else {
// 如果父节点存在,则正常设置 parent_id None
sqlx::query(
r#"
INSERT INTO node (
id,
source_id,
external_id,
namespace_uri,
namespace_index,
identifier_type,
identifier,
browse_name,
display_name,
node_class,
parent_id,
created_at,
updated_at
)
VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,NOW(),NOW())
ON CONFLICT(source_id, external_id) DO UPDATE SET
namespace_uri = excluded.namespace_uri,
namespace_index = excluded.namespace_index,
identifier_type = excluded.identifier_type,
identifier = excluded.identifier,
browse_name = excluded.browse_name,
display_name = excluded.display_name,
node_class = excluded.node_class,
parent_id = excluded.parent_id,
updated_at = NOW()
"#
)
.bind(node_uuid)
.bind(source_id)
.bind(&node_id_str)
.bind(&namespace_uri)
.bind(namespace_index.map(|v| v as i32))
.bind(&identifier_type)
.bind(&identifier)
.bind(&browse_name)
.bind(&display_name)
.bind(&node_class)
.bind(parent_id)
.bind(now)
.bind(now)
.execute(tx.as_mut())
.await
.context("Failed to execute UPSERT query")?;
} }
} else { } else {
// 如果没有 parent_id则正常插入 None
sqlx::query( };
r#"
INSERT INTO node ( // Use RETURNING id so queue always carries the actual DB node id.
id, let persisted_node_id = sqlx::query_scalar::<_, Uuid>(
source_id, r#"
external_id, INSERT INTO node (
namespace_uri, id,
namespace_index, source_id,
identifier_type, external_id,
identifier, namespace_uri,
browse_name, namespace_index,
display_name, identifier_type,
node_class, identifier,
parent_id browse_name,
) display_name,
VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9,$10,NULL) node_class,
ON CONFLICT(source_id, external_id) DO UPDATE SET parent_id
namespace_uri = excluded.namespace_uri,
namespace_index = excluded.namespace_index,
identifier_type = excluded.identifier_type,
identifier = excluded.identifier,
browse_name = excluded.browse_name,
display_name = excluded.display_name,
node_class = excluded.node_class,
updated_at = NOW()
"#
) )
.bind(node_uuid) VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11)
.bind(source_id) ON CONFLICT(source_id, external_id) DO UPDATE SET
.bind(&node_id_str) namespace_uri = excluded.namespace_uri,
.bind(&namespace_uri) namespace_index = excluded.namespace_index,
.bind(namespace_index.map(|v| v as i32)) identifier_type = excluded.identifier_type,
.bind(&identifier_type) identifier = excluded.identifier,
.bind(&identifier) browse_name = excluded.browse_name,
.bind(&browse_name) display_name = excluded.display_name,
.bind(&display_name) node_class = excluded.node_class,
.bind(&node_class) parent_id = COALESCE(excluded.parent_id, node.parent_id),
.execute(tx.as_mut()) updated_at = NOW()
.await RETURNING id
.context("Failed to execute UPSERT query")?; "#,
} )
.bind(Uuid::new_v4())
.bind(source_id)
.bind(&node_id_str)
.bind(&namespace_uri)
.bind(namespace_index.map(|v| v as i32))
.bind(&identifier_type)
.bind(&identifier)
.bind(&browse_name)
.bind(&display_name)
.bind(&node_class)
.bind(effective_parent_id)
.fetch_one(tx.as_mut())
.await
.context("Failed to execute UPSERT query")?;
processed_nodes.insert(node_id_str.clone(), ()); processed_nodes.insert(node_id_str.clone(), ());
queue.push_back((node_id_obj.clone(), Some(node_uuid))); queue.push_back((node_id_obj.clone(), Some(persisted_node_id)));
Ok(()) Ok(())
} }

View File

@@ -17,6 +17,18 @@ pub async fn get_all_enabled_sources(pool: &PgPool) -> Result<Vec<Source>, sqlx:
.await .await
} }
pub async fn get_point_by_id(
pool: &PgPool,
point_id: uuid::Uuid,
) -> Result<Option<crate::model::Point>, sqlx::Error> {
query_as::<_, crate::model::Point>(
r#"SELECT * FROM point WHERE id = $1"#,
)
.bind(point_id)
.fetch_optional(pool)
.await
}
pub async fn get_points_grouped_by_source( pub async fn get_points_grouped_by_source(
pool: &PgPool, pool: &PgPool,
point_ids: &[uuid::Uuid], point_ids: &[uuid::Uuid],

View File

@@ -68,13 +68,7 @@ impl IntoResponse for ApiErr {
impl From<Error> for ApiErr { impl From<Error> for ApiErr {
fn from(err: Error) -> Self { fn from(err: Error) -> Self {
tracing::error!("Error: {:?}; root_cause: {}", err, err.root_cause()); tracing::error!("Error: {:?}; root_cause: {}", err, err.root_cause());
ApiErr::Internal( ApiErr::Internal("internal server error".to_string(), None)
err.to_string(),
Some(serde_json::json!({
"root_cause": err.root_cause().to_string(),
"chain": err.chain().map(|e| e.to_string()).collect::<Vec<_>>()
})),
)
} }
} }