diff --git a/src/buffer.rs b/src/buffer.rs index 6672928..914c90a 100644 --- a/src/buffer.rs +++ b/src/buffer.rs @@ -1,25 +1,86 @@ -use bytes::{Buf, Bytes}; +use bytes::{Buf, BufMut, Bytes}; use crate::{Result, errors::AppError}; -pub fn try_get_string(buf: &mut B) -> Result { - let len = buf.try_get_u16()? as usize; - - if buf.remaining() <= len { - return Err(AppError::IncompleteBuffer); - } - - Ok(String::from_utf8(buf.copy_to_bytes(len).to_vec())?) +pub trait ArchiveBuf { + fn try_get_string(&mut self) -> Result; + fn try_get_bytes(&mut self) -> Result; + fn try_get_bool(&mut self) -> Result; + fn try_get_option(&mut self, f: F) -> Result> + where + T: Sized, + F: FnOnce(&mut B) -> std::result::Result, + AppError: From; } -pub fn try_get_bytes(buf: &mut B) -> Result { - let len = buf.try_get_u32()? as usize; +pub trait ArchiveBufMut { + fn put_bytes_with_length>(&mut self, bytes: T); + fn put_option(&mut self, value: Option, f: F) + where + T: Sized, + F: FnOnce(&mut B, T); +} - if buf.remaining() < len { - return Err(AppError::IncompleteBuffer); +impl ArchiveBuf for B { + fn try_get_string(&mut self) -> Result { + let len = self.try_get_u16()? as usize; + + if self.remaining() < len { + return Err(AppError::IncompleteBuffer); + } + + Ok(String::from_utf8(self.copy_to_bytes(len).to_vec())?) } - let data = buf.copy_to_bytes(len); + fn try_get_bytes(&mut self) -> Result { + let len = self.try_get_u32()? as usize; - Ok(data) + if self.remaining() < len { + return Err(AppError::IncompleteBuffer); + } + + let data = self.copy_to_bytes(len); + + Ok(data) + } + + fn try_get_bool(&mut self) -> Result { + Ok(self.try_get_u8()? == 1) + } + + fn try_get_option(&mut self, f: F) -> Result> + where + T: Sized, + F: FnOnce(&mut B) -> std::result::Result, + AppError: From, + { + if !self.try_get_bool()? { + return Ok(None); + } + + Ok(Some(f(self)?)) + } +} + +impl ArchiveBufMut for B { + fn put_bytes_with_length>(&mut self, bytes: T) { + let bytes = bytes.as_ref(); + + self.put_u32(bytes.len() as u32); + self.put_slice(bytes); + } + + fn put_option(&mut self, value: Option, f: F) + where + T: Sized, + F: FnOnce(&mut B, T), + { + let Some(value) = value else { + self.put_u8(0); + return; + }; + + self.put_u8(1); + f(self, value); + } } diff --git a/src/commands/delete.rs b/src/commands/delete.rs index 5c1e099..6e9b71e 100644 --- a/src/commands/delete.rs +++ b/src/commands/delete.rs @@ -1,8 +1,13 @@ use std::io::Cursor; -use bytes::{BufMut as _, BytesMut}; +use bytes::BytesMut; -use crate::{Result, buffer::try_get_string, connection::Connection, database::Database}; +use crate::{ + Result, + buffer::{ArchiveBuf as _, ArchiveBufMut}, + connection::Connection, + database::Database, +}; #[derive(Debug, Clone)] pub struct Delete { @@ -14,22 +19,15 @@ impl Delete { let value = db.delete(&self.key).await; let mut buf = BytesMut::new(); - match value { - Some(v) => { - buf.put_u8(1); - buf.put_u32(v.len() as u32); - buf.put_slice(&v); - } - None => buf.put_u8(0), - } + buf.put_option(value, ArchiveBufMut::put_bytes_with_length); connection.write(buf.into()).await?; Ok(()) } - pub fn parse(bytes: &mut Cursor<&[u8]>) -> Result { - let key = try_get_string(bytes)?; + pub fn parse(buf: &mut Cursor<&[u8]>) -> Result { + let key = buf.try_get_string()?; Ok(Self { key }) } diff --git a/src/commands/expire.rs b/src/commands/expire.rs index c7f2c39..e214782 100644 --- a/src/commands/expire.rs +++ b/src/commands/expire.rs @@ -2,7 +2,7 @@ use std::io::Cursor; use bytes::{Buf as _, Bytes}; -use crate::{Result, buffer::try_get_string, connection::Connection, database::Database}; +use crate::{Result, buffer::ArchiveBuf as _, connection::Connection, database::Database}; #[derive(Debug, Clone)] pub struct Expire { @@ -12,18 +12,18 @@ pub struct Expire { impl Expire { pub async fn execute(self, db: &Database, connection: &mut Connection) -> Result<()> { - let value = db.expire(&self.key, self.seconds).await?; + let success = db.expire(&self.key, self.seconds).await?; connection - .write(Bytes::from_static(if value { &[1] } else { &[0] })) + .write(Bytes::from_static(if success { &[1] } else { &[0] })) .await?; Ok(()) } - pub fn parse(bytes: &mut Cursor<&[u8]>) -> Result { - let key = try_get_string(bytes)?; - let seconds = bytes.try_get_u64()?; + pub fn parse(buf: &mut Cursor<&[u8]>) -> Result { + let key = buf.try_get_string()?; + let seconds = buf.try_get_u64()?; Ok(Self { key, seconds }) } diff --git a/src/commands/get.rs b/src/commands/get.rs index a04e8b9..e0d6a9e 100644 --- a/src/commands/get.rs +++ b/src/commands/get.rs @@ -1,8 +1,13 @@ use std::io::Cursor; -use bytes::{BufMut as _, BytesMut}; +use bytes::BytesMut; -use crate::{Result, buffer::try_get_string, connection::Connection, database::Database}; +use crate::{ + Result, + buffer::{ArchiveBuf as _, ArchiveBufMut}, + connection::Connection, + database::Database, +}; #[derive(Debug, Clone)] pub struct Get { @@ -14,24 +19,15 @@ impl Get { let value = db.get(&self.key).await; let mut buf = BytesMut::new(); - match value { - Some(v) => { - buf.put_u8(1); - buf.put_u32(v.len() as u32); - buf.put_slice(&v); - } - None => { - buf.put_u8(0); - } - } + buf.put_option(value, ArchiveBufMut::put_bytes_with_length); connection.write(buf.into()).await?; Ok(()) } - pub fn parse(bytes: &mut Cursor<&[u8]>) -> Result { - let key = try_get_string(bytes)?; + pub fn parse(buf: &mut Cursor<&[u8]>) -> Result { + let key = buf.try_get_string()?; Ok(Self { key }) } diff --git a/src/commands/has.rs b/src/commands/has.rs index a30be15..2b801b2 100644 --- a/src/commands/has.rs +++ b/src/commands/has.rs @@ -2,7 +2,7 @@ use std::io::Cursor; use bytes::Bytes; -use crate::{Result, buffer::try_get_string, connection::Connection, database::Database}; +use crate::{Result, buffer::ArchiveBuf as _, connection::Connection, database::Database}; #[derive(Debug, Clone)] pub struct Has { @@ -20,8 +20,8 @@ impl Has { Ok(()) } - pub fn parse(bytes: &mut Cursor<&[u8]>) -> Result { - let key = try_get_string(bytes)?; + pub fn parse(buf: &mut Cursor<&[u8]>) -> Result { + let key = buf.try_get_string()?; Ok(Self { key }) } diff --git a/src/commands/mod.rs b/src/commands/mod.rs index 1d6d5b6..a334f0b 100644 --- a/src/commands/mod.rs +++ b/src/commands/mod.rs @@ -18,7 +18,7 @@ use set::Set; use ttl::Ttl; use crate::{ - Result, buffer::try_get_string, connection::Connection, database::Database, errors::AppError, + Result, buffer::ArchiveBuf as _, connection::Connection, database::Database, errors::AppError, }; #[derive(Debug)] @@ -46,11 +46,12 @@ impl Command { } pub fn parse(bytes: &BytesMut) -> Result<(Self, u64)> { - let mut buffer = Cursor::new(&bytes[..]); + let mut buf = Cursor::new(&bytes[..]); - let name = try_get_string(&mut buffer)?; + let name = buf.try_get_string()?; + println!("Command name: {name}, buf: {buf:?}"); - Self::parse_inner(name, &mut buffer) + Self::parse_inner(name, &mut buf) } fn parse_inner(command_name: String, bytes: &mut Cursor<&[u8]>) -> Result<(Self, u64)> { diff --git a/src/commands/persist.rs b/src/commands/persist.rs index ecf5f45..d7cbd93 100644 --- a/src/commands/persist.rs +++ b/src/commands/persist.rs @@ -2,7 +2,7 @@ use std::io::Cursor; use bytes::Bytes; -use crate::{Result, buffer::try_get_string, connection::Connection, database::Database}; +use crate::{Result, buffer::ArchiveBuf as _, connection::Connection, database::Database}; #[derive(Debug, Clone)] pub struct Persist { @@ -20,8 +20,8 @@ impl Persist { Ok(()) } - pub fn parse(bytes: &mut Cursor<&[u8]>) -> Result { - let key = try_get_string(bytes)?; + pub fn parse(buf: &mut Cursor<&[u8]>) -> Result { + let key = buf.try_get_string()?; Ok(Self { key }) } diff --git a/src/commands/set.rs b/src/commands/set.rs index 1c31d14..037df04 100644 --- a/src/commands/set.rs +++ b/src/commands/set.rs @@ -1,12 +1,6 @@ use std::io::Cursor; -use crate::{ - Result, - buffer::{try_get_bytes, try_get_string}, - connection::Connection, - database::Database, - errors::AppError, -}; +use crate::{Result, buffer::ArchiveBuf as _, connection::Connection, database::Database}; use bytes::{Buf as _, Bytes}; #[derive(Debug, Clone)] @@ -25,16 +19,12 @@ impl Set { Ok(()) } - pub fn parse(bytes: &mut Cursor<&[u8]>) -> Result { - let key = try_get_string(bytes)?; + pub fn parse(buf: &mut Cursor<&[u8]>) -> Result { + let key = buf.try_get_string()?; - let data = try_get_bytes(bytes)?; + let data = buf.try_get_bytes()?; - let expiration: Option = match bytes.try_get_u8()? { - 1 => Some(bytes.try_get_u64()?), - 0 => None, - _ => return Err(AppError::UnexpectedCommandData), - }; + let expiration = buf.try_get_option(Cursor::try_get_u64)?; Ok(Self { key, diff --git a/src/commands/ttl.rs b/src/commands/ttl.rs index 5955590..36b790c 100644 --- a/src/commands/ttl.rs +++ b/src/commands/ttl.rs @@ -1,8 +1,13 @@ use std::io::Cursor; -use bytes::{BufMut, Bytes, BytesMut}; +use bytes::{BufMut, BytesMut}; -use crate::{Result, buffer::try_get_string, connection::Connection, database::Database}; +use crate::{ + Result, + buffer::{ArchiveBuf as _, ArchiveBufMut as _}, + connection::Connection, + database::Database, +}; #[derive(Debug, Clone)] pub struct Ttl { @@ -13,23 +18,16 @@ impl Ttl { pub async fn execute(self, db: &Database, connection: &mut Connection) -> Result<()> { let ttl = db.ttl(&self.key).await; - let Some(ttl) = ttl else { - connection.write(Bytes::from_static(&[0])).await?; - return Ok(()); - }; - let mut buf = BytesMut::new(); - - buf.put_u8(1); - buf.put_u64(ttl); + buf.put_option(ttl, BytesMut::put_u64); connection.write(buf.into()).await?; Ok(()) } - pub fn parse(bytes: &mut Cursor<&[u8]>) -> Result { - let key = try_get_string(bytes)?; + pub fn parse(buf: &mut Cursor<&[u8]>) -> Result { + let key = buf.try_get_string()?; Ok(Self { key }) }