diff --git a/src/buffer.rs b/src/buffer.rs new file mode 100644 index 0000000..6672928 --- /dev/null +++ b/src/buffer.rs @@ -0,0 +1,25 @@ +use bytes::{Buf, Bytes}; + +use crate::{Result, errors::AppError}; + +pub fn try_get_string(buf: &mut B) -> Result { + let len = buf.try_get_u16()? as usize; + + if buf.remaining() <= len { + return Err(AppError::IncompleteBuffer); + } + + Ok(String::from_utf8(buf.copy_to_bytes(len).to_vec())?) +} + +pub fn try_get_bytes(buf: &mut B) -> Result { + let len = buf.try_get_u32()? as usize; + + if buf.remaining() < len { + return Err(AppError::IncompleteBuffer); + } + + let data = buf.copy_to_bytes(len); + + Ok(data) +} diff --git a/src/commands/delete.rs b/src/commands/delete.rs index 9c7b269..5c1e099 100644 --- a/src/commands/delete.rs +++ b/src/commands/delete.rs @@ -1,8 +1,8 @@ use std::io::Cursor; -use bytes::{Buf as _, BufMut as _, BytesMut}; +use bytes::{BufMut as _, BytesMut}; -use crate::{Result, connection::Connection, database::Database, errors::AppError}; +use crate::{Result, buffer::try_get_string, connection::Connection, database::Database}; #[derive(Debug, Clone)] pub struct Delete { @@ -29,13 +29,7 @@ impl Delete { } pub fn parse(bytes: &mut Cursor<&[u8]>) -> Result { - let key_length = bytes.try_get_u16()? as usize; - - if bytes.remaining() < key_length { - return Err(AppError::IncompleteCommandBuffer); - } - - let key = String::from_utf8(bytes.copy_to_bytes(key_length).to_vec())?; + let key = try_get_string(bytes)?; Ok(Self { key }) } diff --git a/src/commands/expire.rs b/src/commands/expire.rs index ea7363f..c7f2c39 100644 --- a/src/commands/expire.rs +++ b/src/commands/expire.rs @@ -2,7 +2,7 @@ use std::io::Cursor; use bytes::{Buf as _, Bytes}; -use crate::{Result, connection::Connection, database::Database, errors::AppError}; +use crate::{Result, buffer::try_get_string, connection::Connection, database::Database}; #[derive(Debug, Clone)] pub struct Expire { @@ -22,14 +22,7 @@ impl Expire { } pub fn parse(bytes: &mut Cursor<&[u8]>) -> Result { - let key_length = bytes.try_get_u16()? as usize; - - if bytes.remaining() < key_length { - return Err(AppError::IncompleteCommandBuffer); - } - - let key = String::from_utf8(bytes.copy_to_bytes(key_length).to_vec())?; - + let key = try_get_string(bytes)?; let seconds = bytes.try_get_u64()?; Ok(Self { key, seconds }) diff --git a/src/commands/get.rs b/src/commands/get.rs index e2dcead..a04e8b9 100644 --- a/src/commands/get.rs +++ b/src/commands/get.rs @@ -1,8 +1,8 @@ use std::io::Cursor; -use bytes::{Buf as _, BufMut as _, BytesMut}; +use bytes::{BufMut as _, BytesMut}; -use crate::{Result, connection::Connection, database::Database, errors::AppError}; +use crate::{Result, buffer::try_get_string, connection::Connection, database::Database}; #[derive(Debug, Clone)] pub struct Get { @@ -31,13 +31,7 @@ impl Get { } pub fn parse(bytes: &mut Cursor<&[u8]>) -> Result { - let key_length = bytes.try_get_u16()? as usize; - - if bytes.remaining() < key_length { - return Err(AppError::IncompleteCommandBuffer); - } - - let key = String::from_utf8(bytes.copy_to_bytes(key_length).to_vec())?; + let key = try_get_string(bytes)?; Ok(Self { key }) } diff --git a/src/commands/has.rs b/src/commands/has.rs index e2e442b..a30be15 100644 --- a/src/commands/has.rs +++ b/src/commands/has.rs @@ -1,8 +1,8 @@ use std::io::Cursor; -use bytes::{Buf as _, Bytes}; +use bytes::Bytes; -use crate::{Result, connection::Connection, database::Database, errors::AppError}; +use crate::{Result, buffer::try_get_string, connection::Connection, database::Database}; #[derive(Debug, Clone)] pub struct Has { @@ -21,13 +21,7 @@ impl Has { } pub fn parse(bytes: &mut Cursor<&[u8]>) -> Result { - let key_length = bytes.try_get_u16()? as usize; - - if bytes.remaining() < key_length { - return Err(AppError::IncompleteCommandBuffer); - } - - let key = String::from_utf8(bytes.copy_to_bytes(key_length).to_vec())?; + let key = try_get_string(bytes)?; Ok(Self { key }) } diff --git a/src/commands/mod.rs b/src/commands/mod.rs index 9569030..1d6d5b6 100644 --- a/src/commands/mod.rs +++ b/src/commands/mod.rs @@ -8,7 +8,7 @@ mod ttl; use std::io::Cursor; -use bytes::{Buf, BytesMut}; +use bytes::BytesMut; use delete::Delete; use expire::Expire; use get::Get; @@ -17,7 +17,9 @@ use persist::Persist; use set::Set; use ttl::Ttl; -use crate::{Result, connection::Connection, database::Database, errors::AppError}; +use crate::{ + Result, buffer::try_get_string, connection::Connection, database::Database, errors::AppError, +}; #[derive(Debug)] pub enum Command { @@ -46,13 +48,7 @@ impl Command { pub fn parse(bytes: &BytesMut) -> Result<(Self, u64)> { let mut buffer = Cursor::new(&bytes[..]); - let name_length = buffer.try_get_u16()? as usize; - - if buffer.remaining() < name_length { - return Err(AppError::IncompleteCommandBuffer); - } - - let name = String::from_utf8(buffer.copy_to_bytes(name_length).to_vec())?; + let name = try_get_string(&mut buffer)?; Self::parse_inner(name, &mut buffer) } diff --git a/src/commands/persist.rs b/src/commands/persist.rs index 0288c44..ecf5f45 100644 --- a/src/commands/persist.rs +++ b/src/commands/persist.rs @@ -1,8 +1,8 @@ use std::io::Cursor; -use bytes::{Buf as _, Bytes}; +use bytes::Bytes; -use crate::{Result, connection::Connection, database::Database, errors::AppError}; +use crate::{Result, buffer::try_get_string, connection::Connection, database::Database}; #[derive(Debug, Clone)] pub struct Persist { @@ -21,13 +21,7 @@ impl Persist { } pub fn parse(bytes: &mut Cursor<&[u8]>) -> Result { - let key_length = bytes.try_get_u16()? as usize; - - if bytes.remaining() < key_length { - return Err(AppError::IncompleteCommandBuffer); - } - - let key = String::from_utf8(bytes.copy_to_bytes(key_length).to_vec())?; + let key = try_get_string(bytes)?; Ok(Self { key }) } diff --git a/src/commands/set.rs b/src/commands/set.rs index 88ac7a3..1c31d14 100644 --- a/src/commands/set.rs +++ b/src/commands/set.rs @@ -1,6 +1,12 @@ use std::io::Cursor; -use crate::{Result, connection::Connection, database::Database, errors::AppError}; +use crate::{ + Result, + buffer::{try_get_bytes, try_get_string}, + connection::Connection, + database::Database, + errors::AppError, +}; use bytes::{Buf as _, Bytes}; #[derive(Debug, Clone)] @@ -20,21 +26,9 @@ impl Set { } pub fn parse(bytes: &mut Cursor<&[u8]>) -> Result { - let key_length = bytes.try_get_u16()? as usize; + let key = try_get_string(bytes)?; - if bytes.remaining() < key_length { - return Err(AppError::IncompleteCommandBuffer); - } - - let key = String::from_utf8(bytes.copy_to_bytes(key_length).to_vec())?; - - let value_length = bytes.try_get_u32()? as usize; - - if bytes.remaining() < value_length { - return Err(AppError::IncompleteCommandBuffer); - } - - let data = bytes.copy_to_bytes(value_length); + let data = try_get_bytes(bytes)?; let expiration: Option = match bytes.try_get_u8()? { 1 => Some(bytes.try_get_u64()?), diff --git a/src/commands/ttl.rs b/src/commands/ttl.rs index 1ef2e44..5955590 100644 --- a/src/commands/ttl.rs +++ b/src/commands/ttl.rs @@ -1,8 +1,8 @@ use std::io::Cursor; -use bytes::{Buf as _, BufMut, Bytes, BytesMut}; +use bytes::{BufMut, Bytes, BytesMut}; -use crate::{Result, connection::Connection, database::Database, errors::AppError}; +use crate::{Result, buffer::try_get_string, connection::Connection, database::Database}; #[derive(Debug, Clone)] pub struct Ttl { @@ -29,13 +29,7 @@ impl Ttl { } pub fn parse(bytes: &mut Cursor<&[u8]>) -> Result { - let key_length = bytes.try_get_u16()? as usize; - - if bytes.remaining() < key_length { - return Err(AppError::IncompleteCommandBuffer); - } - - let key = String::from_utf8(bytes.copy_to_bytes(key_length).to_vec())?; + let key = try_get_string(bytes)?; Ok(Self { key }) } diff --git a/src/errors.rs b/src/errors.rs index b720822..6036aad 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -6,8 +6,8 @@ pub enum AppError { Io(#[from] std::io::Error), #[error("A TryGetError occurred")] TryGet(#[from] bytes::TryGetError), - #[error("The buffer is missing data for a complete command")] - IncompleteCommandBuffer, + #[error("The buffer is missing data")] + IncompleteBuffer, #[error("A Utf8Error occurred")] FromUtf8(#[from] std::string::FromUtf8Error), #[error("The command {0} was not recognized")] diff --git a/src/lib.rs b/src/lib.rs index bc6e8f5..318c2e4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,6 +3,7 @@ use errors::AppError; #[cfg(test)] pub mod tests; +pub mod buffer; pub mod client; pub mod commands; pub mod config;