use axum::{
    extract::{
        ws::{Message, WebSocket, WebSocketUpgrade},
        Path, State,
    },
    response::IntoResponse,
};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{broadcast, RwLock};
use uuid::Uuid;

/// Capacity of each room's broadcast channel.
///
/// A subscriber that falls more than this many messages behind observes
/// `RecvError::Lagged` and skips the oldest messages.
const ROOM_CHANNEL_CAPACITY: usize = 100;

/// Room id of the shared room every public connection subscribes to.
const PUBLIC_ROOM_ID: &str = "public";

/// WebSocket message payload types pushed from the server to clients.
///
/// Serialized adjacently tagged as `{"type": "<Variant>", "data": ...}`.
/// Unlike [`WsClientMessage`], variant names are not renamed, so the tag is the
/// Rust variant name as written (e.g. `"PointValueChange"`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", content = "data")]
pub enum WsMessage {
    PointValueChange(crate::telemetry::WsPointMonitorInfo),
    PointSetValueBatchResult(crate::connection::BatchSetPointValueRes),
}

/// Messages a client may send over the socket.
///
/// Serialized adjacently tagged with snake_case tags, e.g.
/// `{"type": "auth_write", "data": {"key": "..."}}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", content = "data", rename_all = "snake_case")]
pub enum WsClientMessage {
    /// Authenticate this connection for writes.
    AuthWrite(WsAuthWriteReq),
    /// Write a batch of point values; rejected unless the connection has
    /// successfully authenticated with `AuthWrite`.
    PointSetValueBatch(crate::connection::BatchSetPointValueReq),
}

/// Payload of [`WsClientMessage::AuthWrite`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WsAuthWriteReq {
    /// Write key, checked with the application config's `verify_write_key`.
    pub key: String,
}

/// Room manager: maps a room id to the broadcast sender for that room.
#[derive(Clone)]
pub struct RoomManager {
    rooms: Arc<RwLock<HashMap<String, broadcast::Sender<WsMessage>>>>,
}

impl RoomManager {
    /// Creates a manager with no rooms.
    pub fn new() -> Self {
        Self {
            rooms: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Returns the sender for `room_id`, creating the room if it does not exist.
    pub async fn get_or_create_room(&self, room_id: &str) -> broadcast::Sender<WsMessage> {
        let mut rooms = self.rooms.write().await;
        if let Some(sender) = rooms.get(room_id) {
            return sender.clone();
        }
        let (sender, _) = broadcast::channel(ROOM_CHANNEL_CAPACITY);
        rooms.insert(room_id.to_string(), sender.clone());
        tracing::info!("Created new room: {}", room_id);
        sender
    }

    /// Returns the sender for `room_id` if the room exists.
    pub async fn get_room(&self, room_id: &str) -> Option<broadcast::Sender<WsMessage>> {
        self.rooms.read().await.get(room_id).cloned()
    }

    /// Removes `room_id` if it exists and has no live receivers.
    pub async fn remove_room_if_empty(&self, room_id: &str) {
        let mut rooms = self.rooms.write().await;
        if matches!(rooms.get(room_id), Some(sender) if sender.receiver_count() == 0) {
            rooms.remove(room_id);
            tracing::info!("Removed empty room: {}", room_id);
        }
    }

    /// Broadcasts `message` to every subscriber of `room_id`.
    ///
    /// Returns:
    /// - `Ok(n)`: the message was queued for `n` subscribers
    /// - `Ok(0)`: the room does not exist or has no active subscribers
    ///
    /// Currently never returns `Err`.
    pub async fn send_to_room(&self, room_id: &str, message: WsMessage) -> Result<usize, String> {
        let Some(sender) = self.get_room(room_id).await else {
            return Ok(0);
        };
        // `send` fails only when there are no receivers, which is routine for
        // push traffic and therefore not treated as an error.
        Ok(sender.send(message).unwrap_or(0))
    }
}

impl Default for RoomManager {
    fn default() -> Self {
        Self::new()
    }
}

/// WebSocket manager.
///
/// Despite its name, `public_room` holds every room: the shared `"public"`
/// room and the per-client rooms keyed by the client's UUID string.
#[derive(Clone)]
pub struct WebSocketManager {
    public_room: Arc<RoomManager>,
}

impl WebSocketManager {
    /// Creates a manager with no rooms.
    pub fn new() -> Self {
        Self {
            public_room: Arc::new(RoomManager::new()),
        }
    }

    /// Broadcasts `message` to the shared `"public"` room.
    ///
    /// Returns `Ok(0)` when nobody is subscribed. The room is not created here:
    /// subscribers create it on connect, and sending to a missing room is a
    /// no-op, so creating it would only take the write lock on every publish.
    pub async fn send_to_public(&self, message: WsMessage) -> Result<usize, String> {
        self.public_room.send_to_room(PUBLIC_ROOM_ID, message).await
    }

    /// Broadcasts `message` to the dedicated room of `client_id`.
    ///
    /// Returns `Ok(0)` when that room does not exist or has no subscribers.
    pub async fn send_to_client(&self, client_id: Uuid, message: WsMessage) -> Result<usize, String> {
        self.public_room
            .send_to_room(&client_id.to_string(), message)
            .await
    }
}

impl Default for WebSocketManager {
    fn default() -> Self {
        Self::new()
    }
}

/// Upgrades the request to a WebSocket subscribed to the shared `"public"` room.
pub async fn public_websocket_handler(
    ws: WebSocketUpgrade,
    State(state): State<crate::AppState>,
) -> impl IntoResponse {
    let ws_manager = state.ws_manager.clone();
    ws.on_upgrade(move |socket| {
        handle_socket(socket, ws_manager, PUBLIC_ROOM_ID.to_string(), state)
    })
}

/// Upgrades the request to a WebSocket subscribed to the room of `client_id`.
pub async fn client_websocket_handler(
    ws: WebSocketUpgrade,
    Path(client_id): Path<Uuid>,
    State(state): State<crate::AppState>,
) -> impl IntoResponse {
    let ws_manager = state.ws_manager.clone();
    let room_id = client_id.to_string();
    ws.on_upgrade(move |socket| handle_socket(socket, ws_manager, room_id, state))
}

/// Handle websocket connection for one room.
async fn handle_socket( mut socket: WebSocket, ws_manager: Arc, room_id: String, state: crate::AppState, ) { let room_sender = ws_manager.public_room.get_or_create_room(&room_id).await; let mut rx = room_sender.subscribe(); let mut can_write = false; loop { tokio::select! { maybe_msg = socket.recv() => { match maybe_msg { Some(Ok(msg)) => { if matches!(msg, Message::Close(_)) { break; } match msg { Message::Text(text) => { match serde_json::from_str::(&text) { Ok(WsClientMessage::AuthWrite(payload)) => { can_write = state.config.verify_write_key(&payload.key); if !can_write { tracing::warn!("WebSocket write auth failed in room {}", room_id); } } Ok(WsClientMessage::PointSetValueBatch(payload)) => { let response = if !can_write { crate::connection::BatchSetPointValueRes { success: false, err_msg: Some("write permission denied".to_string()), success_count: 0, failed_count: 0, results: vec![], } } else { match state.connection_manager.write_point_values_batch(payload).await { Ok(v) => v, Err(e) => crate::connection::BatchSetPointValueRes { success: false, err_msg: Some(e), success_count: 0, failed_count: 1, results: vec![crate::connection::SetPointValueResItem { point_id: Uuid::nil(), success: false, err_msg: Some("Internal write error".to_string()), }], }, } }; if let Err(e) = ws_manager .public_room .send_to_room(&room_id, WsMessage::PointSetValueBatchResult(response)) .await { tracing::error!( "Failed to send PointSetValueBatchResult to room {}: {}", room_id, e ); } } Err(e) => { tracing::warn!( "Invalid websocket message in room {}: {}", room_id, e ); } } } _ => { tracing::debug!("Received WebSocket message from room {}: {:?}", room_id, msg); } } } Some(Err(e)) => { tracing::error!("WebSocket error in room {}: {}", room_id, e); break; } None => break, } } room_message = rx.recv() => { match room_message { Ok(message) => match serde_json::to_string(&message) { Ok(json_str) => { if socket.send(Message::Text(json_str.into())).await.is_err() { break; } } Err(e) => { 
tracing::error!("Failed to serialize websocket message: {}", e); } }, Err(broadcast::error::RecvError::Lagged(skipped)) => { tracing::warn!("WebSocket room {} lagged, skipped {} messages", room_id, skipped); } Err(broadcast::error::RecvError::Closed) => break, } } } } ws_manager.public_room.remove_room_if_empty(&room_id).await; }