forked from 409/chat-app
auth middleware + message history
This commit is contained in:
4
.env
Normal file
4
.env
Normal file
@@ -0,0 +1,4 @@
|
||||
POSTGRES_HOST="172.19.0.2"
|
||||
POSTGRES_PORT="5432"
|
||||
POSTGRES_USER="postgres"
|
||||
POSTGRES_PASSWORD="pg"
|
||||
733
Cargo.lock
generated
733
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -15,3 +15,5 @@ axum_typed_multipart = "0.16.2"
|
||||
serde_json = "1.0.140"
|
||||
chrono = { version = "0.4.41", features = ["serde"] }
|
||||
jsonwebtoken = "9.3.1"
|
||||
sqlx = { version = "0.8.6", features = ["chrono", "postgres", "runtime-tokio"] }
|
||||
dotenv = "0.15.0"
|
||||
|
||||
9
compose.yaml
Normal file
9
compose.yaml
Normal file
@@ -0,0 +1,9 @@
|
||||
services:
|
||||
postgres:
|
||||
container_name: 'chat-app-pg'
|
||||
image: 'postgres:17'
|
||||
restart: 'unless-stopped'
|
||||
volumes:
|
||||
- './postgres:/var/lib/postgresql/data'
|
||||
environment:
|
||||
POSTGRES_PASSWORD: 'pg'
|
||||
70
src/auth.rs
70
src/auth.rs
@@ -1,15 +1,22 @@
|
||||
use std::sync::LazyLock;
|
||||
|
||||
use axum::{extract::State, http::StatusCode};
|
||||
use axum::{
|
||||
extract::{Request, State},
|
||||
http::StatusCode,
|
||||
middleware::Next,
|
||||
response::Response,
|
||||
};
|
||||
use chrono::Utc;
|
||||
use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation, decode, encode};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::{FromRow, Postgres};
|
||||
|
||||
use crate::state::AppState;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct Claims {
|
||||
pub user_id: u32,
|
||||
pub user_id: i64,
|
||||
pub iat: usize,
|
||||
pub exp: usize,
|
||||
}
|
||||
@@ -29,7 +36,7 @@ pub static AUTH_PUBLIC_KEY: LazyLock<DecodingKey> = LazyLock::new(|| {
|
||||
});
|
||||
|
||||
impl Claims {
|
||||
pub fn new(user_id: u32) -> Self {
|
||||
pub fn new(user_id: i64) -> Self {
|
||||
let now = (Utc::now().timestamp_millis() / 1000) as usize;
|
||||
Self {
|
||||
user_id,
|
||||
@@ -40,16 +47,26 @@ impl Claims {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, FromRow)]
|
||||
pub struct User {
|
||||
id: i64,
|
||||
}
|
||||
|
||||
pub async fn get_auth_token(State(state): State<AppState>) -> Result<String, StatusCode> {
|
||||
let mut next_client_id = state.next_client_id.lock().await;
|
||||
let claims = Claims::new(*next_client_id);
|
||||
let Ok(user) = sqlx::query_as::<Postgres, User>("INSERT INTO users DEFAULT VALUES RETURNING *")
|
||||
.fetch_one(&state.pg_pool)
|
||||
.await
|
||||
else {
|
||||
return Err(StatusCode::INTERNAL_SERVER_ERROR);
|
||||
};
|
||||
|
||||
*next_client_id += 1;
|
||||
let claims = Claims::new(user.id);
|
||||
|
||||
encode(&Header::new(Algorithm::RS512), &claims, &AUTH_SECRET_KEY).map_err(|e| {
|
||||
dbg!(&e);
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
})
|
||||
let Ok(token) = encode(&Header::new(Algorithm::RS512), &claims, &AUTH_SECRET_KEY) else {
|
||||
return Err(StatusCode::INTERNAL_SERVER_ERROR);
|
||||
};
|
||||
|
||||
Ok(token)
|
||||
}
|
||||
|
||||
pub fn verify_token(token: &str) -> Option<Claims> {
|
||||
@@ -57,9 +74,34 @@ pub fn verify_token(token: &str) -> Option<Claims> {
|
||||
|
||||
decode::<Claims>(token, key, &Validation::new(Algorithm::RS512))
|
||||
.map(|token_data| token_data.claims)
|
||||
.map_err(|e| {
|
||||
println!("{e:?}");
|
||||
e
|
||||
})
|
||||
.ok()
|
||||
}
|
||||
|
||||
const AUTH_HEADER_PREFIX: &'static str = "Bearer ";
|
||||
|
||||
pub async fn authentication_middleware(
|
||||
mut request: Request,
|
||||
next: Next,
|
||||
) -> Result<Response, StatusCode> {
|
||||
let Some(Ok(auth_header)) = request
|
||||
.headers()
|
||||
.get(axum::http::header::AUTHORIZATION)
|
||||
.map(|header| header.to_str())
|
||||
else {
|
||||
return Err(StatusCode::UNAUTHORIZED);
|
||||
};
|
||||
|
||||
if auth_header.len() <= AUTH_HEADER_PREFIX.len() {
|
||||
return Err(StatusCode::UNAUTHORIZED);
|
||||
}
|
||||
|
||||
let token = auth_header.split_at(AUTH_HEADER_PREFIX.len()).1;
|
||||
|
||||
let Some(claims) = verify_token(token) else {
|
||||
return Err(StatusCode::UNAUTHORIZED);
|
||||
};
|
||||
|
||||
request.extensions_mut().insert(claims);
|
||||
|
||||
return Ok(next.run(request).await);
|
||||
}
|
||||
|
||||
26
src/main.rs
26
src/main.rs
@@ -1,27 +1,43 @@
|
||||
mod auth;
|
||||
mod message;
|
||||
mod message_history;
|
||||
mod postgres;
|
||||
mod send_message_handler;
|
||||
mod state;
|
||||
mod websockets;
|
||||
|
||||
use auth::authentication_middleware;
|
||||
use axum::{
|
||||
Router,
|
||||
middleware::from_fn,
|
||||
routing::{get, post},
|
||||
};
|
||||
use postgres::get_postgres_pool;
|
||||
use state::AppState;
|
||||
use tokio::{net::TcpListener, sync::broadcast::channel};
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
let (sender, _) = channel(32);
|
||||
let state = AppState::new(sender);
|
||||
dotenv::dotenv().ok();
|
||||
|
||||
let router: Router = Router::new()
|
||||
.route("/auth", get(auth::get_auth_token))
|
||||
.route("/message", post(send_message_handler::handler))
|
||||
let pool = get_postgres_pool().await?;
|
||||
|
||||
let (sender, _) = channel(32);
|
||||
let state = AppState::new(sender, pool);
|
||||
|
||||
let unauthenticated_routes: Router = Router::new()
|
||||
.route("/ws", get(websockets::websocket_handler))
|
||||
.route("/auth", get(auth::get_auth_token))
|
||||
.with_state(state.clone());
|
||||
|
||||
let authenticated_router: Router = Router::new()
|
||||
.route("/message", post(send_message_handler::handler))
|
||||
.route("/messages", get(message_history::handler))
|
||||
.layer(from_fn(authentication_middleware))
|
||||
.with_state(state);
|
||||
|
||||
let router = unauthenticated_routes.merge(authenticated_router);
|
||||
|
||||
let listener = TcpListener::bind("127.0.0.1:3000").await?;
|
||||
|
||||
let _ = axum::serve(listener, router).await;
|
||||
|
||||
@@ -1,20 +1,44 @@
|
||||
use chrono::{NaiveDateTime, Utc};
|
||||
use chrono::NaiveDateTime;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::{Pool, Postgres, prelude::FromRow};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ChatMessage {
|
||||
pub sender_id: u32,
|
||||
#[sqlx(rename = "user_id")]
|
||||
pub sender_id: i64,
|
||||
pub content: String,
|
||||
#[sqlx(rename = "created_at")]
|
||||
pub timestamp: NaiveDateTime,
|
||||
}
|
||||
|
||||
impl ChatMessage {
|
||||
pub fn new(sender_id: u32, content: String) -> Self {
|
||||
Self {
|
||||
sender_id,
|
||||
content,
|
||||
timestamp: Utc::now().naive_utc(),
|
||||
}
|
||||
pub async fn new(
|
||||
pool: &Pool<Postgres>,
|
||||
sender_id: i64,
|
||||
content: String,
|
||||
) -> anyhow::Result<Self> {
|
||||
let mut tx = pool.begin().await?;
|
||||
|
||||
let message: Self = sqlx::query_as(
|
||||
"INSERT INTO messages (
|
||||
user_id,
|
||||
content
|
||||
) VALUES (
|
||||
$1,
|
||||
$2
|
||||
)
|
||||
RETURNING
|
||||
*
|
||||
",
|
||||
)
|
||||
.bind(sender_id)
|
||||
.bind(&content)
|
||||
.fetch_one(&mut *tx)
|
||||
.await?;
|
||||
|
||||
tx.commit().await?;
|
||||
|
||||
Ok(message)
|
||||
}
|
||||
}
|
||||
|
||||
17
src/message_history.rs
Normal file
17
src/message_history.rs
Normal file
@@ -0,0 +1,17 @@
|
||||
use axum::{Json, extract::State, http::StatusCode};
|
||||
use sqlx::Postgres;
|
||||
|
||||
use crate::{message::ChatMessage, state::AppState};
|
||||
|
||||
pub async fn handler(State(state): State<AppState>) -> Result<Json<Vec<ChatMessage>>, StatusCode> {
|
||||
let Ok(messages) = sqlx::query_as::<Postgres, ChatMessage>(
|
||||
"SELECT * FROM messages ORDER BY created_at ASC LIMIT 100",
|
||||
)
|
||||
.fetch_all(&state.pg_pool)
|
||||
.await
|
||||
else {
|
||||
return Err(StatusCode::INTERNAL_SERVER_ERROR);
|
||||
};
|
||||
|
||||
Ok(Json(messages))
|
||||
}
|
||||
12
src/postgres/mod.rs
Normal file
12
src/postgres/mod.rs
Normal file
@@ -0,0 +1,12 @@
|
||||
use sqlx::{Pool, Postgres};
|
||||
|
||||
pub async fn get_postgres_pool() -> anyhow::Result<Pool<Postgres>> {
|
||||
let host = std::env::var("POSTGRES_HOST")?;
|
||||
let port = std::env::var("POSTGRES_PORT")?;
|
||||
let user = std::env::var("POSTGRES_USER")?;
|
||||
let password = std::env::var("POSTGRES_PASSWORD")?;
|
||||
|
||||
let url = format!("postgresql://{user}:{password}@{host}:{port}");
|
||||
|
||||
Ok(Pool::connect(&url).await?)
|
||||
}
|
||||
@@ -1,26 +1,25 @@
|
||||
use axum::{extract::State, http::StatusCode};
|
||||
use axum::{Extension, extract::State, http::StatusCode};
|
||||
use axum_typed_multipart::{TryFromMultipart, TypedMultipart};
|
||||
use axum_valid::ValidifiedByRef;
|
||||
use validify::Validify;
|
||||
|
||||
use crate::{message::ChatMessage, state::AppState};
|
||||
use crate::{auth::Claims, message::ChatMessage, state::AppState};
|
||||
|
||||
#[derive(Validify, TryFromMultipart)]
|
||||
#[try_from_multipart(rename_all = "kebab-case")]
|
||||
pub struct SendMessageData {
|
||||
client_id: u32,
|
||||
#[modify(trim)]
|
||||
content: String,
|
||||
}
|
||||
|
||||
pub async fn handler(
|
||||
State(state): State<AppState>,
|
||||
Extension(claims): Extension<Claims>,
|
||||
ValidifiedByRef(TypedMultipart(data)): ValidifiedByRef<TypedMultipart<SendMessageData>>,
|
||||
) -> StatusCode {
|
||||
let mut messages = state.messages.lock().await;
|
||||
|
||||
let message = ChatMessage::new(data.client_id, data.content);
|
||||
messages.push(message.clone());
|
||||
let Ok(message) = ChatMessage::new(&state.pg_pool, claims.user_id, data.content).await else {
|
||||
return StatusCode::INTERNAL_SERVER_ERROR;
|
||||
};
|
||||
|
||||
let _ = state.broadcast_sender.send(message);
|
||||
|
||||
|
||||
13
src/state.rs
13
src/state.rs
@@ -1,22 +1,19 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use tokio::sync::{Mutex, broadcast::Sender};
|
||||
use sqlx::{Pool, Postgres};
|
||||
use tokio::sync::broadcast::Sender;
|
||||
|
||||
use crate::message::ChatMessage;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AppState {
|
||||
pub messages: Arc<Mutex<Vec<ChatMessage>>>,
|
||||
pub next_client_id: Arc<Mutex<u32>>,
|
||||
pub broadcast_sender: Sender<ChatMessage>,
|
||||
pub pg_pool: Pool<Postgres>,
|
||||
}
|
||||
|
||||
impl AppState {
|
||||
pub fn new(websocket_sender: Sender<ChatMessage>) -> Self {
|
||||
pub fn new(websocket_sender: Sender<ChatMessage>, pool: Pool<Postgres>) -> Self {
|
||||
Self {
|
||||
messages: Arc::default(),
|
||||
next_client_id: Arc::default(),
|
||||
broadcast_sender: websocket_sender,
|
||||
pg_pool: pool,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -29,13 +29,11 @@ pub async fn websocket_handler(
|
||||
}
|
||||
|
||||
async fn handler(mut socket: WebSocket, state: AppState, params: WebsocketHandlerParams) {
|
||||
let Some(claims) = verify_token(¶ms.token) else {
|
||||
if verify_token(¶ms.token).is_none() {
|
||||
let _ = socket.close();
|
||||
return;
|
||||
};
|
||||
|
||||
let _ = socket.send((claims.user_id).to_string().into()).await;
|
||||
|
||||
let (sender, receiver) = socket.split();
|
||||
|
||||
let broadcast_receiver = state.broadcast_sender.subscribe();
|
||||
|
||||
Reference in New Issue
Block a user