add support for BaseMultipart

This commit is contained in:
gengteng
2023-09-04 23:40:39 +08:00
parent d1b3c47f8d
commit 2ca106df46
2 changed files with 53 additions and 9 deletions

View File

@@ -74,10 +74,15 @@ async fn test_main() -> anyhow::Result<()> {
); );
#[cfg(feature = "typed_multipart")] #[cfg(feature = "typed_multipart")]
let router = router.route( let router = router
typed_multipart::route::TYPED_MULTIPART, .route(
post(typed_multipart::extract_typed_header), typed_multipart::route::TYPED_MULTIPART,
); post(typed_multipart::extract_typed_multipart),
)
.route(
typed_multipart::route::BASE_MULTIPART,
post(typed_multipart::extract_base_multipart),
);
#[cfg(feature = "extra")] #[cfg(feature = "extra")]
let router = router let router = router
@@ -203,7 +208,13 @@ async fn test_main() -> anyhow::Result<()> {
#[cfg(feature = "typed_multipart")] #[cfg(feature = "typed_multipart")]
{ {
use axum_typed_multipart::TypedMultipart; use axum_typed_multipart::{BaseMultipart, TypedMultipart, TypedMultipartError};
test_executor
.execute::<BaseMultipart<Parameters, TypedMultipartError>>(
Method::POST,
typed_multipart::route::BASE_MULTIPART,
)
.await?;
test_executor test_executor
.execute::<TypedMultipart<Parameters>>( .execute::<TypedMultipart<Parameters>>(
Method::POST, Method::POST,
@@ -469,10 +480,11 @@ mod typed_multipart {
use crate::test::{validate_again, Parameters}; use crate::test::{validate_again, Parameters};
use crate::Valid; use crate::Valid;
use axum::http::StatusCode; use axum::http::StatusCode;
use axum_typed_multipart::TypedMultipart; use axum_typed_multipart::{BaseMultipart, TypedMultipart, TypedMultipartError};
pub mod route { pub mod route {
pub const TYPED_MULTIPART: &str = "/typed_multipart"; pub const TYPED_MULTIPART: &str = "/typed_multipart";
pub const BASE_MULTIPART: &str = "/base_multipart";
} }
impl From<&Parameters> for reqwest::multipart::Form { impl From<&Parameters> for reqwest::multipart::Form {
@@ -483,11 +495,17 @@ mod typed_multipart {
} }
} }
pub(super) async fn extract_typed_header( pub(super) async fn extract_typed_multipart(
Valid(TypedMultipart(parameters)): Valid<TypedMultipart<Parameters>>, Valid(TypedMultipart(parameters)): Valid<TypedMultipart<Parameters>>,
) -> StatusCode { ) -> StatusCode {
validate_again(parameters) validate_again(parameters)
} }
pub(super) async fn extract_base_multipart(
Valid(BaseMultipart { data, .. }): Valid<BaseMultipart<Parameters, TypedMultipartError>>,
) -> StatusCode {
validate_again(data)
}
} }
#[cfg(feature = "extra")] #[cfg(feature = "extra")]

View File

@@ -2,9 +2,16 @@
//! //!
use crate::HasValidate; use crate::HasValidate;
use axum_typed_multipart::TypedMultipart; use axum_typed_multipart::{BaseMultipart, TypedMultipart};
use validator::Validate; use validator::Validate;
impl<T: Validate, R> HasValidate for BaseMultipart<T, R> {
type Validate = T;
fn get_validate(&self) -> &T {
&self.data
}
}
impl<T: Validate> HasValidate for TypedMultipart<T> { impl<T: Validate> HasValidate for TypedMultipart<T> {
type Validate = T; type Validate = T;
fn get_validate(&self) -> &T { fn get_validate(&self) -> &T {
@@ -16,10 +23,29 @@ impl<T: Validate> HasValidate for TypedMultipart<T> {
mod tests { mod tests {
use crate::tests::{ValidTest, ValidTestParameter}; use crate::tests::{ValidTest, ValidTestParameter};
use axum::http::StatusCode; use axum::http::StatusCode;
use axum_typed_multipart::TypedMultipart; use axum_typed_multipart::{BaseMultipart, TypedMultipart};
use reqwest::multipart::Form; use reqwest::multipart::Form;
use reqwest::RequestBuilder; use reqwest::RequestBuilder;
impl<T: ValidTestParameter, R> ValidTest for BaseMultipart<T, R>
where
Form: From<&'static T>,
{
const ERROR_STATUS_CODE: StatusCode = StatusCode::BAD_REQUEST;
fn set_valid_request(builder: RequestBuilder) -> RequestBuilder {
builder.multipart(Form::from(T::valid()))
}
fn set_error_request(builder: RequestBuilder) -> RequestBuilder {
builder.multipart(Form::new())
}
fn set_invalid_request(builder: RequestBuilder) -> RequestBuilder {
builder.multipart(Form::from(T::invalid()))
}
}
impl<T: ValidTestParameter> ValidTest for TypedMultipart<T> impl<T: ValidTestParameter> ValidTest for TypedMultipart<T>
where where
Form: From<&'static T>, Form: From<&'static T>,