feat: add support for sonic

This commit is contained in:
gengteng
2024-03-01 10:31:31 +08:00
parent dd973172ee
commit 3e1736aff7
6 changed files with 398 additions and 86 deletions

View File

@@ -1,6 +1,6 @@
[package] [package]
name = "axum-valid" name = "axum-valid"
version = "0.15.1" version = "0.16.0"
description = "Provides validation extractors for your Axum application, allowing you to validate data using validator, garde, validify or all of them." description = "Provides validation extractors for your Axum application, allowing you to validate data using validator, garde, validify or all of them."
authors = ["GengTeng <me@gteng.org>"] authors = ["GengTeng <me@gteng.org>"]
license = "MIT" license = "MIT"
@@ -36,7 +36,7 @@ default-features = false
optional = true optional = true
[dependencies.axum-serde] [dependencies.axum-serde]
version = "0.2.0" version = "0.3.0"
optional = true optional = true
[dependencies.axum_typed_multipart] [dependencies.axum_typed_multipart]
@@ -55,7 +55,6 @@ optional = true
anyhow = "1.0.75" anyhow = "1.0.75"
axum = { version = "0.7.1", features = ["macros"] } axum = { version = "0.7.1", features = ["macros"] }
tokio = { version = "1.34.0", features = ["full"] } tokio = { version = "1.34.0", features = ["full"] }
hyper = { version = "0.14.27", features = ["full"] }
reqwest = { version = "0.11.23", features = ["json", "multipart"] } reqwest = { version = "0.11.23", features = ["json", "multipart"] }
serde = { version = "1.0.195", features = ["derive"] } serde = { version = "1.0.195", features = ["derive"] }
validator = { version = "0.16.1", features = ["derive"] } validator = { version = "0.16.1", features = ["derive"] }
@@ -83,6 +82,7 @@ msgpack = ["dep:axum-serde", "axum-serde/msgpack"]
yaml = ["dep:axum-serde", "axum-serde/yaml"] yaml = ["dep:axum-serde", "axum-serde/yaml"]
xml = ["dep:axum-serde", "axum-serde/xml"] xml = ["dep:axum-serde", "axum-serde/xml"]
toml = ["dep:axum-serde", "axum-serde/toml"] toml = ["dep:axum-serde", "axum-serde/toml"]
sonic = ["dep:axum-serde", "axum-serde/sonic"]
typed_multipart = ["dep:axum_typed_multipart"] typed_multipart = ["dep:axum_typed_multipart"]
into_json = ["json", "dep:serde", "garde?/serde"] into_json = ["json", "dep:serde", "garde?/serde"]
422 = [] 422 = []
@@ -92,7 +92,7 @@ extra_query = ["extra", "axum-extra/query"]
extra_form = ["extra", "axum-extra/form"] extra_form = ["extra", "axum-extra/form"]
extra_protobuf = ["extra", "axum-extra/protobuf"] extra_protobuf = ["extra", "axum-extra/protobuf"]
all_extra_types = ["extra", "typed_header", "extra_typed_path", "extra_query", "extra_form", "extra_protobuf"] all_extra_types = ["extra", "typed_header", "extra_typed_path", "extra_query", "extra_form", "extra_protobuf"]
all_types = ["json", "form", "query", "msgpack", "yaml", "xml", "toml", "all_extra_types", "typed_multipart"] all_types = ["json", "form", "query", "msgpack", "yaml", "xml", "toml", "sonic", "all_extra_types", "typed_multipart"]
full_validator = ["validator", "all_types", "422", "into_json"] full_validator = ["validator", "all_types", "422", "into_json"]
full_garde = ["garde", "all_types", "422", "into_json"] full_garde = ["garde", "all_types", "422", "into_json"]
full_validify = ["validify", "all_types", "422", "into_json"] full_validify = ["validify", "all_types", "422", "into_json"]

View File

@@ -7,9 +7,8 @@ use axum::http::StatusCode;
use axum::routing::{get, post}; use axum::routing::{get, post};
use axum::{Form, Json, Router}; use axum::{Form, Json, Router};
use garde::Validate; use garde::Validate;
use hyper::Method;
use once_cell::sync::Lazy; use once_cell::sync::Lazy;
use reqwest::Url; use reqwest::{Method, Url};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::any::type_name; use std::any::type_name;
use std::net::SocketAddr; use std::net::SocketAddr;
@@ -19,8 +18,8 @@ use tokio::net::TcpListener;
#[derive(Clone, Deserialize, Serialize, Validate, Eq, PartialEq)] #[derive(Clone, Deserialize, Serialize, Validate, Eq, PartialEq)]
#[cfg_attr(feature = "extra_protobuf", derive(prost::Message))] #[cfg_attr(feature = "extra_protobuf", derive(prost::Message))]
#[cfg_attr( #[cfg_attr(
feature = "typed_multipart", feature = "typed_multipart",
derive(axum_typed_multipart::TryFromMultipart) derive(axum_typed_multipart::TryFromMultipart)
)] )]
pub struct ParametersGarde { pub struct ParametersGarde {
#[garde(range(min = 5, max = 10))] #[garde(range(min = 5, max = 10))]
@@ -77,13 +76,13 @@ async fn test_main() -> anyhow::Result<()> {
.route(route::JSON, post(extract_json)); .route(route::JSON, post(extract_json));
#[cfg(feature = "typed_header")] #[cfg(feature = "typed_header")]
let router = router.route( let router = router.route(
typed_header::route::TYPED_HEADER, typed_header::route::TYPED_HEADER,
post(typed_header::extract_typed_header), post(typed_header::extract_typed_header),
); );
#[cfg(feature = "typed_multipart")] #[cfg(feature = "typed_multipart")]
let router = router let router = router
.route( .route(
typed_multipart::route::TYPED_MULTIPART, typed_multipart::route::TYPED_MULTIPART,
post(typed_multipart::extract_typed_multipart), post(typed_multipart::extract_typed_multipart),
@@ -94,7 +93,7 @@ async fn test_main() -> anyhow::Result<()> {
); );
#[cfg(feature = "extra")] #[cfg(feature = "extra")]
let router = router let router = router
.route(extra::route::CACHED, post(extra::extract_cached)) .route(extra::route::CACHED, post(extra::extract_cached))
.route( .route(
extra::route::WITH_REJECTION, extra::route::WITH_REJECTION,
@@ -106,34 +105,34 @@ async fn test_main() -> anyhow::Result<()> {
); );
#[cfg(feature = "extra_typed_path")] #[cfg(feature = "extra_typed_path")]
let router = router.route( let router = router.route(
extra_typed_path::route::EXTRA_TYPED_PATH, extra_typed_path::route::EXTRA_TYPED_PATH,
get(extra_typed_path::extract_extra_typed_path), get(extra_typed_path::extract_extra_typed_path),
); );
#[cfg(feature = "extra_query")] #[cfg(feature = "extra_query")]
let router = router.route( let router = router.route(
extra_query::route::EXTRA_QUERY, extra_query::route::EXTRA_QUERY,
post(extra_query::extract_extra_query), post(extra_query::extract_extra_query),
); );
#[cfg(feature = "extra_form")] #[cfg(feature = "extra_form")]
let router = router.route( let router = router.route(
extra_form::route::EXTRA_FORM, extra_form::route::EXTRA_FORM,
post(extra_form::extract_extra_form), post(extra_form::extract_extra_form),
); );
#[cfg(feature = "extra_protobuf")] #[cfg(feature = "extra_protobuf")]
let router = router.route( let router = router.route(
extra_protobuf::route::EXTRA_PROTOBUF, extra_protobuf::route::EXTRA_PROTOBUF,
post(extra_protobuf::extract_extra_protobuf), post(extra_protobuf::extract_extra_protobuf),
); );
#[cfg(feature = "yaml")] #[cfg(feature = "yaml")]
let router = router.route(yaml::route::YAML, post(yaml::extract_yaml)); let router = router.route(yaml::route::YAML, post(yaml::extract_yaml));
#[cfg(feature = "msgpack")] #[cfg(feature = "msgpack")]
let router = router let router = router
.route(msgpack::route::MSGPACK, post(msgpack::extract_msgpack)) .route(msgpack::route::MSGPACK, post(msgpack::extract_msgpack))
.route( .route(
msgpack::route::MSGPACK_RAW, msgpack::route::MSGPACK_RAW,
@@ -141,10 +140,13 @@ async fn test_main() -> anyhow::Result<()> {
); );
#[cfg(feature = "xml")] #[cfg(feature = "xml")]
let router = router.route(xml::route::XML, post(xml::extract_xml)); let router = router.route(xml::route::XML, post(xml::extract_xml));
#[cfg(feature = "toml")] #[cfg(feature = "toml")]
let router = router.route(toml::route::TOML, post(toml::extract_toml)); let router = router.route(toml::route::TOML, post(toml::extract_toml));
#[cfg(feature = "sonic")]
let router = router.route(sonic::route::SONIC, post(sonic::extract_sonic));
let router = router.with_state(MyState::default()); let router = router.with_state(MyState::default());
@@ -340,7 +342,7 @@ async fn test_main() -> anyhow::Result<()> {
extra_typed_path_type_name, extra_typed_path_type_name,
invalid_extra_typed_path_response, invalid_extra_typed_path_response,
) )
.await; .await;
println!("All {} tests passed.", extra_typed_path_type_name); println!("All {} tests passed.", extra_typed_path_type_name);
Ok(()) Ok(())
} }
@@ -410,6 +412,14 @@ async fn test_main() -> anyhow::Result<()> {
.await?; .await?;
} }
#[cfg(feature = "sonic")]
{
use axum_serde::Sonic;
test_executor
.execute::<Sonic<ParametersGarde>>(Method::POST, sonic::route::SONIC)
.await?;
}
Ok(()) Ok(())
} }
@@ -528,7 +538,6 @@ fn validate_again<V: Validate>(validate: V, context: V::Context) -> StatusCode {
#[cfg(feature = "typed_header")] #[cfg(feature = "typed_header")]
mod typed_header { mod typed_header {
pub(crate) mod route { pub(crate) mod route {
pub const TYPED_HEADER: &str = "/typed_header"; pub const TYPED_HEADER: &str = "/typed_header";
} }
@@ -553,9 +562,9 @@ mod typed_header {
} }
fn decode<'i, I>(values: &mut I) -> Result<Self, Error> fn decode<'i, I>(values: &mut I) -> Result<Self, Error>
where where
Self: Sized, Self: Sized,
I: Iterator<Item = &'i HeaderValue>, I: Iterator<Item=&'i HeaderValue>,
{ {
let value = values.next().ok_or_else(Error::invalid)?; let value = values.next().ok_or_else(Error::invalid)?;
let src = std::str::from_utf8(value.as_bytes()).map_err(|_| Error::invalid())?; let src = std::str::from_utf8(value.as_bytes()).map_err(|_| Error::invalid())?;
@@ -646,6 +655,7 @@ mod extra {
pub const WITH_REJECTION: &str = "/with_rejection"; pub const WITH_REJECTION: &str = "/with_rejection";
pub const WITH_REJECTION_GARDE: &str = "/with_rejection_garde"; pub const WITH_REJECTION_GARDE: &str = "/with_rejection_garde";
} }
pub const PARAMETERS_HEADER: &str = "parameters-header"; pub const PARAMETERS_HEADER: &str = "parameters-header";
pub const CACHED_REJECTION_STATUS: StatusCode = StatusCode::FORBIDDEN; pub const CACHED_REJECTION_STATUS: StatusCode = StatusCode::FORBIDDEN;
@@ -673,8 +683,8 @@ mod extra {
// 1.3. Implement your extractor (`FromRequestParts` or `FromRequest`) // 1.3. Implement your extractor (`FromRequestParts` or `FromRequest`)
#[axum::async_trait] #[axum::async_trait]
impl<S> FromRequestParts<S> for ParametersGarde impl<S> FromRequestParts<S> for ParametersGarde
where where
S: Send + Sync, S: Send + Sync,
{ {
type Rejection = ParametersRejection; type Rejection = ParametersRejection;
@@ -903,6 +913,7 @@ mod msgpack {
) -> StatusCode { ) -> StatusCode {
validate_again(parameters, ()) validate_again(parameters, ())
} }
pub async fn extract_msgpack_raw( pub async fn extract_msgpack_raw(
Garde(MsgPackRaw(parameters)): Garde<MsgPackRaw<ParametersGarde>>, Garde(MsgPackRaw(parameters)): Garde<MsgPackRaw<ParametersGarde>>,
) -> StatusCode { ) -> StatusCode {
@@ -941,3 +952,19 @@ mod toml {
validate_again(parameters, ()) validate_again(parameters, ())
} }
} }
#[cfg(feature = "sonic")]
mod sonic {
use super::{validate_again, ParametersGarde};
use crate::Garde;
use axum::http::StatusCode;
use axum_serde::Sonic;
pub mod route {
pub const SONIC: &str = "/sonic";
}
pub async fn extract_sonic(Garde(Sonic(parameters)): Garde<Sonic<ParametersGarde>>) -> StatusCode {
validate_again(parameters, ())
}
}

View File

@@ -29,6 +29,7 @@ pub mod toml;
pub mod typed_multipart; pub mod typed_multipart;
#[cfg(feature = "xml")] #[cfg(feature = "xml")]
pub mod xml; pub mod xml;
mod sonic;
use axum::http::StatusCode; use axum::http::StatusCode;
use axum::response::{IntoResponse, Response}; use axum::response::{IntoResponse, Response};

160
src/sonic.rs Normal file
View File

@@ -0,0 +1,160 @@
//! # Support for `Sonic<T>`
//!
//! ## Feature
//!
//! Enable the `sonic` feature to use `Valid<Sonic<T>>`.
//!
//! ## Usage
//!
//! 1. Implement `Deserialize` and `Validate` for your data type `T`.
//! 2. In your handler function, use `Valid<Sonic<T>>` as some parameter's type.
//!
//! ## Example
//!
//! ```no_run
//! #[cfg(feature = "validator")]
//! mod validator_example {
//! use axum::routing::post;
//! use axum_serde::Sonic;
//! use axum::Router;
//! use axum_valid::Valid;
//! use serde::Deserialize;
//! use validator::Validate;
//!
//! pub fn router() -> Router {
//! Router::new().route("/sonic", post(handler))
//! }
//!
//! async fn handler(Valid(Sonic(parameter)): Valid<Sonic<Parameter>>) {
//! assert!(parameter.validate().is_ok());
//! // Support automatic dereferencing
//! println!("v0 = {}, v1 = {}", parameter.v0, parameter.v1);
//! }
//!
//! #[derive(Validate, Deserialize)]
//! pub struct Parameter {
//! #[validate(range(min = 5, max = 10))]
//! pub v0: i32,
//! #[validate(length(min = 1, max = 10))]
//! pub v1: String,
//! }
//! }
//!
//! #[cfg(feature = "garde")]
//! mod garde_example {
//! use axum::routing::post;
//! use axum_serde::Sonic;
//! use axum::Router;
//! use axum_valid::Garde;
//! use serde::Deserialize;
//! use garde::Validate;
//!
//! pub fn router() -> Router {
//! Router::new().route("/sonic", post(handler))
//! }
//!
//! async fn handler(Garde(Sonic(parameter)): Garde<Sonic<Parameter>>) {
//! assert!(parameter.validate(&()).is_ok());
//! // Support automatic dereferencing
//! println!("v0 = {}, v1 = {}", parameter.v0, parameter.v1);
//! }
//!
//! #[derive(Validate, Deserialize)]
//! pub struct Parameter {
//! #[garde(range(min = 5, max = 10))]
//! pub v0: i32,
//! #[garde(length(min = 1, max = 10))]
//! pub v1: String,
//! }
//! }
//!
//! # #[tokio::main]
//! # async fn main() -> anyhow::Result<()> {
//! # use std::net::SocketAddr;
//! # use axum::Router;
//! # use tokio::net::TcpListener;
//! # let router = Router::new();
//! # #[cfg(feature = "validator")]
//! # let router = router.nest("/validator", validator_example::router());
//! # #[cfg(feature = "garde")]
//! # let router = router.nest("/garde", garde_example::router());
//! # let listener = TcpListener::bind(&SocketAddr::from(([0u8, 0, 0, 0], 0u16))).await?;
//! # axum::serve(listener, router.into_make_service())
//! # .await?;
//! # Ok(())
//! # }
//! ```
use crate::HasValidate;
#[cfg(feature = "validator")]
use crate::HasValidateArgs;
use axum_serde::Sonic;
#[cfg(feature = "validator")]
use validator::ValidateArgs;
impl<T> HasValidate for Sonic<T> {
type Validate = T;
fn get_validate(&self) -> &T {
&self.0
}
}
#[cfg(feature = "validator")]
impl<'v, T: ValidateArgs<'v>> HasValidateArgs<'v> for Sonic<T> {
type ValidateArgs = T;
fn get_validate_args(&self) -> &Self::ValidateArgs {
&self.0
}
}
#[cfg(feature = "validify")]
impl<T: validify::Modify> crate::HasModify for Sonic<T> {
type Modify = T;
fn get_modify(&mut self) -> &mut Self::Modify {
&mut self.0
}
}
#[cfg(feature = "validify")]
impl<T> crate::PayloadExtractor for Sonic<T> {
type Payload = T;
fn get_payload(self) -> Self::Payload {
self.0
}
}
#[cfg(feature = "validify")]
impl<T: validify::Validify + validify::ValidifyPayload> crate::HasValidify for Sonic<T> {
type Validify = T;
type PayloadExtractor = Sonic<T::Payload>;
fn from_validify(v: Self::Validify) -> Self {
Sonic(v)
}
}
#[cfg(test)]
mod tests {
use crate::tests::{ValidTest, ValidTestParameter};
use axum::http::StatusCode;
use axum_serde::Sonic;
use reqwest::RequestBuilder;
use serde::Serialize;
impl<T: ValidTestParameter + Serialize> ValidTest for Sonic<T> {
const ERROR_STATUS_CODE: StatusCode = StatusCode::UNPROCESSABLE_ENTITY;
fn set_valid_request(builder: RequestBuilder) -> RequestBuilder {
builder.json(T::valid())
}
fn set_error_request(builder: RequestBuilder) -> RequestBuilder {
builder.json(&serde_json::json!({ "a" : 1}))
}
fn set_invalid_request(builder: RequestBuilder) -> RequestBuilder {
builder.json(T::invalid())
}
}
}

View File

@@ -6,9 +6,8 @@ use axum::extract::{FromRef, Path, Query};
use axum::http::StatusCode; use axum::http::StatusCode;
use axum::routing::{get, post}; use axum::routing::{get, post};
use axum::{Form, Json, Router}; use axum::{Form, Json, Router};
use hyper::Method;
use once_cell::sync::Lazy; use once_cell::sync::Lazy;
use reqwest::Url; use reqwest::{Method, Url};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::any::type_name; use std::any::type_name;
use std::net::SocketAddr; use std::net::SocketAddr;
@@ -20,8 +19,8 @@ use validator::{Validate, ValidateArgs, ValidationError};
#[derive(Clone, Deserialize, Serialize, Validate, Eq, PartialEq)] #[derive(Clone, Deserialize, Serialize, Validate, Eq, PartialEq)]
#[cfg_attr(feature = "extra_protobuf", derive(prost::Message))] #[cfg_attr(feature = "extra_protobuf", derive(prost::Message))]
#[cfg_attr( #[cfg_attr(
feature = "typed_multipart", feature = "typed_multipart",
derive(axum_typed_multipart::TryFromMultipart) derive(axum_typed_multipart::TryFromMultipart)
)] )]
pub struct Parameters { pub struct Parameters {
#[validate(range(min = 5, max = 10))] #[validate(range(min = 5, max = 10))]
@@ -35,8 +34,8 @@ pub struct Parameters {
#[derive(Clone, Deserialize, Serialize, Validate, Eq, PartialEq)] #[derive(Clone, Deserialize, Serialize, Validate, Eq, PartialEq)]
#[cfg_attr(feature = "extra_protobuf", derive(prost::Message))] #[cfg_attr(feature = "extra_protobuf", derive(prost::Message))]
#[cfg_attr( #[cfg_attr(
feature = "typed_multipart", feature = "typed_multipart",
derive(axum_typed_multipart::TryFromMultipart) derive(axum_typed_multipart::TryFromMultipart)
)] )]
pub struct ParametersEx { pub struct ParametersEx {
#[validate(custom(function = "validate_v0", arg = "&'v_a RangeInclusive<i32>"))] #[validate(custom(function = "validate_v0", arg = "&'v_a RangeInclusive<i32>"))]
@@ -156,7 +155,7 @@ async fn test_main() -> anyhow::Result<()> {
.route(route::JSON_EX, post(extract_json_ex)); .route(route::JSON_EX, post(extract_json_ex));
#[cfg(feature = "typed_header")] #[cfg(feature = "typed_header")]
let router = router let router = router
.route( .route(
typed_header::route::TYPED_HEADER, typed_header::route::TYPED_HEADER,
post(typed_header::extract_typed_header), post(typed_header::extract_typed_header),
@@ -167,7 +166,7 @@ async fn test_main() -> anyhow::Result<()> {
); );
#[cfg(feature = "typed_multipart")] #[cfg(feature = "typed_multipart")]
let router = router let router = router
.route( .route(
typed_multipart::route::TYPED_MULTIPART, typed_multipart::route::TYPED_MULTIPART,
post(typed_multipart::extract_typed_multipart), post(typed_multipart::extract_typed_multipart),
@@ -186,7 +185,7 @@ async fn test_main() -> anyhow::Result<()> {
); );
#[cfg(feature = "extra")] #[cfg(feature = "extra")]
let router = router let router = router
.route(extra::route::CACHED, post(extra::extract_cached)) .route(extra::route::CACHED, post(extra::extract_cached))
.route(extra::route::CACHED_EX, post(extra::extract_cached_ex)) .route(extra::route::CACHED_EX, post(extra::extract_cached_ex))
.route( .route(
@@ -207,7 +206,7 @@ async fn test_main() -> anyhow::Result<()> {
); );
#[cfg(feature = "extra_typed_path")] #[cfg(feature = "extra_typed_path")]
let router = router let router = router
.route( .route(
extra_typed_path::route::EXTRA_TYPED_PATH, extra_typed_path::route::EXTRA_TYPED_PATH,
get(extra_typed_path::extract_extra_typed_path), get(extra_typed_path::extract_extra_typed_path),
@@ -218,7 +217,7 @@ async fn test_main() -> anyhow::Result<()> {
); );
#[cfg(feature = "extra_query")] #[cfg(feature = "extra_query")]
let router = router let router = router
.route( .route(
extra_query::route::EXTRA_QUERY, extra_query::route::EXTRA_QUERY,
post(extra_query::extract_extra_query), post(extra_query::extract_extra_query),
@@ -229,7 +228,7 @@ async fn test_main() -> anyhow::Result<()> {
); );
#[cfg(feature = "extra_form")] #[cfg(feature = "extra_form")]
let router = router let router = router
.route( .route(
extra_form::route::EXTRA_FORM, extra_form::route::EXTRA_FORM,
post(extra_form::extract_extra_form), post(extra_form::extract_extra_form),
@@ -240,7 +239,7 @@ async fn test_main() -> anyhow::Result<()> {
); );
#[cfg(feature = "extra_protobuf")] #[cfg(feature = "extra_protobuf")]
let router = router let router = router
.route( .route(
extra_protobuf::route::EXTRA_PROTOBUF, extra_protobuf::route::EXTRA_PROTOBUF,
post(extra_protobuf::extract_extra_protobuf), post(extra_protobuf::extract_extra_protobuf),
@@ -251,12 +250,12 @@ async fn test_main() -> anyhow::Result<()> {
); );
#[cfg(feature = "yaml")] #[cfg(feature = "yaml")]
let router = router let router = router
.route(yaml::route::YAML, post(yaml::extract_yaml)) .route(yaml::route::YAML, post(yaml::extract_yaml))
.route(yaml::route::YAML_EX, post(yaml::extract_yaml_ex)); .route(yaml::route::YAML_EX, post(yaml::extract_yaml_ex));
#[cfg(feature = "msgpack")] #[cfg(feature = "msgpack")]
let router = router let router = router
.route(msgpack::route::MSGPACK, post(msgpack::extract_msgpack)) .route(msgpack::route::MSGPACK, post(msgpack::extract_msgpack))
.route( .route(
msgpack::route::MSGPACK_EX, msgpack::route::MSGPACK_EX,
@@ -272,15 +271,20 @@ async fn test_main() -> anyhow::Result<()> {
); );
#[cfg(feature = "xml")] #[cfg(feature = "xml")]
let router = router let router = router
.route(xml::route::XML, post(xml::extract_xml)) .route(xml::route::XML, post(xml::extract_xml))
.route(xml::route::XML_EX, post(xml::extract_xml_ex)); .route(xml::route::XML_EX, post(xml::extract_xml_ex));
#[cfg(feature = "toml")] #[cfg(feature = "toml")]
let router = router let router = router
.route(toml::route::TOML, post(toml::extract_toml)) .route(toml::route::TOML, post(toml::extract_toml))
.route(toml::route::TOML_EX, post(toml::extract_toml_ex)); .route(toml::route::TOML_EX, post(toml::extract_toml_ex));
#[cfg(feature = "sonic")]
let router = router
.route(sonic::route::SONIC, post(sonic::extract_sonic))
.route(sonic::route::SONIC_EX, post(sonic::extract_sonic_ex));
let router = router.with_state(state); let router = router.with_state(state);
let listener = TcpListener::bind(&SocketAddr::from(([0u8, 0, 0, 0], 0u16))).await?; let listener = TcpListener::bind(&SocketAddr::from(([0u8, 0, 0, 0], 0u16))).await?;
@@ -524,7 +528,7 @@ async fn test_main() -> anyhow::Result<()> {
extra_typed_path_type_name, extra_typed_path_type_name,
invalid_extra_typed_path_response, invalid_extra_typed_path_response,
) )
.await; .await;
println!("All {} tests passed.", extra_typed_path_type_name); println!("All {} tests passed.", extra_typed_path_type_name);
Ok(()) Ok(())
} }
@@ -616,6 +620,17 @@ async fn test_main() -> anyhow::Result<()> {
.await?; .await?;
} }
#[cfg(feature = "sonic")]
{
use axum_serde::Sonic;
test_executor
.execute::<Sonic<Parameters>>(Method::POST, sonic::route::SONIC)
.await?;
test_executor
.execute::<Sonic<Parameters>>(Method::POST, sonic::route::SONIC_EX)
.await?;
}
Ok(()) Ok(())
} }
@@ -776,7 +791,6 @@ fn validate_again_ex<'v, V: ValidateArgs<'v>>(
#[cfg(feature = "typed_header")] #[cfg(feature = "typed_header")]
mod typed_header { mod typed_header {
pub(crate) mod route { pub(crate) mod route {
pub const TYPED_HEADER: &str = "/typed_header"; pub const TYPED_HEADER: &str = "/typed_header";
pub const TYPED_HEADER_EX: &str = "/typed_header_ex"; pub const TYPED_HEADER_EX: &str = "/typed_header_ex";
@@ -812,9 +826,9 @@ mod typed_header {
} }
fn decode<'i, I>(values: &mut I) -> Result<Self, Error> fn decode<'i, I>(values: &mut I) -> Result<Self, Error>
where where
Self: Sized, Self: Sized,
I: Iterator<Item = &'i HeaderValue>, I: Iterator<Item=&'i HeaderValue>,
{ {
let value = values.next().ok_or_else(Error::invalid)?; let value = values.next().ok_or_else(Error::invalid)?;
let src = std::str::from_utf8(value.as_bytes()).map_err(|_| Error::invalid())?; let src = std::str::from_utf8(value.as_bytes()).map_err(|_| Error::invalid())?;
@@ -845,9 +859,9 @@ mod typed_header {
} }
fn decode<'i, I>(values: &mut I) -> Result<Self, Error> fn decode<'i, I>(values: &mut I) -> Result<Self, Error>
where where
Self: Sized, Self: Sized,
I: Iterator<Item = &'i HeaderValue>, I: Iterator<Item=&'i HeaderValue>,
{ {
let value = values.next().ok_or_else(Error::invalid)?; let value = values.next().ok_or_else(Error::invalid)?;
let src = std::str::from_utf8(value.as_bytes()).map_err(|_| Error::invalid())?; let src = std::str::from_utf8(value.as_bytes()).map_err(|_| Error::invalid())?;
@@ -965,6 +979,7 @@ mod extra {
pub const WITH_REJECTION_VALID: &str = "/with_rejection_valid"; pub const WITH_REJECTION_VALID: &str = "/with_rejection_valid";
pub const WITH_REJECTION_VALID_EX: &str = "/with_rejection_valid_ex"; pub const WITH_REJECTION_VALID_EX: &str = "/with_rejection_valid_ex";
} }
pub const PARAMETERS_HEADER: &str = "parameters-header"; pub const PARAMETERS_HEADER: &str = "parameters-header";
pub const CACHED_REJECTION_STATUS: StatusCode = StatusCode::FORBIDDEN; pub const CACHED_REJECTION_STATUS: StatusCode = StatusCode::FORBIDDEN;
@@ -992,8 +1007,8 @@ mod extra {
// 1.3. Implement your extractor (`FromRequestParts` or `FromRequest`) // 1.3. Implement your extractor (`FromRequestParts` or `FromRequest`)
#[axum::async_trait] #[axum::async_trait]
impl<S> FromRequestParts<S> for Parameters impl<S> FromRequestParts<S> for Parameters
where where
S: Send + Sync, S: Send + Sync,
{ {
type Rejection = ParametersRejection; type Rejection = ParametersRejection;
@@ -1008,8 +1023,8 @@ mod extra {
#[axum::async_trait] #[axum::async_trait]
impl<S> FromRequestParts<S> for ParametersEx impl<S> FromRequestParts<S> for ParametersEx
where where
S: Send + Sync, S: Send + Sync,
{ {
type Rejection = ParametersRejection; type Rejection = ParametersRejection;
@@ -1447,3 +1462,32 @@ mod toml {
validate_again_ex(parameters, args.get()) validate_again_ex(parameters, args.get())
} }
} }
#[cfg(feature = "sonic")]
mod sonic {
use super::{
validate_again, validate_again_ex, Parameters, ParametersEx,
ParametersExValidationArguments,
};
use crate::{Arguments, Valid, ValidEx};
use axum::http::StatusCode;
use axum_serde::Sonic;
pub mod route {
pub const SONIC: &str = "/sonic";
pub const SONIC_EX: &str = "/sonic_ex";
}
pub async fn extract_sonic(Valid(Sonic(parameters)): Valid<Sonic<Parameters>>) -> StatusCode {
validate_again(parameters)
}
pub async fn extract_sonic_ex(
ValidEx(Sonic(parameters), args): ValidEx<
Sonic<ParametersEx>,
ParametersExValidationArguments,
>,
) -> StatusCode {
validate_again_ex(parameters, args.get())
}
}

View File

@@ -8,11 +8,10 @@ use axum::extract::{Path, Query};
use axum::http::StatusCode; use axum::http::StatusCode;
use axum::routing::{get, post}; use axum::routing::{get, post};
use axum::{Form, Json, Router}; use axum::{Form, Json, Router};
use hyper::Method;
use once_cell::sync::Lazy; use once_cell::sync::Lazy;
#[cfg(feature = "extra_protobuf")] #[cfg(feature = "extra_protobuf")]
use prost::Message; use prost::Message;
use reqwest::Url; use reqwest::{Method, Url};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::any::type_name; use std::any::type_name;
use std::net::SocketAddr; use std::net::SocketAddr;
@@ -32,8 +31,8 @@ pub struct ParametersValidify {
#[derive(Clone, Validify, Eq, PartialEq)] #[derive(Clone, Validify, Eq, PartialEq)]
#[cfg_attr(feature = "extra_protobuf", derive(Message))] #[cfg_attr(feature = "extra_protobuf", derive(Message))]
#[cfg_attr( #[cfg_attr(
feature = "typed_multipart", feature = "typed_multipart",
derive(axum_typed_multipart::TryFromMultipart) derive(axum_typed_multipart::TryFromMultipart)
)] )]
pub struct ParametersValidifyWithoutPayload { pub struct ParametersValidifyWithoutPayload {
#[validate(range(min = 5.0, max = 10.0))] #[validate(range(min = 5.0, max = 10.0))]
@@ -146,7 +145,7 @@ async fn test_main() -> anyhow::Result<()> {
); );
#[cfg(feature = "typed_header")] #[cfg(feature = "typed_header")]
let router = router let router = router
.route( .route(
typed_header::route::TYPED_HEADER, typed_header::route::TYPED_HEADER,
post(typed_header::extract_typed_header), post(typed_header::extract_typed_header),
@@ -161,7 +160,7 @@ async fn test_main() -> anyhow::Result<()> {
); );
#[cfg(feature = "typed_multipart")] #[cfg(feature = "typed_multipart")]
let router = router let router = router
.route( .route(
typed_multipart::route::TYPED_MULTIPART, typed_multipart::route::TYPED_MULTIPART,
post(typed_multipart::extract_typed_multipart), post(typed_multipart::extract_typed_multipart),
@@ -188,7 +187,7 @@ async fn test_main() -> anyhow::Result<()> {
); );
#[cfg(feature = "extra")] #[cfg(feature = "extra")]
let router = router let router = router
.route(extra::route::CACHED, post(extra::extract_cached)) .route(extra::route::CACHED, post(extra::extract_cached))
.route( .route(
extra::route::CACHED_MODIFIED, extra::route::CACHED_MODIFIED,
@@ -224,7 +223,7 @@ async fn test_main() -> anyhow::Result<()> {
); );
#[cfg(feature = "extra_typed_path")] #[cfg(feature = "extra_typed_path")]
let router = router let router = router
.route( .route(
extra_typed_path::route::EXTRA_TYPED_PATH, extra_typed_path::route::EXTRA_TYPED_PATH,
get(extra_typed_path::extract_extra_typed_path), get(extra_typed_path::extract_extra_typed_path),
@@ -239,7 +238,7 @@ async fn test_main() -> anyhow::Result<()> {
); );
#[cfg(feature = "extra_query")] #[cfg(feature = "extra_query")]
let router = router let router = router
.route( .route(
extra_query::route::EXTRA_QUERY, extra_query::route::EXTRA_QUERY,
post(extra_query::extract_extra_query), post(extra_query::extract_extra_query),
@@ -258,7 +257,7 @@ async fn test_main() -> anyhow::Result<()> {
); );
#[cfg(feature = "extra_form")] #[cfg(feature = "extra_form")]
let router = router let router = router
.route( .route(
extra_form::route::EXTRA_FORM, extra_form::route::EXTRA_FORM,
post(extra_form::extract_extra_form), post(extra_form::extract_extra_form),
@@ -277,7 +276,7 @@ async fn test_main() -> anyhow::Result<()> {
); );
#[cfg(feature = "extra_protobuf")] #[cfg(feature = "extra_protobuf")]
let router = router let router = router
.route( .route(
extra_protobuf::route::EXTRA_PROTOBUF, extra_protobuf::route::EXTRA_PROTOBUF,
post(extra_protobuf::extract_extra_protobuf), post(extra_protobuf::extract_extra_protobuf),
@@ -292,7 +291,7 @@ async fn test_main() -> anyhow::Result<()> {
); );
#[cfg(feature = "yaml")] #[cfg(feature = "yaml")]
let router = router let router = router
.route(yaml::route::YAML, post(yaml::extract_yaml)) .route(yaml::route::YAML, post(yaml::extract_yaml))
.route( .route(
yaml::route::YAML_MODIFIED, yaml::route::YAML_MODIFIED,
@@ -308,7 +307,7 @@ async fn test_main() -> anyhow::Result<()> {
); );
#[cfg(feature = "msgpack")] #[cfg(feature = "msgpack")]
let router = router let router = router
.route(msgpack::route::MSGPACK, post(msgpack::extract_msgpack)) .route(msgpack::route::MSGPACK, post(msgpack::extract_msgpack))
.route( .route(
msgpack::route::MSGPACK_MODIFIED, msgpack::route::MSGPACK_MODIFIED,
@@ -340,7 +339,7 @@ async fn test_main() -> anyhow::Result<()> {
); );
#[cfg(feature = "xml")] #[cfg(feature = "xml")]
let router = router let router = router
.route(xml::route::XML, post(xml::extract_xml)) .route(xml::route::XML, post(xml::extract_xml))
.route(xml::route::XML_MODIFIED, post(xml::extract_xml_modified)) .route(xml::route::XML_MODIFIED, post(xml::extract_xml_modified))
.route( .route(
@@ -353,7 +352,7 @@ async fn test_main() -> anyhow::Result<()> {
); );
#[cfg(feature = "toml")] #[cfg(feature = "toml")]
let router = router let router = router
.route(toml::route::TOML, post(toml::extract_toml)) .route(toml::route::TOML, post(toml::extract_toml))
.route( .route(
toml::route::TOML_MODIFIED, toml::route::TOML_MODIFIED,
@@ -368,6 +367,22 @@ async fn test_main() -> anyhow::Result<()> {
post(toml::extract_toml_validified_by_ref), post(toml::extract_toml_validified_by_ref),
); );
#[cfg(feature = "sonic")]
let router = router
.route(sonic::route::SONIC, post(sonic::extract_sonic))
.route(
sonic::route::SONIC_MODIFIED,
post(sonic::extract_sonic_modified),
)
.route(
sonic::route::SONIC_VALIDIFIED,
post(sonic::extract_sonic_validified),
)
.route(
sonic::route::SONIC_VALIDIFIED_BY_REF,
post(sonic::extract_sonic_validified_by_ref),
);
let listener = TcpListener::bind(&SocketAddr::from(([0u8, 0, 0, 0], 0u16))).await?; let listener = TcpListener::bind(&SocketAddr::from(([0u8, 0, 0, 0], 0u16))).await?;
let server_addr = listener.local_addr()?; let server_addr = listener.local_addr()?;
let server = axum::serve(listener, router.into_make_service()); let server = axum::serve(listener, router.into_make_service());
@@ -394,7 +409,7 @@ async fn test_main() -> anyhow::Result<()> {
VALIDATION_ERROR_STATUS, VALIDATION_ERROR_STATUS,
true, true,
) )
.await .await
} }
async fn test_extra_path_modified( async fn test_extra_path_modified(
@@ -411,7 +426,7 @@ async fn test_main() -> anyhow::Result<()> {
StatusCode::OK, StatusCode::OK,
false, false,
) )
.await .await
} }
async fn test_extra_path_validified( async fn test_extra_path_validified(
@@ -428,7 +443,7 @@ async fn test_main() -> anyhow::Result<()> {
VALIDATION_ERROR_STATUS, VALIDATION_ERROR_STATUS,
true, true,
) )
.await .await
} }
async fn do_test_extra_path( async fn do_test_extra_path(
@@ -711,7 +726,7 @@ async fn test_main() -> anyhow::Result<()> {
VALIDATION_ERROR_STATUS, VALIDATION_ERROR_STATUS,
true, true,
) )
.await .await
} }
async fn test_extra_typed_path_modified( async fn test_extra_typed_path_modified(
@@ -728,7 +743,7 @@ async fn test_main() -> anyhow::Result<()> {
StatusCode::OK, StatusCode::OK,
false, false,
) )
.await .await
} }
async fn do_test_extra_typed_path( async fn do_test_extra_typed_path(
@@ -789,7 +804,7 @@ async fn test_main() -> anyhow::Result<()> {
extra_typed_path_type_name, extra_typed_path_type_name,
invalid_extra_typed_path_response, invalid_extra_typed_path_response,
) )
.await; .await;
} }
println!("All {} tests passed.", extra_typed_path_type_name); println!("All {} tests passed.", extra_typed_path_type_name);
Ok(()) Ok(())
@@ -803,7 +818,7 @@ async fn test_main() -> anyhow::Result<()> {
"extra_typed_path_validified_by_ref", "extra_typed_path_validified_by_ref",
&server_url, &server_url,
) )
.await?; .await?;
} }
#[cfg(feature = "extra_query")] #[cfg(feature = "extra_query")]
@@ -1020,6 +1035,31 @@ async fn test_main() -> anyhow::Result<()> {
.await?; .await?;
} }
#[cfg(feature = "sonic")]
{
use axum_serde::Sonic;
// Validated
test_executor
.execute::<Sonic<ParametersValidify>>(Method::POST, sonic::route::SONIC)
.await?;
// Modified
test_executor
.execute_modified::<Sonic<ParametersValidify>>(Method::POST, sonic::route::SONIC_MODIFIED)
.await?;
// Validified
test_executor
.execute_validified::<Sonic<ParametersValidify>>(
Method::POST,
sonic::route::SONIC_VALIDIFIED,
)
.await?;
// ValidifiedByRef
test_executor
.execute::<Sonic<ParametersValidify>>(Method::POST, sonic::route::SONIC_VALIDIFIED_BY_REF)
.await?;
}
Ok(()) Ok(())
} }
@@ -1049,7 +1089,7 @@ impl TestExecutor {
T::INVALID_STATUS_CODE, T::INVALID_STATUS_CODE,
true, true,
) )
.await .await
} }
/// Execute all tests for `Modified` without validation /// Execute all tests for `Modified` without validation
@@ -1066,7 +1106,7 @@ impl TestExecutor {
StatusCode::OK, StatusCode::OK,
false, false,
) )
.await .await
} }
/// Execute all tests for `Modified` without validation /// Execute all tests for `Modified` without validation
@@ -1083,7 +1123,7 @@ impl TestExecutor {
T::INVALID_STATUS_CODE, T::INVALID_STATUS_CODE,
false, false,
) )
.await .await
} }
async fn do_execute<T: ValidTest>( async fn do_execute<T: ValidTest>(
@@ -1309,7 +1349,6 @@ fn check_validified<D: IsModified + Validate>(data: &D) -> StatusCode {
#[cfg(feature = "typed_header")] #[cfg(feature = "typed_header")]
mod typed_header { mod typed_header {
pub(crate) mod route { pub(crate) mod route {
pub const TYPED_HEADER: &str = "/typed_header"; pub const TYPED_HEADER: &str = "/typed_header";
pub const TYPED_HEADER_MODIFIED: &str = "/typed_header_modified"; pub const TYPED_HEADER_MODIFIED: &str = "/typed_header_modified";
@@ -1348,9 +1387,9 @@ mod typed_header {
} }
fn decode<'i, I>(values: &mut I) -> Result<Self, Error> fn decode<'i, I>(values: &mut I) -> Result<Self, Error>
where where
Self: Sized, Self: Sized,
I: Iterator<Item = &'i HeaderValue>, I: Iterator<Item=&'i HeaderValue>,
{ {
let value = values.next().ok_or_else(Error::invalid)?; let value = values.next().ok_or_else(Error::invalid)?;
let src = std::str::from_utf8(value.as_bytes()).map_err(|_| Error::invalid())?; let src = std::str::from_utf8(value.as_bytes()).map_err(|_| Error::invalid())?;
@@ -1490,6 +1529,7 @@ mod extra {
pub const WITH_REJECTION_VALIDIFY_VALIDIFIED_BY_REF: &str = pub const WITH_REJECTION_VALIDIFY_VALIDIFIED_BY_REF: &str =
"/with_rejection_validify_validified_by_ref"; "/with_rejection_validify_validified_by_ref";
} }
pub const PARAMETERS_HEADER: &str = "parameters-header"; pub const PARAMETERS_HEADER: &str = "parameters-header";
pub const CACHED_REJECTION_STATUS: StatusCode = StatusCode::FORBIDDEN; pub const CACHED_REJECTION_STATUS: StatusCode = StatusCode::FORBIDDEN;
@@ -1517,8 +1557,8 @@ mod extra {
// 1.3. Implement your extractor (`FromRequestParts` or `FromRequest`) // 1.3. Implement your extractor (`FromRequestParts` or `FromRequest`)
#[axum::async_trait] #[axum::async_trait]
impl<S> FromRequestParts<S> for ParametersValidify impl<S> FromRequestParts<S> for ParametersValidify
where where
S: Send + Sync, S: Send + Sync,
{ {
type Rejection = ParametersRejection; type Rejection = ParametersRejection;
@@ -1969,6 +2009,7 @@ mod msgpack {
) -> StatusCode { ) -> StatusCode {
check_validified(&parameters) check_validified(&parameters)
} }
pub async fn extract_msgpack_raw( pub async fn extract_msgpack_raw(
Validated(MsgPackRaw(parameters)): Validated<MsgPackRaw<ParametersValidify>>, Validated(MsgPackRaw(parameters)): Validated<MsgPackRaw<ParametersValidify>>,
) -> StatusCode { ) -> StatusCode {
@@ -2071,3 +2112,42 @@ mod toml {
check_validified(&parameters) check_validified(&parameters)
} }
} }
#[cfg(feature = "sonic")]
mod sonic {
use super::{check_modified, check_validated, check_validified, ParametersValidify};
use crate::{Modified, Validated, Validified, ValidifiedByRef};
use axum::http::StatusCode;
use axum_serde::Sonic;
pub mod route {
pub const SONIC: &str = "/sonic";
pub const SONIC_MODIFIED: &str = "/sonic_modified";
pub const SONIC_VALIDIFIED: &str = "/sonic_validified";
pub const SONIC_VALIDIFIED_BY_REF: &str = "/sonic_validified_by_ref";
}
pub async fn extract_sonic(
Validated(Sonic(parameters)): Validated<Sonic<ParametersValidify>>,
) -> StatusCode {
check_validated(&parameters)
}
pub async fn extract_sonic_modified(
Modified(Sonic(parameters)): Modified<Sonic<ParametersValidify>>,
) -> StatusCode {
check_modified(&parameters)
}
pub async fn extract_sonic_validified(
Validified(Sonic(parameters)): Validified<Sonic<ParametersValidify>>,
) -> StatusCode {
check_validified(&parameters)
}
pub async fn extract_sonic_validified_by_ref(
ValidifiedByRef(Sonic(parameters)): ValidifiedByRef<Sonic<ParametersValidify>>,
) -> StatusCode {
check_validified(&parameters)
}
}