From 8e8af0e0b0e8f80c2b8b9e78219d4b89dc6253d8 Mon Sep 17 00:00:00 2001 From: gengteng Date: Fri, 4 Aug 2023 22:02:02 +0800 Subject: [PATCH] add tests for Protobuf --- Cargo.toml | 4 +- src/extra/protobuf.rs | 24 ++++++ src/test.rs | 175 +++++++++++++++++++++++++++++------------- 3 files changed, 150 insertions(+), 53 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 6fb32d9..852935b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -48,6 +48,8 @@ serde = { version = "1.0.180", features = ["derive"] } validator = { version = "0.16.0", features = ["derive"] } serde_json = "1.0.104" mime = "0.3.17" +prost = "0.11.9" +once_cell = "1.18.0" [features] default = ["json", "form", "query"] @@ -63,5 +65,5 @@ extra = ["axum-extra"] extra_query = ["axum-extra/query"] extra_form = ["axum-extra/form"] extra_protobuf = ["axum-extra/protobuf"] -extra_all = ["extra","extra_query", "extra_form", "extra_protobuf"] +extra_all = ["extra", "extra_query", "extra_form", "extra_protobuf"] all_types = ["json", "form", "query", "typed_header", "msgpack", "yaml", "extra_all"] diff --git a/src/extra/protobuf.rs b/src/extra/protobuf.rs index 41f07a1..d04ce3b 100644 --- a/src/extra/protobuf.rs +++ b/src/extra/protobuf.rs @@ -11,3 +11,27 @@ impl HasValidate for Protobuf { &self.0 } } + +#[cfg(test)] +mod tests { + use crate::tests::{ValidTest, ValidTestParameter}; + use axum::http::StatusCode; + use axum_extra::protobuf::Protobuf; + use reqwest::RequestBuilder; + + impl ValidTest for Protobuf { + const ERROR_STATUS_CODE: StatusCode = StatusCode::UNPROCESSABLE_ENTITY; + + fn set_valid_request(builder: RequestBuilder) -> RequestBuilder { + builder.body(T::valid().encode_to_vec()) + } + + fn set_error_request(builder: RequestBuilder) -> RequestBuilder { + builder.body("invalid protobuf") + } + + fn set_invalid_request(builder: RequestBuilder) -> RequestBuilder { + builder.body(T::invalid().encode_to_vec()) + } + } +} diff --git a/src/test.rs b/src/test.rs index 253fb5f..49a7c9d 100644 --- a/src/test.rs +++ b/src/test.rs @@ -4,34 +4,38 @@ use axum::extract::{Path, Query}; use axum::routing::{get, post}; use axum::{Form, Json, Router}; use hyper::Method; +use once_cell::sync::Lazy; use reqwest::{StatusCode, Url}; use serde::{Deserialize, Serialize}; use std::any::type_name; -use std::borrow::Cow; use std::net::SocketAddr; +use std::ops::Deref; use validator::Validate; -#[derive(Debug, Clone, Deserialize, Serialize, Validate, Eq, PartialEq)] +#[derive(Clone, Deserialize, Serialize, Validate, Eq, PartialEq)] +#[cfg_attr(feature = "extra_protobuf", derive(prost::Message))] pub struct Parameters { #[validate(range(min = 5, max = 10))] + #[cfg_attr(feature = "extra_protobuf", prost(int32, tag = "1"))] v0: i32, #[validate(length(min = 1, max = 10))] - v1: Cow<'static, str>, + #[cfg_attr(feature = "extra_protobuf", prost(string, tag = "2"))] + v1: String, } -static VALID_PARAMETERS: Parameters = Parameters { +static VALID_PARAMETERS: Lazy = Lazy::new(|| Parameters { v0: 5, - v1: Cow::Borrowed("0123456789"), -}; + v1: String::from("0123456789"), +}); -static INVALID_PARAMETERS: Parameters = Parameters { +static INVALID_PARAMETERS: Lazy = Lazy::new(|| Parameters { v0: 6, - v1: Cow::Borrowed("01234567890"), -}; + v1: String::from("01234567890"), +}); impl ValidTestParameter for Parameters { fn valid() -> &'static Self { - &VALID_PARAMETERS + VALID_PARAMETERS.deref() } fn error() -> &'static [(&'static str, &'static str)] { @@ -39,7 +43,7 @@ impl ValidTestParameter for Parameters { } fn invalid() -> &'static Self { - &INVALID_PARAMETERS + INVALID_PARAMETERS.deref() } } @@ -77,6 +81,12 @@ async fn test_main() -> anyhow::Result<()> { post(extra_form::extract_extra_form), ); + #[cfg(feature = "extra_protobuf")] + let router = router.route( + extra_protobuf::route::EXTRA_PROTOBUF, + post(extra_protobuf::extract_extra_protobuf), + ); + let server = axum::Server::bind(&SocketAddr::from(([0u8, 0, 0, 0], 0u16))) .serve(router.into_make_service()); let server_addr = server.local_addr(); @@ -90,36 +100,53 @@ async fn test_main() -> anyhow::Result<()> { let server_url = format!("http://{}", server_addr); let test_executor = TestExecutor::from(Url::parse(&format!("http://{}", server_addr))?); - // Valid> - let valid_path_response = test_executor - .client() - .get(format!( - "{}/path/{}/{}", - server_url, VALID_PARAMETERS.v0, VALID_PARAMETERS.v1 - )) - .send() - .await?; - assert_eq!(valid_path_response.status(), StatusCode::OK); + { + let path_type_name = type_name::>(); + let valid_path_response = test_executor + .client() + .get(format!( + "{}/path/{}/{}", + server_url, VALID_PARAMETERS.v0, VALID_PARAMETERS.v1 + )) + .send() + .await?; + assert_eq!( + valid_path_response.status(), + StatusCode::OK, + "Valid '{}' test failed.", + path_type_name + ); - let invalid_path_response = test_executor - .client() - .get(format!("{}/path/not_i32/path", server_url)) - .send() - .await?; - assert_eq!(invalid_path_response.status(), StatusCode::BAD_REQUEST); + let error_path_response = test_executor + .client() + .get(format!("{}/path/not_i32/path", server_url)) + .send() + .await?; + assert_eq!( + error_path_response.status(), + StatusCode::BAD_REQUEST, + "Error '{}' test failed.", + path_type_name + ); - let invalid_path_response = test_executor - .client() - .get(format!( - "{}/path/{}/{}", - server_url, INVALID_PARAMETERS.v0, INVALID_PARAMETERS.v1 - )) - .send() - .await?; - assert_eq!(invalid_path_response.status(), VALIDATION_ERROR_STATUS); - #[cfg(feature = "into_json")] - check_json(invalid_path_response).await; - println!("Valid> works."); + let invalid_path_response = test_executor + .client() + .get(format!( + "{}/path/{}/{}", + server_url, INVALID_PARAMETERS.v0, INVALID_PARAMETERS.v1 + )) + .send() + .await?; + assert_eq!( + invalid_path_response.status(), + VALIDATION_ERROR_STATUS, + "Invalid '{}' test failed.", + path_type_name + ); + #[cfg(feature = "into_json")] + check_json(path_type_name, invalid_path_response).await; + println!("All {} tests passed.", path_type_name); + } test_executor .execute::>(Method::GET, route::QUERY) @@ -172,6 +199,14 @@ async fn test_main() -> anyhow::Result<()> { .await?; } + #[cfg(feature = "extra_protobuf")] + { + use axum_extra::protobuf::Protobuf; + test_executor + .execute::>(Method::POST, extra_protobuf::route::EXTRA_PROTOBUF) + .await?; + } + drop(server_guard); server_handle.await??; Ok(()) @@ -201,20 +236,37 @@ impl TestExecutor { url_builder }; + let type_name = type_name::(); + let valid_builder = self.client.request(method.clone(), url.clone()); let valid_response = T::set_valid_request(valid_builder).send().await?; - assert_eq!(valid_response.status(), StatusCode::OK); + assert_eq!( + valid_response.status(), + StatusCode::OK, + "Valid '{}' test failed.", + type_name + ); let error_builder = self.client.request(method.clone(), url.clone()); let error_response = T::set_error_request(error_builder).send().await?; - assert_eq!(error_response.status(), T::ERROR_STATUS_CODE); + assert_eq!( + error_response.status(), + T::ERROR_STATUS_CODE, + "Error '{}' test failed.", + type_name + ); let invalid_builder = self.client.request(method, url); let invalid_response = T::set_invalid_request(invalid_builder).send().await?; - assert_eq!(invalid_response.status(), VALIDATION_ERROR_STATUS); + assert_eq!( + invalid_response.status(), + VALIDATION_ERROR_STATUS, + "Invalid '{}' test failed.", + type_name + ); #[cfg(feature = "into_json")] - check_json(invalid_response).await; - println!("{} works.", type_name::()); + check_json(type_name, invalid_response).await; + println!("All '{}' tests passed.", type_name); Ok(()) } @@ -226,10 +278,12 @@ impl TestExecutor { /// Check if the response is a json response #[cfg(feature = "into_json")] -pub async fn check_json(response: reqwest::Response) { +pub async fn check_json(type_name: &'static str, response: reqwest::Response) { assert_eq!( response.headers()[axum::http::header::CONTENT_TYPE], - axum::http::HeaderValue::from_static(mime::APPLICATION_JSON.as_ref()) + axum::http::HeaderValue::from_static(mime::APPLICATION_JSON.as_ref()), + "'{}' rejection into json test failed", + type_name ); assert!(response.json::().await.is_ok()); } @@ -257,12 +311,12 @@ async fn extract_json(Valid(Json(parameters)): Valid>) -> Statu validate_again(parameters) } -fn validate_again(parameters: Parameters) -> StatusCode { +fn validate_again(validate: V) -> StatusCode { // The `Valid` extractor has validated the `parameters` once, // it should have returned `400 BAD REQUEST` if the `parameters` were invalid, // Let's validate them again to check if the `Valid` extractor works well. // If it works properly, this function will never return `500 INTERNAL SERVER ERROR` - match parameters.validate() { + match validate.validate() { Ok(_) => StatusCode::OK, Err(_) => StatusCode::INTERNAL_SERVER_ERROR, } @@ -280,7 +334,6 @@ mod typed_header { use axum::headers::{Error, Header, HeaderName, HeaderValue}; use axum::http::StatusCode; use axum::TypedHeader; - use std::borrow::Cow; pub static AXUM_VALID_PARAMETERS: HeaderName = HeaderName::from_static("axum-valid-parameters"); @@ -306,7 +359,7 @@ mod typed_header { match split.as_slice() { [v0, v1] => Ok(Parameters { v0: v0.parse().map_err(|_| Error::invalid())?, - v1: Cow::Owned(v1.to_string()), + v1: v1.to_string(), }), _ => Err(Error::invalid()), } @@ -327,7 +380,7 @@ mod typed_header { fn parameter_is_header() -> anyhow::Result<()> { let parameter = Parameters { v0: 123456, - v1: Cow::Owned("111111".to_string()), + v1: "111111".to_string(), }; let mut vec = Vec::new(); parameter.encode(&mut vec); @@ -474,7 +527,7 @@ mod extra_query { } #[cfg(feature = "extra_form")] -pub mod extra_form { +mod extra_form { use crate::test::{validate_again, Parameters}; use crate::Valid; use axum::http::StatusCode; @@ -490,3 +543,21 @@ pub mod extra_form { validate_again(parameters) } } + +#[cfg(feature = "extra_protobuf")] +mod extra_protobuf { + use crate::test::{validate_again, Parameters}; + use crate::Valid; + use axum::http::StatusCode; + use axum_extra::protobuf::Protobuf; + + pub mod route { + pub const EXTRA_PROTOBUF: &str = "/extra_protobuf"; + } + + pub async fn extract_extra_protobuf( + Valid(Protobuf(parameters)): Valid>, + ) -> StatusCode { + validate_again(parameters) + } +}