diff --git a/.gitignore b/.gitignore index b8705c1..282dd71 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,4 @@ /Cargo.lock /.idea /lcov.info +/build_rs_cov.profraw diff --git a/Cargo.toml b/Cargo.toml index 6e83f65..be5923e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "axum-valid" -version = "0.5.1" -description = "Validation tools for axum using the validator library." +version = "0.6.0" +description = "Provide validator extractor for your axum application." authors = ["GengTeng "] license = "MIT" homepage = "https://github.com/gengteng/axum-valid" @@ -9,7 +9,7 @@ repository = "https://github.com/gengteng/axum-valid" keywords = [ "http", "web", - "framework", + "axum", "validator", ] categories = [ @@ -23,6 +23,11 @@ edition = "2021" axum = { version = "0.6.18", default-features = false } validator = "0.16.0" +[dependencies.axum_typed_multipart] +version = "0.9.0" +default-features = false +optional = true + [dependencies.axum-msgpack] version = "0.3.0" default-features = false @@ -43,7 +48,7 @@ anyhow = "1.0.72" axum = { version = "0.6.19" } tokio = { version = "1.29.1", features = ["full"] } hyper = { version = "0.14.27", features = ["full"] } -reqwest = { version = "0.11.18", features = ["json"] } +reqwest = { version = "0.11.18", features = ["json", "multipart"] } serde = { version = "1.0.181", features = ["derive"] } validator = { version = "0.16.0", features = ["derive"] } serde_json = "1.0.104" @@ -59,6 +64,7 @@ json = ["axum/json"] form = ["axum/form"] query = ["axum/query"] typed_header = ["axum/headers"] +typed_multipart = ["axum_typed_multipart"] msgpack = ["axum-msgpack"] yaml = ["axum-yaml"] into_json = ["json"] @@ -68,5 +74,5 @@ extra_query = ["axum-extra/query"] extra_form = ["axum-extra/form"] extra_protobuf = ["axum-extra/protobuf"] all_extra_types = ["extra", "extra_query", "extra_form", "extra_protobuf"] -all_types = ["json", "form", "query", "typed_header", "msgpack", "yaml", "all_extra_types"] +all_types = ["json", "form", "query", "typed_header", "typed_multipart", "msgpack", "yaml", "all_extra_types"] full = ["all_types", "422", "into_json"] diff --git a/README.md b/README.md index 6f69bd6..516a4dd 100644 --- a/README.md +++ b/README.md @@ -57,6 +57,7 @@ When validation errors occur, the extractor will automatically return 400 with v | query | Enables support for `Query` | ✅ | ✅ | | form | Enables support for `Form` | ✅ | ✅ | | typed_header | Enables support for `TypedHeader` | ❌ | ✅ | +| typed_multipart | Enables support for `TypedMultipart` from `axum_typed_multipart` | ❌ | ✅ | | msgpack | Enables support for `MsgPack` and `MsgPackRaw` from `axum-msgpack` | ❌ | ✅ | | yaml | Enables support for `Yaml` from `axum-yaml` | ❌ | ✅ | | extra | Enables support for `Cached`, `WithRejection` from `axum-extra` | ❌ | ✅ | diff --git a/build_rs_cov.profraw b/build_rs_cov.profraw new file mode 100644 index 0000000..4a1eb66 Binary files /dev/null and b/build_rs_cov.profraw differ diff --git a/src/lib.rs b/src/lib.rs index d2b7812..c493b0c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -16,6 +16,8 @@ pub mod query; pub mod test; #[cfg(feature = "typed_header")] pub mod typed_header; +#[cfg(feature = "typed_multipart")] +pub mod typed_multipart; #[cfg(feature = "yaml")] pub mod yaml; diff --git a/src/test.rs b/src/test.rs index 441c409..0d91218 100644 --- a/src/test.rs +++ b/src/test.rs @@ -14,6 +14,10 @@ use validator::Validate; #[derive(Clone, Deserialize, Serialize, Validate, Eq, PartialEq)] #[cfg_attr(feature = "extra_protobuf", derive(prost::Message))] +#[cfg_attr( + feature = "typed_multipart", + derive(axum_typed_multipart::TryFromMultipart) +)] pub struct Parameters { #[validate(range(min = 5, max = 10))] #[cfg_attr(feature = "extra_protobuf", prost(int32, tag = "1"))] @@ -69,6 +73,12 @@ async fn test_main() -> anyhow::Result<()> { post(typed_header::extract_typed_header), ); + #[cfg(feature = "typed_multipart")] + let router = router.route( + typed_multipart::route::TYPED_MULTIPART, + post(typed_multipart::extract_typed_header), + ); + #[cfg(feature = "extra")] let router = router .route(extra::route::CACHED, post(extra::extract_cached)) @@ -191,6 +201,17 @@ async fn test_main() -> anyhow::Result<()> { .await?; } + #[cfg(feature = "typed_multipart")] + { + use axum_typed_multipart::TypedMultipart; + test_executor + .execute::>( + Method::POST, + typed_multipart::route::TYPED_MULTIPART, + ) + .await?; + } + #[cfg(feature = "extra")] { use axum_extra::extract::{Cached, WithRejection}; @@ -443,6 +464,32 @@ mod typed_header { } } +#[cfg(feature = "typed_multipart")] +mod typed_multipart { + use crate::test::{validate_again, Parameters}; + use crate::Valid; + use axum::http::StatusCode; + use axum_typed_multipart::TypedMultipart; + + pub mod route { + pub const TYPED_MULTIPART: &str = "/typed_multipart"; + } + + impl From<&Parameters> for reqwest::multipart::Form { + fn from(value: &Parameters) -> Self { + reqwest::multipart::Form::new() + .text("v0", value.v0.to_string()) + .text("v1", value.v1.clone()) + } + } + + pub(super) async fn extract_typed_header( + Valid(TypedMultipart(parameters)): Valid>, + ) -> StatusCode { + validate_again(parameters) + } +} + #[cfg(feature = "extra")] mod extra { use crate::test::{validate_again, Parameters}; diff --git a/src/typed_multipart.rs b/src/typed_multipart.rs new file mode 100644 index 0000000..7d2df21 --- /dev/null +++ b/src/typed_multipart.rs @@ -0,0 +1,41 @@ +//! # Implementation of the `HasValidate` trait for the `TypedMultipart` extractor. +//! + +use crate::HasValidate; +use axum_typed_multipart::TypedMultipart; +use validator::Validate; + +impl HasValidate for TypedMultipart { + type Validate = T; + fn get_validate(&self) -> &T { + &self.0 + } +} + +#[cfg(test)] +mod tests { + use crate::tests::{ValidTest, ValidTestParameter}; + use axum::http::StatusCode; + use axum_typed_multipart::TypedMultipart; + use reqwest::multipart::Form; + use reqwest::RequestBuilder; + + impl ValidTest for TypedMultipart + where + Form: From<&'static T>, + { + const ERROR_STATUS_CODE: StatusCode = StatusCode::BAD_REQUEST; + + fn set_valid_request(builder: RequestBuilder) -> RequestBuilder { + builder.multipart(Form::from(T::valid())) + } + + fn set_error_request(builder: RequestBuilder) -> RequestBuilder { + builder.multipart(Form::new()) + } + + fn set_invalid_request(builder: RequestBuilder) -> RequestBuilder { + builder.multipart(Form::from(T::invalid())) + } + } +}