feat: 软件第一个版本

This commit is contained in:
caoqianming 2026-03-03 13:30:49 +08:00
commit 44f4a794d3
24 changed files with 7180 additions and 0 deletions

28
.gitignore vendored Normal file
View File

@ -0,0 +1,28 @@
# Rust build output
/target
# Environment and local secrets
.env
.env.*
!.env.example
# Local runtime data/logs
/logs/
/data/
# PKI private keys and generated cert artifacts
/pki/
*.pem
*.key
*.pfx
*.p12
# IDE/editor
.vscode/
.idea/
*.swp
*.swo
# OS files
.DS_Store
Thumbs.db

3216
Cargo.lock generated Normal file

File diff suppressed because it is too large Load Diff

43
Cargo.toml Normal file
View File

@ -0,0 +1,43 @@
[package]
name = "gateway_rs"
version = "0.1.0"
edition = "2021"

[dependencies]
# Async runtime (full feature set: multithreaded runtime, macros, io, time)
tokio = { version = "1.49", features = ["full"] }
# Web framework with WebSocket support, plus CORS middleware
axum = { version = "0.8", features = ["ws"] }
tower-http = { version = "0.6", features = ["cors"] }
# Database: async Postgres driver with chrono/uuid column mappings
sqlx = { version = "0.8", features = ["runtime-tokio", "postgres", "chrono", "uuid"] }
# Serialization
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
# Time handling
chrono = "0.4"
time = "0.3"
# UUID generation (v4) with serde support
uuid = { version = "1.21", features = ["serde", "v4"] }
# OPC UA client (crate is imported as `opcua` in source)
async-opcua = { version = "0.18", features = ["client"] }
# Logging / structured tracing
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter", "time"] }
tracing-appender = "0.2"
# .env file loading
dotenv = "0.15"
# Request payload validation
validator = { version = "0.20", features = ["derive"] }
# Error handling
anyhow = "1.0"
View File

@ -0,0 +1,58 @@
-- Add migration script here
-- 1) Source table: one row per upstream data source endpoint.
CREATE TABLE source (
    id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
    name TEXT NOT NULL,
    protocol TEXT NOT NULL,
    endpoint TEXT NOT NULL,
    security_policy TEXT,
    security_mode TEXT,
    username TEXT,
    password TEXT,
    enabled BOOLEAN NOT NULL DEFAULT TRUE,
    created_at TIMESTAMPTZ DEFAULT NOW(),
    updated_at TIMESTAMPTZ DEFAULT NOW(),
    unique (endpoint) -- unique constraint: no duplicate endpoints
);
-- 2) Node table: browsed address-space nodes belonging to a source.
CREATE TABLE node (
    id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
    source_id UUID NOT NULL,
    external_id TEXT NOT NULL, -- primary lookup field
    namespace_uri TEXT, -- stored so namespace-index changes don't break lookups
    namespace_index INTEGER, -- informational only
    identifier_type TEXT, -- informational only
    identifier TEXT, -- informational only
    browse_name TEXT NOT NULL,
    display_name TEXT,
    node_class TEXT NOT NULL,
    parent_id UUID,
    created_at TIMESTAMPTZ DEFAULT NOW(),
    updated_at TIMESTAMPTZ DEFAULT NOW(),
    FOREIGN KEY (source_id) REFERENCES source(id),
    FOREIGN KEY (parent_id) REFERENCES node(id),
    UNIQUE(source_id, external_id)
);
-- Common node indexes
CREATE INDEX idx_node_source_id ON node(source_id);
CREATE INDEX idx_node_parent_id ON node(parent_id);
-- 3) Point table: a collected point, bound 1:1 to a node.
CREATE TABLE point (
    id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
    node_id UUID NOT NULL,
    name TEXT NOT NULL,
    description TEXT,
    unit TEXT,
    scan_interval_s INTEGER NOT NULL DEFAULT 1 CHECK (scan_interval_s > 0),
    created_at TIMESTAMPTZ DEFAULT NOW(),
    updated_at TIMESTAMPTZ DEFAULT NOW(),
    FOREIGN KEY (node_id) REFERENCES node(id),
    UNIQUE (node_id) -- at most one point per node
);
-- Common point index
CREATE INDEX idx_point_node_id ON point(node_id);

View File

@ -0,0 +1,15 @@
-- Tag table: user-defined labels that can be attached to points.
CREATE TABLE tag (
    id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
    name TEXT NOT NULL,
    description TEXT,
    created_at TIMESTAMPTZ DEFAULT NOW(),
    updated_at TIMESTAMPTZ DEFAULT NOW(),
    UNIQUE (name)
);
-- Add an optional tag_id column to point; deleting a tag detaches it.
ALTER TABLE point ADD COLUMN tag_id UUID REFERENCES tag(id) ON DELETE SET NULL;
-- Common index
CREATE INDEX idx_point_tag_id ON point(tag_id);

40
src/config.rs Normal file
View File

@ -0,0 +1,40 @@
use std::env;
/// Runtime configuration resolved from environment variables.
#[derive(Clone)]
pub struct AppConfig {
    pub database_url: String,
    pub server_host: String,
    pub server_port: u16,
    pub write_api_key: Option<String>,
}

impl AppConfig {
    /// Build the configuration from the process environment.
    ///
    /// `DATABASE_URL` is required; `HOST` defaults to "0.0.0.0",
    /// `PORT` defaults to 60309, and `WRITE_KEY` is optional.
    pub fn from_env() -> Result<Self, String> {
        let database_url = get_env("DATABASE_URL")?;
        let server_host = match env::var("HOST") {
            Ok(host) => host,
            Err(_) => String::from("0.0.0.0"),
        };
        let raw_port = env::var("PORT").unwrap_or_else(|_| String::from("60309"));
        let server_port: u16 = raw_port.parse().map_err(|_| "PORT must be a number")?;
        let write_api_key = env::var("WRITE_KEY").ok();
        Ok(Self {
            database_url,
            server_host,
            server_port,
            write_api_key,
        })
    }

    /// True only when a write key is configured and `key` matches it exactly.
    pub fn verify_write_key(&self, key: &str) -> bool {
        matches!(self.write_api_key.as_deref(), Some(expected) if expected == key)
    }
}

/// Read a required environment variable, with a descriptive error when absent.
fn get_env(key: &str) -> Result<String, String> {
    match env::var(key) {
        Ok(value) => Ok(value),
        Err(_) => Err(format!("Missing environment variable: {}", key)),
    }
}

1170
src/connection.rs Normal file

File diff suppressed because it is too large Load Diff

15
src/db.rs Normal file
View File

@ -0,0 +1,15 @@
use sqlx::PgPool;
use sqlx::postgres::PgPoolOptions;
use tracing::info;
/// Connect to Postgres and return a pool capped at 10 connections.
///
/// Migrations are intentionally not executed here; apply them manually
/// (the log line below reminds the operator of that).
pub async fn init_database(database_url: &str) -> Result<PgPool, sqlx::Error> {
    let options = PgPoolOptions::new().max_connections(10);
    let pool = options.connect(database_url).await?;
    info!("数据库已连接,如有迁移请手动执行");
    Ok(pool)
}

180
src/event.rs Normal file
View File

@ -0,0 +1,180 @@
use tokio::sync::mpsc;
use uuid::Uuid;

/// Events that drive runtime reconfiguration of connections and
/// subscriptions, plus live point-value updates from the protocol layer.
#[derive(Debug, Clone)]
pub enum ReloadEvent {
    /// A source row was created; a connection should be opened.
    SourceCreate {
        source_id: Uuid,
    },
    /// A source row changed (the handler currently only logs this).
    SourceUpdate {
        source_id: Uuid,
    },
    /// A source row was deleted; its connection should be torn down.
    SourceDelete {
        source_id: Uuid,
    },
    /// One point was created and should be subscribed.
    PointCreate {
        source_id: Uuid,
        point_id: Uuid,
    },
    /// Several points were created for one source.
    PointCreateBatch {
        source_id: Uuid,
        point_ids: Vec<Uuid>,
    },
    /// Several points were deleted for one source.
    PointDeleteBatch {
        source_id: Uuid,
        point_ids: Vec<Uuid>,
    },
    /// A live value change reported by the connection layer.
    PointValueChange(crate::telemetry::PointValueChangeEvent),
}

/// Owns the sending half of the event channel; the receiving half is
/// consumed by the background task spawned in `EventManager::new`.
pub struct EventManager {
    sender: mpsc::UnboundedSender<ReloadEvent>,
}
impl EventManager {
    /// Spawn the background event loop and return the sending handle.
    ///
    /// Events are processed strictly in order by a single spawned task;
    /// the loop ends when every sender (i.e. this EventManager) is dropped.
    pub fn new(
        pool: sqlx::PgPool,
        connection_manager: std::sync::Arc<crate::connection::ConnectionManager>,
        ws_manager: Option<std::sync::Arc<crate::websocket::WebSocketManager>>,
    ) -> Self {
        let (sender, mut receiver) = mpsc::unbounded_channel::<ReloadEvent>();
        let ws_manager_clone = ws_manager.clone();
        tokio::spawn(async move {
            while let Some(event) = receiver.recv().await {
                match event {
                    // New source: open a connection to it.
                    ReloadEvent::SourceCreate { source_id } => {
                        tracing::info!("Processing SourceCreate event for {}", source_id);
                        if let Err(e) = connection_manager.connect_from_source(&pool, source_id).await {
                            tracing::error!("Failed to connect to source {}: {}", source_id, e);
                        }
                    }
                    // Source updates are acknowledged but not acted upon yet.
                    ReloadEvent::SourceUpdate { source_id } => {
                        tracing::info!("SourceUpdate event for {}: not implemented yet", source_id);
                    }
                    // Removed source: drop its connection.
                    ReloadEvent::SourceDelete { source_id } => {
                        tracing::info!("Processing SourceDelete event for {}", source_id);
                        if let Err(e) = connection_manager.disconnect(source_id).await {
                            tracing::error!("Failed to disconnect from source {}: {}", source_id, e);
                        }
                    }
                    // One new point: subscribe it; the returned stats map
                    // carries "subscribed"/"polled"/"total" counters.
                    ReloadEvent::PointCreate { source_id, point_id } => {
                        match connection_manager
                            .subscribe_points_from_source(source_id, Some(vec![point_id]), &pool)
                            .await
                        {
                            Ok(stats) => {
                                let subscribed = *stats.get("subscribed").unwrap_or(&0);
                                let polled = *stats.get("polled").unwrap_or(&0);
                                let total = *stats.get("total").unwrap_or(&0);
                                tracing::info!(
                                    "PointCreate subscribe finished for source {} point {}: subscribed={}, polled={}, total={}",
                                    source_id,
                                    point_id,
                                    subscribed,
                                    polled,
                                    total
                                );
                            }
                            Err(e) => {
                                tracing::error!("Failed to subscribe to point {}: {}", point_id, e);
                            }
                        }
                    }
                    // Batch of new points for one source.
                    ReloadEvent::PointCreateBatch { source_id, point_ids } => {
                        let requested_count = point_ids.len();
                        match connection_manager
                            .subscribe_points_from_source(source_id, Some(point_ids), &pool)
                            .await
                        {
                            Ok(stats) => {
                                let subscribed = *stats.get("subscribed").unwrap_or(&0);
                                let polled = *stats.get("polled").unwrap_or(&0);
                                let total = *stats.get("total").unwrap_or(&0);
                                tracing::info!(
                                    "PointCreateBatch subscribe finished for source {}: requested={}, subscribed={}, polled={}, total={}",
                                    source_id,
                                    requested_count,
                                    subscribed,
                                    polled,
                                    total
                                );
                            }
                            Err(e) => {
                                tracing::error!("Failed to subscribe to points: {}", e);
                            }
                        }
                    }
                    // Batch of deleted points: stop monitoring them.
                    ReloadEvent::PointDeleteBatch { source_id, point_ids } => {
                        tracing::info!(
                            "Processing PointDeleteBatch event for source {} with {} points",
                            source_id,
                            point_ids.len()
                        );
                        if let Err(e) = connection_manager
                            .unsubscribe_points_from_source(source_id, point_ids)
                            .await
                        {
                            tracing::error!("Failed to unsubscribe points: {}", e);
                        }
                    }
                    // Live value change: cache it and fan out over WebSocket.
                    ReloadEvent::PointValueChange(payload) => {
                        let source_id = payload.source_id;
                        let client_handle = payload.client_handle;
                        // Resolve the point id, falling back to the per-source
                        // client_handle map when the payload carries none.
                        let point_id = if let Some(point_id) = payload.point_id {
                            Some(point_id)
                        } else {
                            // Read guard is scoped to this block, so it is
                            // released before the awaits below.
                            let status = connection_manager.get_status_read_guard().await;
                            status
                                .get(&source_id)
                                .and_then(|s| s.client_handle_map.get(&client_handle).copied())
                        };
                        if let Some(point_id) = point_id {
                            let monitor = crate::telemetry::PointMonitorInfo {
                                protocol: payload.protocol.clone(),
                                source_id,
                                point_id,
                                client_handle,
                                scan_mode: payload.scan_mode.clone(),
                                timestamp: payload.timestamp,
                                quality: payload.quality.clone(),
                                value: payload.value.clone(),
                                value_type: payload.value_type.clone(),
                                value_text: payload.value_text.clone(),
                            };
                            // Update the in-memory snapshot served by the REST API.
                            if let Err(e) = connection_manager.update_point_monitor_data(monitor.clone()).await {
                                tracing::error!("Failed to update point monitor data for point {}: {}", point_id, e);
                            }
                            // Broadcast to the public room and to the point's own room.
                            if let Some(ws_manager) = &ws_manager_clone {
                                let ws_message = crate::websocket::WsMessage::PointValueChange(
                                    crate::telemetry::WsPointMonitorInfo::from(&monitor),
                                );
                                if let Err(e) = ws_manager.send_to_public(ws_message.clone()).await {
                                    tracing::error!("Failed to send WebSocket message to public room: {}", e);
                                }
                                if let Err(e) = ws_manager.send_to_client(point_id, ws_message).await {
                                    tracing::error!("Failed to send WebSocket message to client room {}: {}", point_id, e);
                                }
                            }
                        } else {
                            tracing::warn!("Point not found for source {} client_handle {}", source_id, client_handle);
                        }
                    }
                }
            }
        });
        Self { sender }
    }

    /// Queue an event for the background loop; fails only when the loop
    /// task has terminated (receiver dropped).
    pub fn send(&self, event: ReloadEvent) -> Result<(), String> {
        self.sender
            .send(event)
            .map_err(|e| format!("Failed to send event: {}", e))
    }
}

3
src/handler.rs Normal file
View File

@ -0,0 +1,3 @@
// HTTP handler modules, one per resource.
pub mod source;
pub mod point;
pub mod tag;

508
src/handler/point.rs Normal file
View File

@ -0,0 +1,508 @@
use axum::{Json, extract::{Path, Query, State}, http::HeaderMap, response::IntoResponse};
use serde::{Deserialize, Serialize};
use sqlx::Row;
use uuid::Uuid;
use validator::Validate;

use crate::util::response::ApiErr;
use crate::{
    AppState,
    model::{Node, Point},
};
/// Query parameters for listing points; optionally filter by source.
#[derive(Deserialize)]
pub struct GetPointListQuery {
    pub source_id: Option<Uuid>,
}

/// A point row flattened together with its latest in-memory monitor
/// snapshot, when the connection manager has one for it.
#[derive(Serialize)]
pub struct PointWithMonitor {
    #[serde(flatten)]
    pub point: Point,
    pub point_monitor: Option<crate::telemetry::WsPointMonitorInfo>,
}
/// GET handler: list every point (optionally restricted to one source),
/// each paired with its current monitor snapshot when available.
pub async fn get_point_list(
    State(state): State<AppState>,
    Query(query): Query<GetPointListQuery>,
) -> Result<impl IntoResponse, ApiErr> {
    let pool = &state.pool;
    // Fetch rows, filtered by source when the caller supplied one.
    let points: Vec<Point> = if let Some(source_id) = query.source_id {
        sqlx::query_as::<_, Point>(
            r#"
        SELECT p.*
        FROM point p
        INNER JOIN node n ON p.node_id = n.id
        WHERE n.source_id = $1
        ORDER BY p.created_at
        "#,
        )
        .bind(source_id)
        .fetch_all(pool)
        .await?
    } else {
        sqlx::query_as::<_, Point>(r#"SELECT * FROM point ORDER BY created_at"#)
            .fetch_all(pool)
            .await?
    };
    // Attach live monitor data under a single read lock.
    let monitor_guard = state
        .connection_manager
        .get_point_monitor_data_read_guard()
        .await;
    let mut resp: Vec<PointWithMonitor> = Vec::with_capacity(points.len());
    for point in points {
        let point_monitor = monitor_guard
            .get(&point.id)
            .map(crate::telemetry::WsPointMonitorInfo::from);
        resp.push(PointWithMonitor { point, point_monitor });
    }
    Ok(Json(resp))
}
/// GET handler: fetch a single point by id (serializes to JSON `null`
/// when no such point exists).
pub async fn get_point(
    State(state): State<AppState>,
    Path(point_id): Path<Uuid>,
) -> Result<impl IntoResponse, ApiErr> {
    let maybe_point = sqlx::query_as::<_, Point>(r#"SELECT * FROM point WHERE id = $1"#)
        .bind(point_id)
        .fetch_optional(&state.pool)
        .await?;
    Ok(Json(maybe_point))
}
/// Request payload for updating editable point fields.
/// All fields are optional; only the supplied ones are written.
#[derive(Deserialize, Validate)]
pub struct UpdatePointReq {
    pub name: Option<String>,
    pub description: Option<String>,
    pub unit: Option<String>,
    pub tag_id: Option<Uuid>,
}

/// Request payload for batch setting point tags.
/// A `tag_id` of null clears the tag on the listed points.
#[derive(Deserialize, Validate)]
pub struct BatchSetPointTagsReq {
    pub point_ids: Vec<Uuid>,
    pub tag_id: Option<Uuid>,
}
/// Update point metadata (name/description/unit/tag).
///
/// Builds a dynamic UPDATE containing only the provided fields.
///
/// BUG FIX: all dynamic parameters are carried in a single `Vec<String>`,
/// so UUID values (`tag_id`, and the `WHERE id` parameter) were bound as
/// text against UUID columns, which Postgres rejects. The placeholders for
/// those parameters now carry an explicit `::uuid` cast.
pub async fn update_point(
    State(state): State<AppState>,
    Path(point_id): Path<Uuid>,
    Json(payload): Json<UpdatePointReq>,
) -> Result<impl IntoResponse, ApiErr> {
    payload.validate()?;
    let pool = &state.pool;
    // Nothing to do when no editable field was supplied.
    if payload.name.is_none() && payload.description.is_none() && payload.unit.is_none() && payload.tag_id.is_none() {
        return Ok(Json(serde_json::json!({"ok_msg": "No fields to update"})));
    }
    // If tag_id is provided, ensure the tag exists.
    if let Some(tag_id) = payload.tag_id {
        let tag_exists = sqlx::query(r#"SELECT 1 FROM tag WHERE id = $1"#)
            .bind(tag_id)
            .fetch_optional(pool)
            .await?
            .is_some();
        if !tag_exists {
            return Err(ApiErr::NotFound("Tag not found".to_string(), None));
        }
    }
    // Ensure the target point exists.
    let existing_point = sqlx::query_as::<_, Point>(r#"SELECT * FROM point WHERE id = $1"#)
        .bind(point_id)
        .fetch_optional(pool)
        .await?;
    if existing_point.is_none() {
        return Err(ApiErr::NotFound("Point not found".to_string(), None));
    }
    // Build dynamic UPDATE SQL for the provided fields.
    let mut updates = Vec::new();
    let mut values: Vec<String> = Vec::new();
    let mut param_count = 1;
    if let Some(name) = &payload.name {
        updates.push(format!("name = ${}", param_count));
        values.push(name.clone());
        param_count += 1;
    }
    if let Some(description) = &payload.description {
        updates.push(format!("description = ${}", param_count));
        values.push(description.clone());
        param_count += 1;
    }
    if let Some(unit) = &payload.unit {
        updates.push(format!("unit = ${}", param_count));
        values.push(unit.clone());
        param_count += 1;
    }
    if let Some(tag_id) = &payload.tag_id {
        // tag_id is a UUID column but the value travels as text — cast it.
        updates.push(format!("tag_id = ${}::uuid", param_count));
        values.push(tag_id.to_string());
        param_count += 1;
    }
    // Always refresh the timestamp.
    updates.push("updated_at = NOW()".to_string());
    // id is a UUID column as well — cast the text parameter.
    let sql = format!(
        "UPDATE point SET {} WHERE id = ${}::uuid",
        updates.join(", "),
        param_count
    );
    values.push(point_id.to_string());
    let mut query = sqlx::query(&sql);
    for value in &values {
        query = query.bind(value);
    }
    query.execute(pool).await?;
    Ok(Json(serde_json::json!({"ok_msg": "Point updated successfully"})))
}
/// POST handler: set (or clear) the tag on a batch of points.
pub async fn batch_set_point_tags(
    State(state): State<AppState>,
    Json(payload): Json<BatchSetPointTagsReq>,
) -> Result<impl IntoResponse, ApiErr> {
    payload.validate()?;
    if payload.point_ids.is_empty() {
        return Err(ApiErr::BadRequest("point_ids cannot be empty".to_string(), None));
    }
    let pool = &state.pool;
    // When assigning (not clearing), the tag must exist.
    if let Some(tag_id) = payload.tag_id {
        let found = sqlx::query(r#"SELECT 1 FROM tag WHERE id = $1"#)
            .bind(tag_id)
            .fetch_optional(pool)
            .await?;
        if found.is_none() {
            return Err(ApiErr::NotFound("Tag not found".to_string(), None));
        }
    }
    // Keep only the ids that actually exist.
    let rows = sqlx::query(r#"SELECT id FROM point WHERE id = ANY($1)"#)
        .bind(&payload.point_ids)
        .fetch_all(pool)
        .await?;
    let existing_points: Vec<Uuid> = rows
        .iter()
        .map(|row| row.get::<Uuid, _>("id"))
        .collect();
    if existing_points.is_empty() {
        return Err(ApiErr::NotFound("No valid points found".to_string(), None));
    }
    // NULL is written when tag_id is None, clearing the association.
    let result = sqlx::query(r#"UPDATE point SET tag_id = $1, updated_at = NOW() WHERE id = ANY($2)"#)
        .bind(payload.tag_id)
        .bind(&existing_points)
        .execute(pool)
        .await?;
    Ok(Json(serde_json::json!({
        "ok_msg": "Point tags updated successfully",
        "updated_count": result.rows_affected()
    })))
}
/// DELETE handler: remove one point and notify the connection layer so
/// its subscription is dropped.
pub async fn delete_point(
    State(state): State<AppState>,
    Path(point_id): Path<Uuid>,
) -> Result<impl IntoResponse, ApiErr> {
    let pool = &state.pool;
    // Resolve the owning source first so an unsubscribe event can be sent
    // after the row is gone.
    let grouped = crate::service::get_points_grouped_by_source(pool, &[point_id]).await?;
    let source_id = grouped.keys().next().copied();
    // 404 when the point does not exist.
    let found = sqlx::query_as::<_, Point>(r#"SELECT * FROM point WHERE id = $1"#)
        .bind(point_id)
        .fetch_optional(pool)
        .await?;
    if found.is_none() {
        return Err(ApiErr::NotFound("Point not found".to_string(), None));
    }
    sqlx::query(r#"delete from point WHERE id = $1"#)
        .bind(point_id)
        .execute(pool)
        .await?;
    // Best-effort event; a failure is logged, not surfaced to the caller.
    if let Some(source_id) = source_id {
        let event = crate::event::ReloadEvent::PointDeleteBatch {
            source_id,
            point_ids: vec![point_id],
        };
        if let Err(e) = state.event_manager.send(event) {
            tracing::error!("Failed to send PointDeleteBatch event: {}", e);
        }
    }
    Ok(Json(serde_json::json!({"ok_msg": "Point deleted successfully"})))
}
/// Request payload for batch point creation from node ids.
#[derive(Deserialize, Validate)]
pub struct BatchCreatePointsReq {
    pub node_ids: Vec<Uuid>,
}

/// Response payload for batch point creation.
#[derive(Serialize)]
pub struct BatchCreatePointsRes {
    // Number of points actually inserted.
    pub success_count: usize,
    // Number of node ids that did not resolve to an existing node.
    pub failed_count: usize,
    pub failed_node_ids: Vec<Uuid>,
    pub created_point_ids: Vec<Uuid>,
}
/// Batch create points by node ids.
///
/// For each requested node: count it as failed when the node does not
/// exist, skip it silently when it already has a point, otherwise insert
/// a point named after the node's browse_name. All inserts share one
/// transaction; grouped PointCreateBatch events are emitted afterwards.
///
/// IMPROVEMENT: the original issued two identical per-node queries (a
/// `SELECT 1` existence probe plus a `SELECT *` for the name); a single
/// fetch now serves both purposes, and the unreachable `Point_{id}`
/// name fallback is gone.
pub async fn batch_create_points(
    State(state): State<AppState>,
    Json(payload): Json<BatchCreatePointsReq>,
) -> Result<impl IntoResponse, ApiErr> {
    payload.validate()?;
    let pool = &state.pool;
    if payload.node_ids.is_empty() {
        return Err(ApiErr::BadRequest("node_ids cannot be empty".to_string(), None));
    }
    let mut success_count = 0;
    let mut failed_count = 0;
    let mut failed_node_ids = Vec::new();
    let mut created_point_ids = Vec::new();
    // Use one transaction for the full batch.
    let mut tx = pool.begin().await?;
    for node_id in payload.node_ids {
        // One lookup: existence check and name source at the same time.
        let node = match sqlx::query_as::<_, Node>(r#"SELECT * FROM node WHERE id = $1"#)
            .bind(node_id)
            .fetch_optional(&mut *tx)
            .await?
        {
            Some(node) => node,
            None => {
                failed_count += 1;
                failed_node_ids.push(node_id);
                continue;
            }
        };
        // Skip nodes that already have a point (point has UNIQUE(node_id)).
        let point_exists = sqlx::query(r#"SELECT 1 FROM point WHERE node_id = $1"#)
            .bind(node_id)
            .fetch_optional(&mut *tx)
            .await?
            .is_some();
        if point_exists {
            continue;
        }
        // Use the node's browse_name as the default point name.
        let new_id = Uuid::new_v4();
        sqlx::query(
            r#"
        INSERT INTO point (id, node_id, name)
        VALUES ($1, $2, $3)
        "#
        )
        .bind(new_id)
        .bind(node_id)
        .bind(&node.browse_name)
        .execute(&mut *tx)
        .await?;
        success_count += 1;
        created_point_ids.push(new_id);
    }
    // Commit the transaction.
    tx.commit().await?;
    // Emit grouped create events by source (best effort; failures logged).
    if !created_point_ids.is_empty() {
        let grouped = crate::service::get_points_grouped_by_source(pool, &created_point_ids).await?;
        for (source_id, points) in grouped {
            let point_ids: Vec<Uuid> = points.into_iter().map(|p| p.point_id).collect();
            if let Err(e) = state
                .event_manager
                .send(crate::event::ReloadEvent::PointCreateBatch { source_id, point_ids })
            {
                tracing::error!("Failed to send PointCreateBatch event: {}", e);
            }
        }
    }
    Ok(Json(BatchCreatePointsRes {
        success_count,
        failed_count,
        failed_node_ids,
        created_point_ids,
    }))
}
/// Request payload for batch point deletion.
#[derive(Deserialize, Validate)]
pub struct BatchDeletePointsReq {
    pub point_ids: Vec<Uuid>,
}

/// Response payload for batch point deletion.
#[derive(Serialize)]
pub struct BatchDeletePointsRes {
    // Rows actually removed; unknown ids are silently ignored.
    pub deleted_count: u64,
}
/// POST handler: delete many points at once and emit one PointDeleteBatch
/// event per owning source for the points that actually existed.
pub async fn batch_delete_points(
    State(state): State<AppState>,
    Json(payload): Json<BatchDeletePointsReq>,
) -> Result<impl IntoResponse, ApiErr> {
    payload.validate()?;
    if payload.point_ids.is_empty() {
        return Err(ApiErr::BadRequest("point_ids cannot be empty".to_string(), None));
    }
    let pool = &state.pool;
    let point_ids = payload.point_ids;
    // Group the requested ids by owning source; unknown ids drop out here.
    let grouped = crate::service::get_points_grouped_by_source(pool, &point_ids).await?;
    let mut existing_point_ids: Vec<Uuid> = Vec::new();
    for points in grouped.values() {
        existing_point_ids.extend(points.iter().map(|p| p.point_id));
    }
    if existing_point_ids.is_empty() {
        return Ok(Json(BatchDeletePointsRes { deleted_count: 0 }));
    }
    let result = sqlx::query(r#"DELETE FROM point WHERE id = ANY($1)"#)
        .bind(&existing_point_ids)
        .execute(pool)
        .await?;
    // Notify the connection layer, one event per source (best effort).
    for (source_id, points) in grouped {
        let ids: Vec<Uuid> = points.into_iter().map(|p| p.point_id).collect();
        let event = crate::event::ReloadEvent::PointDeleteBatch {
            source_id,
            point_ids: ids,
        };
        if let Err(e) = state.event_manager.send(event) {
            tracing::error!("Failed to send PointDeleteBatch event: {}", e);
        }
    }
    Ok(Json(BatchDeletePointsRes {
        deleted_count: result.rows_affected(),
    }))
}
/// POST handler: write values to points, guarded by the configured write key.
///
/// The caller must supply the value of the `WRITE_KEY` environment variable
/// in the `X-Write-Key` header; writes are rejected when no key is set.
///
/// BUG FIX: the rejection hint previously told callers to "set
/// WRITE_API_KEY", but the config loader reads the `WRITE_KEY` environment
/// variable, so the hint was misleading.
pub async fn batch_set_point_value(
    State(state): State<AppState>,
    headers: HeaderMap,
    Json(payload): Json<crate::connection::BatchSetPointValueReq>,
) -> Result<impl IntoResponse, ApiErr> {
    let write_key = headers
        .get("X-Write-Key")
        .and_then(|v| v.to_str().ok())
        .unwrap_or_default();
    if !state.config.verify_write_key(write_key) {
        return Err(ApiErr::Forbidden(
            "write permission denied".to_string(),
            Some(serde_json::json!({
                "hint": "set WRITE_KEY and pass header X-Write-Key"
            })),
        ));
    }
    let result = state
        .connection_manager
        .write_point_values_batch(payload)
        .await
        .map_err(|e| ApiErr::Internal(e, None))?;
    Ok(Json(result))
}

583
src/handler/source.rs Normal file
View File

@ -0,0 +1,583 @@
use axum::{Json, extract::{Path, State}, http::StatusCode, response::IntoResponse};
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use validator::Validate;
use opcua::types::{
NodeId, BrowseDescription, ReferenceDescription,
BrowseDirection as OpcuaBrowseDirection, Identifier, ReadValueId, AttributeId, NumericRange, TimestampsToReturn, Variant
};
use opcua::types::ReferenceTypeId;
use opcua::client::Session;
use std::collections::{HashMap, VecDeque};
use crate::util::response::ApiErr;
use crate::{AppState, model::{Node, Source}};
use anyhow::{Context};
// Tree node: a database `node` row plus its nested children,
// serialized for the node-tree API response.
#[derive(Debug, Serialize, Clone)]
pub struct TreeNode {
    pub id: Uuid,
    pub source_id: Uuid,
    pub external_id: String,
    pub namespace_uri: Option<String>,
    pub namespace_index: Option<i32>,
    pub identifier_type: Option<String>,
    pub identifier: Option<String>,
    pub browse_name: String,
    pub display_name: Option<String>,
    pub node_class: String,
    pub parent_id: Option<Uuid>,
    pub children: Vec<TreeNode>,
}
impl TreeNode {
fn from_node(node: Node) -> Self {
TreeNode {
id: node.id,
source_id: node.source_id,
external_id: node.external_id,
namespace_uri: node.namespace_uri,
namespace_index: node.namespace_index,
identifier_type: node.identifier_type,
identifier: node.identifier,
browse_name: node.browse_name,
display_name: node.display_name,
node_class: node.node_class,
parent_id: node.parent_id,
children: Vec::new(),
}
}
}
// Source row combined with its live connection status for API responses.
#[derive(Debug, Serialize, Clone)]
pub struct SourceWithStatus {
    #[serde(flatten)]
    pub source: Source,
    // Copied from the connection manager's status snapshot.
    pub is_connected: bool,
    // Most recent connection error, if any.
    pub last_error: Option<String>,
    // Serialized as a local-time string for clients.
    #[serde(serialize_with = "crate::util::datetime::option_utc_to_local_str")]
    pub last_time: Option<DateTime<Utc>>,
}
pub async fn get_source_list(State(state): State<AppState>) -> Result<impl IntoResponse, ApiErr> {
let pool = &state.pool;
let sources: Vec<Source> = sqlx::query_as(
r#"SELECT * FROM source where enabled is true"#,
).fetch_all(pool).await?;
// 获取所有连接状态
let status_map: std::collections::HashMap<Uuid, (bool, Option<String>, Option<DateTime<Utc>>)> =
state.connection_manager.get_all_status().await
.into_iter()
.map(|(source_id, s)| (source_id, (s.is_connected, s.last_error, Some(s.last_time))))
.collect();
// 组合Source和连接状态
let sources_with_status: Vec<SourceWithStatus> = sources
.into_iter()
.map(|source| {
let (is_connected, last_error, last_time) = status_map
.get(&source.id)
.map(|(connected, error, time)| (*connected, error.clone(), *time))
.unwrap_or((false, None, None));
SourceWithStatus {
source,
is_connected,
last_error,
last_time,
}
})
.collect();
Ok(Json(sources_with_status))
}
/// GET handler: return the full node tree for one source.
pub async fn get_node_tree(
    State(state): State<AppState>,
    Path(source_id): Path<Uuid>,
) -> Result<impl IntoResponse, ApiErr> {
    // Load every node belonging to this source, oldest first.
    let nodes: Vec<Node> = sqlx::query_as::<_, Node>(
        r#"SELECT * FROM node WHERE source_id = $1 ORDER BY created_at"#,
    )
    .bind(source_id)
    .fetch_all(&state.pool)
    .await?;
    // Assemble the parent/child hierarchy in memory.
    Ok(Json(build_node_tree(nodes)))
}
/// Assemble a forest of `TreeNode`s from flat rows.
///
/// Children keep the input order (rows arrive sorted by created_at).
///
/// BUG FIX: rows whose parent_id referenced a node absent from `nodes`
/// (or referenced themselves) were previously registered as children of a
/// parent that was never traversed, so they silently vanished from the
/// output. Such rows are now promoted to roots instead.
fn build_node_tree(nodes: Vec<Node>) -> Vec<TreeNode> {
    // 1) Index every node by id, remembering the original order.
    let mut node_map: HashMap<Uuid, TreeNode> = HashMap::new();
    let mut order: Vec<Uuid> = Vec::new();
    for node in nodes {
        let tree_node = TreeNode::from_node(node);
        order.push(tree_node.id);
        node_map.insert(tree_node.id, tree_node);
    }
    // 2) Classify each node: child of a known parent, or a root.
    let mut children_map: HashMap<Uuid, Vec<Uuid>> = HashMap::new();
    let mut roots: Vec<Uuid> = Vec::new();
    for id in &order {
        let parent = node_map.get(id).and_then(|n| n.parent_id);
        match parent {
            Some(pid) if pid != *id && node_map.contains_key(&pid) => {
                children_map.entry(pid).or_default().push(*id);
            }
            // No parent, unknown parent, or self-parent: treat as a root.
            _ => roots.push(*id),
        }
    }
    // 3) Recursively move nodes out of the map into their parents.
    fn attach_children(
        id: Uuid,
        node_map: &mut HashMap<Uuid, TreeNode>,
        children_map: &HashMap<Uuid, Vec<Uuid>>,
    ) -> TreeNode {
        let mut node = node_map
            .remove(&id)
            .expect("every id in roots/children_map was inserted into node_map");
        if let Some(child_ids) = children_map.get(&id) {
            for &cid in child_ids {
                node.children.push(attach_children(cid, node_map, children_map));
            }
        }
        node
    }
    roots
        .into_iter()
        .map(|rid| attach_children(rid, &mut node_map, &children_map))
        .collect()
}
/// Request payload for creating a source (protocol is fixed to "opcua").
#[derive(Deserialize, Validate)]
pub struct CreateSourceReq {
    pub name: String,
    pub endpoint: String,
    pub enabled: bool,
}

/// Response payload: the id of the newly created source.
#[derive(Serialize)]
pub struct CreateSourceRes {
    pub id: Uuid,
}
/// POST handler: insert a source row and trigger connection establishment.
pub async fn create_source(
    State(state): State<AppState>,
    Json(payload): Json<CreateSourceReq>,
) -> Result<impl IntoResponse, ApiErr> {
    payload.validate()?;
    let new_id = Uuid::new_v4();
    sqlx::query(
        r#"INSERT INTO source (id, name, endpoint, enabled, protocol) VALUES ($1, $2, $3, $4, $5)"#,
    )
    .bind(new_id)
    .bind(&payload.name)
    .bind(&payload.endpoint)
    .bind(payload.enabled)
    .bind("opcua") // only the OPC UA protocol is supported here
    .execute(&state.pool)
    .await?;
    // Fire-and-forget: ask the background loop to open the connection.
    let _ = state
        .event_manager
        .send(crate::event::ReloadEvent::SourceCreate { source_id: new_id });
    Ok((StatusCode::CREATED, Json(CreateSourceRes { id: new_id })))
}
/// DELETE handler: remove a source row and tear down its connection.
pub async fn delete_source(
    State(state): State<AppState>,
    Path(source_id): Path<Uuid>,
) -> Result<impl IntoResponse, ApiErr> {
    let outcome = sqlx::query("DELETE FROM source WHERE id = $1")
        .bind(source_id)
        .execute(&state.pool)
        .await?;
    // Zero affected rows means the id was unknown.
    if outcome.rows_affected() == 0 {
        return Err(ApiErr::NotFound(format!("Source with id {} not found", source_id), None));
    }
    // Fire-and-forget disconnect request, mirroring create_source.
    let _ = state
        .event_manager
        .send(crate::event::ReloadEvent::SourceDelete { source_id });
    Ok(StatusCode::NO_CONTENT)
}
/// POST handler: walk the OPC UA address space of a connected source,
/// breadth-first starting from the Objects folder, and upsert every
/// discovered node row inside one transaction.
pub async fn browse_and_save_nodes(
    State(state): State<AppState>,
    Path(source_id): Path<Uuid>,
) -> Result<impl IntoResponse, ApiErr> {
    let pool = &state.pool;
    // Errors out (via sqlx) when the source row does not exist.
    sqlx::query("SELECT 1 FROM source WHERE id = $1")
        .bind(source_id)
        .fetch_one(pool)
        .await?;
    let session = state
        .connection_manager
        .get_session(source_id)
        .await
        .ok_or_else(|| anyhow::anyhow!("Source not connected"))?;
    // Namespace-index -> URI mapping, persisted alongside each node.
    let namespace_map = load_namespace_map(&session)
        .await
        .context("Failed to load namespace map")?;
    // One transaction covers the whole browse.
    let mut tx = pool.begin().await.context("Failed to begin transaction")?;
    let mut processed_nodes: HashMap<String, ()> = HashMap::new();
    let mut pending: VecDeque<(NodeId, Option<Uuid>)> = VecDeque::new();
    pending.push_back((NodeId::objects_folder_id(), None));
    while let Some((node_id, parent_id)) = pending.pop_front() {
        browse_single_node(
            &session,
            &mut tx,
            source_id,
            &node_id,
            parent_id,
            &namespace_map,
            &mut processed_nodes,
            &mut pending,
        )
        .await
        .with_context(|| format!("Failed to browse node: {:?}", node_id))?;
    }
    tx.commit().await.context("Failed to commit transaction")?;
    Ok(Json(serde_json::json!({
        "ok_msg": "Browse completed",
        "total_nodes": processed_nodes.len()
    })))
}
////////////////////////////////////////////////////////////////
// Browse a single node (handles continuation points)
////////////////////////////////////////////////////////////////
/// Browse one node's forward hierarchical references, following
/// continuation points until the server reports completion, and hand
/// every reference to `process_reference` for persistence.
async fn browse_single_node(
    session: &Session,
    tx: &mut sqlx::Transaction<'_, sqlx::Postgres>,
    source_id: Uuid,
    node_id: &NodeId,
    parent_id: Option<Uuid>,
    namespace_map: &HashMap<i32, String>,
    processed_nodes: &mut HashMap<String, ()>,
    queue: &mut VecDeque<(NodeId, Option<Uuid>)>,
) -> anyhow::Result<()> {
    let browse_desc = BrowseDescription {
        node_id: node_id.clone(),
        browse_direction: OpcuaBrowseDirection::Forward,
        reference_type_id: ReferenceTypeId::HierarchicalReferences.into(),
        include_subtypes: true,
        node_class_mask: 0, // 0 = no node-class filter
        result_mask: 0x3F, // request all reference fields
    };
    let mut results = session.browse(&[browse_desc], 0u32, None).await
        .context("Failed to browse node")?;
    loop {
        // NOTE(review): assumes the server returns one result per browse
        // description; indexing would panic on an empty result set.
        let result = &results[0];
        if let Some(refs) = &result.references {
            for ref_desc in refs {
                process_reference(
                    ref_desc,
                    tx,
                    source_id,
                    parent_id,
                    namespace_map,
                    processed_nodes,
                    queue,
                ).await
                .with_context(|| format!("Failed to process reference: {:?}", ref_desc.node_id.node_id))?;
            }
        }
        // A non-null continuation point means more references remain.
        if !result.continuation_point.is_null() {
            let cp = result.continuation_point.clone();
            results = session.browse_next(false, &[cp]).await
                .context("Failed to browse next")?;
        } else {
            break;
        }
    }
    Ok(())
}
////////////////////////////////////////////////////////////////
// Process a single browsed reference
////////////////////////////////////////////////////////////////
/// Persist one browsed reference as a `node` row (UPSERT keyed on
/// `(source_id, external_id)`) and enqueue it for its own browse pass.
///
/// BUG FIXES versus the original:
/// 1. The "parent given but not yet persisted" branch had a trailing comma
///    in its INSERT column list (`parent_id,` directly before `)`), which
///    is invalid SQL and failed at runtime. After removing it, that branch
///    is textually identical to the "no parent" branch, so the two are
///    merged: the parent id is resolved first and a missing parent
///    degrades to NULL, avoiding the node.parent_id FK violation.
/// 2. The "parent exists" branch bound `now` twice even though its SQL
///    uses `NOW(),NOW()` literals, sending 13 parameters for an
///    11-placeholder statement — another runtime bind error. The extra
///    binds (and the unused `now` local) are removed.
///
/// NOTE(review): on conflict the existing row keeps its id, yet children
/// are enqueued with the freshly generated `node_uuid` as parent; on a
/// re-browse that id is absent from the DB, so children fall back to a
/// NULL parent. Pre-existing behavior, left unchanged — worth confirming.
async fn process_reference(
    ref_desc: &ReferenceDescription,
    tx: &mut sqlx::Transaction<'_, sqlx::Postgres>,
    source_id: Uuid,
    parent_id: Option<Uuid>,
    namespace_map: &HashMap<i32, String>,
    processed_nodes: &mut HashMap<String, ()>,
    queue: &mut VecDeque<(NodeId, Option<Uuid>)>,
) -> anyhow::Result<()> {
    let node_id_obj = &ref_desc.node_id.node_id;
    let node_id_str = node_id_obj.to_string();
    // In-memory dedup: each NodeId is processed at most once per browse.
    if processed_nodes.contains_key(&node_id_str) {
        return Ok(());
    }
    let (namespace_index, identifier_type, identifier) =
        parse_node_id(node_id_obj);
    let namespace_uri = namespace_map
        .get(&(namespace_index.unwrap_or(0) as i32))
        .cloned()
        .unwrap_or_default();
    let browse_name = ref_desc.browse_name.name.to_string();
    let display_name = ref_desc.display_name.text.to_string();
    let node_class = format!("{:?}", ref_desc.node_class);
    let node_uuid = Uuid::new_v4();
    // Only reference a parent row that actually exists; otherwise the
    // FOREIGN KEY (parent_id) REFERENCES node(id) constraint would fail.
    let resolved_parent: Option<Uuid> = match parent_id {
        Some(pid) => {
            let exists = sqlx::query(r#"SELECT 1 FROM node WHERE id = $1"#)
                .bind(pid)
                .fetch_optional(tx.as_mut())
                .await?
                .is_some();
            if exists { Some(pid) } else { None }
        }
        None => None,
    };
    if let Some(pid) = resolved_parent {
        // Parent known: upsert including parent_id and explicit timestamps.
        sqlx::query(
            r#"
        INSERT INTO node (
            id, source_id, external_id, namespace_uri, namespace_index,
            identifier_type, identifier, browse_name, display_name,
            node_class, parent_id, created_at, updated_at
        )
        VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,NOW(),NOW())
        ON CONFLICT(source_id, external_id) DO UPDATE SET
            namespace_uri = excluded.namespace_uri,
            namespace_index = excluded.namespace_index,
            identifier_type = excluded.identifier_type,
            identifier = excluded.identifier,
            browse_name = excluded.browse_name,
            display_name = excluded.display_name,
            node_class = excluded.node_class,
            parent_id = excluded.parent_id,
            updated_at = NOW()
        "#,
        )
        .bind(node_uuid)
        .bind(source_id)
        .bind(&node_id_str)
        .bind(&namespace_uri)
        .bind(namespace_index.map(|v| v as i32))
        .bind(&identifier_type)
        .bind(&identifier)
        .bind(&browse_name)
        .bind(&display_name)
        .bind(&node_class)
        .bind(pid)
        .execute(tx.as_mut())
        .await
        .context("Failed to execute UPSERT query")?;
    } else {
        // No usable parent: insert with NULL parent_id, and do not touch
        // an existing row's parent_id on conflict.
        sqlx::query(
            r#"
        INSERT INTO node (
            id, source_id, external_id, namespace_uri, namespace_index,
            identifier_type, identifier, browse_name, display_name,
            node_class, parent_id
        )
        VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9,$10,NULL)
        ON CONFLICT(source_id, external_id) DO UPDATE SET
            namespace_uri = excluded.namespace_uri,
            namespace_index = excluded.namespace_index,
            identifier_type = excluded.identifier_type,
            identifier = excluded.identifier,
            browse_name = excluded.browse_name,
            display_name = excluded.display_name,
            node_class = excluded.node_class,
            updated_at = NOW()
        "#,
        )
        .bind(node_uuid)
        .bind(source_id)
        .bind(&node_id_str)
        .bind(&namespace_uri)
        .bind(namespace_index.map(|v| v as i32))
        .bind(&identifier_type)
        .bind(&identifier)
        .bind(&browse_name)
        .bind(&display_name)
        .bind(&node_class)
        .execute(tx.as_mut())
        .await
        .context("Failed to execute UPSERT query")?;
    }
    processed_nodes.insert(node_id_str.clone(), ());
    // Enqueue this node for its own breadth-first browse pass.
    queue.push_back((node_id_obj.clone(), Some(node_uuid)));
    Ok(())
}
////////////////////////////////////////////////////////////////
// 解析 NodeId
////////////////////////////////////////////////////////////////
/// Split an OPC UA `NodeId` into (namespace index, identifier-type tag, identifier text).
///
/// The type tags mirror OPC UA string notation: "i" numeric, "s" string,
/// "g" GUID, "b" byte string.
fn parse_node_id(node_id: &NodeId) -> (Option<u16>, Option<String>, String) {
    let (kind, text) = match &node_id.identifier {
        Identifier::Numeric(n) => ("i", n.to_string()),
        Identifier::String(s) => ("s", s.to_string()),
        Identifier::Guid(g) => ("g", g.to_string()),
        // NOTE(review): Debug-formatting a ByteString is lossy; confirm callers
        // only use this text for display/storage keys.
        Identifier::ByteString(bytes) => ("b", format!("{:?}", bytes)),
    };
    (Some(node_id.namespace), Some(kind.to_string()), text)
}
////////////////////////////////////////////////////////////////
// 读取 NamespaceArray
////////////////////////////////////////////////////////////////
/// Read the server's NamespaceArray (ns=0;i=2255) and build a map from
/// namespace index to namespace URI.
///
/// # Errors
/// Fails only when the OPC UA read itself fails; a missing or non-array
/// value yields an empty map.
async fn load_namespace_map(
    session: &Session,
) -> anyhow::Result<HashMap<i32, String>> {
    // NodeId 0:2255 is the standard Server_NamespaceArray variable.
    let ns_node = NodeId::new(0, 2255);
    let read_request = ReadValueId {
        node_id: ns_node,
        attribute_id: AttributeId::Value as u32,
        index_range: NumericRange::None,
        data_encoding: Default::default(),
    };
    let result = session
        .read(&[read_request], TimestampsToReturn::Neither, 0f64)
        .await
        .context("Failed to read namespace map")?;
    let mut map = HashMap::new();
    // BUG FIX: the previous code indexed `result[0]`, which panics if the
    // server returns an empty result set; `first()` makes that case a no-op.
    if let Some(Variant::Array(array)) = result.first().and_then(|dv| dv.value.as_ref()) {
        for (i, item) in array.values.iter().enumerate() {
            if let (Ok(index), Variant::String(uri)) = (i32::try_from(i), item) {
                map.insert(index, uri.to_string());
            }
        }
    }
    Ok(map)
}

102
src/handler/tag.rs Normal file
View File

@ -0,0 +1,102 @@
use axum::{Json, extract::{Path, State}, http::StatusCode, response::IntoResponse};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use validator::Validate;
use crate::util::response::ApiErr;
use crate::{AppState};
/// 获取所有标签
pub async fn get_tag_list(
State(state): State<AppState>,
) -> Result<impl IntoResponse, ApiErr> {
let tags = crate::service::get_all_tags(&state.pool).await?;
Ok(Json(tags))
}
/// Return all points currently assigned to the given tag.
pub async fn get_tag_points(
    State(state): State<AppState>,
    Path(tag_id): Path<Uuid>,
) -> Result<impl IntoResponse, ApiErr> {
    let point_list = crate::service::get_tag_points(&state.pool, tag_id).await?;
    Ok(Json(point_list))
}
/// Request body for creating a tag.
#[derive(Debug, Deserialize, Validate)]
pub struct CreateTagReq {
    // Tag display name; validated to 1..=100 characters.
    #[validate(length(min = 1, max = 100))]
    pub name: String,
    // Optional free-text description.
    pub description: Option<String>,
    // Optional initial set of point ids to assign to the new tag.
    pub point_ids: Option<Vec<Uuid>>,
}
/// Request body for partially updating a tag; absent fields are left unchanged.
#[derive(Debug, Deserialize, Validate)]
pub struct UpdateTagReq {
    // New name, validated to 1..=100 characters when present.
    #[validate(length(min = 1, max = 100))]
    pub name: Option<String>,
    // New description when present.
    pub description: Option<String>,
    // When present, replaces the tag's point membership wholesale.
    pub point_ids: Option<Vec<Uuid>>,
}
/// 创建标签
pub async fn create_tag(
State(state): State<AppState>,
Json(payload): Json<CreateTagReq>,
) -> Result<impl IntoResponse, ApiErr> {
payload.validate()?;
let point_ids = payload.point_ids.as_deref().unwrap_or(&[]);
let tag_id = crate::service::create_tag(
&state.pool,
&payload.name,
payload.description.as_deref(),
point_ids,
).await?;
Ok((StatusCode::CREATED, Json(serde_json::json!({
"id": tag_id,
"ok_msg": "Tag created successfully"
}))))
}
/// Update a tag's name/description and/or replace its point membership.
/// Responds 404 when the tag id does not exist.
pub async fn update_tag(
    State(state): State<AppState>,
    Path(tag_id): Path<Uuid>,
    Json(payload): Json<UpdateTagReq>,
) -> Result<impl IntoResponse, ApiErr> {
    payload.validate()?;
    // Reject updates against an unknown tag id up front.
    if crate::service::get_tag_by_id(&state.pool, tag_id).await?.is_none() {
        return Err(ApiErr::NotFound("Tag not found".to_string(), None));
    }
    crate::service::update_tag(
        &state.pool,
        tag_id,
        payload.name.as_deref(),
        payload.description.as_deref(),
        payload.point_ids.as_deref(),
    )
    .await?;
    Ok(Json(serde_json::json!({
        "ok_msg": "Tag updated successfully"
    })))
}
/// Delete a tag; respond 204 on success, 404 when the id is unknown.
pub async fn delete_tag(
    State(state): State<AppState>,
    Path(tag_id): Path<Uuid>,
) -> Result<impl IntoResponse, ApiErr> {
    match crate::service::delete_tag(&state.pool, tag_id).await? {
        true => Ok(StatusCode::NO_CONTENT),
        false => Err(ApiErr::NotFound("Tag not found".to_string(), None)),
    }
}

150
src/main.rs Normal file
View File

@ -0,0 +1,150 @@
mod model;
mod config;
mod util;
mod db;
mod handler;
mod middleware;
mod connection;
mod event;
mod service;
mod websocket;
mod telemetry;
use config::AppConfig;
use tower_http::cors::{Any, CorsLayer};
use db::init_database;
use middleware::simple_logger;
use connection::ConnectionManager;
use event::EventManager;
use std::sync::Arc;
use axum::{
routing::get,
Router,
};
/// Shared application state handed to every axum handler via `State`.
#[derive(Clone)]
pub struct AppState {
    // Loaded once at startup from environment variables (AppConfig::from_env).
    pub config: AppConfig,
    // Postgres connection pool created by init_database.
    pub pool: sqlx::PgPool,
    // Manages connections to sources and point subscriptions/writes.
    pub connection_manager: Arc<ConnectionManager>,
    // Dispatches point value-change events — presumably toward the websocket
    // layer; confirm in the event module.
    pub event_manager: Arc<EventManager>,
    // Tracks websocket rooms/subscribers for pushing updates to clients.
    pub ws_manager: Arc<websocket::WebSocketManager>,
}
/// Process entry point: load config, wire managers, connect to every enabled
/// source, then serve the HTTP/WebSocket API until Ctrl+C.
#[tokio::main]
async fn main() {
    // Load .env (best-effort) before reading configuration from the environment.
    dotenv::dotenv().ok();
    util::log::init_logger();
    let config = AppConfig::from_env().expect("Failed to load configuration");
    let pool = init_database(&config.database_url).await.expect("Failed to initialize database");
    // Wiring order matters: the event manager is built from a clone of the
    // connection manager, then installed back into it before it is Arc-frozen.
    let mut connection_manager = ConnectionManager::new_with_pool(pool.clone());
    let ws_manager = Arc::new(websocket::WebSocketManager::new());
    let event_manager = Arc::new(EventManager::new(
        pool.clone(),
        Arc::new(connection_manager.clone()),
        Some(ws_manager.clone()),
    ));
    connection_manager.set_event_manager(event_manager.clone());
    let connection_manager = Arc::new(connection_manager);
    // Connect to all enabled sources; failures are logged, not fatal.
    let sources = service::get_all_enabled_sources(&pool)
        .await
        .expect("Failed to fetch sources");
    for source in sources {
        tracing::info!("Connecting to source: {} ({})", source.name, source.endpoint);
        match connection_manager.connect_from_source(&pool, source.id).await {
            Ok(_) => {
                tracing::info!("Successfully connected to source: {}", source.name);
                // Subscribe to all points of this source (None = no id filter).
                match connection_manager
                    .subscribe_points_from_source(source.id, None, &pool)
                    .await
                {
                    Ok(stats) => {
                        let subscribed = *stats.get("subscribed").unwrap_or(&0);
                        let polled = *stats.get("polled").unwrap_or(&0);
                        let total = *stats.get("total").unwrap_or(&0);
                        tracing::info!(
                            "Point subscribe setup for source {}: subscribed={}, polled={}, total={}",
                            source.name,
                            subscribed,
                            polled,
                            total
                        );
                    }
                    Err(e) => {
                        // Connection succeeded but subscription setup failed: keep going.
                        tracing::error!("Failed to subscribe to points for source {}: {}", source.name, e);
                    }
                }
            }
            Err(e) => {
                // A failing source must not prevent the API from starting.
                tracing::error!("Failed to connect to source {}: {}", source.name, e);
            }
        }
    }
    let state = AppState {
        config: config.clone(),
        pool,
        connection_manager: connection_manager.clone(),
        event_manager,
        ws_manager,
    };
    let app = build_router(state.clone());
    let addr = format!("{}:{}", config.server_host, config.server_port);
    tracing::info!("Starting server at http://{}", addr);
    let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
    // On Ctrl+C, close all protocol connections before the server exits.
    let shutdown_signal = async move{
        tokio::signal::ctrl_c()
            .await
            .expect("Failed to install Ctrl+C handler");
        tracing::info!("Received shutdown signal, closing all connections...");
        connection_manager.disconnect_all().await;
        tracing::info!("All connections closed");
    };
    axum::serve(listener, app)
        .with_graceful_shutdown(shutdown_signal)
        .await
        .unwrap();
}
/// Assemble the HTTP/WebSocket router and attach shared state, request
/// logging, and a permissive CORS layer (any origin/method/header).
fn build_router(state: AppState) -> Router {
    let all_route = Router::new()
        .route("/api/source", get(handler::source::get_source_list).post(handler::source::create_source))
        .route("/api/source/{source_id}", axum::routing::delete(handler::source::delete_source))
        .route("/api/source/{source_id}/browse", axum::routing::post(handler::source::browse_and_save_nodes))
        .route("/api/source/{source_id}/node-tree", get(handler::source::get_node_tree))
        .route("/api/point", get(handler::point::get_point_list))
        .route(
            "/api/point/value/batch",
            axum::routing::post(handler::point::batch_set_point_value),
        )
        .route(
            "/api/point/batch",
            axum::routing::post(handler::point::batch_create_points)
                .delete(handler::point::batch_delete_points),
        )
        .route("/api/point/{point_id}", get(handler::point::get_point).put(handler::point::update_point).delete(handler::point::delete_point))
        // BUG FIX: `put` was referenced unqualified but only `routing::get` is
        // imported in this file; qualify it like the other non-GET methods.
        .route("/api/point/batch/set-tags", axum::routing::put(handler::point::batch_set_point_tags))
        .route("/api/tag", get(handler::tag::get_tag_list).post(handler::tag::create_tag))
        .route("/api/tag/{tag_id}", get(handler::tag::get_tag_points).put(handler::tag::update_tag).delete(handler::tag::delete_tag));
    Router::new()
        .merge(all_route)
        .route("/ws/public", get(websocket::public_websocket_handler))
        .route("/ws/client/{client_id}", get(websocket::client_websocket_handler))
        .layer(axum::middleware::from_fn(simple_logger))
        .layer(
            CorsLayer::new()
                .allow_origin(Any)
                .allow_methods(Any)
                .allow_headers(Any),
        )
        .with_state(state)
}

37
src/middleware.rs Normal file
View File

@ -0,0 +1,37 @@
use axum::{
body::Body,
http::Request,
middleware::Next,
response::Response,
};
use std::time::Instant;
pub async fn simple_logger(
req: Request<Body>,
next: Next,
) -> Response {
// 直接获取字符串引用,不用克隆
let method = req.method().to_string();
let uri = req.uri().to_string(); // Uri 的 to_string() 创建新字符串
let start = Instant::now();
let res = next.run(req).await;
let duration = start.elapsed();
let status = res.status();
match status.as_u16() {
100..=399 => {
tracing::info!("{} {} {} {:?}", method, uri, status, duration);
}
400..=499 => {
tracing::warn!("{} {} {} {:?}", method, uri, status, duration);
}
500..=599 => {
tracing::error!("{} {} {} {:?}", method, uri, status, duration);
}
_ => {
tracing::warn!("{} {} {} {:?}", method, uri, status, duration);
}
}
res
}

120
src/model.rs Normal file
View File

@ -0,0 +1,120 @@
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use sqlx::FromRow;
use uuid::Uuid;
use crate::util::datetime::utc_to_local_str;
/// How a point's value is acquired from its source.
///
/// Serialized in lowercase ("poll" / "subscribe"), matching `as_str` and
/// the `FromStr` impl below.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum ScanMode {
    // Value is read periodically (see Point::scan_interval_s).
    Poll,
    // Value changes are pushed by the source via a subscription.
    Subscribe,
}
impl ScanMode {
    /// Canonical lowercase wire/database representation.
    pub fn as_str(&self) -> &'static str {
        if *self == ScanMode::Poll { "poll" } else { "subscribe" }
    }
}
impl From<ScanMode> for String {
    /// Convert via the canonical `as_str` form.
    fn from(mode: ScanMode) -> Self {
        String::from(mode.as_str())
    }
}
impl std::fmt::Display for ScanMode {
    /// Display uses the same canonical lowercase text as `as_str`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}
impl std::str::FromStr for ScanMode {
    type Err = String;
    /// Case-insensitive parse of "poll" / "subscribe"; anything else errors.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let lowered = s.to_lowercase();
        if lowered == "poll" {
            Ok(ScanMode::Poll)
        } else if lowered == "subscribe" {
            Ok(ScanMode::Subscribe)
        } else {
            Err(format!("Invalid scan mode: {}", s))
        }
    }
}
/// A data source the gateway connects to (one row of the `source` table).
#[derive(Debug, Serialize, Deserialize, FromRow, Clone)]
pub struct Source {
    pub id: Uuid,
    pub name: String,
    pub protocol: String, // opcua, modbus
    // Connection endpoint; unique per the `source` table's constraint.
    pub endpoint: String,
    pub security_policy: Option<String>,
    pub security_mode: Option<String>,
    pub username: Option<String>,
    // NOTE(review): serialized as-is in API responses — consider redacting.
    pub password: Option<String>,
    pub enabled: bool,
    // Timestamps stored in UTC, serialized as local-time strings.
    #[serde(serialize_with = "utc_to_local_str")]
    pub created_at: DateTime<Utc>,
    #[serde(serialize_with = "utc_to_local_str")]
    pub updated_at: DateTime<Utc>,
}
/// An address-space node browsed from a source (one row of the `node` table).
#[derive(Debug, Serialize, Deserialize, FromRow)]
#[allow(dead_code)]
pub struct Node {
    pub id: Uuid,
    pub source_id: Uuid,
    // Source-native node id string.
    pub external_id: String, // ns=2;s=Temperature
    // Namespace URI — presumably resolved from the server's NamespaceArray
    // during browsing; confirm against the browse/save code.
    pub namespace_uri: Option<String>,
    pub namespace_index: Option<i32>,
    // One-letter id kind: i=numeric, s=string, g=GUID, b=byte string.
    pub identifier_type: Option<String>, // i/s/g/b
    pub identifier: Option<String>,
    pub browse_name: String,
    pub display_name: Option<String>,
    pub node_class: String, // Object/Variable/Method coil/input topic
    // Parent node in the tree; None for roots or when the parent was not
    // persisted yet (see the upsert logic that inserts NULL parents).
    pub parent_id: Option<Uuid>,
    #[serde(serialize_with = "utc_to_local_str")]
    pub created_at: DateTime<Utc>,
    #[serde(serialize_with = "utc_to_local_str")]
    pub updated_at: DateTime<Utc>,
}
/// A monitored point bound to a node (one row of the `point` table).
#[derive(Debug, Serialize, Deserialize, FromRow)]
#[allow(dead_code)]
pub struct Point {
    pub id: Uuid,
    // FK to `node.id`.
    pub node_id: Uuid,
    pub name: String,
    pub description: Option<String>,
    pub unit: Option<String>,
    // Scan interval in seconds.
    pub scan_interval_s: i32, // s
    // Optional FK to `tag.id`; a point carries at most one tag.
    pub tag_id: Option<Uuid>,
    #[serde(serialize_with = "utc_to_local_str")]
    pub created_at: DateTime<Utc>,
    #[serde(serialize_with = "utc_to_local_str")]
    pub updated_at: DateTime<Utc>,
}
/// Minimal per-point data needed to set up polling/subscription.
#[derive(Debug, Clone)]
pub struct PointSubscriptionInfo {
    pub point_id: Uuid,
    // Source-native node id string (node.external_id).
    pub external_id: String,
    // Scan interval in seconds.
    pub scan_interval_s: i32,
}
/// A user-defined grouping label for points (one row of the `tag` table).
#[derive(Debug, Serialize, Deserialize, FromRow, Clone)]
pub struct Tag {
    pub id: Uuid,
    pub name: String,
    pub description: Option<String>,
    #[serde(serialize_with = "utc_to_local_str")]
    pub created_at: DateTime<Utc>,
    #[serde(serialize_with = "utc_to_local_str")]
    pub updated_at: DateTime<Utc>,
}

300
src/service.rs Normal file
View File

@ -0,0 +1,300 @@
use crate::model::{PointSubscriptionInfo, Source};
use sqlx::{PgPool, query_as};
/// Fetch a single source by id, but only when it is enabled.
pub async fn get_enabled_source(
    pool: &PgPool,
    source_id: uuid::Uuid,
) -> Result<Option<Source>, sqlx::Error> {
    let sql = "SELECT * FROM source WHERE id = $1 AND enabled = true";
    query_as::<_, Source>(sql)
        .bind(source_id)
        .fetch_optional(pool)
        .await
}
/// Fetch every enabled source.
pub async fn get_all_enabled_sources(pool: &PgPool) -> Result<Vec<Source>, sqlx::Error> {
    let sql = "SELECT * FROM source WHERE enabled = true";
    query_as::<_, Source>(sql).fetch_all(pool).await
}
/// Load subscription info for the given points, grouped by their source id.
/// An empty `point_ids` slice short-circuits to an empty map.
pub async fn get_points_grouped_by_source(
    pool: &PgPool,
    point_ids: &[uuid::Uuid],
) -> Result<std::collections::HashMap<uuid::Uuid, Vec<PointSubscriptionInfo>>, sqlx::Error> {
    use sqlx::Row;
    let mut grouped: std::collections::HashMap<uuid::Uuid, Vec<PointSubscriptionInfo>> =
        std::collections::HashMap::new();
    if point_ids.is_empty() {
        return Ok(grouped);
    }
    let rows = sqlx::query(
        r#"
        SELECT
            p.id as point_id,
            n.source_id,
            n.external_id,
            p.scan_interval_s
        FROM point p
        INNER JOIN node n ON p.node_id = n.id
        WHERE p.id = ANY($1)
        ORDER BY n.source_id, p.created_at
        "#,
    )
    .bind(point_ids)
    .fetch_all(pool)
    .await?;
    for row in rows {
        let source_id: uuid::Uuid = row.get("source_id");
        grouped.entry(source_id).or_default().push(PointSubscriptionInfo {
            point_id: row.get("point_id"),
            external_id: row.get("external_id"),
            scan_interval_s: row.get("scan_interval_s"),
        });
    }
    Ok(grouped)
}
/// Load subscription info for points under one source.
/// An empty `point_ids` slice means "all points of the source".
pub async fn get_points_with_ids(
    pool: &PgPool,
    source_id: uuid::Uuid,
    point_ids: &[uuid::Uuid],
) -> Result<Vec<PointSubscriptionInfo>, sqlx::Error> {
    use sqlx::Row;
    let all_sql = r#"
        SELECT
            p.id as point_id,
            n.external_id,
            p.scan_interval_s
        FROM point p
        INNER JOIN node n ON p.node_id = n.id
        WHERE n.source_id = $1
        ORDER BY p.created_at
        "#;
    let filtered_sql = r#"
        SELECT
            p.id as point_id,
            n.external_id,
            p.scan_interval_s
        FROM point p
        INNER JOIN node n ON p.node_id = n.id
        WHERE n.source_id = $1
        AND p.id = ANY($2)
        ORDER BY p.created_at
        "#;
    let rows = if point_ids.is_empty() {
        sqlx::query(all_sql).bind(source_id).fetch_all(pool).await?
    } else {
        sqlx::query(filtered_sql)
            .bind(source_id)
            .bind(point_ids)
            .fetch_all(pool)
            .await?
    };
    let infos = rows
        .iter()
        .map(|row| PointSubscriptionInfo {
            point_id: row.get("point_id"),
            external_id: row.get("external_id"),
            scan_interval_s: row.get("scan_interval_s"),
        })
        .collect();
    Ok(infos)
}
// ==================== Tag 相关服务函数 ====================
/// List every tag, oldest first.
pub async fn get_all_tags(
    pool: &PgPool,
) -> Result<Vec<crate::model::Tag>, sqlx::Error> {
    let sql = "SELECT * FROM tag ORDER BY created_at";
    query_as::<_, crate::model::Tag>(sql).fetch_all(pool).await
}
/// Fetch a single tag by id; `None` when no such row exists.
pub async fn get_tag_by_id(
    pool: &PgPool,
    tag_id: uuid::Uuid,
) -> Result<Option<crate::model::Tag>, sqlx::Error> {
    let sql = "SELECT * FROM tag WHERE id = $1";
    query_as::<_, crate::model::Tag>(sql)
        .bind(tag_id)
        .fetch_optional(pool)
        .await
}
/// List all points carrying the given tag, oldest first.
pub async fn get_tag_points(
    pool: &PgPool,
    tag_id: uuid::Uuid,
) -> Result<Vec<crate::model::Point>, sqlx::Error> {
    let sql = r#"
        SELECT *
        FROM point
        WHERE tag_id = $1
        ORDER BY created_at
        "#;
    query_as::<_, crate::model::Point>(sql)
        .bind(tag_id)
        .fetch_all(pool)
        .await
}
/// Create a tag and, in the same transaction, assign the given points to it.
///
/// Returns the id of the newly created tag.
///
/// # Errors
/// Any database failure aborts the transaction (dropped without commit),
/// so either everything is applied or nothing is.
pub async fn create_tag(
    pool: &PgPool,
    name: &str,
    description: Option<&str>,
    point_ids: &[uuid::Uuid],
) -> Result<uuid::Uuid, sqlx::Error> {
    let mut tx = pool.begin().await?;
    let tag_id = uuid::Uuid::new_v4();
    sqlx::query(
        r#"
        INSERT INTO tag (id, name, description)
        VALUES ($1, $2, $3)
        "#
    )
    .bind(tag_id)
    .bind(name)
    .bind(description)
    .execute(&mut *tx)
    .await?;
    // Assign the initial points. Iterating an empty slice is a no-op, so the
    // previous explicit `if !point_ids.is_empty()` guard was redundant.
    for point_id in point_ids {
        sqlx::query(
            r#"
            UPDATE point
            SET tag_id = $1
            WHERE id = $2
            "#
        )
        .bind(tag_id)
        .bind(point_id)
        .execute(&mut *tx)
        .await?;
    }
    tx.commit().await?;
    Ok(tag_id)
}
/// Partially update a tag and optionally replace its point membership.
///
/// Only provided (`Some`) name/description fields are updated. When
/// `point_ids` is `Some`, membership is replaced wholesale: all points
/// currently on the tag are detached, then the new list is attached.
/// Everything runs in one transaction.
pub async fn update_tag(
    pool: &PgPool,
    tag_id: uuid::Uuid,
    name: Option<&str>,
    description: Option<&str>,
    point_ids: Option<&[uuid::Uuid]>,
) -> Result<(), sqlx::Error> {
    let mut tx = pool.begin().await?;
    // Build the UPDATE statement dynamically from the provided fields.
    if name.is_some() || description.is_some() {
        let mut updates = Vec::new();
        let mut param_count = 1;
        // Idiom fix: use `is_some()` instead of `if let Some(n) = …` with an
        // unused binding (silences unused-variable warnings).
        if name.is_some() {
            updates.push(format!("name = ${}", param_count));
            param_count += 1;
        }
        if description.is_some() {
            updates.push(format!("description = ${}", param_count));
            param_count += 1;
        }
        updates.push("updated_at = NOW()".to_string());
        // `param_count` now indexes the WHERE-clause id placeholder.
        let sql = format!(
            r#"UPDATE tag SET {} WHERE id = ${}"#,
            updates.join(", "),
            param_count
        );
        // Bind in the same order the placeholders were generated.
        let mut query = sqlx::query(&sql);
        if let Some(n) = name {
            query = query.bind(n);
        }
        if let Some(d) = description {
            query = query.bind(d);
        }
        query = query.bind(tag_id);
        query.execute(&mut *tx).await?;
    }
    if let Some(new_point_ids) = point_ids {
        // Detach every point currently on this tag...
        sqlx::query(
            r#"
            UPDATE point
            SET tag_id = NULL
            WHERE tag_id = $1
            "#
        )
        .bind(tag_id)
        .execute(&mut *tx)
        .await?;
        // ...then attach the requested set.
        for point_id in new_point_ids {
            sqlx::query(
                r#"
                UPDATE point
                SET tag_id = $1
                WHERE id = $2
                "#
            )
            .bind(tag_id)
            .bind(point_id)
            .execute(&mut *tx)
            .await?;
        }
    }
    tx.commit().await?;
    Ok(())
}
/// Delete a tag by id; returns `true` when a row was actually removed.
/// (Points referencing the tag are handled by the schema, not here.)
pub async fn delete_tag(
    pool: &PgPool,
    tag_id: uuid::Uuid,
) -> Result<bool, sqlx::Error> {
    let outcome = sqlx::query("DELETE FROM tag WHERE id = $1")
        .bind(tag_id)
        .execute(pool)
        .await?;
    Ok(outcome.rows_affected() > 0)
}

154
src/telemetry.rs Normal file
View File

@ -0,0 +1,154 @@
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use crate::model::ScanMode;
/// Coarse quality classification of a point sample, derived from the
/// protocol-level status (see `from_status_code`). Serialized in snake_case.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum PointQuality {
    Good,
    Bad,
    Uncertain,
    // Status matched none of good/bad/uncertain.
    Unknown,
}
impl PointQuality {
    /// Map an OPC UA status code onto the coarse quality buckets,
    /// checking good, then bad, then uncertain.
    pub fn from_status_code(status: &opcua::types::StatusCode) -> Self {
        match (status.is_good(), status.is_bad(), status.is_uncertain()) {
            (true, _, _) => Self::Good,
            (_, true, _) => Self::Bad,
            (_, _, true) => Self::Uncertain,
            _ => Self::Unknown,
        }
    }
}
/// Protocol-agnostic point value, serialized as `{"kind": ..., "value": ...}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "kind", content = "value", rename_all = "snake_case")]
pub enum DataValue {
    Null,
    Bool(bool),
    Int(i64),
    UInt(u64),
    Float(f64),
    Text(String),
    Bytes(Vec<u8>),
    // Element kinds are not constrained to be homogeneous.
    Array(Vec<DataValue>),
    // Arbitrary structured payload passed through as raw JSON.
    Object(serde_json::Value),
}
impl DataValue {
pub fn to_json_value(&self) -> serde_json::Value {
match self {
DataValue::Null => serde_json::Value::Null,
DataValue::Bool(v) => serde_json::Value::Bool(*v),
DataValue::Int(v) => serde_json::json!(*v),
DataValue::UInt(v) => serde_json::json!(*v),
DataValue::Float(v) => serde_json::json!(*v),
DataValue::Text(v) => serde_json::json!(v),
DataValue::Bytes(v) => serde_json::json!(v),
DataValue::Array(v) => {
serde_json::Value::Array(v.iter().map(DataValue::to_json_value).collect())
}
DataValue::Object(v) => v.clone(),
}
}
}
/// Full monitoring snapshot for a point, as tracked internally.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PointMonitorInfo {
    pub protocol: String,
    pub source_id: Uuid,
    pub point_id: Uuid,
    // Protocol-side handle for the monitored item — presumably the OPC UA
    // client handle; confirm against the connection module.
    pub client_handle: u32,
    pub scan_mode: ScanMode,
    pub timestamp: Option<DateTime<Utc>>,
    pub quality: PointQuality,
    pub value: Option<DataValue>,
    // Textual type name of the value (see `opcua_variant_type`).
    pub value_type: Option<String>,
    pub value_text: Option<String>,
}
impl PointMonitorInfo {
    /// JSON view of the current value, when one is present.
    pub fn value_as_json(&self) -> Option<serde_json::Value> {
        self.value.as_ref().map(|v| v.to_json_value())
    }
}
/// Internal event describing a point value change.
///
/// `point_id` is optional — presumably for changes that cannot be matched to
/// a known point; confirm against the event manager.
#[derive(Debug, Clone)]
pub struct PointValueChangeEvent {
    pub source_id: Uuid,
    pub point_id: Option<Uuid>,
    pub client_handle: u32,
    pub value: Option<DataValue>,
    pub value_type: Option<String>,
    pub value_text: Option<String>,
    pub quality: PointQuality,
    pub protocol: String,
    pub timestamp: Option<DateTime<Utc>>,
    pub scan_mode: ScanMode,
}
/// Wire form of `PointMonitorInfo` pushed to websocket clients: timestamp
/// pre-formatted as a local-time string, value flattened to plain JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WsPointMonitorInfo {
    pub protocol: String,
    pub point_id: Uuid,
    // Stringified ScanMode ("poll"/"subscribe").
    pub scan_mode: String,
    pub timestamp: Option<String>,
    pub quality: PointQuality,
    pub value: Option<serde_json::Value>,
    pub value_type: Option<String>,
}
impl From<&PointMonitorInfo> for WsPointMonitorInfo {
fn from(m: &PointMonitorInfo) -> Self {
Self {
protocol: m.protocol.clone(),
point_id: m.point_id,
scan_mode: m.scan_mode.to_string(),
timestamp: m
.timestamp
.as_ref()
.map(crate::util::datetime::utc_to_local_string),
quality: m.quality.clone(),
value: m.value_as_json(),
value_type: m.value_type.clone(),
}
}
}
/// Translate an OPC UA `Variant` into the protocol-agnostic `DataValue`.
///
/// Scalar kinds map directly; unhandled kinds fall back to their display text.
pub fn opcua_variant_to_data(value: &opcua::types::Variant) -> DataValue {
    use opcua::types::Variant;
    match value {
        Variant::Empty => DataValue::Null,
        Variant::Boolean(b) => DataValue::Bool(*b),
        Variant::SByte(n) => DataValue::Int(i64::from(*n)),
        Variant::Byte(n) => DataValue::UInt(u64::from(*n)),
        Variant::Int16(n) => DataValue::Int(i64::from(*n)),
        Variant::UInt16(n) => DataValue::UInt(u64::from(*n)),
        Variant::Int32(n) => DataValue::Int(i64::from(*n)),
        Variant::UInt32(n) => DataValue::UInt(u64::from(*n)),
        Variant::Int64(n) => DataValue::Int(*n),
        Variant::UInt64(n) => DataValue::UInt(*n),
        Variant::Float(x) => DataValue::Float(f64::from(*x)),
        Variant::Double(x) => DataValue::Float(*x),
        Variant::String(s) => DataValue::Text(s.to_string()),
        // A ByteString's inner value may be absent; treat that as empty bytes.
        Variant::ByteString(bs) => DataValue::Bytes(bs.value.clone().unwrap_or_default()),
        Variant::Array(arr) => {
            let items = arr.values.iter().map(opcua_variant_to_data).collect();
            DataValue::Array(items)
        }
        other => DataValue::Text(other.to_string()),
    }
}
/// Name of the variant's scalar type, or "unknown" when it has none.
pub fn opcua_variant_type(value: &opcua::types::Variant) -> String {
    value
        .scalar_type_id()
        .map(|t| t.to_string())
        .unwrap_or_else(|| "unknown".to_string())
}

4
src/util.rs Normal file
View File

@ -0,0 +1,4 @@
pub mod datetime;
pub mod log;
pub mod response;
pub mod validator;

24
src/util/datetime.rs Normal file
View File

@ -0,0 +1,24 @@
use chrono::{DateTime, Local, Utc};
use serde::Serializer;
/// Format a UTC instant in local time as `YYYY-MM-DD HH:MM:SS.mmm`.
pub fn utc_to_local_string(date: &DateTime<Utc>) -> String {
    let local = date.with_timezone(&Local);
    local.format("%Y-%m-%d %H:%M:%S%.3f").to_string()
}
/// Serde serializer: emit a UTC timestamp as a local-time string.
pub fn utc_to_local_str<S>(date: &DateTime<Utc>, serializer: S) -> Result<S::Ok, S::Error>
where
    S: Serializer,
{
    serializer.serialize_str(&utc_to_local_string(date))
}
/// Serde serializer: like `utc_to_local_str`, but `None` serializes as null.
pub fn option_utc_to_local_str<S>(date: &Option<DateTime<Utc>>, serializer: S) -> Result<S::Ok, S::Error>
where
    S: Serializer,
{
    if let Some(d) = date {
        utc_to_local_str(d, serializer)
    } else {
        serializer.serialize_none()
    }
}

36
src/util/log.rs Normal file
View File

@ -0,0 +1,36 @@
use tracing_subscriber::{fmt, prelude::*, EnvFilter};
use tracing_appender::{rolling, non_blocking};
use std::sync::OnceLock;
use time::UtcOffset;
// Guard for the non-blocking file writer; kept alive for the whole process
// lifetime — dropping it would stop background log flushing.
static LOG_GUARD: OnceLock<non_blocking::WorkerGuard> = OnceLock::new();
/// Initialise tracing: compact output to stdout and to a daily-rolling
/// `./logs/app.log`, filtered via the standard env filter (`RUST_LOG`).
pub fn init_logger() {
    // Best-effort: ignore failure (e.g. directory already exists).
    std::fs::create_dir_all("./logs").ok();
    let file_appender = rolling::daily("./logs", "app.log");
    let (file_writer, guard) = non_blocking(file_appender);
    // Store the guard globally so the writer thread keeps running.
    LOG_GUARD.set(guard).ok();
    // NOTE(review): timestamps are hard-coded to UTC+8; confirm this is
    // intended for all deployments.
    let timer = fmt::time::OffsetTime::new(
        UtcOffset::from_hms(8, 0, 0).unwrap(),
        time::format_description::well_known::Rfc3339,
    );
    tracing_subscriber::registry()
        .with(EnvFilter::from_default_env())
        .with(
            fmt::layer()
                .compact()
                .with_timer(timer.clone())
                .with_writer(std::io::stdout),
        )
        .with(
            fmt::layer()
                .compact()
                .with_timer(timer)
                .with_writer(file_writer)
                // No ANSI colour codes in the log file.
                .with_ansi(false),
        )
        .init();
}

91
src/util/response.rs Normal file
View File

@ -0,0 +1,91 @@
use anyhow::Error;
use axum::{Json, http::StatusCode, response::IntoResponse};
use serde::Serialize;
use serde_json::Value;
use sqlx::Error as SqlxError;
/// JSON body returned for every API error.
#[derive(Debug, Serialize)]
pub struct ErrResp {
    // Mirrors the HTTP status code (400/401/403/404/500).
    pub err_code: i32,
    pub err_msg: String,
    // Optional machine-readable context (e.g. per-field validation errors).
    pub err_detail: Option<Value>,
}
impl ErrResp {
pub fn new(err_code: i32, err_msg: impl Into<String>, detail: Option<Value>) -> Self {
Self {
err_code,
err_msg: err_msg.into(),
err_detail: detail,
}
}
}
/// Application-level error returned by handlers. Each variant maps to an
/// HTTP status in the `IntoResponse` impl and carries a message plus
/// optional structured detail.
#[derive(Debug)]
#[allow(dead_code)]
pub enum ApiErr {
    Unauthorized(String, Option<Value>),
    Forbidden(String, Option<Value>),
    BadRequest(String, Option<Value>),
    NotFound(String, Option<Value>),
    Internal(String, Option<Value>),
}
impl IntoResponse for ApiErr {
    /// Map each variant onto its HTTP status and a JSON `ErrResp` body whose
    /// `err_code` mirrors the status code.
    fn into_response(self) -> axum::response::Response {
        let (status, code, msg, detail) = match self {
            ApiErr::Unauthorized(msg, detail) => (StatusCode::UNAUTHORIZED, 401, msg, detail),
            ApiErr::Forbidden(msg, detail) => (StatusCode::FORBIDDEN, 403, msg, detail),
            ApiErr::BadRequest(msg, detail) => (StatusCode::BAD_REQUEST, 400, msg, detail),
            ApiErr::NotFound(msg, detail) => (StatusCode::NOT_FOUND, 404, msg, detail),
            ApiErr::Internal(msg, detail) => (StatusCode::INTERNAL_SERVER_ERROR, 500, msg, detail),
        };
        (status, Json(ErrResp::new(code, msg, detail))).into_response()
    }
}
impl From<Error> for ApiErr {
    /// Any unhandled `anyhow` error becomes a 500, with its root cause and
    /// full cause chain attached as detail (and logged).
    fn from(err: Error) -> Self {
        tracing::error!("Error: {:?}; root_cause: {}", err, err.root_cause());
        let chain: Vec<String> = err.chain().map(|e| e.to_string()).collect();
        let detail = serde_json::json!({
            "root_cause": err.root_cause().to_string(),
            "chain": chain
        });
        ApiErr::Internal(err.to_string(), Some(detail))
    }
}
impl From<SqlxError> for ApiErr {
fn from(err: SqlxError) -> Self {
match err {
SqlxError::RowNotFound => {
ApiErr::NotFound("Resource not found".into(), None)
}
SqlxError::Database(db_err) => {
if db_err.code().as_deref() == Some("23505") {
ApiErr::BadRequest("数据已存在".into(), None)
} else {
tracing::error!("Database error: {}", db_err);
ApiErr::Internal("Database error".into(), None)
}
}
_ => {
tracing::error!("Database error: {}", err);
ApiErr::Internal("Database error".into(), None)
}
}
}
}

34
src/util/validator.rs Normal file
View File

@ -0,0 +1,34 @@
use crate::util::response::ApiErr;
use serde_json::{json, Value};
use validator::ValidationErrors;
impl From<ValidationErrors> for ApiErr {
    /// Flatten `validator` errors into a 400: `err_detail` maps each field to
    /// its list of messages, and the headline message carries the first field
    /// error encountered (falling back to a generic message).
    ///
    /// NOTE: `field_errors()` iterates a HashMap, so which field supplies the
    /// "first" message is not deterministic across runs.
    fn from(errors: ValidationErrors) -> Self {
        let mut error_details = serde_json::Map::new();
        // IMPROVEMENT: previously a sentinel-string comparison decided whether
        // the headline was still the default; an Option cannot collide with a
        // real error message that happens to equal the sentinel.
        let mut first_error_msg: Option<String> = None;
        for (field, field_errors) in errors.field_errors() {
            let error_list: Vec<String> = field_errors
                .iter()
                .map(|e| {
                    e.message
                        .as_ref()
                        .map(|m| m.to_string())
                        .unwrap_or_else(|| e.code.to_string())
                })
                .collect();
            error_details.insert(field.to_string(), json!(error_list));
            // Remember the first field's first error as the headline message.
            if first_error_msg.is_none() && !error_list.is_empty() {
                let msg = match field_errors[0].message.as_ref() {
                    Some(m) => format!("{}: {}", field, m),
                    None => format!("{}: {}", field, field_errors[0].code),
                };
                first_error_msg = Some(msg);
            }
        }
        ApiErr::BadRequest(
            first_error_msg.unwrap_or_else(|| String::from("请求参数验证失败")),
            Some(Value::Object(error_details)),
        )
    }
}

269
src/websocket.rs Normal file
View File

@ -0,0 +1,269 @@
use axum::{
extract::{
ws::{Message, WebSocket, WebSocketUpgrade},
Path, State,
},
response::IntoResponse,
};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{broadcast, RwLock};
use uuid::Uuid;
/// Server-to-client websocket payloads, tagged `{"type": ..., "data": ...}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", content = "data")]
pub enum WsMessage {
    // Push of a point's latest monitored value.
    PointValueChange(crate::telemetry::WsPointMonitorInfo),
    // Outcome of a batch point-write request.
    PointSetValueBatchResult(crate::connection::BatchSetPointValueRes),
}
/// Client-to-server websocket payloads, tagged `{"type": ..., "data": ...}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", content = "data", rename_all = "snake_case")]
pub enum WsClientMessage {
    // Present a write key to unlock point writes on this connection.
    AuthWrite(WsAuthWriteReq),
    // Request a batch write of point values (requires prior AuthWrite).
    PointSetValueBatch(crate::connection::BatchSetPointValueReq),
}
/// Payload of `AuthWrite`: the write key checked via `config.verify_write_key`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WsAuthWriteReq {
    pub key: String,
}
/// Room registry: maps a room id to its broadcast sender.
#[derive(Clone)]
pub struct RoomManager {
    // Shared map guarded by an async RwLock; senders are cheaply cloneable.
    rooms: Arc<RwLock<HashMap<String, broadcast::Sender<WsMessage>>>>,
}
impl RoomManager {
pub fn new() -> Self {
Self {
rooms: Arc::new(RwLock::new(HashMap::new())),
}
}
/// Get or create room sender.
pub async fn get_or_create_room(&self, room_id: &str) -> broadcast::Sender<WsMessage> {
let mut rooms = self.rooms.write().await;
if let Some(sender) = rooms.get(room_id) {
return sender.clone();
}
let (sender, _) = broadcast::channel(100);
rooms.insert(room_id.to_string(), sender.clone());
tracing::info!("Created new room: {}", room_id);
sender
}
/// Get room sender if room exists.
pub async fn get_room(&self, room_id: &str) -> Option<broadcast::Sender<WsMessage>> {
let rooms = self.rooms.read().await;
rooms.get(room_id).cloned()
}
/// Remove room if there are no receivers left.
pub async fn remove_room_if_empty(&self, room_id: &str) {
let mut rooms = self.rooms.write().await;
let should_remove = rooms
.get(room_id)
.map(|sender| sender.receiver_count() == 0)
.unwrap_or(false);
if should_remove {
rooms.remove(room_id);
tracing::info!("Removed empty room: {}", room_id);
}
}
/// Send message to room.
///
/// Returns:
/// - Ok(n): n subscribers received it
/// - Ok(0): room missing or no active subscribers
pub async fn send_to_room(&self, room_id: &str, message: WsMessage) -> Result<usize, String> {
if let Some(sender) = self.get_room(room_id).await {
match sender.send(message) {
Ok(count) => Ok(count),
// No receiver is not exceptional in push scenarios.
Err(broadcast::error::SendError(_)) => Ok(0),
}
} else {
Ok(0)
}
}
}
impl Default for RoomManager {
fn default() -> Self {
Self::new()
}
}
/// WebSocket manager.
///
/// A single `RoomManager` backs both the shared "public" room and the
/// per-client rooms (keyed by client UUID string).
#[derive(Clone)]
pub struct WebSocketManager {
    public_room: Arc<RoomManager>,
}
impl WebSocketManager {
pub fn new() -> Self {
Self {
public_room: Arc::new(RoomManager::new()),
}
}
/// Send message to public room.
pub async fn send_to_public(&self, message: WsMessage) -> Result<usize, String> {
self.public_room.get_or_create_room("public").await;
self.public_room.send_to_room("public", message).await
}
/// Send message to a dedicated client room.
pub async fn send_to_client(&self, client_id: Uuid, message: WsMessage) -> Result<usize, String> {
self.public_room
.send_to_room(&client_id.to_string(), message)
.await
}
}
impl Default for WebSocketManager {
fn default() -> Self {
Self::new()
}
}
/// Upgrade an HTTP request into a websocket joined to the shared "public" room.
pub async fn public_websocket_handler(
    ws: WebSocketUpgrade,
    State(state): State<crate::AppState>,
) -> impl IntoResponse {
    let manager = state.ws_manager.clone();
    ws.on_upgrade(move |socket| handle_socket(socket, manager, "public".to_string(), state))
}
/// Upgrade an HTTP request into a websocket joined to a per-client room
/// (room id = client UUID).
pub async fn client_websocket_handler(
    ws: WebSocketUpgrade,
    Path(client_id): Path<Uuid>,
    State(state): State<crate::AppState>,
) -> impl IntoResponse {
    let manager = state.ws_manager.clone();
    let room_id = client_id.to_string();
    ws.on_upgrade(move |socket| handle_socket(socket, manager, room_id, state))
}
/// Handle websocket connection for one room.
///
/// Runs a select loop that simultaneously:
/// - reads client frames (write-key auth and batch point-write requests), and
/// - forwards room broadcasts to this socket.
///
/// Exits on socket close/error, end of stream, or a closed room channel;
/// the room entry is then removed if no subscribers remain.
async fn handle_socket(
    mut socket: WebSocket,
    ws_manager: Arc<WebSocketManager>,
    room_id: String,
    state: crate::AppState,
) {
    let room_sender = ws_manager.public_room.get_or_create_room(&room_id).await;
    let mut rx = room_sender.subscribe();
    // Writes are denied until the client authenticates via AuthWrite.
    let mut can_write = false;
    loop {
        tokio::select! {
            maybe_msg = socket.recv() => {
                match maybe_msg {
                    Some(Ok(msg)) => {
                        if matches!(msg, Message::Close(_)) {
                            break;
                        }
                        match msg {
                            Message::Text(text) => {
                                match serde_json::from_str::<WsClientMessage>(&text) {
                                    Ok(WsClientMessage::AuthWrite(payload)) => {
                                        // (Re)evaluate write permission from the supplied key.
                                        can_write = state.config.verify_write_key(&payload.key);
                                        if !can_write {
                                            tracing::warn!("WebSocket write auth failed in room {}", room_id);
                                        }
                                    }
                                    Ok(WsClientMessage::PointSetValueBatch(payload)) => {
                                        let response = if !can_write {
                                            // Not authenticated: reject without touching any source.
                                            crate::connection::BatchSetPointValueRes {
                                                success: false,
                                                err_msg: Some("write permission denied".to_string()),
                                                success_count: 0,
                                                failed_count: 0,
                                                results: vec![],
                                            }
                                        } else {
                                            match state.connection_manager.write_point_values_batch(payload).await {
                                                Ok(v) => v,
                                                // Manager-level failure: report one synthetic failed item.
                                                Err(e) => crate::connection::BatchSetPointValueRes {
                                                    success: false,
                                                    err_msg: Some(e),
                                                    success_count: 0,
                                                    failed_count: 1,
                                                    results: vec![crate::connection::SetPointValueResItem {
                                                        point_id: Uuid::nil(),
                                                        success: false,
                                                        err_msg: Some("Internal write error".to_string()),
                                                    }],
                                                },
                                            }
                                        };
                                        // The result is broadcast to the whole room, not just this socket.
                                        if let Err(e) = ws_manager
                                            .public_room
                                            .send_to_room(&room_id, WsMessage::PointSetValueBatchResult(response))
                                            .await
                                        {
                                            tracing::error!(
                                                "Failed to send PointSetValueBatchResult to room {}: {}",
                                                room_id,
                                                e
                                            );
                                        }
                                    }
                                    Err(e) => {
                                        // Unparseable client payload: log and keep the socket open.
                                        tracing::warn!(
                                            "Invalid websocket message in room {}: {}",
                                            room_id,
                                            e
                                        );
                                    }
                                }
                            }
                            _ => {
                                // Binary/ping/pong frames are ignored apart from debug logging.
                                tracing::debug!("Received WebSocket message from room {}: {:?}", room_id, msg);
                            }
                        }
                    }
                    Some(Err(e)) => {
                        tracing::error!("WebSocket error in room {}: {}", room_id, e);
                        break;
                    }
                    None => break,
                }
            }
            room_message = rx.recv() => {
                match room_message {
                    Ok(message) => match serde_json::to_string(&message) {
                        Ok(json_str) => {
                            // Failing to push means the socket is gone: exit the loop.
                            if socket.send(Message::Text(json_str.into())).await.is_err() {
                                break;
                            }
                        }
                        Err(e) => {
                            tracing::error!("Failed to serialize websocket message: {}", e);
                        }
                    },
                    // Slow consumer: missed messages are dropped, socket stays alive.
                    Err(broadcast::error::RecvError::Lagged(skipped)) => {
                        tracing::warn!("WebSocket room {} lagged, skipped {} messages", room_id, skipped);
                    }
                    Err(broadcast::error::RecvError::Closed) => break,
                }
            }
        }
    }
    // Last subscriber gone -> free the room entry.
    ws_manager.public_room.remove_room_if_empty(&room_id).await;
}