From c527bb00728bf35deff34b5092213e62c49fa627 Mon Sep 17 00:00:00 2001 From: 409 <409dev@protonmail.com> Date: Tue, 1 Jul 2025 18:33:13 +0200 Subject: [PATCH] mset / mget commands (cli currently only supports mget) --- src/bin/cli.rs | 14 +++++++ src/buffer.rs | 47 +++++++++++++++++++++ src/client.rs | 93 ++++++++++++++++++++++++++++++----------- src/commands/delete.rs | 13 ++++++ src/commands/expire.rs | 27 +++++++++++- src/commands/get.rs | 15 ++++++- src/commands/has.rs | 22 +++++++++- src/commands/m_get.rs | 60 ++++++++++++++++++++++++++ src/commands/m_set.rs | 49 ++++++++++++++++++++++ src/commands/mod.rs | 26 ++++++++---- src/commands/persist.rs | 22 +++++++++- src/commands/set.rs | 39 ++++++++++++++--- src/commands/ttl.rs | 16 +++++++ src/tests.rs | 48 +++++++++++++++++++++ 14 files changed, 448 insertions(+), 43 deletions(-) create mode 100644 src/commands/m_get.rs create mode 100644 src/commands/m_set.rs diff --git a/src/bin/cli.rs b/src/bin/cli.rs index e912f84..0113c20 100644 --- a/src/bin/cli.rs +++ b/src/bin/cli.rs @@ -39,6 +39,13 @@ enum Commands { Persist { key: String, }, + + #[command(name = "mget")] + MGet { + #[arg(num_args = 1..)] + keys: Vec, + }, + #[command(aliases = &["exit", "q"])] Quit, } @@ -108,6 +115,13 @@ async fn main() -> Result<()> { let value = client.persist(&key).await?; println!("{value:?}"); } + + Commands::MGet { keys } => { + let value = client.m_get(keys).await?; + + println!("{value:?}"); + } + Commands::Quit => break, } } diff --git a/src/buffer.rs b/src/buffer.rs index c00c5fe..f76444c 100644 --- a/src/buffer.rs +++ b/src/buffer.rs @@ -1,3 +1,5 @@ +use std::num::TryFromIntError; + use bytes::{Buf, BufMut, Bytes}; use crate::{Result, errors::AppError}; @@ -12,6 +14,11 @@ pub trait ArchiveBuf { T: Sized, F: FnOnce(&mut B) -> std::result::Result, AppError: From; + fn try_get_vec(&mut self, f: F) -> Result> + where + T: Sized, + F: FnMut(&mut B) -> std::result::Result, + AppError: From; } pub trait ArchiveBufMut { @@ -22,6 +29,12 @@ pub trait ArchiveBufMut { where T: Sized, F: FnOnce(&mut B, T); + fn try_put_vec(&mut self, value: V, f: F) -> Result<()> + where + T: Sized, + F: FnMut(&T, &mut B) -> std::result::Result<(), E>, + V: AsRef<[T]>, + AppError: From + From; } impl ArchiveBuf for B { @@ -77,6 +90,23 @@ impl ArchiveBuf for B { Ok(Some(f(self)?)) } + + fn try_get_vec(&mut self, mut f: F) -> Result> + where + T: Sized, + F: FnMut(&mut B) -> std::result::Result, + AppError: From, + { + let len = self.try_get_u16()?; + + let mut vec = Vec::with_capacity(len.into()); + + for _ in 0..len { + vec.push(f(self)?); + } + + Ok(vec) + } } impl ArchiveBufMut for B { @@ -114,4 +144,21 @@ impl ArchiveBufMut for B { self.put_u8(1); f(self, value); } + + fn try_put_vec(&mut self, vec: V, mut f: F) -> Result<()> + where + T: Sized, + F: FnMut(&T, &mut B) -> std::result::Result<(), E>, + V: AsRef<[T]>, + AppError: From + From, + { + let vec = vec.as_ref(); + self.put_u16(vec.len().try_into()?); + + for element in vec { + f(element, self)?; + } + + Ok(()) + } } diff --git a/src/client.rs b/src/client.rs index fcbfd62..ae9b304 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,9 +1,13 @@ -use bytes::{Buf, BufMut as _, Bytes, BytesMut}; +use bytes::{Buf, Bytes, BytesMut}; use tokio::net::{TcpStream, ToSocketAddrs}; use crate::{ Result, - buffer::{ArchiveBuf, ArchiveBufMut as _}, + buffer::ArchiveBuf, + commands::{ + delete::Delete, expire::Expire, get::Get, has::Has, m_get::MGet, m_set::MSet, + persist::Persist, set::Set, ttl::Ttl, + }, connection::Connection, errors::AppError, }; @@ -24,8 +28,8 @@ impl Client { pub async fn get(&mut self, key: &str) -> Result> { let mut bytes = BytesMut::new(); - bytes.put_short_string("get")?; - bytes.put_short_string(key)?; + let cmd = Get::new(key.to_owned()); + cmd.put(&mut bytes)?; self.connection.write(bytes.into()).await?; @@ -42,16 +46,12 @@ impl Client { data: &[u8], expiration_secs: Option, ) -> Result<()> { - let mut bytes = BytesMut::new(); + let mut buf = BytesMut::new(); - bytes.put_short_string("set")?; - bytes.put_short_string(key)?; + let cmd = Set::new(key.to_owned(), data.into(), expiration_secs); + cmd.put(&mut buf)?; - bytes.put_bytes_with_length(data); - - bytes.put_option(expiration_secs, BytesMut::put_u64); - - self.connection.write(bytes.into()).await?; + self.connection.write(buf.into()).await?; let mut r = self.get_response().await?; @@ -65,8 +65,8 @@ impl Client { pub async fn delete(&mut self, key: &str) -> Result> { let mut bytes = BytesMut::new(); - bytes.put_short_string("delete")?; - bytes.put_short_string(key)?; + let cmd = Delete::new(key.to_owned()); + cmd.put(&mut bytes)?; self.connection.write(bytes.into()).await?; @@ -80,8 +80,8 @@ impl Client { pub async fn has(&mut self, key: &str) -> Result { let mut bytes = BytesMut::new(); - bytes.put_short_string("has")?; - bytes.put_short_string(key)?; + let cmd = Has::new(key.to_owned()); + cmd.put(&mut bytes)?; self.connection.write(bytes.into()).await?; @@ -95,8 +95,8 @@ impl Client { pub async fn ttl(&mut self, key: &str) -> Result> { let mut bytes = BytesMut::new(); - bytes.put_short_string("ttl")?; - bytes.put_short_string(key)?; + let cmd = Ttl::new(key.to_owned()); + cmd.put(&mut bytes)?; self.connection.write(bytes.into()).await?; @@ -110,10 +110,8 @@ impl Client { pub async fn expire(&mut self, key: &str, seconds: u64) -> Result { let mut bytes = BytesMut::new(); - bytes.put_short_string("expire")?; - bytes.put_short_string(key)?; - - bytes.put_u64(seconds); + let cmd = Expire::new(key.to_owned(), seconds); + cmd.put(&mut bytes)?; self.connection.write(bytes.into()).await?; @@ -124,8 +122,9 @@ impl Client { pub async fn persist(&mut self, key: &str) -> Result { let mut bytes = BytesMut::new(); - bytes.put_short_string("persist")?; - bytes.put_short_string(key)?; + + let cmd = Persist::new(key.to_owned()); + cmd.put(&mut bytes)?; self.connection.write(bytes.into()).await?; @@ -134,6 +133,52 @@ impl Client { Ok(r.try_get_bool()?) } + pub async fn m_set( + &mut self, + keys: Vec<&str>, + data: Vec<&[u8]>, + expirations: Vec>, + ) -> Result<()> { + let mut bytes = BytesMut::new(); + + let len = keys.len().min(data.len()).min(expirations.len()); + + let mut sets = Vec::with_capacity(len); + + for i in 0..len { + sets.push(Set::new(keys[i].to_owned(), data[i].into(), expirations[i])); + } + + let cmd = MSet::new(sets); + cmd.put(&mut bytes)?; + + self.connection.write(bytes.into()).await?; + + let mut r = self.get_response().await?; + + if !r.try_get_bool()? { + return Err(AppError::InvalidCommandResponse); + } + + Ok(()) + } + + pub async fn m_get(&mut self, keys: Vec) -> Result>> { + let mut bytes = BytesMut::new(); + + let gets: Vec = keys.into_iter().map(Get::new).collect(); + let cmd = MGet::new(gets); + cmd.put(&mut bytes)?; + + self.connection.write(bytes.into()).await?; + + let mut r = self.get_response().await?; + + let values = r.try_get_vec(|b| b.try_get_option(ArchiveBuf::try_get_bytes))?; + + Ok(values) + } + async fn get_response(&mut self) -> Result { self.connection .read_bytes() diff --git a/src/commands/delete.rs b/src/commands/delete.rs index f4dfc64..656dcef 100644 --- a/src/commands/delete.rs +++ b/src/commands/delete.rs @@ -15,6 +15,10 @@ pub struct Delete { } impl Delete { + pub fn new(key: String) -> Self { + Self { key } + } + pub async fn execute(self, db: &Database, connection: &mut Connection) -> Result<()> { let value = db.delete(&self.key).await; @@ -31,4 +35,13 @@ impl Delete { Ok(Self { key }) } + + pub fn put(&self, buf: &mut BytesMut) -> Result<()> { + buf.put_short_string("delete")?; + self.put_without_cmd_name(buf) + } + + pub fn put_without_cmd_name(&self, buf: &mut BytesMut) -> Result<()> { + buf.put_short_string(&self.key) + } } diff --git a/src/commands/expire.rs b/src/commands/expire.rs index e0598fd..4d138a9 100644 --- a/src/commands/expire.rs +++ b/src/commands/expire.rs @@ -1,8 +1,13 @@ use std::io::Cursor; -use bytes::{Buf as _, Bytes}; +use bytes::{Buf as _, BufMut as _, Bytes, BytesMut}; -use crate::{Result, buffer::ArchiveBuf as _, connection::Connection, database::Database}; +use crate::{ + Result, + buffer::{ArchiveBuf as _, ArchiveBufMut as _}, + connection::Connection, + database::Database, +}; #[derive(Debug, Clone)] pub struct Expire { @@ -11,6 +16,10 @@ pub struct Expire { } impl Expire { + pub fn new(key: String, seconds: u64) -> Self { + Self { key, seconds } + } + pub async fn execute(self, db: &Database, connection: &mut Connection) -> Result<()> { let success = db.expire(&self.key, self.seconds).await?; @@ -27,4 +36,18 @@ impl Expire { Ok(Self { key, seconds }) } + + pub fn put(&self, buf: &mut BytesMut) -> Result<()> { + buf.put_short_string("expire")?; + + self.put_without_cmd_name(buf) + } + + pub fn put_without_cmd_name(&self, buf: &mut BytesMut) -> Result<()> { + buf.put_short_string(&self.key)?; + + buf.put_u64(self.seconds); + + Ok(()) + } } diff --git a/src/commands/get.rs b/src/commands/get.rs index fb462ea..7083398 100644 --- a/src/commands/get.rs +++ b/src/commands/get.rs @@ -11,10 +11,14 @@ use crate::{ #[derive(Debug, Clone)] pub struct Get { - key: String, + pub(super) key: String, } impl Get { + pub fn new(key: String) -> Self { + Self { key } + } + pub async fn execute(self, db: &Database, connection: &mut Connection) -> Result<()> { let value = db.get(&self.key).await; @@ -31,4 +35,13 @@ impl Get { Ok(Self { key }) } + + pub fn put(&self, buf: &mut BytesMut) -> Result<()> { + buf.put_short_string("get")?; + self.put_without_cmd_name(buf) + } + + pub fn put_without_cmd_name(&self, buf: &mut BytesMut) -> Result<()> { + buf.put_short_string(&self.key) + } } diff --git a/src/commands/has.rs b/src/commands/has.rs index 6341b82..1ae7d54 100644 --- a/src/commands/has.rs +++ b/src/commands/has.rs @@ -1,8 +1,13 @@ use std::io::Cursor; -use bytes::Bytes; +use bytes::{Bytes, BytesMut}; -use crate::{Result, buffer::ArchiveBuf as _, connection::Connection, database::Database}; +use crate::{ + Result, + buffer::{ArchiveBuf as _, ArchiveBufMut as _}, + connection::Connection, + database::Database, +}; #[derive(Debug, Clone)] pub struct Has { @@ -10,6 +15,10 @@ pub struct Has { } impl Has { + pub fn new(key: String) -> Self { + Self { key } + } + pub async fn execute(self, db: &Database, connection: &mut Connection) -> Result<()> { let value = db.has(&self.key).await; @@ -25,4 +34,13 @@ impl Has { Ok(Self { key }) } + + pub fn put(&self, buf: &mut BytesMut) -> Result<()> { + buf.put_short_string("has")?; + self.put_without_cmd_name(buf) + } + + pub fn put_without_cmd_name(&self, buf: &mut BytesMut) -> Result<()> { + buf.put_short_string(&self.key) + } } diff --git a/src/commands/m_get.rs b/src/commands/m_get.rs new file mode 100644 index 0000000..030fc94 --- /dev/null +++ b/src/commands/m_get.rs @@ -0,0 +1,60 @@ +use std::io::Cursor; + +use crate::{ + Result, + buffer::{ArchiveBuf as _, ArchiveBufMut}, + connection::Connection, + database::Database, + errors::AppError, +}; +use bytes::BytesMut; + +use super::get::Get; + +#[derive(Debug, Clone)] +pub struct MGet { + gets: Vec, +} + +impl MGet { + pub fn new(gets: Vec) -> Self { + Self { gets } + } + + pub async fn execute(self, db: &Database, connection: &mut Connection) -> Result<()> { + let mut values = Vec::with_capacity(self.gets.len()); + + for get in self.gets { + values.push(db.get(&get.key).await); + } + + let mut buf = BytesMut::new(); + + buf.try_put_vec(values, |data, buf| { + buf.put_option(data.as_deref(), ArchiveBufMut::put_bytes_with_length); + + Ok::<(), AppError>(()) + })?; + + connection.write(buf.into()).await?; + + Ok(()) + } + + pub fn parse(buf: &mut Cursor<&[u8]>) -> Result { + let gets = buf.try_get_vec(Get::parse)?; + + Ok(Self { gets }) + } + + pub fn put(&self, buf: &mut BytesMut) -> Result<()> { + buf.put_short_string("mget")?; + self.put_without_cmd_name(buf) + } + + pub fn put_without_cmd_name(&self, buf: &mut BytesMut) -> Result<()> { + buf.try_put_vec(&self.gets, Get::put_without_cmd_name)?; + + Ok(()) + } +} diff --git a/src/commands/m_set.rs b/src/commands/m_set.rs new file mode 100644 index 0000000..ec715fa --- /dev/null +++ b/src/commands/m_set.rs @@ -0,0 +1,49 @@ +use std::io::Cursor; + +use crate::{ + Result, + buffer::{ArchiveBuf as _, ArchiveBufMut as _}, + connection::Connection, + database::Database, +}; +use bytes::{Bytes, BytesMut}; + +use super::set::Set; + +#[derive(Debug, Clone)] +pub struct MSet { + sets: Vec, +} + +impl MSet { + pub fn new(sets: Vec) -> Self { + Self { sets } + } + + pub async fn execute(self, db: &Database, connection: &mut Connection) -> Result<()> { + for set in self.sets { + db.set(set.key, set.data, set.expiration).await?; + } + + connection.write(Bytes::from_static(&[1])).await?; + + Ok(()) + } + + pub fn parse(buf: &mut Cursor<&[u8]>) -> Result { + let sets = buf.try_get_vec(Set::parse)?; + + Ok(Self { sets }) + } + + pub fn put(&self, buf: &mut BytesMut) -> Result<()> { + buf.put_short_string("mset")?; + self.put_without_cmd_name(buf) + } + + pub fn put_without_cmd_name(&self, buf: &mut BytesMut) -> Result<()> { + buf.try_put_vec(&self.sets, Set::put_without_cmd_name)?; + + Ok(()) + } +} diff --git a/src/commands/mod.rs b/src/commands/mod.rs index 715189b..5e00dc0 100644 --- a/src/commands/mod.rs +++ b/src/commands/mod.rs @@ -1,10 +1,12 @@ -mod delete; -mod expire; -mod get; -mod has; -mod persist; -mod set; -mod ttl; +pub mod delete; +pub mod expire; +pub mod get; +pub mod has; +pub mod m_get; +pub mod m_set; +pub mod persist; +pub mod set; +pub mod ttl; use std::io::Cursor; @@ -13,6 +15,8 @@ use delete::Delete; use expire::Expire; use get::Get; use has::Has; +use m_get::MGet; +use m_set::MSet; use persist::Persist; use set::Set; use ttl::Ttl; @@ -30,6 +34,9 @@ pub enum Command { Ttl(Ttl), Expire(Expire), Persist(Persist), + + MSet(MSet), + MGet(MGet), } impl Command { @@ -42,6 +49,9 @@ impl Command { Command::Ttl(ttl) => ttl.execute(db, connection).await, Command::Expire(expire) => expire.execute(db, connection).await, Command::Persist(persist) => persist.execute(db, connection).await, + + Command::MSet(m_set) => m_set.execute(db, connection).await, + Command::MGet(m_get) => m_get.execute(db, connection).await, } } @@ -62,6 +72,8 @@ impl Command { "ttl" => Self::Ttl(Ttl::parse(bytes)?), "expire" => Self::Expire(Expire::parse(bytes)?), "persist" => Self::Persist(Persist::parse(bytes)?), + "mset" => Self::MSet(MSet::parse(bytes)?), + "mget" => Self::MGet(MGet::parse(bytes)?), _ => return Err(AppError::UnknownCommand(command_name)), }; diff --git a/src/commands/persist.rs b/src/commands/persist.rs index c76b179..3c559bb 100644 --- a/src/commands/persist.rs +++ b/src/commands/persist.rs @@ -1,8 +1,13 @@ use std::io::Cursor; -use bytes::Bytes; +use bytes::{Bytes, BytesMut}; -use crate::{Result, buffer::ArchiveBuf as _, connection::Connection, database::Database}; +use crate::{ + Result, + buffer::{ArchiveBuf as _, ArchiveBufMut as _}, + connection::Connection, + database::Database, +}; #[derive(Debug, Clone)] pub struct Persist { @@ -10,6 +15,10 @@ pub struct Persist { } impl Persist { + pub fn new(key: String) -> Self { + Self { key } + } + pub async fn execute(self, db: &Database, connection: &mut Connection) -> Result<()> { let value = db.persist(&self.key).await?; @@ -25,4 +34,13 @@ impl Persist { Ok(Self { key }) } + + pub fn put(&self, buf: &mut BytesMut) -> Result<()> { + buf.put_short_string("persist")?; + self.put_without_cmd_name(buf) + } + + pub fn put_without_cmd_name(&self, buf: &mut BytesMut) -> Result<()> { + buf.put_short_string(&self.key) + } } diff --git a/src/commands/set.rs b/src/commands/set.rs index 11bf6ff..d3e2265 100644 --- a/src/commands/set.rs +++ b/src/commands/set.rs @@ -1,16 +1,29 @@ use std::io::Cursor; -use crate::{Result, buffer::ArchiveBuf as _, connection::Connection, database::Database}; -use bytes::{Buf as _, Bytes}; +use crate::{ + Result, + buffer::{ArchiveBuf as _, ArchiveBufMut as _}, + connection::Connection, + database::Database, +}; +use bytes::{Buf as _, BufMut as _, Bytes, BytesMut}; #[derive(Debug, Clone)] pub struct Set { - key: String, - data: Box<[u8]>, - expiration: Option, + pub(super) key: String, + pub(super) data: Box<[u8]>, + pub(super) expiration: Option, } impl Set { + pub fn new(key: String, data: Box<[u8]>, expiration: Option) -> Self { + Self { + key, + data, + expiration, + } + } + pub async fn execute(self, db: &Database, connection: &mut Connection) -> Result<()> { db.set(self.key, self.data, self.expiration).await?; @@ -32,4 +45,20 @@ impl Set { expiration, }) } + + pub fn put(&self, buf: &mut BytesMut) -> Result<()> { + buf.put_short_string("set")?; + + self.put_without_cmd_name(buf)?; + + Ok(()) + } + + pub fn put_without_cmd_name(&self, buf: &mut BytesMut) -> Result<()> { + buf.put_short_string(&self.key)?; + buf.put_bytes_with_length(&self.data); + buf.put_option(self.expiration, BytesMut::put_u64); + + Ok(()) + } } diff --git a/src/commands/ttl.rs b/src/commands/ttl.rs index c18dfee..1bf30f8 100644 --- a/src/commands/ttl.rs +++ b/src/commands/ttl.rs @@ -15,6 +15,10 @@ pub struct Ttl { } impl Ttl { + pub fn new(key: String) -> Self { + Self { key } + } + pub async fn execute(self, db: &Database, connection: &mut Connection) -> Result<()> { let ttl = db.ttl(&self.key).await; @@ -31,4 +35,16 @@ impl Ttl { Ok(Self { key }) } + + pub fn put(&self, buf: &mut BytesMut) -> Result<()> { + buf.put_short_string("ttl")?; + + self.put_without_cmd_name(buf)?; + + Ok(()) + } + + pub fn put_without_cmd_name(&self, buf: &mut BytesMut) -> Result<()> { + buf.put_short_string(&self.key) + } } diff --git a/src/tests.rs b/src/tests.rs index f2d1703..2b77e12 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -50,6 +50,54 @@ async fn expiration() -> Result<(), Box> { Ok(()) } +#[tokio::test] +async fn m_set_m_get() -> Result<(), Box> { + let config = ServerConfig::builder() + .host("127.0.0.1".into()) + .port(6172) + .build(); + let mut server = Server::new(&config).await?; + + let (shutdown_tx, shutdown_rx) = oneshot::channel(); + let server_handle = tokio::spawn(async move { server.run(shutdown_rx).await }); + + let mut client = client("127.0.0.1:6172").await?; + + client + .m_set( + vec!["key-0", "key-1", "key-2"], + vec![b"value-0", b"value-1", b"value-2"], + vec![None, Some(2), None], + ) + .await?; + + assert_eq!( + client + .m_get(vec!["key-0".into(), "key-1".into(), "key-2".into()]) + .await?, + vec![ + Some("value-0".into()), + Some("value-1".into()), + Some("value-2".into()) + ] + ); + + tokio::time::sleep(Duration::from_secs(2)).await; + + assert_eq!( + client + .m_get(vec!["key-0".into(), "key-1".into(), "key-2".into()]) + .await?, + vec![Some("value-0".into()), None, Some("value-2".into())] + ); + + shutdown_tx.send(()).unwrap(); + + server_handle.await??; + + Ok(()) +} + async fn client( addr: A, ) -> Result> {