diff --git a/src/test.rs b/src/test.rs index dcf3adf..7989f5d 100644 --- a/src/test.rs +++ b/src/test.rs @@ -1,5 +1,5 @@ use crate::tests::{ValidTest, ValidTestParameter}; -use crate::{HasValidate, Valid, VALIDATION_ERROR_STATUS}; +use crate::{Arguments, HasValidate, Valid, ValidEx, ValidationContext, VALIDATION_ERROR_STATUS}; use axum::extract::{Path, Query}; use axum::routing::{get, post}; use axum::{Form, Json, Router}; @@ -10,7 +10,7 @@ use serde::{Deserialize, Serialize}; use std::any::type_name; use std::net::SocketAddr; use std::ops::{Deref, RangeInclusive}; -use validator::{Validate, ValidationError}; +use validator::{Validate, ValidateArgs, ValidationError}; #[derive(Clone, Deserialize, Serialize, Validate, Eq, PartialEq)] #[cfg_attr(feature = "extra_protobuf", derive(prost::Message))] @@ -54,13 +54,21 @@ fn validate_v1(v: &str, args: &RangeInclusive) -> Result<(), ValidationEr .ok_or_else(|| ValidationError::new("v1 is invalid")) } -#[derive(Debug)] -pub struct ValidationArgs { +#[derive(Debug, Clone)] +pub struct ParametersExValidationArguments { v0_range: RangeInclusive, v1_length_range: RangeInclusive, } -impl Default for ValidationArgs { +impl<'a> Arguments<'a, ParametersEx> for ParametersExValidationArguments { + type A = >::Args; + + fn get(&'a self) -> Self::A { + (&self.v0_range, &self.v1_length_range) + } +} + +impl Default for ParametersExValidationArguments { fn default() -> Self { Self { v0_range: 5..=10, @@ -109,6 +117,13 @@ async fn test_main() -> anyhow::Result<()> { .route(route::FORM, post(extract_form)) .route(route::JSON, post(extract_json)); + let router_ex = Router::new() + .route(route::QUERY_EX, get(extract_query_ex)) + .route(route::JSON_EX, post(extract_json_ex)) + .with_state(ValidationContext::::default()); + + let router = router.merge(router_ex); + #[cfg(feature = "typed_header")] let router = router.route( typed_header::route::TYPED_HEADER, @@ -234,18 +249,30 @@ async fn test_main() -> anyhow::Result<()> { println!("All {} tests passed.", path_type_name); } + // Valid test_executor .execute::>(Method::GET, route::QUERY) .await?; + // ValidEx + test_executor + .execute::>(Method::GET, route::QUERY_EX) + .await?; + test_executor .execute::>(Method::POST, route::FORM) .await?; + // Valid test_executor .execute::>(Method::POST, route::JSON) .await?; + // ValidEx + test_executor + .execute::>(Method::POST, route::JSON_EX) + .await?; + #[cfg(feature = "typed_header")] { use axum::TypedHeader; @@ -479,8 +506,10 @@ pub async fn check_json(type_name: &'static str, response: reqwest::Response) { mod route { pub const PATH: &str = "/path/:v0/:v1"; pub const QUERY: &str = "/query"; + pub const QUERY_EX: &str = "/query_ex"; pub const FORM: &str = "/form"; pub const JSON: &str = "/json"; + pub const JSON_EX: &str = "/json_ex"; } async fn extract_path(Valid(Path(parameters)): Valid>) -> StatusCode { @@ -491,6 +520,15 @@ async fn extract_query(Valid(Query(parameters)): Valid>) -> St validate_again(parameters) } +async fn extract_query_ex( + ValidEx(Query(parameters), args): ValidEx, ParametersExValidationArguments>, +) -> StatusCode { + match parameters.validate_args(args.get()) { + Ok(_) => StatusCode::OK, + Err(_) => StatusCode::INTERNAL_SERVER_ERROR, + } +} + async fn extract_form(Valid(Form(parameters)): Valid>) -> StatusCode { validate_again(parameters) } @@ -499,6 +537,15 @@ async fn extract_json(Valid(Json(parameters)): Valid>) -> Statu validate_again(parameters) } +async fn extract_json_ex( + ValidEx(Json(parameters), args): ValidEx, ParametersExValidationArguments>, +) -> StatusCode { + match parameters.validate_args(args.get()) { + Ok(_) => StatusCode::OK, + Err(_) => StatusCode::INTERNAL_SERVER_ERROR, + } +} + fn validate_again(validate: V) -> StatusCode { // The `Valid` extractor has validated the `parameters` once, // it should have returned `400 BAD REQUEST` if the `parameters` were invalid,