diff --git a/Cargo.toml b/Cargo.toml index dac3911..1f441bc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -39,6 +39,10 @@ optional = true version = "0.1.0" optional = true +[dependencies.axum_typed_multipart] +version = "0.11.0" +optional = true + [dependencies.serde] version = "1.0.193" optional = true @@ -74,6 +78,7 @@ msgpack = ["dep:axum-serde", "axum-serde/msgpack"] yaml = ["dep:axum-serde", "axum-serde/yaml"] xml = ["dep:axum-serde", "axum-serde/xml"] toml = ["dep:axum-serde", "axum-serde/toml"] +typed_multipart = ["dep:axum_typed_multipart"] into_json = ["json", "dep:serde"] 422 = [] extra = ["dep:axum-extra"] @@ -82,7 +87,7 @@ extra_query = ["extra", "axum-extra/query"] extra_form = ["extra", "axum-extra/form"] extra_protobuf = ["extra", "axum-extra/protobuf"] all_extra_types = ["extra", "typed_header", "extra_typed_path", "extra_query", "extra_form", "extra_protobuf"] -all_types = ["json", "form", "query", "msgpack", "yaml", "xml", "toml", "all_extra_types"] +all_types = ["json", "form", "query", "msgpack", "yaml", "xml", "toml", "all_extra_types", "typed_multipart"] full_validator = ["validator", "all_types", "422", "into_json"] full_garde = ["garde", "all_types", "422", "into_json"] full_validify = ["validify", "all_types", "422", "into_json"] diff --git a/README.md b/README.md index c11ff8b..dd0790e 100644 --- a/README.md +++ b/README.md @@ -440,6 +440,7 @@ Current module documentation predominantly showcases `Valid` examples, the usage | query | Enables support for `Query` | [`query`] | ✅ | ✅ | ✅ | | form | Enables support for `Form` | [`form`] | ✅ | ✅ | ✅ | | typed_header | Enables support for `TypedHeader` from `axum-extra` | [`typed_header`] | ❌ | ✅ | ✅ | +| typed_multipart | Enables support for `TypedMultipart` and `BaseMultipart` from `axum_typed_multipart` | [`typed_multipart`] | ❌ | ✅ | ✅ | | msgpack | Enables support for `MsgPack` and `MsgPackRaw` from `axum-serde` | [`msgpack`] | ❌ | ✅ | ✅ | | yaml | Enables support for `Yaml` from `axum-serde` | [`yaml`] | ❌ | ✅ | ✅ | | xml | Enables support for `Xml` from `axum-serde` | [`xml`] | ❌ | ✅ | ✅ | diff --git a/src/lib.rs b/src/lib.rs index 6637cc0..4a26c34 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -25,6 +25,8 @@ pub mod yaml; #[cfg(feature = "toml")] pub mod toml; +#[cfg(feature = "typed_multipart")] +pub mod typed_multipart; #[cfg(feature = "xml")] pub mod xml; diff --git a/src/typed_multipart.rs b/src/typed_multipart.rs new file mode 100644 index 0000000..231752f --- /dev/null +++ b/src/typed_multipart.rs @@ -0,0 +1,207 @@ +//! # Support for `TypedMultipart` and `BaseMultipart` from `axum_typed_multipart` +//! +//! ## Feature +//! +//! Enable the `typed_multipart` feature to use `Valid>` and `Valid>`. +//! +//! ## Usage +//! +//! 1. Implement `TryFromMultipart` and `Validate` for your data type `T`. +//! 2. In your handler function, use `Valid>` or `Valid` as some parameter's type. +//! +//! ## Example +//! +//! ```no_run +//! #[cfg(feature = "validator")] +//! mod validator_example { +//! use axum::routing::post; +//! use axum::Router; +//! use axum_typed_multipart::{BaseMultipart, TryFromMultipart, TypedMultipart, TypedMultipartError}; +//! use axum_valid::Valid; +//! use validator::Validate; +//! +//! pub fn router() -> Router { +//! Router::new() +//! .route("/typed_multipart", post(handler)) +//! .route("/base_multipart", post(base_handler)) +//! } +//! +//! async fn handler(Valid(TypedMultipart(parameter)): Valid>) { +//! assert!(parameter.validate().is_ok()); +//! // Support automatic dereferencing +//! println!("v0 = {}, v1 = {}", parameter.v0, parameter.v1); +//! } +//! +//! async fn base_handler( +//! Valid(BaseMultipart { +//! data: parameter, .. +//! }): Valid>, +//! ) { +//! assert!(parameter.validate().is_ok()); +//! } +//! +//! #[derive(TryFromMultipart, Validate)] +//! struct Parameter { +//! #[validate(range(min = 5, max = 10))] +//! v0: i32, +//! #[validate(length(min = 1, max = 10))] +//! v1: String, +//! } +//! } +//! +//! #[cfg(feature = "garde")] +//! mod garde_example { +//! use axum::routing::post; +//! use axum::Router; +//! use axum_typed_multipart::{BaseMultipart, TryFromMultipart, TypedMultipart, TypedMultipartError}; +//! use axum_valid::Garde; +//! use serde::Deserialize; +//! use garde::Validate; +//! +//! pub fn router() -> Router { +//! Router::new() +//! .route("/typed_multipart", post(handler)) +//! .route("/base_multipart", post(base_handler)) +//! } +//! +//! async fn handler(Garde(TypedMultipart(parameter)): Garde>) { +//! assert!(parameter.validate(&()).is_ok()); +//! // Support automatic dereferencing +//! println!("v0 = {}, v1 = {}", parameter.v0, parameter.v1); +//! } +//! +//! async fn base_handler( +//! Garde(BaseMultipart { +//! data: parameter, .. +//! }): Garde>, +//! ) { +//! assert!(parameter.validate(&()).is_ok()); +//! } +//! +//! #[derive(TryFromMultipart, Validate)] +//! pub struct Parameter { +//! #[garde(range(min = 5, max = 10))] +//! pub v0: i32, +//! #[garde(length(min = 1, max = 10))] +//! pub v1: String, +//! } +//! } +//! +//! # #[tokio::main] +//! # async fn main() -> anyhow::Result<()> { +//! # use std::net::SocketAddr; +//! # use axum::Router; +//! # use tokio::net::TcpListener; +//! # let router = Router::new(); +//! # #[cfg(feature = "validator")] +//! # let router = router.nest("/validator", validator_example::router()); +//! # #[cfg(feature = "garde")] +//! # let router = router.nest("/garde", garde_example::router()); +//! # let listener = TcpListener::bind(&SocketAddr::from(([0u8, 0, 0, 0], 0u16))).await?; +//! # axum::serve(listener, router.into_make_service()) +//! # .await?; +//! # Ok(()) +//! # } +//! ``` + +use crate::HasValidate; +#[cfg(feature = "validator")] +use crate::HasValidateArgs; +use axum_typed_multipart::{BaseMultipart, TypedMultipart}; +#[cfg(feature = "validator")] +use validator::ValidateArgs; + +impl HasValidate for BaseMultipart { + type Validate = T; + fn get_validate(&self) -> &T { + &self.data + } +} + +#[cfg(feature = "validator")] +impl<'v, T: ValidateArgs<'v>, R> HasValidateArgs<'v> for BaseMultipart { + type ValidateArgs = T; + fn get_validate_args(&self) -> &Self::ValidateArgs { + &self.data + } +} + +#[cfg(feature = "validify")] +impl crate::HasModify for BaseMultipart { + type Modify = T; + + fn get_modify(&mut self) -> &mut Self::Modify { + &mut self.data + } +} + +impl HasValidate for TypedMultipart { + type Validate = T; + fn get_validate(&self) -> &T { + &self.0 + } +} + +#[cfg(feature = "validator")] +impl<'v, T: ValidateArgs<'v>> HasValidateArgs<'v> for TypedMultipart { + type ValidateArgs = T; + fn get_validate_args(&self) -> &Self::ValidateArgs { + &self.0 + } +} + +#[cfg(feature = "validify")] +impl crate::HasModify for TypedMultipart { + type Modify = T; + + fn get_modify(&mut self) -> &mut Self::Modify { + &mut self.0 + } +} + +#[cfg(test)] +mod tests { + use crate::tests::{ValidTest, ValidTestParameter}; + use axum::http::StatusCode; + use axum_typed_multipart::{BaseMultipart, TypedMultipart}; + use reqwest::multipart::Form; + use reqwest::RequestBuilder; + + impl ValidTest for BaseMultipart + where + Form: From<&'static T>, + { + const ERROR_STATUS_CODE: StatusCode = StatusCode::BAD_REQUEST; + + fn set_valid_request(builder: RequestBuilder) -> RequestBuilder { + builder.multipart(Form::from(T::valid())) + } + + fn set_error_request(builder: RequestBuilder) -> RequestBuilder { + builder.multipart(Form::new()) + } + + fn set_invalid_request(builder: RequestBuilder) -> RequestBuilder { + builder.multipart(Form::from(T::invalid())) + } + } + + impl ValidTest for TypedMultipart + where + Form: From<&'static T>, + { + const ERROR_STATUS_CODE: StatusCode = StatusCode::BAD_REQUEST; + + fn set_valid_request(builder: RequestBuilder) -> RequestBuilder { + builder.multipart(Form::from(T::valid())) + } + + fn set_error_request(builder: RequestBuilder) -> RequestBuilder { + builder.multipart(Form::new()) + } + + fn set_invalid_request(builder: RequestBuilder) -> RequestBuilder { + builder.multipart(Form::from(T::invalid())) + } + } +} diff --git a/src/validator/test.rs b/src/validator/test.rs index a337a1f..6e966ad 100644 --- a/src/validator/test.rs +++ b/src/validator/test.rs @@ -19,6 +19,10 @@ use validator::{Validate, ValidateArgs, ValidationError}; #[derive(Clone, Deserialize, Serialize, Validate, Eq, PartialEq)] #[cfg_attr(feature = "extra_protobuf", derive(prost::Message))] +#[cfg_attr( + feature = "typed_multipart", + derive(axum_typed_multipart::TryFromMultipart) +)] pub struct Parameters { #[validate(range(min = 5, max = 10))] #[cfg_attr(feature = "extra_protobuf", prost(int32, tag = "1"))] @@ -30,6 +34,10 @@ pub struct Parameters { #[derive(Clone, Deserialize, Serialize, Validate, Eq, PartialEq)] #[cfg_attr(feature = "extra_protobuf", derive(prost::Message))] +#[cfg_attr( + feature = "typed_multipart", + derive(axum_typed_multipart::TryFromMultipart) +)] pub struct ParametersEx { #[validate(custom(function = "validate_v0", arg = "&'v_a RangeInclusive"))] #[cfg_attr(feature = "extra_protobuf", prost(int32, tag = "1"))]