1use std::{collections::HashMap, net::SocketAddr, sync::Arc};
4
5use tokio::{
6 net::TcpListener,
7 sync::{RwLock, broadcast, mpsc},
8};
9
10use crate::{
11 address::NodeId,
12 buffer::NetBufferProducer,
13 clock::ClockSync,
14 error::{NetError, Result},
15 message::NetMessage,
16 websocket::{
17 connection::ConnectionState,
18 protocol::{ClientMessage, ParamState, ServerMessage},
19 room::RoomManager,
20 },
21};
22
/// Configuration for [`WsServer`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WsServerConfig {
    /// Address the TCP listener binds to.
    pub bind_addr: SocketAddr,
    /// Intended per-room client limit.
    // NOTE(review): not read by `WsServer` in this file — confirm where it is enforced.
    pub max_connections_per_room: usize,
    /// Intended keep-alive ping interval, in milliseconds.
    // NOTE(review): not read by `WsServer` in this file.
    pub ping_interval_ms: u64,
    /// Intended idle-connection timeout, in milliseconds.
    // NOTE(review): not read by `WsServer` in this file.
    pub connection_timeout_ms: u64,
    /// Capacity of the server-to-client broadcast channel.
    ///
    /// Must be non-zero: `tokio::sync::broadcast::channel` panics on a zero capacity.
    pub broadcast_capacity: usize,
}

impl Default for WsServerConfig {
    /// All IPv4 interfaces on port 8080, 100 clients per room, 5 s ping interval,
    /// 30 s connection timeout and a broadcast buffer of 1024 messages.
    fn default() -> Self {
        Self {
            // Built directly rather than parsed from a string, so there is no
            // runtime failure path to `unwrap`.
            bind_addr: SocketAddr::from(([0, 0, 0, 0], 8080)),
            max_connections_per_room: 100,
            ping_interval_ms: 5000,
            connection_timeout_ms: 30000,
            broadcast_capacity: 1024,
        }
    }
}
48
49pub enum ServerCommand {
51 BroadcastParam { name: String, value: f32 },
53 BroadcastVisualization { data: Vec<f32> },
55 CreateRoom {
57 response: tokio::sync::oneshot::Sender<String>,
58 },
59 CloseRoom { code: String },
61 Shutdown,
63}
64
65pub struct WsServer {
67 config: WsServerConfig,
68 room_manager: Arc<RwLock<RoomManager>>,
69 clock_sync: Arc<ClockSync>,
70 producer: NetBufferProducer,
71 state_broadcast: broadcast::Sender<ServerMessage>,
72 command_rx: mpsc::Receiver<ServerCommand>,
73 connections: Arc<RwLock<HashMap<NodeId, ConnectionState>>>,
74 message_tx: mpsc::Sender<NetMessage>,
75 message_rx: mpsc::Receiver<NetMessage>,
76}
77
78impl WsServer {
79 pub fn new(config: WsServerConfig, producer: NetBufferProducer, command_rx: mpsc::Receiver<ServerCommand>) -> Self {
87 let (state_broadcast, _) = broadcast::channel(config.broadcast_capacity);
88 let (message_tx, message_rx) = mpsc::channel(1024);
89
90 Self {
91 config,
92 room_manager: Arc::new(RwLock::new(RoomManager::new())),
93 clock_sync: Arc::new(ClockSync::new()),
94 producer,
95 state_broadcast,
96 command_rx,
97 connections: Arc::new(RwLock::new(HashMap::new())),
98 message_tx,
99 message_rx,
100 }
101 }
102
103 pub fn command_channel() -> (mpsc::Sender<ServerCommand>, mpsc::Receiver<ServerCommand>) {
105 mpsc::channel(256)
106 }
107
108 pub fn clock(&self) -> Arc<ClockSync> {
110 Arc::clone(&self.clock_sync)
111 }
112
113 pub fn subscribe(&self) -> broadcast::Receiver<ServerMessage> {
115 self.state_broadcast.subscribe()
116 }
117
118 pub async fn run(mut self) -> Result<()> {
122 let listener = TcpListener::bind(&self.config.bind_addr)
123 .await
124 .map_err(|_| NetError::ConnectionFailed)?;
125
126 loop {
127 tokio::select! {
128 accept_result = listener.accept() => {
129 match accept_result {
130 Ok((stream, addr)) => {
131 self.handle_connection(stream, addr).await;
132 }
133 Err(_) => {
134 continue;
135 }
136 }
137 }
138
139 Some(msg) = self.message_rx.recv() => {
140 let _ = self.producer.try_send(msg);
141 }
142
143 Some(cmd) = self.command_rx.recv() => {
144 match cmd {
145 ServerCommand::Shutdown => {
146 return Ok(());
147 }
148 ServerCommand::BroadcastParam { name, value } => {
149 let _ = self.state_broadcast.send(ServerMessage::Update {
150 param: name,
151 value,
152 });
153 }
154 ServerCommand::BroadcastVisualization { data: _ } => {
155 }
157 ServerCommand::CreateRoom { response } => {
158 let code = self.room_manager.write().await.create_room();
159 let _ = response.send(code);
160 }
161 ServerCommand::CloseRoom { code } => {
162 let clients = self.room_manager.write().await.close_room(&code);
163 for node_id in clients {
164 self.connections.write().await.remove(&node_id);
165 }
166 let _ = self.state_broadcast.send(ServerMessage::RoomClosed);
167 }
168 }
169 }
170 }
171 }
172 }
173
174 async fn handle_connection(&self, stream: tokio::net::TcpStream, _addr: SocketAddr) {
175 use futures_util::{SinkExt, StreamExt};
176 use tokio_tungstenite::accept_async;
177
178 let ws_stream = match accept_async(stream).await {
179 Ok(ws) => ws,
180 Err(_) => return,
181 };
182
183 let (mut write, mut read) = ws_stream.split();
184
185 let room_manager = Arc::clone(&self.room_manager);
186 let clock_sync = Arc::clone(&self.clock_sync);
187 let connections = Arc::clone(&self.connections);
188 let mut broadcast_rx = self.state_broadcast.subscribe();
189 let message_tx = self.message_tx.clone();
190
191 tokio::spawn(async move {
192 let mut node_id: Option<NodeId> = None;
193 let mut room_code: Option<String> = None;
194 let mut should_cleanup = true;
195
196 loop {
197 tokio::select! {
198 msg = read.next() => {
199 match msg {
200 Some(Ok(tokio_tungstenite::tungstenite::Message::Text(text))) => {
201 if let Ok(client_msg) = serde_json::from_str::<ClientMessage>(&text) {
202 match client_msg {
203 ClientMessage::Join { room_code: code, client_name } => {
204 let mut rm = room_manager.write().await;
205 let connections_read = connections.read().await;
206 let new_node_id = NodeId::generate_unique(
207 clock_sync.now().as_micros(),
208 |id| connections_read.contains_key(id),
209 );
210 drop(connections_read);
211
212 match rm.join_room(&code, new_node_id, client_name.clone()) {
213 Ok(()) => {
214 node_id = Some(new_node_id);
215 room_code = Some(code);
216
217 let state = ConnectionState::new(
218 new_node_id,
219 room_code.clone().unwrap(),
220 client_name,
221 );
222 connections.write().await.insert(new_node_id, state);
223
224 let response = ServerMessage::Welcome {
225 node_id: new_node_id.to_uuid_string(),
226 server_time: clock_sync.now().as_micros(),
227 };
228
229 let json = serde_json::to_string(&response).unwrap();
230 let _ = write.send(
231 tokio_tungstenite::tungstenite::Message::Text(json)
232 ).await;
233 }
234 Err(NetError::InvalidRoomCode) => {
235 let response = ServerMessage::invalid_room_code();
236 let json = serde_json::to_string(&response).unwrap();
237 let _ = write.send(
238 tokio_tungstenite::tungstenite::Message::Text(json)
239 ).await;
240 }
241 Err(NetError::RoomFull) => {
242 let response = ServerMessage::room_full();
243 let json = serde_json::to_string(&response).unwrap();
244 let _ = write.send(
245 tokio_tungstenite::tungstenite::Message::Text(json)
246 ).await;
247 }
248 Err(_) => {}
249 }
250 }
251 ClientMessage::Parameter { param, value, at: _ } => {
252 if let Some(nid) = node_id {
253 let msg = NetMessage::param_change(¶m, value, nid)
254 .with_timestamp(clock_sync.now());
255 let _ = message_tx.send(msg).await;
256 }
257 }
258 ClientMessage::Trigger { name, at: _ } => {
259 if let Some(nid) = node_id {
260 let msg = if let Some((x, y)) = parse_trigger_coordinates(&name) {
261 NetMessage::trigger_with_coordinates(&name, x, y, nid)
262 } else {
263 NetMessage::trigger(&name, nid)
264 };
265 let msg = msg.with_timestamp(clock_sync.now());
266 let _ = message_tx.send(msg).await;
267 }
268 }
269 ClientMessage::Ping { client_time } => {
270 let server_time = clock_sync.now().as_micros();
271 let response = ServerMessage::Pong {
272 client_time,
273 server_time,
274 };
275 let json = serde_json::to_string(&response).unwrap();
276 let _ = write.send(
277 tokio_tungstenite::tungstenite::Message::Text(json)
278 ).await;
279 }
280 ClientMessage::Leave => {
281 if let (Some(nid), Some(code)) = (node_id, &room_code) {
282 room_manager.write().await.leave_room(code, nid);
283 connections.write().await.remove(&nid);
284 }
285 should_cleanup = false;
286 break;
287 }
288 _ => {}
289 }
290 }
291 }
292 Some(Ok(tokio_tungstenite::tungstenite::Message::Close(_))) | None => {
293 break;
294 }
295 _ => {}
296 }
297 }
298
299 Ok(broadcast_msg) = broadcast_rx.recv() => {
300 if node_id.is_some() {
301 let json = serde_json::to_string(&broadcast_msg).unwrap();
302 let _ = write.send(
303 tokio_tungstenite::tungstenite::Message::Text(json)
304 ).await;
305 }
306 }
307 }
308 }
309
310 if should_cleanup && let (Some(nid), Some(code)) = (node_id, room_code) {
311 room_manager.write().await.leave_room(&code, nid);
312 connections.write().await.remove(&nid);
313 }
314 });
315 }
316}
317
/// Extracts `(x, y)` from a trigger name of the form `name:x,y`.
///
/// Only the segment between the first and second `:` is examined. Within it, the
/// first two comma-separated fields must parse as `f32`, and any further fields are
/// ignored. Returns `None` when there is no `:`, fewer than two fields, or a field
/// fails to parse.
fn parse_trigger_coordinates(name: &str) -> Option<(f32, f32)> {
    let (_, after_colon) = name.split_once(':')?;
    let segment = after_colon.split(':').next()?;
    let mut fields = segment.split(',').map(str::parse::<f32>);
    match (fields.next()?, fields.next()?) {
        (Ok(x), Ok(y)) => Some((x, y)),
        _ => None,
    }
}
326
327#[allow(dead_code)]
329pub fn create_param_state_list(params: &[(String, f32)]) -> Vec<ParamState> {
330 params
331 .iter()
332 .map(|(name, value)| ParamState::new(name.clone(), *value))
333 .collect()
334}
335
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_server_config_default() {
        let config = WsServerConfig::default();
        assert!(config.bind_addr.ip().is_unspecified());
        assert_eq!(config.bind_addr.port(), 8080);
        assert_eq!(config.max_connections_per_room, 100);
        assert_eq!(config.ping_interval_ms, 5000);
        assert_eq!(config.connection_timeout_ms, 30000);
        assert_eq!(config.broadcast_capacity, 1024);
    }

    #[test]
    fn test_create_param_state_list() {
        let params = vec![("gain".to_string(), 0.5), ("freq".to_string(), 440.0)];

        let states = create_param_state_list(&params);

        assert_eq!(states.len(), 2);
        assert_eq!(states[0].name, "gain");
        assert!((states[0].value - 0.5).abs() < f32::EPSILON);
        assert_eq!(states[1].name, "freq");
        assert!((states[1].value - 440.0).abs() < f32::EPSILON);
    }

    #[test]
    fn test_create_param_state_list_empty() {
        assert!(create_param_state_list(&[]).is_empty());
    }

    #[test]
    fn test_parse_trigger_coordinates() {
        assert_eq!(parse_trigger_coordinates("tap:0.5,0.25"), Some((0.5, 0.25)));
        assert_eq!(parse_trigger_coordinates("tap:1,2,3"), Some((1.0, 2.0)));
        assert_eq!(parse_trigger_coordinates("tap"), None);
        assert_eq!(parse_trigger_coordinates("tap:1"), None);
        assert_eq!(parse_trigger_coordinates("tap:x,2"), None);
    }

    #[test]
    fn test_command_channel() {
        let (tx, _rx) = WsServer::command_channel();
        assert!(!tx.is_closed());
    }
}