diff --git a/Cargo.toml b/Cargo.toml index 6b5ba83..b40e08f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,5 +24,4 @@ edition = "2021" [dependencies] axum = "0.6.18" -serde = "1.0.163" validator = "0.16.0" diff --git a/src/lib.rs b/src/lib.rs index 4e15204..92658f1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,11 +1,9 @@ -use axum::body::HttpBody; use axum::extract::rejection::{FormRejection, JsonRejection, PathRejection, QueryRejection}; use axum::extract::{FromRequest, FromRequestParts, Path, Query}; use axum::http::request::Parts; use axum::http::{Request, StatusCode}; use axum::response::{IntoResponse, Response}; -use axum::{async_trait, BoxError, Form, Json}; -use serde::de::DeserializeOwned; +use axum::{async_trait, Form, Json}; use validator::{Validate, ValidationErrors}; #[derive(Debug, Clone, Copy, Default)] @@ -39,86 +37,95 @@ impl From for ValidRejection { } } -#[async_trait] -impl FromRequest for Valid> -where - T: DeserializeOwned + Validate, - B: HttpBody + Send + 'static, - B::Data: Send, - B::Error: Into, - S: Send + Sync, -{ - type Rejection = ValidRejection; - - async fn from_request(req: Request, state: &S) -> Result { - let json = Json::::from_request(req, state).await?; - json.0.validate()?; - Ok(Valid(json)) - } -} - impl From for ValidRejection { fn from(value: QueryRejection) -> Self { Self::Inner(value) } } -#[async_trait] -impl FromRequestParts for Valid> -where - T: DeserializeOwned + Validate, - S: Send + Sync, -{ - type Rejection = ValidRejection; - - async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { - let query = Query::::from_request_parts(parts, state).await?; - query.validate()?; - Ok(Valid(query)) - } -} - impl From for ValidRejection { fn from(value: PathRejection) -> Self { Self::Inner(value) } } -#[async_trait] -impl FromRequestParts for Valid> -where - T: DeserializeOwned + Validate + Send, - S: Send + Sync, -{ - type Rejection = ValidRejection; - - async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { - let path = Path::::from_request_parts(parts, state).await?; - path.validate()?; - Ok(Valid(path)) - } -} - impl From for ValidRejection { fn from(value: FormRejection) -> Self { Self::Inner(value) } } -#[async_trait] -impl FromRequest for Valid> -where - T: DeserializeOwned + Validate, - B: HttpBody + Send + 'static, - B::Data: Send, - B::Error: Into, - S: Send + Sync, -{ - type Rejection = ValidRejection; +pub trait Inner0 { + type Inner: Validate; + type Rejection; + fn inner0_ref(&self) -> &Self::Inner; +} - async fn from_request(req: Request, state: &S) -> Result { - let form = Form::::from_request(req, state).await?; - form.validate()?; - Ok(Valid(form)) +impl Inner0 for Json { + type Inner = T; + type Rejection = JsonRejection; + fn inner0_ref(&self) -> &T { + &self.0 + } +} + +impl Inner0 for Form { + type Inner = T; + type Rejection = FormRejection; + fn inner0_ref(&self) -> &T { + &self.0 + } +} + +impl Inner0 for Query { + type Inner = T; + type Rejection = QueryRejection; + fn inner0_ref(&self) -> &T { + &self.0 + } +} + +impl Inner0 for Path { + type Inner = T; + type Rejection = QueryRejection; + fn inner0_ref(&self) -> &T { + &self.0 + } +} + +#[async_trait] +impl FromRequest for Valid +where + S: Send + Sync + 'static, + B: Send + Sync + 'static, + T: Inner0 + FromRequest, + T::Inner: Validate, + ::Rejection: IntoResponse, + ValidRejection<::Rejection>: From<>::Rejection>, +{ + type Rejection = ValidRejection<::Rejection>; + + async fn from_request(req: Request, state: &S) -> Result { + let valid = T::from_request(req, state).await?; + valid.inner0_ref().validate()?; + Ok(Valid(valid)) + } +} + +#[async_trait] +impl FromRequestParts for Valid +where + S: Send + Sync + 'static, + T: Inner0 + FromRequestParts, + T::Inner: Validate, + ::Rejection: IntoResponse, + ValidRejection<::Rejection>: From<>::Rejection>, +{ + type Rejection = ValidRejection<::Rejection>; + + async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { + let valid = T::from_request_parts(parts, state).await?; + valid.inner0_ref().validate()?; + Ok(Valid(valid)) } }