add ValidationContext and refactor ValidRejection

This commit is contained in:
gengteng
2023-09-23 21:46:10 +08:00
parent 600dfcd92a
commit 0f1e595efe
2 changed files with 91 additions and 31 deletions

View File

@@ -22,7 +22,7 @@ pub mod typed_multipart;
pub mod yaml;
use axum::async_trait;
use axum::extract::{FromRequest, FromRequestParts};
use axum::extract::{FromRef, FromRequest, FromRequestParts};
use axum::http::request::Parts;
use axum::http::{Request, StatusCode};
use axum::response::{IntoResponse, Response};
@@ -73,44 +73,74 @@ impl<E> Valid<E> {
}
}
/// Validation context
#[derive(Debug, Copy, Clone)]
pub struct ValidationContext {
/// Validation error response builder
pub response_builder: fn(ValidationErrors) -> Response,
}
impl Default for ValidationContext {
fn default() -> Self {
fn response_builder(ve: ValidationErrors) -> Response {
#[cfg(feature = "into_json")]
{
(VALIDATION_ERROR_STATUS, axum::Json(ve)).into_response()
}
#[cfg(not(feature = "into_json"))]
{
(VALIDATION_ERROR_STATUS, ve.to_string()).into_response()
}
}
Self { response_builder }
}
}
impl FromRef<()> for ValidationContext {
fn from_ref(_: &()) -> Self {
ValidationContext::default()
}
}
/// If the valid extractor fails it'll use this "rejection" type.
/// This rejection type can be converted into a response.
#[derive(Debug)]
pub enum ValidRejection<E> {
pub enum ValidError<E> {
/// Validation errors
Valid(ValidationErrors),
/// Inner extractor error
Inner(E),
}
impl<E: Display> Display for ValidRejection<E> {
impl<E: Display> Display for ValidError<E> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
ValidRejection::Valid(errors) => write!(f, "{errors}"),
ValidRejection::Inner(error) => write!(f, "{error}"),
ValidError::Valid(errors) => write!(f, "{errors}"),
ValidError::Inner(error) => write!(f, "{error}"),
}
}
}
impl<E: Error + 'static> Error for ValidRejection<E> {
impl<E: Error + 'static> Error for ValidError<E> {
fn source(&self) -> Option<&(dyn Error + 'static)> {
match self {
ValidRejection::Valid(ve) => Some(ve),
ValidRejection::Inner(e) => Some(e),
ValidError::Valid(ve) => Some(ve),
ValidError::Inner(e) => Some(e),
}
}
}
impl<E> From<ValidationErrors> for ValidRejection<E> {
impl<E> From<ValidationErrors> for ValidError<E> {
fn from(value: ValidationErrors) -> Self {
Self::Valid(value)
}
}
impl<E: IntoResponse> IntoResponse for ValidRejection<E> {
impl<E: IntoResponse> IntoResponse for ValidError<E> {
fn into_response(self) -> Response {
match self {
ValidRejection::Valid(validate_error) => {
ValidError::Valid(validate_error) => {
#[cfg(feature = "into_json")]
{
(VALIDATION_ERROR_STATUS, axum::Json(validate_error)).into_response()
@@ -120,7 +150,22 @@ impl<E: IntoResponse> IntoResponse for ValidRejection<E> {
(VALIDATION_ERROR_STATUS, validate_error.to_string()).into_response()
}
}
ValidRejection::Inner(json_error) => json_error.into_response(),
ValidError::Inner(json_error) => json_error.into_response(),
}
}
}
/// Validation Rejection
pub struct ValidRejection<E> {
error: ValidError<E>,
response_builder: fn(ValidationErrors) -> Response,
}
impl<E: IntoResponse> IntoResponse for ValidRejection<E> {
fn into_response(self) -> Response {
match self.error {
ValidError::Valid(ve) => (self.response_builder)(ve),
ValidError::Inner(e) => e.into_response(),
}
}
}
@@ -143,14 +188,25 @@ where
B: Send + Sync + 'static,
E: HasValidate + FromRequest<S, B>,
E::Validate: Validate,
ValidationContext: FromRef<S>,
{
type Rejection = ValidRejection<<E as FromRequest<S, B>>::Rejection>;
async fn from_request(req: Request<B>, state: &S) -> Result<Self, Self::Rejection> {
let context: ValidationContext = FromRef::from_ref(state);
let inner = E::from_request(req, state)
.await
.map_err(ValidRejection::Inner)?;
inner.get_validate().validate()?;
.map_err(|e| ValidRejection {
error: ValidError::Inner(e),
response_builder: context.response_builder,
})?;
inner
.get_validate()
.validate()
.map_err(|e| ValidRejection {
error: ValidError::Valid(e),
response_builder: context.response_builder,
})?;
Ok(Valid(inner))
}
}
@@ -161,21 +217,32 @@ where
S: Send + Sync + 'static,
E: HasValidate + FromRequestParts<S>,
E::Validate: Validate,
ValidationContext: FromRef<S>,
{
type Rejection = ValidRejection<<E as FromRequestParts<S>>::Rejection>;
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
let context: ValidationContext = FromRef::from_ref(state);
let inner = E::from_request_parts(parts, state)
.await
.map_err(ValidRejection::Inner)?;
inner.get_validate().validate()?;
.map_err(|e| ValidRejection {
error: ValidError::Inner(e),
response_builder: context.response_builder,
})?;
inner
.get_validate()
.validate()
.map_err(|e| ValidRejection {
error: ValidError::Valid(e),
response_builder: context.response_builder,
})?;
Ok(Valid(inner))
}
}
#[cfg(test)]
pub mod tests {
use crate::{Valid, ValidRejection};
use crate::{Valid, ValidError};
use reqwest::{RequestBuilder, StatusCode};
use serde::Serialize;
use std::error::Error;
@@ -241,24 +308,24 @@ pub mod tests {
// ValidRejection::Valid Display
let mut ve = ValidationErrors::new();
ve.add(TEST, ValidationError::new(TEST));
let vr = ValidRejection::<String>::Valid(ve.clone());
let vr = ValidError::<String>::Valid(ve.clone());
assert_eq!(vr.to_string(), ve.to_string());
// ValidRejection::Inner Display
let inner = String::from(TEST);
let vr = ValidRejection::<String>::Inner(inner.clone());
let vr = ValidError::<String>::Inner(inner.clone());
assert_eq!(inner.to_string(), vr.to_string());
// ValidRejection::Valid Error
let mut ve = ValidationErrors::new();
ve.add(TEST, ValidationError::new(TEST));
let vr = ValidRejection::<io::Error>::Valid(ve.clone());
let vr = ValidError::<io::Error>::Valid(ve.clone());
assert!(
matches!(vr.source(), Some(source) if source.downcast_ref::<ValidationErrors>().is_some())
);
// ValidRejection::Valid Error
let vr = ValidRejection::<io::Error>::Inner(io::Error::new(io::ErrorKind::Other, TEST));
let vr = ValidError::<io::Error>::Inner(io::Error::new(io::ErrorKind::Other, TEST));
assert!(
matches!(vr.source(), Some(source) if source.downcast_ref::<io::Error>().is_some())
);

View File

@@ -703,16 +703,9 @@ mod extra {
impl<E: IntoResponse> IntoResponse for WithRejectionValidRejection<E> {
fn into_response(self) -> Response {
match self.inner {
ValidRejection::Valid(v) => {
(StatusCode::IM_A_TEAPOT, v.to_string()).into_response()
}
ValidRejection::Inner(i) => {
let mut res = i.into_response();
*res.status_mut() = StatusCode::IM_A_TEAPOT;
res
}
}
let mut res = self.inner.into_response();
*res.status_mut() = StatusCode::IM_A_TEAPOT;
res
}
}