From ba2b94b78396d0e6f7c6784098ee716f0949d713 Mon Sep 17 00:00:00 2001 From: gengteng Date: Sat, 22 Jul 2023 13:04:03 +0800 Subject: [PATCH] add test --- Cargo.toml | 3 +- README.md | 4 + examples/custom.rs | 2 - {examples => tests}/basic.rs | 9 ++- tests/custom.rs | 138 +++++++++++++++++++++++++++++++++++ 5 files changed, 152 insertions(+), 4 deletions(-) delete mode 100644 examples/custom.rs rename {examples => tests}/basic.rs (97%) create mode 100644 tests/custom.rs diff --git a/Cargo.toml b/Cargo.toml index 3262283..fd45fb0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "axum-valid" -version = "0.2.3" +version = "0.3.0" description = "Validation tools for axum using the validator library." authors = ["GengTeng "] license = "MIT" @@ -31,6 +31,7 @@ hyper = { version = "0.14.26", features = ["full"] } reqwest = { version = "0.11.18", features = ["json"] } serde = { version = "1.0.163", features = ["derive"] } validator = { version = "0.16.0", features = ["derive"] } +serde_json = "1.0.103" [features] default = ["json", "form", "query"] diff --git a/README.md b/README.md index 9e2a2a0..3436794 100644 --- a/README.md +++ b/README.md @@ -12,6 +12,8 @@ cargo add axum-valid use validator::Validate; use serde::Deserialize; use axum_valid::Valid; +use axum::extract::Query; +use axum::Json; #[derive(Debug, Validate, Deserialize)] pub struct Pager { @@ -35,3 +37,5 @@ pub async fn get_page_by_json( assert!((1..).contains(&pager.page_no)); } ``` + +For more usage examples, please refer to the `basic.rs` and `custom.rs` files in the `tests` directory. \ No newline at end of file diff --git a/examples/custom.rs b/examples/custom.rs deleted file mode 100644 index 7f755fb..0000000 --- a/examples/custom.rs +++ /dev/null @@ -1,2 +0,0 @@ -#[tokio::main] -async fn main() {} diff --git a/examples/basic.rs b/tests/basic.rs similarity index 97% rename from examples/basic.rs rename to tests/basic.rs index 9e60b29..d941a73 100644 --- a/examples/basic.rs +++ b/tests/basic.rs @@ -1,3 +1,10 @@ +//! # Basic extractors validation +//! +//! * `Path` +//! * `Query` +//! * `Form` +//! * `Json` + use axum::extract::{Path, Query}; use axum::http::StatusCode; use axum::routing::{get, post}; @@ -14,7 +21,7 @@ mod route { pub const JSON: &'static str = "/json"; } -#[tokio::main] +#[tokio::test] async fn main() -> anyhow::Result<()> { let router = Router::new() .route(route::PATH, get(extract_path)) diff --git a/tests/custom.rs b/tests/custom.rs new file mode 100644 index 0000000..af29f2e --- /dev/null +++ b/tests/custom.rs @@ -0,0 +1,138 @@ +//! # Custom extractor validation +//! + +use axum::extract::FromRequestParts; +use axum::http::request::Parts; +use axum::response::{IntoResponse, Response}; +use axum::routing::get; +use axum::Router; +use axum_valid::{HasValidate, Valid, ValidRejection}; +use hyper::StatusCode; +use serde::{Deserialize, Serialize}; +use std::net::SocketAddr; +use validator::Validate; + +const MY_DATA_HEADER: &'static str = "My-Data"; + +// 1. Implement your own extractor. +// 1.1. Define you own extractor type. +#[derive(Debug, Serialize, Deserialize, Validate)] +struct MyData { + #[validate(length(min = 1, max = 10))] + content: String, +} + +// 1.2. Define you own `Rejection` type and implement `IntoResponse` for it. +enum MyDataRejection { + Null, + InvalidJson(serde_json::error::Error), +} + +impl IntoResponse for MyDataRejection { + fn into_response(self) -> Response { + match self { + MyDataRejection::Null => { + (StatusCode::BAD_REQUEST, "My-Data header is missing").into_response() + } + MyDataRejection::InvalidJson(e) => ( + StatusCode::BAD_REQUEST, + format!("My-Data is not valid json string: {e}"), + ) + .into_response(), + } + } +} + +// 1.3. Implement your extractor (`FromRequestParts` or `FromRequest`) +#[axum::async_trait] +impl FromRequestParts for MyData +where + S: Send + Sync, +{ + type Rejection = MyDataRejection; + + async fn from_request_parts(parts: &mut Parts, _: &S) -> Result { + let Some(value) = parts.headers.get(MY_DATA_HEADER) else { + return Err(MyDataRejection::Null); + }; + + serde_json::from_slice(value.as_bytes()).map_err(|e| MyDataRejection::InvalidJson(e)) + } +} + +// 2. Use axum-valid to validate the extractor +// 2.1. Implement `HasValidate` for your extractor +impl HasValidate for MyData { + type Validate = Self; + type Rejection = MyDataRejection; + + fn get_validate(&self) -> &Self::Validate { + &self + } +} + +// 2.2. Implement `From` for `ValidRejection`. +impl From for ValidRejection { + fn from(value: MyDataRejection) -> Self { + Self::Inner(value) + } +} + +#[tokio::test] +async fn main() -> anyhow::Result<()> { + let router = Router::new().route("/", get(handler)); + + let server = axum::Server::bind(&SocketAddr::from(([0u8, 0, 0, 0], 0u16))) + .serve(router.into_make_service()); + + let server_addr = server.local_addr(); + println!("Axum server address: {}.", server_addr); + + let (server_guard, close) = tokio::sync::oneshot::channel::<()>(); + let server_handle = tokio::spawn(server.with_graceful_shutdown(async move { + let _ = close.await; + })); + + let client = reqwest::Client::default(); + let url = format!("http://{}/", server_addr); + + let valid_my_data = MyData { + content: String::from("hello"), + }; + let valid_my_data_response = client + .get(&url) + .header(MY_DATA_HEADER, serde_json::to_string(&valid_my_data)?) + .send() + .await?; + assert_eq!(valid_my_data_response.status(), StatusCode::OK); + + let invalid_json = String::from("{{}"); + let valid_my_data_response = client + .get(&url) + .header(MY_DATA_HEADER, invalid_json) + .send() + .await?; + assert_eq!(valid_my_data_response.status(), StatusCode::BAD_REQUEST); + + let invalid_my_data = MyData { + content: String::new(), + }; + let invalid_my_data_response = client + .get(&url) + .header(MY_DATA_HEADER, serde_json::to_string(&invalid_my_data)?) + .send() + .await?; + assert_eq!(invalid_my_data_response.status(), StatusCode::BAD_REQUEST); + println!("Valid works."); + + drop(server_guard); + server_handle.await??; + Ok(()) +} + +async fn handler(Valid(my_data): Valid) -> StatusCode { + match my_data.validate() { + Ok(_) => StatusCode::OK, + Err(_) => StatusCode::INTERNAL_SERVER_ERROR, + } +}