diff --git a/src/lib.rs b/src/lib.rs index 0fb3202..f6abfe2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -22,7 +22,7 @@ pub mod typed_multipart; pub mod yaml; use axum::async_trait; -use axum::extract::{FromRequest, FromRequestParts}; +use axum::extract::{FromRef, FromRequest, FromRequestParts}; use axum::http::request::Parts; use axum::http::{Request, StatusCode}; use axum::response::{IntoResponse, Response}; @@ -73,44 +73,74 @@ impl Valid { } } +/// Validation context +#[derive(Debug, Copy, Clone)] +pub struct ValidationContext { + /// Validation error response builder + pub response_builder: fn(ValidationErrors) -> Response, +} + +impl Default for ValidationContext { + fn default() -> Self { + fn response_builder(ve: ValidationErrors) -> Response { + #[cfg(feature = "into_json")] + { + (VALIDATION_ERROR_STATUS, axum::Json(ve)).into_response() + } + #[cfg(not(feature = "into_json"))] + { + (VALIDATION_ERROR_STATUS, ve.to_string()).into_response() + } + } + + Self { response_builder } + } +} + +impl FromRef<()> for ValidationContext { + fn from_ref(_: &()) -> Self { + ValidationContext::default() + } +} + /// If the valid extractor fails it'll use this "rejection" type. /// This rejection type can be converted into a response. #[derive(Debug)] -pub enum ValidRejection { +pub enum ValidError { /// Validation errors Valid(ValidationErrors), /// Inner extractor error Inner(E), } -impl Display for ValidRejection { +impl Display for ValidError { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { match self { - ValidRejection::Valid(errors) => write!(f, "{errors}"), - ValidRejection::Inner(error) => write!(f, "{error}"), + ValidError::Valid(errors) => write!(f, "{errors}"), + ValidError::Inner(error) => write!(f, "{error}"), } } } -impl Error for ValidRejection { +impl Error for ValidError { fn source(&self) -> Option<&(dyn Error + 'static)> { match self { - ValidRejection::Valid(ve) => Some(ve), - ValidRejection::Inner(e) => Some(e), + ValidError::Valid(ve) => Some(ve), + ValidError::Inner(e) => Some(e), } } } -impl From for ValidRejection { +impl From for ValidError { fn from(value: ValidationErrors) -> Self { Self::Valid(value) } } -impl IntoResponse for ValidRejection { +impl IntoResponse for ValidError { fn into_response(self) -> Response { match self { - ValidRejection::Valid(validate_error) => { + ValidError::Valid(validate_error) => { #[cfg(feature = "into_json")] { (VALIDATION_ERROR_STATUS, axum::Json(validate_error)).into_response() @@ -120,7 +150,22 @@ impl IntoResponse for ValidRejection { (VALIDATION_ERROR_STATUS, validate_error.to_string()).into_response() } } - ValidRejection::Inner(json_error) => json_error.into_response(), + ValidError::Inner(json_error) => json_error.into_response(), + } + } +} + +/// Validation Rejection +pub struct ValidRejection { + error: ValidError, + response_builder: fn(ValidationErrors) -> Response, +} + +impl IntoResponse for ValidRejection { + fn into_response(self) -> Response { + match self.error { + ValidError::Valid(ve) => (self.response_builder)(ve), + ValidError::Inner(e) => e.into_response(), } } } @@ -143,14 +188,25 @@ where B: Send + Sync + 'static, E: HasValidate + FromRequest, E::Validate: Validate, + ValidationContext: FromRef, { type Rejection = ValidRejection<>::Rejection>; async fn from_request(req: Request, state: &S) -> Result { + let context: ValidationContext = FromRef::from_ref(state); let inner = E::from_request(req, state) .await - .map_err(ValidRejection::Inner)?; - inner.get_validate().validate()?; + .map_err(|e| ValidRejection { + error: ValidError::Inner(e), + response_builder: context.response_builder, + })?; + inner + .get_validate() + .validate() + .map_err(|e| ValidRejection { + error: ValidError::Valid(e), + response_builder: context.response_builder, + })?; Ok(Valid(inner)) } } @@ -161,21 +217,32 @@ where S: Send + Sync + 'static, E: HasValidate + FromRequestParts, E::Validate: Validate, + ValidationContext: FromRef, { type Rejection = ValidRejection<>::Rejection>; async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { + let context: ValidationContext = FromRef::from_ref(state); let inner = E::from_request_parts(parts, state) .await - .map_err(ValidRejection::Inner)?; - inner.get_validate().validate()?; + .map_err(|e| ValidRejection { + error: ValidError::Inner(e), + response_builder: context.response_builder, + })?; + inner + .get_validate() + .validate() + .map_err(|e| ValidRejection { + error: ValidError::Valid(e), + response_builder: context.response_builder, + })?; Ok(Valid(inner)) } } #[cfg(test)] pub mod tests { - use crate::{Valid, ValidRejection}; + use crate::{Valid, ValidError}; use reqwest::{RequestBuilder, StatusCode}; use serde::Serialize; use std::error::Error; @@ -241,24 +308,24 @@ pub mod tests { // ValidRejection::Valid Display let mut ve = ValidationErrors::new(); ve.add(TEST, ValidationError::new(TEST)); - let vr = ValidRejection::::Valid(ve.clone()); + let vr = ValidError::::Valid(ve.clone()); assert_eq!(vr.to_string(), ve.to_string()); // ValidRejection::Inner Display let inner = String::from(TEST); - let vr = ValidRejection::::Inner(inner.clone()); + let vr = ValidError::::Inner(inner.clone()); assert_eq!(inner.to_string(), vr.to_string()); // ValidRejection::Valid Error let mut ve = ValidationErrors::new(); ve.add(TEST, ValidationError::new(TEST)); - let vr = ValidRejection::::Valid(ve.clone()); + let vr = ValidError::::Valid(ve.clone()); assert!( matches!(vr.source(), Some(source) if source.downcast_ref::().is_some()) ); // ValidRejection::Valid Error - let vr = ValidRejection::::Inner(io::Error::new(io::ErrorKind::Other, TEST)); + let vr = ValidError::::Inner(io::Error::new(io::ErrorKind::Other, TEST)); assert!( matches!(vr.source(), Some(source) if source.downcast_ref::().is_some()) ); diff --git a/src/test.rs b/src/test.rs index 81ac9bd..9501e61 100644 --- a/src/test.rs +++ b/src/test.rs @@ -703,16 +703,9 @@ mod extra { impl IntoResponse for WithRejectionValidRejection { fn into_response(self) -> Response { - match self.inner { - ValidRejection::Valid(v) => { - (StatusCode::IM_A_TEAPOT, v.to_string()).into_response() - } - ValidRejection::Inner(i) => { - let mut res = i.into_response(); - *res.status_mut() = StatusCode::IM_A_TEAPOT; - res - } - } + let mut res = self.inner.into_response(); + *res.status_mut() = StatusCode::IM_A_TEAPOT; + res } }