use axum::{ extract::{ Query, State, WebSocketUpgrade, ws::{Message, WebSocket}, }, response::Response, }; use futures_util::{ SinkExt as _, StreamExt, stream::{SplitSink, SplitStream}, }; use serde::Deserialize; use tokio::sync::broadcast::Receiver; use crate::{auth::verify_token, message::ChatMessage, state::AppState}; #[derive(Debug, Deserialize)] pub struct WebsocketHandlerParams { token: String, } pub async fn websocket_handler( upgrade: WebSocketUpgrade, Query(params): Query, State(state): State, ) -> Response { upgrade.on_upgrade(move |websocket| handler(websocket, state, params)) } async fn handler(mut socket: WebSocket, state: AppState, params: WebsocketHandlerParams) { let Some(claims) = verify_token(¶ms.token) else { let _ = socket.close(); return; }; let _ = socket.send((claims.user_id).to_string().into()).await; let (sender, receiver) = socket.split(); let broadcast_receiver = state.broadcast_sender.subscribe(); tokio::spawn(receive(receiver)); tokio::spawn(send(sender, broadcast_receiver)); } async fn send( mut sender: SplitSink, mut broadcast_receiver: Receiver, ) { while let Ok(message) = broadcast_receiver.recv().await { let Ok(json_message) = serde_json::to_string(&message) else { continue; }; if sender.send(json_message.into()).await.is_err() { break; } } } async fn receive(mut stream: SplitStream) { while let Some(Ok(message)) = stream.next().await { let Ok(text) = message.to_text() else { continue; }; println!("Message: {text}"); } }