auth middleware + message history

This commit is contained in:
2025-06-08 21:29:06 +02:00
parent 4f03cde9b5
commit 368ae7209c
12 changed files with 888 additions and 57 deletions

4
.env Normal file
View File

@@ -0,0 +1,4 @@
POSTGRES_HOST="172.19.0.2"
POSTGRES_PORT="5432"
POSTGRES_USER="postgres"
POSTGRES_PASSWORD="pg"

733
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -15,3 +15,5 @@ axum_typed_multipart = "0.16.2"
serde_json = "1.0.140" serde_json = "1.0.140"
chrono = { version = "0.4.41", features = ["serde"] } chrono = { version = "0.4.41", features = ["serde"] }
jsonwebtoken = "9.3.1" jsonwebtoken = "9.3.1"
sqlx = { version = "0.8.6", features = ["chrono", "postgres", "runtime-tokio"] }
dotenv = "0.15.0"

9
compose.yaml Normal file
View File

@@ -0,0 +1,9 @@
services:
postgres:
container_name: 'chat-app-pg'
image: 'postgres:17'
restart: 'unless-stopped'
volumes:
- './postgres:/var/lib/postgresql/data'
environment:
POSTGRES_PASSWORD: 'pg'

View File

@@ -1,15 +1,22 @@
use std::sync::LazyLock; use std::sync::LazyLock;
use axum::{extract::State, http::StatusCode}; use axum::{
extract::{Request, State},
http::StatusCode,
middleware::Next,
response::Response,
};
use chrono::Utc; use chrono::Utc;
use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation, decode, encode}; use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation, decode, encode};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use sqlx::{FromRow, Postgres};
use crate::state::AppState; use crate::state::AppState;
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Claims { pub struct Claims {
pub user_id: u32, pub user_id: i64,
pub iat: usize, pub iat: usize,
pub exp: usize, pub exp: usize,
} }
@@ -29,7 +36,7 @@ pub static AUTH_PUBLIC_KEY: LazyLock<DecodingKey> = LazyLock::new(|| {
}); });
impl Claims { impl Claims {
pub fn new(user_id: u32) -> Self { pub fn new(user_id: i64) -> Self {
let now = (Utc::now().timestamp_millis() / 1000) as usize; let now = (Utc::now().timestamp_millis() / 1000) as usize;
Self { Self {
user_id, user_id,
@@ -40,16 +47,26 @@ impl Claims {
} }
} }
#[derive(Debug, Clone, FromRow)]
pub struct User {
id: i64,
}
pub async fn get_auth_token(State(state): State<AppState>) -> Result<String, StatusCode> { pub async fn get_auth_token(State(state): State<AppState>) -> Result<String, StatusCode> {
let mut next_client_id = state.next_client_id.lock().await; let Ok(user) = sqlx::query_as::<Postgres, User>("INSERT INTO users DEFAULT VALUES RETURNING *")
let claims = Claims::new(*next_client_id); .fetch_one(&state.pg_pool)
.await
else {
return Err(StatusCode::INTERNAL_SERVER_ERROR);
};
*next_client_id += 1; let claims = Claims::new(user.id);
encode(&Header::new(Algorithm::RS512), &claims, &AUTH_SECRET_KEY).map_err(|e| { let Ok(token) = encode(&Header::new(Algorithm::RS512), &claims, &AUTH_SECRET_KEY) else {
dbg!(&e); return Err(StatusCode::INTERNAL_SERVER_ERROR);
StatusCode::INTERNAL_SERVER_ERROR };
})
Ok(token)
} }
pub fn verify_token(token: &str) -> Option<Claims> { pub fn verify_token(token: &str) -> Option<Claims> {
@@ -57,9 +74,34 @@ pub fn verify_token(token: &str) -> Option<Claims> {
decode::<Claims>(token, key, &Validation::new(Algorithm::RS512)) decode::<Claims>(token, key, &Validation::new(Algorithm::RS512))
.map(|token_data| token_data.claims) .map(|token_data| token_data.claims)
.map_err(|e| {
println!("{e:?}");
e
})
.ok() .ok()
} }
const AUTH_HEADER_PREFIX: &'static str = "Bearer ";
pub async fn authentication_middleware(
mut request: Request,
next: Next,
) -> Result<Response, StatusCode> {
let Some(Ok(auth_header)) = request
.headers()
.get(axum::http::header::AUTHORIZATION)
.map(|header| header.to_str())
else {
return Err(StatusCode::UNAUTHORIZED);
};
if auth_header.len() <= AUTH_HEADER_PREFIX.len() {
return Err(StatusCode::UNAUTHORIZED);
}
let token = auth_header.split_at(AUTH_HEADER_PREFIX.len()).1;
let Some(claims) = verify_token(token) else {
return Err(StatusCode::UNAUTHORIZED);
};
request.extensions_mut().insert(claims);
return Ok(next.run(request).await);
}

View File

@@ -1,27 +1,43 @@
mod auth; mod auth;
mod message; mod message;
mod message_history;
mod postgres;
mod send_message_handler; mod send_message_handler;
mod state; mod state;
mod websockets; mod websockets;
use auth::authentication_middleware;
use axum::{ use axum::{
Router, Router,
middleware::from_fn,
routing::{get, post}, routing::{get, post},
}; };
use postgres::get_postgres_pool;
use state::AppState; use state::AppState;
use tokio::{net::TcpListener, sync::broadcast::channel}; use tokio::{net::TcpListener, sync::broadcast::channel};
#[tokio::main] #[tokio::main]
async fn main() -> anyhow::Result<()> { async fn main() -> anyhow::Result<()> {
let (sender, _) = channel(32); dotenv::dotenv().ok();
let state = AppState::new(sender);
let router: Router = Router::new() let pool = get_postgres_pool().await?;
.route("/auth", get(auth::get_auth_token))
.route("/message", post(send_message_handler::handler)) let (sender, _) = channel(32);
let state = AppState::new(sender, pool);
let unauthenticated_routes: Router = Router::new()
.route("/ws", get(websockets::websocket_handler)) .route("/ws", get(websockets::websocket_handler))
.route("/auth", get(auth::get_auth_token))
.with_state(state.clone());
let authenticated_router: Router = Router::new()
.route("/message", post(send_message_handler::handler))
.route("/messages", get(message_history::handler))
.layer(from_fn(authentication_middleware))
.with_state(state); .with_state(state);
let router = unauthenticated_routes.merge(authenticated_router);
let listener = TcpListener::bind("127.0.0.1:3000").await?; let listener = TcpListener::bind("127.0.0.1:3000").await?;
let _ = axum::serve(listener, router).await; let _ = axum::serve(listener, router).await;

View File

@@ -1,20 +1,44 @@
use chrono::{NaiveDateTime, Utc}; use chrono::NaiveDateTime;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use sqlx::{Pool, Postgres, prelude::FromRow};
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct ChatMessage { pub struct ChatMessage {
pub sender_id: u32, #[sqlx(rename = "user_id")]
pub sender_id: i64,
pub content: String, pub content: String,
#[sqlx(rename = "created_at")]
pub timestamp: NaiveDateTime, pub timestamp: NaiveDateTime,
} }
impl ChatMessage { impl ChatMessage {
pub fn new(sender_id: u32, content: String) -> Self { pub async fn new(
Self { pool: &Pool<Postgres>,
sender_id, sender_id: i64,
content, content: String,
timestamp: Utc::now().naive_utc(), ) -> anyhow::Result<Self> {
} let mut tx = pool.begin().await?;
let message: Self = sqlx::query_as(
"INSERT INTO messages (
user_id,
content
) VALUES (
$1,
$2
)
RETURNING
*
",
)
.bind(sender_id)
.bind(&content)
.fetch_one(&mut *tx)
.await?;
tx.commit().await?;
Ok(message)
} }
} }

17
src/message_history.rs Normal file
View File

@@ -0,0 +1,17 @@
use axum::{Json, extract::State, http::StatusCode};
use sqlx::Postgres;
use crate::{message::ChatMessage, state::AppState};
pub async fn handler(State(state): State<AppState>) -> Result<Json<Vec<ChatMessage>>, StatusCode> {
let Ok(messages) = sqlx::query_as::<Postgres, ChatMessage>(
"SELECT * FROM messages ORDER BY created_at ASC LIMIT 100",
)
.fetch_all(&state.pg_pool)
.await
else {
return Err(StatusCode::INTERNAL_SERVER_ERROR);
};
Ok(Json(messages))
}

12
src/postgres/mod.rs Normal file
View File

@@ -0,0 +1,12 @@
use sqlx::{Pool, Postgres};
pub async fn get_postgres_pool() -> anyhow::Result<Pool<Postgres>> {
let host = std::env::var("POSTGRES_HOST")?;
let port = std::env::var("POSTGRES_PORT")?;
let user = std::env::var("POSTGRES_USER")?;
let password = std::env::var("POSTGRES_PASSWORD")?;
let url = format!("postgresql://{user}:{password}@{host}:{port}");
Ok(Pool::connect(&url).await?)
}

View File

@@ -1,26 +1,25 @@
use axum::{extract::State, http::StatusCode}; use axum::{Extension, extract::State, http::StatusCode};
use axum_typed_multipart::{TryFromMultipart, TypedMultipart}; use axum_typed_multipart::{TryFromMultipart, TypedMultipart};
use axum_valid::ValidifiedByRef; use axum_valid::ValidifiedByRef;
use validify::Validify; use validify::Validify;
use crate::{message::ChatMessage, state::AppState}; use crate::{auth::Claims, message::ChatMessage, state::AppState};
#[derive(Validify, TryFromMultipart)] #[derive(Validify, TryFromMultipart)]
#[try_from_multipart(rename_all = "kebab-case")] #[try_from_multipart(rename_all = "kebab-case")]
pub struct SendMessageData { pub struct SendMessageData {
client_id: u32,
#[modify(trim)] #[modify(trim)]
content: String, content: String,
} }
pub async fn handler( pub async fn handler(
State(state): State<AppState>, State(state): State<AppState>,
Extension(claims): Extension<Claims>,
ValidifiedByRef(TypedMultipart(data)): ValidifiedByRef<TypedMultipart<SendMessageData>>, ValidifiedByRef(TypedMultipart(data)): ValidifiedByRef<TypedMultipart<SendMessageData>>,
) -> StatusCode { ) -> StatusCode {
let mut messages = state.messages.lock().await; let Ok(message) = ChatMessage::new(&state.pg_pool, claims.user_id, data.content).await else {
return StatusCode::INTERNAL_SERVER_ERROR;
let message = ChatMessage::new(data.client_id, data.content); };
messages.push(message.clone());
let _ = state.broadcast_sender.send(message); let _ = state.broadcast_sender.send(message);

View File

@@ -1,22 +1,19 @@
use std::sync::Arc; use sqlx::{Pool, Postgres};
use tokio::sync::broadcast::Sender;
use tokio::sync::{Mutex, broadcast::Sender};
use crate::message::ChatMessage; use crate::message::ChatMessage;
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct AppState { pub struct AppState {
pub messages: Arc<Mutex<Vec<ChatMessage>>>,
pub next_client_id: Arc<Mutex<u32>>,
pub broadcast_sender: Sender<ChatMessage>, pub broadcast_sender: Sender<ChatMessage>,
pub pg_pool: Pool<Postgres>,
} }
impl AppState { impl AppState {
pub fn new(websocket_sender: Sender<ChatMessage>) -> Self { pub fn new(websocket_sender: Sender<ChatMessage>, pool: Pool<Postgres>) -> Self {
Self { Self {
messages: Arc::default(),
next_client_id: Arc::default(),
broadcast_sender: websocket_sender, broadcast_sender: websocket_sender,
pg_pool: pool,
} }
} }
} }

View File

@@ -29,13 +29,11 @@ pub async fn websocket_handler(
} }
async fn handler(mut socket: WebSocket, state: AppState, params: WebsocketHandlerParams) { async fn handler(mut socket: WebSocket, state: AppState, params: WebsocketHandlerParams) {
let Some(claims) = verify_token(&params.token) else { if verify_token(&params.token).is_none() {
let _ = socket.close(); let _ = socket.close();
return; return;
}; };
let _ = socket.send((claims.user_id).to_string().into()).await;
let (sender, receiver) = socket.split(); let (sender, receiver) = socket.split();
let broadcast_receiver = state.broadcast_sender.subscribe(); let broadcast_receiver = state.broadcast_sender.subscribe();