diff --git a/src/test.rs b/src/test.rs index 0d91218..a2854aa 100644 --- a/src/test.rs +++ b/src/test.rs @@ -74,10 +74,15 @@ async fn test_main() -> anyhow::Result<()> { ); #[cfg(feature = "typed_multipart")] - let router = router.route( - typed_multipart::route::TYPED_MULTIPART, - post(typed_multipart::extract_typed_header), - ); + let router = router + .route( + typed_multipart::route::TYPED_MULTIPART, + post(typed_multipart::extract_typed_multipart), + ) + .route( + typed_multipart::route::BASE_MULTIPART, + post(typed_multipart::extract_base_multipart), + ); #[cfg(feature = "extra")] let router = router @@ -203,7 +208,13 @@ async fn test_main() -> anyhow::Result<()> { #[cfg(feature = "typed_multipart")] { - use axum_typed_multipart::TypedMultipart; + use axum_typed_multipart::{BaseMultipart, TypedMultipart, TypedMultipartError}; + test_executor + .execute::>( + Method::POST, + typed_multipart::route::BASE_MULTIPART, + ) + .await?; test_executor .execute::>( Method::POST, @@ -469,10 +480,11 @@ mod typed_multipart { use crate::test::{validate_again, Parameters}; use crate::Valid; use axum::http::StatusCode; - use axum_typed_multipart::TypedMultipart; + use axum_typed_multipart::{BaseMultipart, TypedMultipart, TypedMultipartError}; pub mod route { pub const TYPED_MULTIPART: &str = "/typed_multipart"; + pub const BASE_MULTIPART: &str = "/base_multipart"; } impl From<&Parameters> for reqwest::multipart::Form { @@ -483,11 +495,17 @@ mod typed_multipart { } } - pub(super) async fn extract_typed_header( + pub(super) async fn extract_typed_multipart( Valid(TypedMultipart(parameters)): Valid>, ) -> StatusCode { validate_again(parameters) } + + pub(super) async fn extract_base_multipart( + Valid(BaseMultipart { data, .. }): Valid>, + ) -> StatusCode { + validate_again(data) + } } #[cfg(feature = "extra")] diff --git a/src/typed_multipart.rs b/src/typed_multipart.rs index 7d2df21..9675412 100644 --- a/src/typed_multipart.rs +++ b/src/typed_multipart.rs @@ -2,9 +2,16 @@ //! use crate::HasValidate; -use axum_typed_multipart::TypedMultipart; +use axum_typed_multipart::{BaseMultipart, TypedMultipart}; use validator::Validate; +impl HasValidate for BaseMultipart { + type Validate = T; + fn get_validate(&self) -> &T { + &self.data + } +} + impl HasValidate for TypedMultipart { type Validate = T; fn get_validate(&self) -> &T { @@ -16,10 +23,29 @@ impl HasValidate for TypedMultipart { mod tests { use crate::tests::{ValidTest, ValidTestParameter}; use axum::http::StatusCode; - use axum_typed_multipart::TypedMultipart; + use axum_typed_multipart::{BaseMultipart, TypedMultipart}; use reqwest::multipart::Form; use reqwest::RequestBuilder; + impl ValidTest for BaseMultipart + where + Form: From<&'static T>, + { + const ERROR_STATUS_CODE: StatusCode = StatusCode::BAD_REQUEST; + + fn set_valid_request(builder: RequestBuilder) -> RequestBuilder { + builder.multipart(Form::from(T::valid())) + } + + fn set_error_request(builder: RequestBuilder) -> RequestBuilder { + builder.multipart(Form::new()) + } + + fn set_invalid_request(builder: RequestBuilder) -> RequestBuilder { + builder.multipart(Form::from(T::invalid())) + } + } + impl ValidTest for TypedMultipart where Form: From<&'static T>,