From 19fd34b92cf65cd2ce99e48c5ae743993147206a Mon Sep 17 00:00:00 2001 From: 409 <409dev@protonmail.com> Date: Sat, 7 Jun 2025 13:35:49 +0200 Subject: [PATCH] websockets + send messages as json --- Cargo.lock | 1 + Cargo.toml | 1 + src/main.rs | 12 +++++------ src/message.rs | 14 ++++++++++++ src/send_message_handler.rs | 15 +++++++------ src/state.rs | 22 +++++++++++++++++++ src/websockets.rs | 43 +++++++++++++++++++++++++++---------- 7 files changed, 84 insertions(+), 24 deletions(-) create mode 100644 src/message.rs create mode 100644 src/state.rs diff --git a/Cargo.lock b/Cargo.lock index c056bca..b56e798 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -338,6 +338,7 @@ dependencies = [ "axum_typed_multipart", "futures-util", "serde", + "serde_json", "tokio", "validify", ] diff --git a/Cargo.toml b/Cargo.toml index dbc5722..05f2535 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,3 +12,4 @@ tokio = { version = "1.45.1", features = ["full"] } axum-valid = { path = "../axum-valid", features = ["basic", "typed_multipart", "validify"], default-features = false } validify = "2.0.0" axum_typed_multipart = "0.16.2" +serde_json = "1.0.140" diff --git a/src/main.rs b/src/main.rs index 2792dcc..94ff645 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,19 +1,19 @@ +mod message; mod send_message_handler; +mod state; mod websockets; -use std::sync::Arc; - use axum::{ Router, routing::{get, post}, }; -use tokio::{net::TcpListener, sync::Mutex}; - -pub type MyState = Arc>>; +use state::AppState; +use tokio::{net::TcpListener, sync::broadcast::channel}; #[tokio::main] async fn main() -> anyhow::Result<()> { - let state = MyState::default(); + let (sender, _) = channel(32); + let state = AppState::new(sender); let router: Router = Router::new() .route("/message", post(send_message_handler::handler)) diff --git a/src/message.rs b/src/message.rs new file mode 100644 index 0000000..dc71ce7 --- /dev/null +++ b/src/message.rs @@ -0,0 +1,14 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ChatMessage { + pub sender_id: u32, + pub content: String, +} + +impl ChatMessage { + pub fn new(sender_id: u32, content: String) -> Self { + Self { sender_id, content } + } +} diff --git a/src/send_message_handler.rs b/src/send_message_handler.rs index 4548238..9fc6435 100644 --- a/src/send_message_handler.rs +++ b/src/send_message_handler.rs @@ -3,25 +3,26 @@ use axum_typed_multipart::{TryFromMultipart, TypedMultipart}; use axum_valid::ValidifiedByRef; use validify::Validify; -use crate::MyState; +use crate::{message::ChatMessage, state::AppState}; #[derive(Validify, TryFromMultipart)] +#[try_from_multipart(rename_all = "kebab-case")] pub struct SendMessageData { + client_id: u32, #[modify(trim)] content: String, } pub async fn handler( - State(state): State, + State(state): State, ValidifiedByRef(TypedMultipart(data)): ValidifiedByRef>, ) -> StatusCode { - let mut messages = state.lock().await; + let mut messages = state.messages.lock().await; - // println!("{}", &data.content); + let message = ChatMessage::new(data.client_id, data.content); + messages.push(message.clone()); - messages.push(data.content); - - dbg!(&messages); + let _ = state.broadcast_sender.send(message); StatusCode::OK } diff --git a/src/state.rs b/src/state.rs new file mode 100644 index 0000000..290aeb4 --- /dev/null +++ b/src/state.rs @@ -0,0 +1,22 @@ +use std::sync::Arc; + +use tokio::sync::{Mutex, broadcast::Sender}; + +use crate::message::ChatMessage; + +#[derive(Debug, Clone)] +pub struct AppState { + pub messages: Arc>>, + pub next_client_id: Arc>, + pub broadcast_sender: Sender, +} + +impl AppState { + pub fn new(websocket_sender: Sender) -> Self { + Self { + messages: Arc::default(), + next_client_id: Arc::default(), + broadcast_sender: websocket_sender, + } + } +} diff --git a/src/websockets.rs b/src/websockets.rs index 5af1f32..2867bca 100644 --- a/src/websockets.rs +++ b/src/websockets.rs @@ -1,8 +1,6 @@ -use std::time::Duration; - use axum::{ extract::{ - WebSocketUpgrade, + State, WebSocketUpgrade, ws::{Message, WebSocket}, }, response::Response, @@ -12,22 +10,45 @@ use futures_util::{ SinkExt as _, StreamExt, stream::{SplitSink, SplitStream}, }; +use tokio::sync::broadcast::Receiver; -pub async fn websocket_handler(upgrade: WebSocketUpgrade) -> Response { - upgrade.on_upgrade(handler) +use crate::{message::ChatMessage, state::AppState}; + +pub async fn websocket_handler( + upgrade: WebSocketUpgrade, + State(state): State, +) -> Response { + upgrade.on_upgrade(move |websocket| handler(websocket, state)) } -async fn handler(socket: WebSocket) { +async fn handler(mut socket: WebSocket, state: AppState) { + let mut next_client_id = state.next_client_id.lock().await; + + *next_client_id += 1; + let _ = socket.send((*next_client_id).to_string().into()).await; + + drop(next_client_id); + let (sender, receiver) = socket.split(); + let broadcast_receiver = state.broadcast_sender.subscribe(); + tokio::spawn(receive(receiver)); - tokio::spawn(send(sender)); + tokio::spawn(send(sender, broadcast_receiver)); } -async fn send(mut sender: SplitSink) { - loop { - let _ = sender.send("Hello client!".into()); - tokio::time::sleep(Duration::from_secs(5)).await; +async fn send( + mut sender: SplitSink, + mut broadcast_receiver: Receiver, +) { + while let Ok(message) = broadcast_receiver.recv().await { + let Ok(json_message) = serde_json::to_string(&message) else { + continue; + }; + + if sender.send(json_message.into()).await.is_err() { + break; + } } }