refactor Arguments

This commit is contained in:
gengteng
2023-09-29 10:45:39 +08:00
parent e4f54a5cb9
commit 21d7ce5e4e
2 changed files with 37 additions and 49 deletions

View File

@@ -122,16 +122,6 @@ impl<E, A> ValidEx<E, A> {
} }
} }
/// `ValidationArguments` configures the response returned when validation fails.
///
/// By providing a ValidationArguments to the Valid extractor, you can customize
/// the HTTP status code and response body returned on validation failure.
///
#[derive(Debug, Copy, Clone, Default)]
pub struct ValidationContext<Arguments> {
arguments: Arguments,
}
fn response_builder(ve: ValidationErrors) -> Response { fn response_builder(ve: ValidationErrors) -> Response {
#[cfg(feature = "into_json")] #[cfg(feature = "into_json")]
{ {
@@ -143,30 +133,17 @@ fn response_builder(ve: ValidationErrors) -> Response {
} }
} }
impl ValidationContext<()> { /// `Arguments` provides the validation arguments for the data type `T`.
/// Creates a new `ValidationArguments`.
pub fn with_arguments<Arguments>(self, arguments: Arguments) -> ValidationContext<Arguments> {
ValidationContext { arguments }
}
}
impl<Arguments> ValidationContext<Arguments> {
/// Creates a `ValidationArguments` with arguments
pub fn new(arguments: Arguments) -> Self {
Self { arguments }
}
}
/// # Arguments
/// ///
/// * `T`: The data type to validate using arguments /// This trait has an associated type `T` which represents the data type to
/// validate. `T` must implement the `ValidateArgs` trait which defines the
/// validation logic.
/// ///
pub trait Arguments<'a, T> pub trait Arguments<'a> {
where /// The data type to validate using this arguments
T: ValidateArgs<'a>, type T: ValidateArgs<'a>;
{ /// This method gets the arguments required by `ValidateArgs::validate_args`
/// Get arguments from `self` fn get(&'a self) -> <<Self as Arguments<'a>>::T as ValidateArgs<'a>>::Args;
fn get(&'a self) -> T::Args;
} }
/// `ValidRejection` is returned when the `Valid` extractor fails. /// `ValidRejection` is returned when the `Valid` extractor fails.
@@ -276,15 +253,17 @@ impl<State, Body, Extractor, Args> FromRequest<State, Body> for ValidEx<Extracto
where where
State: Send + Sync, State: Send + Sync,
Body: Send + Sync + 'static, Body: Send + Sync + 'static,
Args: Send + Sync + for<'a> Arguments<'a, <Extractor as HasValidateArgs<'a>>::ValidateArgs>, Args: Send
+ Sync
+ FromRef<State>
+ for<'a> Arguments<'a, T = <Extractor as HasValidateArgs<'a>>::ValidateArgs>,
Extractor: for<'v> HasValidateArgs<'v> + FromRequest<State, Body>, Extractor: for<'v> HasValidateArgs<'v> + FromRequest<State, Body>,
for<'v> <Extractor as HasValidateArgs<'v>>::ValidateArgs: ValidateArgs<'v>, for<'v> <Extractor as HasValidateArgs<'v>>::ValidateArgs: ValidateArgs<'v>,
ValidationContext<Args>: FromRef<State>,
{ {
type Rejection = ValidRejection<<Extractor as FromRequest<State, Body>>::Rejection>; type Rejection = ValidRejection<<Extractor as FromRequest<State, Body>>::Rejection>;
async fn from_request(req: Request<Body>, state: &State) -> Result<Self, Self::Rejection> { async fn from_request(req: Request<Body>, state: &State) -> Result<Self, Self::Rejection> {
let ValidationContext { arguments }: ValidationContext<Args> = FromRef::from_ref(state); let arguments: Args = FromRef::from_ref(state);
let inner = Extractor::from_request(req, state) let inner = Extractor::from_request(req, state)
.await .await
.map_err(ValidRejection::Inner)?; .map_err(ValidRejection::Inner)?;
@@ -298,15 +277,17 @@ where
impl<State, Extractor, Args> FromRequestParts<State> for ValidEx<Extractor, Args> impl<State, Extractor, Args> FromRequestParts<State> for ValidEx<Extractor, Args>
where where
State: Send + Sync, State: Send + Sync,
Args: Send + Sync + for<'a> Arguments<'a, <Extractor as HasValidateArgs<'a>>::ValidateArgs>, Args: Send
+ Sync
+ FromRef<State>
+ for<'a> Arguments<'a, T = <Extractor as HasValidateArgs<'a>>::ValidateArgs>,
Extractor: for<'v> HasValidateArgs<'v> + FromRequestParts<State>, Extractor: for<'v> HasValidateArgs<'v> + FromRequestParts<State>,
for<'v> <Extractor as HasValidateArgs<'v>>::ValidateArgs: ValidateArgs<'v>, for<'v> <Extractor as HasValidateArgs<'v>>::ValidateArgs: ValidateArgs<'v>,
ValidationContext<Args>: FromRef<State>,
{ {
type Rejection = ValidRejection<<Extractor as FromRequestParts<State>>::Rejection>; type Rejection = ValidRejection<<Extractor as FromRequestParts<State>>::Rejection>;
async fn from_request_parts(parts: &mut Parts, state: &State) -> Result<Self, Self::Rejection> { async fn from_request_parts(parts: &mut Parts, state: &State) -> Result<Self, Self::Rejection> {
let ValidationContext { arguments }: ValidationContext<Args> = FromRef::from_ref(state); let arguments: Args = FromRef::from_ref(state);
let inner = Extractor::from_request_parts(parts, state) let inner = Extractor::from_request_parts(parts, state)
.await .await
.map_err(ValidRejection::Inner)?; .map_err(ValidRejection::Inner)?;

View File

@@ -1,6 +1,6 @@
use crate::test::extra_typed_path::TypedPathParamExValidationArguments; use crate::test::extra_typed_path::TypedPathParamExValidationArguments;
use crate::tests::{ValidTest, ValidTestParameter}; use crate::tests::{ValidTest, ValidTestParameter};
use crate::{Arguments, HasValidate, Valid, ValidEx, ValidationContext, VALIDATION_ERROR_STATUS}; use crate::{Arguments, HasValidate, Valid, ValidEx, VALIDATION_ERROR_STATUS};
use axum::extract::{FromRef, Path, Query}; use axum::extract::{FromRef, Path, Query};
use axum::routing::{get, post}; use axum::routing::{get, post};
use axum::{Form, Json, Router}; use axum::{Form, Json, Router};
@@ -11,6 +11,7 @@ use serde::{Deserialize, Serialize};
use std::any::type_name; use std::any::type_name;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::ops::{Deref, RangeInclusive}; use std::ops::{Deref, RangeInclusive};
use std::sync::Arc;
use validator::{Validate, ValidateArgs, ValidationError}; use validator::{Validate, ValidateArgs, ValidationError};
#[derive(Clone, Deserialize, Serialize, Validate, Eq, PartialEq)] #[derive(Clone, Deserialize, Serialize, Validate, Eq, PartialEq)]
@@ -56,18 +57,24 @@ fn validate_v1(v: &str, args: &RangeInclusive<usize>) -> Result<(), ValidationEr
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct ParametersExValidationArguments { struct ParametersExValidationArgumentsInner {
v0_range: RangeInclusive<i32>, v0_range: RangeInclusive<i32>,
v1_length_range: RangeInclusive<usize>, v1_length_range: RangeInclusive<usize>,
} }
impl<'a> Arguments<'a, ParametersEx> for ParametersExValidationArguments { #[derive(Debug, Clone, Default)]
pub struct ParametersExValidationArguments {
inner: Arc<ParametersExValidationArgumentsInner>,
}
impl<'a> Arguments<'a> for ParametersExValidationArguments {
type T = ParametersEx;
fn get(&'a self) -> <ParametersEx as ValidateArgs<'a>>::Args { fn get(&'a self) -> <ParametersEx as ValidateArgs<'a>>::Args {
(&self.v0_range, &self.v1_length_range) (&self.inner.v0_range, &self.inner.v1_length_range)
} }
} }
impl Default for ParametersExValidationArguments { impl Default for ParametersExValidationArgumentsInner {
fn default() -> Self { fn default() -> Self {
Self { Self {
v0_range: 5..=10, v0_range: 5..=10,
@@ -110,16 +117,15 @@ impl HasValidate for Parameters {
#[derive(Debug, Clone, FromRef)] #[derive(Debug, Clone, FromRef)]
struct MyState { struct MyState {
param_validation_ctx: ValidationContext<ParametersExValidationArguments>, param_validation_ctx: ParametersExValidationArguments,
typed_path_validation_ctx: ValidationContext<TypedPathParamExValidationArguments>, typed_path_validation_ctx: TypedPathParamExValidationArguments,
} }
#[tokio::test] #[tokio::test]
async fn test_main() -> anyhow::Result<()> { async fn test_main() -> anyhow::Result<()> {
let state = MyState { let state = MyState {
param_validation_ctx: ValidationContext::<ParametersExValidationArguments>::default(), param_validation_ctx: ParametersExValidationArguments::default(),
typed_path_validation_ctx: typed_path_validation_ctx: TypedPathParamExValidationArguments::default(),
ValidationContext::<TypedPathParamExValidationArguments>::default(),
}; };
let router = Router::new() let router = Router::new()
@@ -1068,7 +1074,8 @@ mod extra_typed_path {
} }
} }
impl<'a> Arguments<'a, TypedPathParamEx> for TypedPathParamExValidationArguments { impl<'a> Arguments<'a> for TypedPathParamExValidationArguments {
type T = TypedPathParamEx;
fn get(&'a self) -> <TypedPathParamEx as ValidateArgs<'a>>::Args { fn get(&'a self) -> <TypedPathParamEx as ValidateArgs<'a>>::Args {
(&self.v0_range, &self.v1_length_range) (&self.v0_range, &self.v1_length_range)
} }