From cff5c37a40df57ab64886ab6d31841109ecb89dd Mon Sep 17 00:00:00 2001 From: 409 <409dev@protonmail.com> Date: Tue, 1 Jul 2025 16:59:39 +0200 Subject: [PATCH] refactor client byte write / read --- src/buffer.rs | 33 +++++++- src/client.rs | 162 +++++++++------------------------------- src/commands/delete.rs | 2 +- src/commands/expire.rs | 2 +- src/commands/get.rs | 2 +- src/commands/has.rs | 2 +- src/commands/mod.rs | 3 +- src/commands/persist.rs | 2 +- src/commands/set.rs | 2 +- src/commands/ttl.rs | 2 +- src/errors.rs | 6 +- 11 files changed, 80 insertions(+), 138 deletions(-) diff --git a/src/buffer.rs b/src/buffer.rs index 914c90a..c00c5fe 100644 --- a/src/buffer.rs +++ b/src/buffer.rs @@ -4,6 +4,7 @@ use crate::{Result, errors::AppError}; pub trait ArchiveBuf { fn try_get_string(&mut self) -> Result; + fn try_get_short_string(&mut self) -> Result; fn try_get_bytes(&mut self) -> Result; fn try_get_bool(&mut self) -> Result; fn try_get_option(&mut self, f: F) -> Result> @@ -14,6 +15,8 @@ pub trait ArchiveBuf { } pub trait ArchiveBufMut { + fn put_string(&mut self, s: &str) -> Result<()>; + fn put_short_string(&mut self, s: &str) -> Result<()>; fn put_bytes_with_length>(&mut self, bytes: T); fn put_option(&mut self, value: Option, f: F) where @@ -23,6 +26,16 @@ pub trait ArchiveBufMut { impl ArchiveBuf for B { fn try_get_string(&mut self) -> Result { + let len = self.try_get_u32()? as usize; + + if self.remaining() < len { + return Err(AppError::IncompleteBuffer); + } + + Ok(String::from_utf8(self.copy_to_bytes(len).to_vec())?) + } + + fn try_get_short_string(&mut self) -> Result { let len = self.try_get_u16()? as usize; if self.remaining() < len { @@ -45,7 +58,11 @@ impl ArchiveBuf for B { } fn try_get_bool(&mut self) -> Result { - Ok(self.try_get_u8()? == 1) + Ok(match self.try_get_u8()? { + 1 => true, + 0 => false, + _ => return Err(AppError::UnexpectedData), + }) } fn try_get_option(&mut self, f: F) -> Result> @@ -63,6 +80,20 @@ impl ArchiveBuf for B { } impl ArchiveBufMut for B { + fn put_string(&mut self, s: &str) -> Result<()> { + self.put_u32(s.len().try_into()?); + self.put_slice(s.as_bytes()); + + Ok(()) + } + + fn put_short_string(&mut self, s: &str) -> Result<()> { + self.put_u16(s.len().try_into()?); + self.put_slice(s.as_bytes()); + + Ok(()) + } + fn put_bytes_with_length>(&mut self, bytes: T) { let bytes = bytes.as_ref(); diff --git a/src/client.rs b/src/client.rs index 6a7f49a..fcbfd62 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,7 +1,12 @@ use bytes::{Buf, BufMut as _, Bytes, BytesMut}; use tokio::net::{TcpStream, ToSocketAddrs}; -use crate::{Result, connection::Connection, errors::AppError}; +use crate::{ + Result, + buffer::{ArchiveBuf, ArchiveBufMut as _}, + connection::Connection, + errors::AppError, +}; pub struct Client { connection: Connection, @@ -19,35 +24,16 @@ impl Client { pub async fn get(&mut self, key: &str) -> Result> { let mut bytes = BytesMut::new(); - bytes.put_u16(3); - bytes.put_slice(b"get"); - - let key_length: u16 = key - .len() - .try_into() - .map_err(|_| AppError::KeyLength(key.len()))?; - - bytes.put_u16(key_length); - bytes.put_slice(key.as_bytes()); + bytes.put_short_string("get")?; + bytes.put_short_string(key)?; self.connection.write(bytes.into()).await?; let mut r = self.get_response().await?; - let response = match r.try_get_u8()? { - 0 => None, - 1 => { - let len = r.try_get_u32()? as usize; - if r.remaining() < len { - return Err(AppError::InvalidCommandResponse); - } + let value = r.try_get_option(ArchiveBuf::try_get_bytes)?; - Some(r.copy_to_bytes(len)) - } - _ => return Err(AppError::InvalidCommandResponse), - }; - - Ok(response) + Ok(value) } pub async fn set( @@ -58,130 +44,74 @@ impl Client { ) -> Result<()> { let mut bytes = BytesMut::new(); - bytes.put_u16(3); - bytes.put_slice(b"set"); + bytes.put_short_string("set")?; + bytes.put_short_string(key)?; - let key_length: u16 = key - .len() - .try_into() - .map_err(|_| AppError::KeyLength(key.len()))?; + bytes.put_bytes_with_length(data); - bytes.put_u16(key_length); - bytes.put_slice(key.as_bytes()); - - bytes.put_u32(data.len() as u32); - bytes.put_slice(data); - - match expiration_secs { - Some(seconds) => { - bytes.put_u8(1); - bytes.put_u64(seconds); - } - None => { - bytes.put_u8(0); - } - } + bytes.put_option(expiration_secs, BytesMut::put_u64); self.connection.write(bytes.into()).await?; let mut r = self.get_response().await?; - match r.try_get_u8()? { - 1 => return Ok(()), - _ => return Err(AppError::InvalidCommandResponse), + if !r.try_get_bool()? { + return Err(AppError::InvalidCommandResponse); } + + Ok(()) } pub async fn delete(&mut self, key: &str) -> Result> { let mut bytes = BytesMut::new(); - bytes.put_u16(6); - bytes.put_slice(b"delete"); - - let key_length: u16 = key - .len() - .try_into() - .map_err(|_| AppError::KeyLength(key.len()))?; - - bytes.put_u16(key_length); - bytes.put_slice(key.as_bytes()); + bytes.put_short_string("delete")?; + bytes.put_short_string(key)?; self.connection.write(bytes.into()).await?; let mut r = self.get_response().await?; - let response = match r.try_get_u8()? { - 1 => { - let len = r.try_get_u32()?; - let bytes = r.copy_to_bytes(len as usize); + let value = r.try_get_option(ArchiveBuf::try_get_bytes)?; - Some(bytes) - } - 0 => None, - _ => return Err(AppError::InvalidCommandResponse), - }; - - Ok(response) + Ok(value) } pub async fn has(&mut self, key: &str) -> Result { let mut bytes = BytesMut::new(); - bytes.put_u16(3); - bytes.put_slice(b"has"); - let key_length: u16 = key - .len() - .try_into() - .map_err(|_| AppError::KeyLength(key.len()))?; - - bytes.put_u16(key_length); - bytes.put_slice(key.as_bytes()); + bytes.put_short_string("has")?; + bytes.put_short_string(key)?; self.connection.write(bytes.into()).await?; let mut r = self.get_response().await?; - Ok(r.try_get_u8()? == 1) + let has = r.try_get_bool()?; + + Ok(has) } pub async fn ttl(&mut self, key: &str) -> Result> { let mut bytes = BytesMut::new(); - bytes.put_u16(3); - bytes.put_slice(b"ttl"); - let key_length: u16 = key - .len() - .try_into() - .map_err(|_| AppError::KeyLength(key.len()))?; - - bytes.put_u16(key_length); - bytes.put_slice(key.as_bytes()); + bytes.put_short_string("ttl")?; + bytes.put_short_string(key)?; self.connection.write(bytes.into()).await?; let mut r = self.get_response().await?; - let ttl = match r.try_get_u8()? { - 1 => Some(r.try_get_u64()?), - 0 => None, - _ => return Err(AppError::InvalidCommandResponse), - }; + let ttl = r.try_get_option(Bytes::try_get_u64)?; Ok(ttl) } pub async fn expire(&mut self, key: &str, seconds: u64) -> Result { let mut bytes = BytesMut::new(); - bytes.put_u16(6); - bytes.put_slice(b"expire"); - let key_length: u16 = key - .len() - .try_into() - .map_err(|_| AppError::KeyLength(key.len()))?; - - bytes.put_u16(key_length); - bytes.put_slice(key.as_bytes()); + bytes.put_short_string("expire")?; + bytes.put_short_string(key)?; bytes.put_u64(seconds); @@ -189,39 +119,19 @@ impl Client { let mut r = self.get_response().await?; - let success = match r.try_get_u8()? { - 1 => true, - 0 => false, - _ => return Err(AppError::InvalidCommandResponse), - }; - - Ok(success) + Ok(r.try_get_bool()?) } pub async fn persist(&mut self, key: &str) -> Result { let mut bytes = BytesMut::new(); - bytes.put_u16(7); - bytes.put_slice(b"persist"); - - let key_length: u16 = key - .len() - .try_into() - .map_err(|_| AppError::KeyLength(key.len()))?; - - bytes.put_u16(key_length); - bytes.put_slice(key.as_bytes()); + bytes.put_short_string("persist")?; + bytes.put_short_string(key)?; self.connection.write(bytes.into()).await?; let mut r = self.get_response().await?; - let success = match r.try_get_u8()? { - 1 => true, - 0 => false, - _ => return Err(AppError::InvalidCommandResponse), - }; - - Ok(success) + Ok(r.try_get_bool()?) } async fn get_response(&mut self) -> Result { diff --git a/src/commands/delete.rs b/src/commands/delete.rs index 6e9b71e..f4dfc64 100644 --- a/src/commands/delete.rs +++ b/src/commands/delete.rs @@ -27,7 +27,7 @@ impl Delete { } pub fn parse(buf: &mut Cursor<&[u8]>) -> Result { - let key = buf.try_get_string()?; + let key = buf.try_get_short_string()?; Ok(Self { key }) } diff --git a/src/commands/expire.rs b/src/commands/expire.rs index e214782..e0598fd 100644 --- a/src/commands/expire.rs +++ b/src/commands/expire.rs @@ -22,7 +22,7 @@ impl Expire { } pub fn parse(buf: &mut Cursor<&[u8]>) -> Result { - let key = buf.try_get_string()?; + let key = buf.try_get_short_string()?; let seconds = buf.try_get_u64()?; Ok(Self { key, seconds }) diff --git a/src/commands/get.rs b/src/commands/get.rs index e0d6a9e..fb462ea 100644 --- a/src/commands/get.rs +++ b/src/commands/get.rs @@ -27,7 +27,7 @@ impl Get { } pub fn parse(buf: &mut Cursor<&[u8]>) -> Result { - let key = buf.try_get_string()?; + let key = buf.try_get_short_string()?; Ok(Self { key }) } diff --git a/src/commands/has.rs b/src/commands/has.rs index 2b801b2..6341b82 100644 --- a/src/commands/has.rs +++ b/src/commands/has.rs @@ -21,7 +21,7 @@ impl Has { } pub fn parse(buf: &mut Cursor<&[u8]>) -> Result { - let key = buf.try_get_string()?; + let key = buf.try_get_short_string()?; Ok(Self { key }) } diff --git a/src/commands/mod.rs b/src/commands/mod.rs index a334f0b..715189b 100644 --- a/src/commands/mod.rs +++ b/src/commands/mod.rs @@ -48,8 +48,7 @@ impl Command { pub fn parse(bytes: &BytesMut) -> Result<(Self, u64)> { let mut buf = Cursor::new(&bytes[..]); - let name = buf.try_get_string()?; - println!("Command name: {name}, buf: {buf:?}"); + let name = buf.try_get_short_string()?; Self::parse_inner(name, &mut buf) } diff --git a/src/commands/persist.rs b/src/commands/persist.rs index d7cbd93..c76b179 100644 --- a/src/commands/persist.rs +++ b/src/commands/persist.rs @@ -21,7 +21,7 @@ impl Persist { } pub fn parse(buf: &mut Cursor<&[u8]>) -> Result { - let key = buf.try_get_string()?; + let key = buf.try_get_short_string()?; Ok(Self { key }) } diff --git a/src/commands/set.rs b/src/commands/set.rs index 037df04..11bf6ff 100644 --- a/src/commands/set.rs +++ b/src/commands/set.rs @@ -20,7 +20,7 @@ impl Set { } pub fn parse(buf: &mut Cursor<&[u8]>) -> Result { - let key = buf.try_get_string()?; + let key = buf.try_get_short_string()?; let data = buf.try_get_bytes()?; diff --git a/src/commands/ttl.rs b/src/commands/ttl.rs index 36b790c..c18dfee 100644 --- a/src/commands/ttl.rs +++ b/src/commands/ttl.rs @@ -27,7 +27,7 @@ impl Ttl { } pub fn parse(buf: &mut Cursor<&[u8]>) -> Result { - let key = buf.try_get_string()?; + let key = buf.try_get_short_string()?; Ok(Self { key }) } diff --git a/src/errors.rs b/src/errors.rs index 6036aad..606e5d7 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -18,6 +18,8 @@ pub enum AppError { NoResponse, #[error("Expected a different response for the executed command")] InvalidCommandResponse, - #[error("The binary command data is not structured correctly")] - UnexpectedCommandData, + #[error("The binary data is not structured correctly")] + UnexpectedData, + #[error("Failed to convert integer")] + TryFromInt(#[from] std::num::TryFromIntError), }