add ValidationContext and refactor ValidRejection

This commit is contained in:
gengteng
2023-09-23 21:46:10 +08:00
parent 600dfcd92a
commit 0f1e595efe
2 changed files with 91 additions and 31 deletions

View File

@@ -22,7 +22,7 @@ pub mod typed_multipart;
pub mod yaml; pub mod yaml;
use axum::async_trait; use axum::async_trait;
use axum::extract::{FromRequest, FromRequestParts}; use axum::extract::{FromRef, FromRequest, FromRequestParts};
use axum::http::request::Parts; use axum::http::request::Parts;
use axum::http::{Request, StatusCode}; use axum::http::{Request, StatusCode};
use axum::response::{IntoResponse, Response}; use axum::response::{IntoResponse, Response};
@@ -73,44 +73,74 @@ impl<E> Valid<E> {
} }
} }
/// Validation context
#[derive(Debug, Copy, Clone)]
pub struct ValidationContext {
/// Validation error response builder
pub response_builder: fn(ValidationErrors) -> Response,
}
impl Default for ValidationContext {
fn default() -> Self {
fn response_builder(ve: ValidationErrors) -> Response {
#[cfg(feature = "into_json")]
{
(VALIDATION_ERROR_STATUS, axum::Json(ve)).into_response()
}
#[cfg(not(feature = "into_json"))]
{
(VALIDATION_ERROR_STATUS, ve.to_string()).into_response()
}
}
Self { response_builder }
}
}
impl FromRef<()> for ValidationContext {
fn from_ref(_: &()) -> Self {
ValidationContext::default()
}
}
/// If the valid extractor fails it'll use this "rejection" type. /// If the valid extractor fails it'll use this "rejection" type.
/// This rejection type can be converted into a response. /// This rejection type can be converted into a response.
#[derive(Debug)] #[derive(Debug)]
pub enum ValidRejection<E> { pub enum ValidError<E> {
/// Validation errors /// Validation errors
Valid(ValidationErrors), Valid(ValidationErrors),
/// Inner extractor error /// Inner extractor error
Inner(E), Inner(E),
} }
impl<E: Display> Display for ValidRejection<E> { impl<E: Display> Display for ValidError<E> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self { match self {
ValidRejection::Valid(errors) => write!(f, "{errors}"), ValidError::Valid(errors) => write!(f, "{errors}"),
ValidRejection::Inner(error) => write!(f, "{error}"), ValidError::Inner(error) => write!(f, "{error}"),
} }
} }
} }
impl<E: Error + 'static> Error for ValidRejection<E> { impl<E: Error + 'static> Error for ValidError<E> {
fn source(&self) -> Option<&(dyn Error + 'static)> { fn source(&self) -> Option<&(dyn Error + 'static)> {
match self { match self {
ValidRejection::Valid(ve) => Some(ve), ValidError::Valid(ve) => Some(ve),
ValidRejection::Inner(e) => Some(e), ValidError::Inner(e) => Some(e),
} }
} }
} }
impl<E> From<ValidationErrors> for ValidRejection<E> { impl<E> From<ValidationErrors> for ValidError<E> {
fn from(value: ValidationErrors) -> Self { fn from(value: ValidationErrors) -> Self {
Self::Valid(value) Self::Valid(value)
} }
} }
impl<E: IntoResponse> IntoResponse for ValidRejection<E> { impl<E: IntoResponse> IntoResponse for ValidError<E> {
fn into_response(self) -> Response { fn into_response(self) -> Response {
match self { match self {
ValidRejection::Valid(validate_error) => { ValidError::Valid(validate_error) => {
#[cfg(feature = "into_json")] #[cfg(feature = "into_json")]
{ {
(VALIDATION_ERROR_STATUS, axum::Json(validate_error)).into_response() (VALIDATION_ERROR_STATUS, axum::Json(validate_error)).into_response()
@@ -120,7 +150,22 @@ impl<E: IntoResponse> IntoResponse for ValidRejection<E> {
(VALIDATION_ERROR_STATUS, validate_error.to_string()).into_response() (VALIDATION_ERROR_STATUS, validate_error.to_string()).into_response()
} }
} }
ValidRejection::Inner(json_error) => json_error.into_response(), ValidError::Inner(json_error) => json_error.into_response(),
}
}
}
/// Validation Rejection
pub struct ValidRejection<E> {
error: ValidError<E>,
response_builder: fn(ValidationErrors) -> Response,
}
impl<E: IntoResponse> IntoResponse for ValidRejection<E> {
fn into_response(self) -> Response {
match self.error {
ValidError::Valid(ve) => (self.response_builder)(ve),
ValidError::Inner(e) => e.into_response(),
} }
} }
} }
@@ -143,14 +188,25 @@ where
B: Send + Sync + 'static, B: Send + Sync + 'static,
E: HasValidate + FromRequest<S, B>, E: HasValidate + FromRequest<S, B>,
E::Validate: Validate, E::Validate: Validate,
ValidationContext: FromRef<S>,
{ {
type Rejection = ValidRejection<<E as FromRequest<S, B>>::Rejection>; type Rejection = ValidRejection<<E as FromRequest<S, B>>::Rejection>;
async fn from_request(req: Request<B>, state: &S) -> Result<Self, Self::Rejection> { async fn from_request(req: Request<B>, state: &S) -> Result<Self, Self::Rejection> {
let context: ValidationContext = FromRef::from_ref(state);
let inner = E::from_request(req, state) let inner = E::from_request(req, state)
.await .await
.map_err(ValidRejection::Inner)?; .map_err(|e| ValidRejection {
inner.get_validate().validate()?; error: ValidError::Inner(e),
response_builder: context.response_builder,
})?;
inner
.get_validate()
.validate()
.map_err(|e| ValidRejection {
error: ValidError::Valid(e),
response_builder: context.response_builder,
})?;
Ok(Valid(inner)) Ok(Valid(inner))
} }
} }
@@ -161,21 +217,32 @@ where
S: Send + Sync + 'static, S: Send + Sync + 'static,
E: HasValidate + FromRequestParts<S>, E: HasValidate + FromRequestParts<S>,
E::Validate: Validate, E::Validate: Validate,
ValidationContext: FromRef<S>,
{ {
type Rejection = ValidRejection<<E as FromRequestParts<S>>::Rejection>; type Rejection = ValidRejection<<E as FromRequestParts<S>>::Rejection>;
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> { async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
let context: ValidationContext = FromRef::from_ref(state);
let inner = E::from_request_parts(parts, state) let inner = E::from_request_parts(parts, state)
.await .await
.map_err(ValidRejection::Inner)?; .map_err(|e| ValidRejection {
inner.get_validate().validate()?; error: ValidError::Inner(e),
response_builder: context.response_builder,
})?;
inner
.get_validate()
.validate()
.map_err(|e| ValidRejection {
error: ValidError::Valid(e),
response_builder: context.response_builder,
})?;
Ok(Valid(inner)) Ok(Valid(inner))
} }
} }
#[cfg(test)] #[cfg(test)]
pub mod tests { pub mod tests {
use crate::{Valid, ValidRejection}; use crate::{Valid, ValidError};
use reqwest::{RequestBuilder, StatusCode}; use reqwest::{RequestBuilder, StatusCode};
use serde::Serialize; use serde::Serialize;
use std::error::Error; use std::error::Error;
@@ -241,24 +308,24 @@ pub mod tests {
// ValidRejection::Valid Display // ValidRejection::Valid Display
let mut ve = ValidationErrors::new(); let mut ve = ValidationErrors::new();
ve.add(TEST, ValidationError::new(TEST)); ve.add(TEST, ValidationError::new(TEST));
let vr = ValidRejection::<String>::Valid(ve.clone()); let vr = ValidError::<String>::Valid(ve.clone());
assert_eq!(vr.to_string(), ve.to_string()); assert_eq!(vr.to_string(), ve.to_string());
// ValidRejection::Inner Display // ValidRejection::Inner Display
let inner = String::from(TEST); let inner = String::from(TEST);
let vr = ValidRejection::<String>::Inner(inner.clone()); let vr = ValidError::<String>::Inner(inner.clone());
assert_eq!(inner.to_string(), vr.to_string()); assert_eq!(inner.to_string(), vr.to_string());
// ValidRejection::Valid Error // ValidRejection::Valid Error
let mut ve = ValidationErrors::new(); let mut ve = ValidationErrors::new();
ve.add(TEST, ValidationError::new(TEST)); ve.add(TEST, ValidationError::new(TEST));
let vr = ValidRejection::<io::Error>::Valid(ve.clone()); let vr = ValidError::<io::Error>::Valid(ve.clone());
assert!( assert!(
matches!(vr.source(), Some(source) if source.downcast_ref::<ValidationErrors>().is_some()) matches!(vr.source(), Some(source) if source.downcast_ref::<ValidationErrors>().is_some())
); );
// ValidRejection::Valid Error // ValidRejection::Valid Error
let vr = ValidRejection::<io::Error>::Inner(io::Error::new(io::ErrorKind::Other, TEST)); let vr = ValidError::<io::Error>::Inner(io::Error::new(io::ErrorKind::Other, TEST));
assert!( assert!(
matches!(vr.source(), Some(source) if source.downcast_ref::<io::Error>().is_some()) matches!(vr.source(), Some(source) if source.downcast_ref::<io::Error>().is_some())
); );

View File

@@ -703,16 +703,9 @@ mod extra {
impl<E: IntoResponse> IntoResponse for WithRejectionValidRejection<E> { impl<E: IntoResponse> IntoResponse for WithRejectionValidRejection<E> {
fn into_response(self) -> Response { fn into_response(self) -> Response {
match self.inner { let mut res = self.inner.into_response();
ValidRejection::Valid(v) => { *res.status_mut() = StatusCode::IM_A_TEAPOT;
(StatusCode::IM_A_TEAPOT, v.to_string()).into_response() res
}
ValidRejection::Inner(i) => {
let mut res = i.into_response();
*res.status_mut() = StatusCode::IM_A_TEAPOT;
res
}
}
} }
} }