bbx_net/websocket/
server.rs

1//! WebSocket server for phone PWA connections.
2
3use std::{collections::HashMap, net::SocketAddr, sync::Arc};
4
5use tokio::{
6    net::TcpListener,
7    sync::{RwLock, broadcast, mpsc},
8};
9
10use crate::{
11    address::NodeId,
12    buffer::NetBufferProducer,
13    clock::ClockSync,
14    error::{NetError, Result},
15    message::NetMessage,
16    websocket::{
17        connection::ConnectionState,
18        protocol::{ClientMessage, ParamState, ServerMessage},
19        room::RoomManager,
20    },
21};
22
/// Configuration for the WebSocket server.
///
/// All time values are expressed in milliseconds.
///
/// NOTE(review): within this file only `bind_addr` and `broadcast_capacity` are read;
/// the remaining fields are presumably consumed elsewhere — TODO confirm.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WsServerConfig {
    /// Address to bind to (e.g., "0.0.0.0:8080").
    pub bind_addr: SocketAddr,
    /// Maximum connections per room.
    pub max_connections_per_room: usize,
    /// Ping interval in milliseconds.
    pub ping_interval_ms: u64,
    /// Connection timeout in milliseconds.
    pub connection_timeout_ms: u64,
    /// Capacity for the state broadcast channel.
    pub broadcast_capacity: usize,
}
36
37impl Default for WsServerConfig {
38    fn default() -> Self {
39        Self {
40            bind_addr: "0.0.0.0:8080".parse().unwrap(),
41            max_connections_per_room: 100,
42            ping_interval_ms: 5000,
43            connection_timeout_ms: 30000,
44            broadcast_capacity: 1024,
45        }
46    }
47}
48
49/// Commands from main thread to server.
50pub enum ServerCommand {
51    /// Broadcast a parameter update to all connected clients.
52    BroadcastParam { name: String, value: f32 },
53    /// Broadcast visualization data to all clients.
54    BroadcastVisualization { data: Vec<f32> },
55    /// Create a new room and return its code via the response channel.
56    CreateRoom {
57        response: tokio::sync::oneshot::Sender<String>,
58    },
59    /// Close a room.
60    CloseRoom { code: String },
61    /// Shutdown the server.
62    Shutdown,
63}
64
65/// WebSocket server for phone PWA connections.
66pub struct WsServer {
67    config: WsServerConfig,
68    room_manager: Arc<RwLock<RoomManager>>,
69    clock_sync: Arc<ClockSync>,
70    producer: NetBufferProducer,
71    state_broadcast: broadcast::Sender<ServerMessage>,
72    command_rx: mpsc::Receiver<ServerCommand>,
73    connections: Arc<RwLock<HashMap<NodeId, ConnectionState>>>,
74    message_tx: mpsc::Sender<NetMessage>,
75    message_rx: mpsc::Receiver<NetMessage>,
76}
77
78impl WsServer {
79    /// Create a new WebSocket server.
80    ///
81    /// # Arguments
82    ///
83    /// * `config` - Server configuration
84    /// * `producer` - Buffer producer for sending messages to audio thread
85    /// * `command_rx` - Channel for receiving commands from main thread
86    pub fn new(config: WsServerConfig, producer: NetBufferProducer, command_rx: mpsc::Receiver<ServerCommand>) -> Self {
87        let (state_broadcast, _) = broadcast::channel(config.broadcast_capacity);
88        let (message_tx, message_rx) = mpsc::channel(1024);
89
90        Self {
91            config,
92            room_manager: Arc::new(RwLock::new(RoomManager::new())),
93            clock_sync: Arc::new(ClockSync::new()),
94            producer,
95            state_broadcast,
96            command_rx,
97            connections: Arc::new(RwLock::new(HashMap::new())),
98            message_tx,
99            message_rx,
100        }
101    }
102
103    /// Create a command channel pair for communicating with the server.
104    pub fn command_channel() -> (mpsc::Sender<ServerCommand>, mpsc::Receiver<ServerCommand>) {
105        mpsc::channel(256)
106    }
107
108    /// Get the clock synchronization instance.
109    pub fn clock(&self) -> Arc<ClockSync> {
110        Arc::clone(&self.clock_sync)
111    }
112
113    /// Get a handle to subscribe to state broadcasts.
114    pub fn subscribe(&self) -> broadcast::Receiver<ServerMessage> {
115        self.state_broadcast.subscribe()
116    }
117
118    /// Run the WebSocket server.
119    ///
120    /// This method runs until a shutdown command is received or an error occurs.
121    pub async fn run(mut self) -> Result<()> {
122        let listener = TcpListener::bind(&self.config.bind_addr)
123            .await
124            .map_err(|_| NetError::ConnectionFailed)?;
125
126        loop {
127            tokio::select! {
128                accept_result = listener.accept() => {
129                    match accept_result {
130                        Ok((stream, addr)) => {
131                            self.handle_connection(stream, addr).await;
132                        }
133                        Err(_) => {
134                            continue;
135                        }
136                    }
137                }
138
139                Some(msg) = self.message_rx.recv() => {
140                    let _ = self.producer.try_send(msg);
141                }
142
143                Some(cmd) = self.command_rx.recv() => {
144                    match cmd {
145                        ServerCommand::Shutdown => {
146                            return Ok(());
147                        }
148                        ServerCommand::BroadcastParam { name, value } => {
149                            let _ = self.state_broadcast.send(ServerMessage::Update {
150                                param: name,
151                                value,
152                            });
153                        }
154                        ServerCommand::BroadcastVisualization { data: _ } => {
155                            // Visualization broadcasts would be handled here
156                        }
157                        ServerCommand::CreateRoom { response } => {
158                            let code = self.room_manager.write().await.create_room();
159                            let _ = response.send(code);
160                        }
161                        ServerCommand::CloseRoom { code } => {
162                            let clients = self.room_manager.write().await.close_room(&code);
163                            for node_id in clients {
164                                self.connections.write().await.remove(&node_id);
165                            }
166                            let _ = self.state_broadcast.send(ServerMessage::RoomClosed);
167                        }
168                    }
169                }
170            }
171        }
172    }
173
174    async fn handle_connection(&self, stream: tokio::net::TcpStream, _addr: SocketAddr) {
175        use futures_util::{SinkExt, StreamExt};
176        use tokio_tungstenite::accept_async;
177
178        let ws_stream = match accept_async(stream).await {
179            Ok(ws) => ws,
180            Err(_) => return,
181        };
182
183        let (mut write, mut read) = ws_stream.split();
184
185        let room_manager = Arc::clone(&self.room_manager);
186        let clock_sync = Arc::clone(&self.clock_sync);
187        let connections = Arc::clone(&self.connections);
188        let mut broadcast_rx = self.state_broadcast.subscribe();
189        let message_tx = self.message_tx.clone();
190
191        tokio::spawn(async move {
192            let mut node_id: Option<NodeId> = None;
193            let mut room_code: Option<String> = None;
194            let mut should_cleanup = true;
195
196            loop {
197                tokio::select! {
198                    msg = read.next() => {
199                        match msg {
200                            Some(Ok(tokio_tungstenite::tungstenite::Message::Text(text))) => {
201                                if let Ok(client_msg) = serde_json::from_str::<ClientMessage>(&text) {
202                                    match client_msg {
203                                        ClientMessage::Join { room_code: code, client_name } => {
204                                            let mut rm = room_manager.write().await;
205                                            let connections_read = connections.read().await;
206                                            let new_node_id = NodeId::generate_unique(
207                                                clock_sync.now().as_micros(),
208                                                |id| connections_read.contains_key(id),
209                                            );
210                                            drop(connections_read);
211
212                                            match rm.join_room(&code, new_node_id, client_name.clone()) {
213                                                Ok(()) => {
214                                                    node_id = Some(new_node_id);
215                                                    room_code = Some(code);
216
217                                                    let state = ConnectionState::new(
218                                                        new_node_id,
219                                                        room_code.clone().unwrap(),
220                                                        client_name,
221                                                    );
222                                                    connections.write().await.insert(new_node_id, state);
223
224                                                    let response = ServerMessage::Welcome {
225                                                        node_id: new_node_id.to_uuid_string(),
226                                                        server_time: clock_sync.now().as_micros(),
227                                                    };
228
229                                                    let json = serde_json::to_string(&response).unwrap();
230                                                    let _ = write.send(
231                                                        tokio_tungstenite::tungstenite::Message::Text(json)
232                                                    ).await;
233                                                }
234                                                Err(NetError::InvalidRoomCode) => {
235                                                    let response = ServerMessage::invalid_room_code();
236                                                    let json = serde_json::to_string(&response).unwrap();
237                                                    let _ = write.send(
238                                                        tokio_tungstenite::tungstenite::Message::Text(json)
239                                                    ).await;
240                                                }
241                                                Err(NetError::RoomFull) => {
242                                                    let response = ServerMessage::room_full();
243                                                    let json = serde_json::to_string(&response).unwrap();
244                                                    let _ = write.send(
245                                                        tokio_tungstenite::tungstenite::Message::Text(json)
246                                                    ).await;
247                                                }
248                                                Err(_) => {}
249                                            }
250                                        }
251                                        ClientMessage::Parameter { param, value, at: _ } => {
252                                            if let Some(nid) = node_id {
253                                                let msg = NetMessage::param_change(&param, value, nid)
254                                                    .with_timestamp(clock_sync.now());
255                                                let _ = message_tx.send(msg).await;
256                                            }
257                                        }
258                                        ClientMessage::Trigger { name, at: _ } => {
259                                            if let Some(nid) = node_id {
260                                                let msg = if let Some((x, y)) = parse_trigger_coordinates(&name) {
261                                                    NetMessage::trigger_with_coordinates(&name, x, y, nid)
262                                                } else {
263                                                    NetMessage::trigger(&name, nid)
264                                                };
265                                                let msg = msg.with_timestamp(clock_sync.now());
266                                                let _ = message_tx.send(msg).await;
267                                            }
268                                        }
269                                        ClientMessage::Ping { client_time } => {
270                                            let server_time = clock_sync.now().as_micros();
271                                            let response = ServerMessage::Pong {
272                                                client_time,
273                                                server_time,
274                                            };
275                                            let json = serde_json::to_string(&response).unwrap();
276                                            let _ = write.send(
277                                                tokio_tungstenite::tungstenite::Message::Text(json)
278                                            ).await;
279                                        }
280                                        ClientMessage::Leave => {
281                                            if let (Some(nid), Some(code)) = (node_id, &room_code) {
282                                                room_manager.write().await.leave_room(code, nid);
283                                                connections.write().await.remove(&nid);
284                                            }
285                                            should_cleanup = false;
286                                            break;
287                                        }
288                                        _ => {}
289                                    }
290                                }
291                            }
292                            Some(Ok(tokio_tungstenite::tungstenite::Message::Close(_))) | None => {
293                                break;
294                            }
295                            _ => {}
296                        }
297                    }
298
299                    Ok(broadcast_msg) = broadcast_rx.recv() => {
300                        if node_id.is_some() {
301                            let json = serde_json::to_string(&broadcast_msg).unwrap();
302                            let _ = write.send(
303                                tokio_tungstenite::tungstenite::Message::Text(json)
304                            ).await;
305                        }
306                    }
307                }
308            }
309
310            if should_cleanup && let (Some(nid), Some(code)) = (node_id, room_code) {
311                room_manager.write().await.leave_room(&code, nid);
312                connections.write().await.remove(&nid);
313            }
314        });
315    }
316}
317
/// Parse coordinates from trigger names in "prefix:x,y" format.
///
/// The coordinates are taken from the text between the first and second `:`; the first
/// two comma-separated components are used and any further components are ignored.
/// Whitespace around a component is tolerated.
///
/// Returns `None` when there is no `:` section, when fewer than two components are
/// present, or when a component is not a finite number. Trigger names come from
/// clients, so `nan`, `inf` and overflowing values are rejected rather than passed on.
fn parse_trigger_coordinates(name: &str) -> Option<(f32, f32)> {
    // Parse one component, accepting only finite values.
    fn finite_component(raw: &str) -> Option<f32> {
        raw.trim().parse::<f32>().ok().filter(|value| value.is_finite())
    }

    let coords = name.split(':').nth(1)?;
    let mut parts = coords.split(',');
    let x = finite_component(parts.next()?)?;
    let y = finite_component(parts.next()?)?;
    Some((x, y))
}
326
327/// Create a parameter state list for state synchronization.
328#[allow(dead_code)]
329pub fn create_param_state_list(params: &[(String, f32)]) -> Vec<ParamState> {
330    params
331        .iter()
332        .map(|(name, value)| ParamState::new(name.clone(), *value))
333        .collect()
334}
335
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration listens on port 8080 and allows 100 connections per room.
    #[test]
    fn test_server_config_default() {
        let WsServerConfig {
            bind_addr,
            max_connections_per_room,
            ..
        } = WsServerConfig::default();

        assert_eq!(bind_addr.port(), 8080);
        assert_eq!(max_connections_per_room, 100);
    }

    /// Every `(name, value)` pair becomes one state, in order.
    #[test]
    fn test_create_param_state_list() {
        let params = vec![("gain".to_string(), 0.5), ("freq".to_string(), 440.0)];

        let states = create_param_state_list(&params);
        assert_eq!(states.len(), 2);

        let first = &states[0];
        assert_eq!(first.name, "gain");
        assert!((first.value - 0.5).abs() < f32::EPSILON);
    }

    /// A freshly created command channel is open while its receiver is alive.
    #[test]
    fn test_command_channel() {
        let (sender, _receiver) = WsServer::command_channel();
        assert!(!sender.is_closed());
    }
}
361        assert!(!tx.is_closed());
362    }
363}