diff --git a/src/lib.rs b/src/lib.rs index e6da0c2..26020f7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -128,7 +128,7 @@ impl ValidEx { /// the HTTP status code and response body returned on validation failure. /// #[derive(Debug, Copy, Clone, Default)] -pub struct ValidationArguments { +pub struct ValidationContext { arguments: Arguments, } @@ -143,24 +143,30 @@ fn response_builder(ve: ValidationErrors) -> Response { } } -impl ValidationArguments<()> { +impl ValidationContext<()> { /// Creates a new `ValidationArguments`. - pub fn with_arguments(self, arguments: Arguments) -> ValidationArguments { - ValidationArguments { arguments } + pub fn with_arguments(self, arguments: Arguments) -> ValidationContext { + ValidationContext { arguments } } } -impl ValidationArguments { +impl ValidationContext { /// Creates a `ValidationArguments` with arguments pub fn new(arguments: Arguments) -> Self { Self { arguments } } } -impl FromRef<()> for ValidationArguments<()> { - fn from_ref(_: &()) -> Self { - ValidationArguments::default() - } +/// Arguement store +/// +/// `T`: data type to validate +/// `Self::A`: dependent arguments +/// +pub trait ArgumentsStore<'a, T> { + /// Argument type + type A: 'a; + /// Get dependent arguments + fn get(&'a self) -> Self::A; } /// `ValidRejection` is returned when the `Valid` extractor fails. @@ -226,17 +232,17 @@ pub trait HasValidateArgs<'v> { } #[async_trait] -impl FromRequest for Valid +impl FromRequest for Valid where - S: Send + Sync, - B: Send + Sync + 'static, - E: HasValidate + FromRequest, - E::Validate: Validate, + State: Send + Sync, + Body: Send + Sync + 'static, + Extractor: HasValidate + FromRequest, + Extractor::Validate: Validate, { - type Rejection = ValidRejection<>::Rejection>; + type Rejection = ValidRejection<>::Rejection>; - async fn from_request(req: Request, state: &S) -> Result { - let inner = E::from_request(req, state) + async fn from_request(req: Request, state: &State) -> Result { + let inner = Extractor::from_request(req, state) .await .map_err(ValidRejection::Inner)?; inner.get_validate().validate()?; @@ -245,16 +251,16 @@ where } #[async_trait] -impl FromRequestParts for Valid +impl FromRequestParts for Valid where - S: Send + Sync, - E: HasValidate + FromRequestParts, - E::Validate: Validate, + State: Send + Sync, + Extractor: HasValidate + FromRequestParts, + Extractor::Validate: Validate, { - type Rejection = ValidRejection<>::Rejection>; + type Rejection = ValidRejection<>::Rejection>; - async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { - let inner = E::from_request_parts(parts, state) + async fn from_request_parts(parts: &mut Parts, state: &State) -> Result { + let inner = Extractor::from_request_parts(parts, state) .await .map_err(ValidRejection::Inner)?; inner.get_validate().validate()?; @@ -263,46 +269,53 @@ where } #[async_trait] -impl FromRequest for ValidEx +impl FromRequest for ValidEx where - S: Send + Sync, - B: Send + Sync + 'static, - Arguments: Send + Sync, - E: for<'v> HasValidateArgs<'v> + FromRequest, - for<'v> >::ValidateArgs: ValidateArgs<'v, Args = &'v Arguments>, - ValidationArguments: FromRef, + State: Send + Sync, + Body: Send + Sync + 'static, + Store: + Send + Sync + for<'a> ArgumentsStore<'a, >::ValidateArgs>, + Extractor: for<'v> HasValidateArgs<'v> + FromRequest, + for<'v> >::ValidateArgs: ValidateArgs< + 'v, + Args = >::ValidateArgs>>::A, + >, + ValidationContext: FromRef, { - type Rejection = ValidRejection<>::Rejection>; + type Rejection = ValidRejection<>::Rejection>; - async fn from_request(req: Request, state: &S) -> Result { - let ValidationArguments { arguments }: ValidationArguments = - FromRef::from_ref(state); - let inner = E::from_request(req, state) + async fn from_request(req: Request, state: &State) -> Result { + let ValidationContext { arguments }: ValidationContext = FromRef::from_ref(state); + let inner = Extractor::from_request(req, state) .await .map_err(ValidRejection::Inner)?; - inner.get_validate_args().validate_args(&arguments)?; + + inner.get_validate_args().validate_args(arguments.get())?; Ok(ValidEx(inner, arguments)) } } #[async_trait] -impl FromRequestParts for ValidEx +impl FromRequestParts for ValidEx where - S: Send + Sync, - Arguments: Send + Sync, - E: for<'v> HasValidateArgs<'v> + FromRequestParts, - for<'v> >::ValidateArgs: ValidateArgs<'v, Args = &'v Arguments>, - ValidationArguments: FromRef, + State: Send + Sync, + Store: + Send + Sync + for<'a> ArgumentsStore<'a, >::ValidateArgs>, + Extractor: for<'v> HasValidateArgs<'v> + FromRequestParts, + for<'v> >::ValidateArgs: ValidateArgs< + 'v, + Args = >::ValidateArgs>>::A, + >, + ValidationContext: FromRef, { - type Rejection = ValidRejection<>::Rejection>; + type Rejection = ValidRejection<>::Rejection>; - async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { - let ValidationArguments { arguments }: ValidationArguments = - FromRef::from_ref(state); - let inner = E::from_request_parts(parts, state) + async fn from_request_parts(parts: &mut Parts, state: &State) -> Result { + let ValidationContext { arguments }: ValidationContext = FromRef::from_ref(state); + let inner = Extractor::from_request_parts(parts, state) .await .map_err(ValidRejection::Inner)?; - inner.get_validate_args().validate_args(&arguments)?; + inner.get_validate_args().validate_args(arguments.get())?; Ok(ValidEx(inner, arguments)) } } diff --git a/src/test.rs b/src/test.rs index 9501e61..dcf3adf 100644 --- a/src/test.rs +++ b/src/test.rs @@ -9,8 +9,8 @@ use reqwest::{StatusCode, Url}; use serde::{Deserialize, Serialize}; use std::any::type_name; use std::net::SocketAddr; -use std::ops::Deref; -use validator::Validate; +use std::ops::{Deref, RangeInclusive}; +use validator::{Validate, ValidationError}; #[derive(Clone, Deserialize, Serialize, Validate, Eq, PartialEq)] #[cfg_attr(feature = "extra_protobuf", derive(prost::Message))] @@ -27,6 +27,48 @@ pub struct Parameters { v1: String, } +#[derive(Clone, Deserialize, Serialize, Validate, Eq, PartialEq)] +#[cfg_attr(feature = "extra_protobuf", derive(prost::Message))] +#[cfg_attr( + feature = "typed_multipart", + derive(axum_typed_multipart::TryFromMultipart) +)] +pub struct ParametersEx { + #[validate(custom(function = "validate_v0", arg = "&'v_a RangeInclusive"))] + #[cfg_attr(feature = "extra_protobuf", prost(int32, tag = "1"))] + v0: i32, + #[validate(custom(function = "validate_v1", arg = "&'v_a RangeInclusive"))] + #[cfg_attr(feature = "extra_protobuf", prost(string, tag = "2"))] + v1: String, +} + +fn validate_v0(v: i32, args: &RangeInclusive) -> Result<(), ValidationError> { + args.contains(&v) + .then_some(()) + .ok_or_else(|| ValidationError::new("v0 is out of range")) +} + +fn validate_v1(v: &str, args: &RangeInclusive) -> Result<(), ValidationError> { + args.contains(&v.len()) + .then_some(()) + .ok_or_else(|| ValidationError::new("v1 is invalid")) +} + +#[derive(Debug)] +pub struct ValidationArgs { + v0_range: RangeInclusive, + v1_length_range: RangeInclusive, +} + +impl Default for ValidationArgs { + fn default() -> Self { + Self { + v0_range: 5..=10, + v1_length_range: 1..=10, + } + } +} + static VALID_PARAMETERS: Lazy = Lazy::new(|| Parameters { v0: 5, v1: String::from("0123456789"),