fix(engine): fix supervisor restart, deduplicate helpers, fix notify race

- engine.rs: replace HashSet<Uuid> with HashMap<Uuid, JoinHandle> in
  supervise(); use is_finished() to detect exited tasks so units that
  are disabled then re-enabled get a new task on next 10s scan
- control/mod.rs: extract shared monitor_value_as_bool (using the more
  complete validator version that includes "yes"); remove duplicate
  copies from engine.rs and validator.rs
- runtime.rs: fix get_or_create_notify TOCTOU by using entry API
  instead of read-drop-write pattern

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
caoqianming 2026-03-26 09:08:25 +08:00
parent 4ce91adf60
commit b3f92867bc
4 changed files with 38 additions and 45 deletions

View File

@ -1,4 +1,4 @@
use std::collections::{HashMap, HashSet}; use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
use tokio::sync::Notify; use tokio::sync::Notify;
@ -12,7 +12,7 @@ use crate::{
}, },
event::AppEvent, event::AppEvent,
service::EquipmentRolePoint, service::EquipmentRolePoint,
telemetry::{DataValue, PointMonitorInfo, PointQuality}, telemetry::{PointMonitorInfo, PointQuality},
websocket::WsMessage, websocket::WsMessage,
AppState, AppState,
}; };
@ -25,8 +25,9 @@ pub fn start(state: AppState, runtime_store: Arc<ControlRuntimeStore>) {
} }
/// Supervisor: scans for enabled units every 10 s and ensures each has a running task. /// Supervisor: scans for enabled units every 10 s and ensures each has a running task.
/// Uses JoinHandle to detect exited tasks so disabled-then-re-enabled units are restarted.
async fn supervise(state: AppState, store: Arc<ControlRuntimeStore>) { async fn supervise(state: AppState, store: Arc<ControlRuntimeStore>) {
let mut spawned: HashSet<Uuid> = HashSet::new(); let mut tasks: HashMap<Uuid, tokio::task::JoinHandle<()>> = HashMap::new();
let mut interval = tokio::time::interval(Duration::from_secs(10)); let mut interval = tokio::time::interval(Duration::from_secs(10));
interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
@ -35,10 +36,14 @@ async fn supervise(state: AppState, store: Arc<ControlRuntimeStore>) {
match crate::service::get_all_enabled_units(&state.pool).await { match crate::service::get_all_enabled_units(&state.pool).await {
Ok(units) => { Ok(units) => {
for unit in units { for unit in units {
if spawned.insert(unit.id) { let needs_spawn = tasks
.get(&unit.id)
.map_or(true, |h| h.is_finished());
if needs_spawn {
let s = state.clone(); let s = state.clone();
let st = store.clone(); let st = store.clone();
tokio::spawn(async move { unit_task(s, st, unit.id).await; }); let handle = tokio::spawn(async move { unit_task(s, st, unit.id).await; });
tasks.insert(unit.id, handle);
} }
} }
} }
@ -296,7 +301,7 @@ async fn check_fault_comm(
roles roles
.get("flt") .get("flt")
.and_then(|rp| monitor.get(&rp.point_id)) .and_then(|rp| monitor.get(&rp.point_id))
.map(|m| monitor_value_as_bool(m)) .map(|m| super::monitor_value_as_bool(m))
.unwrap_or(false) .unwrap_or(false)
}); });
@ -307,7 +312,7 @@ async fn check_fault_comm(
roles roles
.get("flt") .get("flt")
.and_then(|rp| monitor.get(&rp.point_id)) .and_then(|rp| monitor.get(&rp.point_id))
.map(|m| monitor_value_as_bool(m)) .map(|m| super::monitor_value_as_bool(m))
.unwrap_or(false) .unwrap_or(false)
}) })
.map(|(eq_id, _)| *eq_id) .map(|(eq_id, _)| *eq_id)
@ -404,13 +409,13 @@ fn find_cmd(
let rem_ok = roles let rem_ok = roles
.get("rem") .get("rem")
.and_then(|rp| monitor.get(&rp.point_id)) .and_then(|rp| monitor.get(&rp.point_id))
.map(|m| monitor_value_as_bool(m) && m.quality == PointQuality::Good) .map(|m| super::monitor_value_as_bool(m) && m.quality == PointQuality::Good)
.unwrap_or(true); .unwrap_or(true);
let flt_ok = roles let flt_ok = roles
.get("flt") .get("flt")
.and_then(|rp| monitor.get(&rp.point_id)) .and_then(|rp| monitor.get(&rp.point_id))
.map(|m| !monitor_value_as_bool(m) && m.quality == PointQuality::Good) .map(|m| !super::monitor_value_as_bool(m) && m.quality == PointQuality::Good)
.unwrap_or(true); .unwrap_or(true);
if rem_ok && flt_ok { if rem_ok && flt_ok {
@ -423,15 +428,3 @@ fn find_cmd(
} }
} }
/// Interpret a point's monitored value as a boolean flag.
///
/// Truthiness rules: `Bool` is taken as-is; integer, unsigned and float
/// variants are true when non-zero; text is true for "1", "true" or "on"
/// (case-insensitive, surrounding whitespace ignored). A missing value or
/// any other variant is false.
fn monitor_value_as_bool(monitor: &PointMonitorInfo) -> bool {
    // Absent value: treat as false rather than erroring.
    let Some(value) = monitor.value.as_ref() else {
        return false;
    };
    match value {
        DataValue::Bool(b) => *b,
        DataValue::Int(n) => *n != 0,
        DataValue::UInt(n) => *n != 0,
        DataValue::Float(f) => *f != 0.0,
        DataValue::Text(s) => {
            let normalized = s.trim().to_ascii_lowercase();
            normalized == "1" || normalized == "true" || normalized == "on"
        }
        _ => false,
    }
}

View File

@ -2,3 +2,18 @@ pub mod command;
pub mod engine; pub mod engine;
pub mod runtime; pub mod runtime;
pub mod validator; pub mod validator;
use crate::telemetry::{DataValue, PointMonitorInfo};
/// Shared helper: coerce a point's monitored value into a boolean.
///
/// `Bool` is returned directly; numeric variants (`Int`, `UInt`, `Float`)
/// count as true when non-zero; `Text` counts as true for "1", "true",
/// "on" or "yes" after trimming and ASCII-lowercasing. `None` and any
/// other variant yield false.
pub(crate) fn monitor_value_as_bool(monitor: &PointMonitorInfo) -> bool {
    monitor
        .value
        .as_ref()
        .map(|value| match value {
            DataValue::Bool(flag) => *flag,
            DataValue::Int(n) => *n != 0,
            DataValue::UInt(n) => *n != 0,
            DataValue::Float(x) => *x != 0.0,
            DataValue::Text(raw) => matches!(
                raw.trim().to_ascii_lowercase().as_str(),
                "1" | "true" | "on" | "yes"
            ),
            _ => false,
        })
        .unwrap_or(false)
}

View File

@ -73,14 +73,12 @@ impl ControlRuntimeStore {
} }
pub async fn get_or_create_notify(&self, unit_id: Uuid) -> Arc<Notify> { pub async fn get_or_create_notify(&self, unit_id: Uuid) -> Arc<Notify> {
let read = self.notifiers.read().await; self.notifiers
if let Some(n) = read.get(&unit_id) { .write()
return n.clone(); .await
} .entry(unit_id)
drop(read); .or_insert_with(|| Arc::new(Notify::new()))
let n = Arc::new(Notify::new()); .clone()
self.notifiers.write().await.insert(unit_id, n.clone());
n
} }
/// Wake the engine task for a unit (e.g., when auto_enabled or fault_locked changes). /// Wake the engine task for a unit (e.g., when auto_enabled or fault_locked changes).

View File

@ -5,7 +5,7 @@ use uuid::Uuid;
use crate::{ use crate::{
service::EquipmentRolePoint, service::EquipmentRolePoint,
telemetry::{DataValue, PointMonitorInfo, PointQuality, ValueType}, telemetry::{PointMonitorInfo, PointQuality, ValueType},
util::response::ApiErr, util::response::ApiErr,
AppState, AppState,
}; };
@ -95,7 +95,7 @@ pub async fn validate_manual_control(
let rem_monitor = monitor_guard let rem_monitor = monitor_guard
.get(&rem_point.point_id) .get(&rem_point.point_id)
.ok_or_else(|| missing_monitor_err("REM", equipment_id))?; .ok_or_else(|| missing_monitor_err("REM", equipment_id))?;
if !monitor_value_as_bool(rem_monitor) { if !super::monitor_value_as_bool(rem_monitor) {
return Err(ApiErr::Forbidden( return Err(ApiErr::Forbidden(
"Remote control not allowed, REM is not enabled".to_string(), "Remote control not allowed, REM is not enabled".to_string(),
Some(json!({ "equipment_id": equipment_id })), Some(json!({ "equipment_id": equipment_id })),
@ -107,7 +107,7 @@ pub async fn validate_manual_control(
let flt_monitor = monitor_guard let flt_monitor = monitor_guard
.get(&flt_point.point_id) .get(&flt_point.point_id)
.ok_or_else(|| missing_monitor_err("FLT", equipment_id))?; .ok_or_else(|| missing_monitor_err("FLT", equipment_id))?;
if monitor_value_as_bool(flt_monitor) { if super::monitor_value_as_bool(flt_monitor) {
return Err(ApiErr::Forbidden( return Err(ApiErr::Forbidden(
"Equipment fault is active, command denied".to_string(), "Equipment fault is active, command denied".to_string(),
Some(json!({ "equipment_id": equipment_id })), Some(json!({ "equipment_id": equipment_id })),
@ -199,16 +199,3 @@ fn missing_monitor_err(role: &str, equipment_id: Uuid) -> ApiErr {
) )
} }
/// Read a monitored point value as a boolean.
///
/// Semantics: `Bool` passes through; `Int`/`UInt`/`Float` are true when
/// non-zero; `Text` is true when it trims and lowercases to "1", "true",
/// "on" or "yes". A missing value or unrecognised variant is false.
fn monitor_value_as_bool(monitor: &PointMonitorInfo) -> bool {
    match monitor.value.as_ref() {
        None => false,
        Some(DataValue::Bool(flag)) => *flag,
        Some(DataValue::Int(n)) => *n != 0,
        Some(DataValue::UInt(n)) => *n != 0,
        Some(DataValue::Float(x)) => *x != 0.0,
        Some(DataValue::Text(raw)) => {
            // Normalise once, then test against the accepted truthy tokens.
            let token = raw.trim().to_ascii_lowercase();
            ["1", "true", "on", "yes"].contains(&token.as_str())
        }
        Some(_) => false,
    }
}