diff --git a/src/client.rs b/src/client.rs index 366804c..6a7f49a 100644 --- a/src/client.rs +++ b/src/client.rs @@ -198,6 +198,32 @@ impl Client { Ok(success) } + pub async fn persist(&mut self, key: &str) -> Result { + let mut bytes = BytesMut::new(); + bytes.put_u16(7); + bytes.put_slice(b"persist"); + + let key_length: u16 = key + .len() + .try_into() + .map_err(|_| AppError::KeyLength(key.len()))?; + + bytes.put_u16(key_length); + bytes.put_slice(key.as_bytes()); + + self.connection.write(bytes.into()).await?; + + let mut r = self.get_response().await?; + + let success = match r.try_get_u8()? { + 1 => true, + 0 => false, + _ => return Err(AppError::InvalidCommandResponse), + }; + + Ok(success) + } + async fn get_response(&mut self) -> Result { self.connection .read_bytes() diff --git a/src/commands/mod.rs b/src/commands/mod.rs index 43cf882..9569030 100644 --- a/src/commands/mod.rs +++ b/src/commands/mod.rs @@ -2,6 +2,7 @@ mod delete; mod expire; mod get; mod has; +mod persist; mod set; mod ttl; @@ -12,6 +13,7 @@ use delete::Delete; use expire::Expire; use get::Get; use has::Has; +use persist::Persist; use set::Set; use ttl::Ttl; @@ -25,6 +27,7 @@ pub enum Command { Has(Has), Ttl(Ttl), Expire(Expire), + Persist(Persist), } impl Command { @@ -36,6 +39,7 @@ impl Command { Command::Has(has) => has.execute(db, connection).await, Command::Ttl(ttl) => ttl.execute(db, connection).await, Command::Expire(expire) => expire.execute(db, connection).await, + Command::Persist(persist) => persist.execute(db, connection).await, } } @@ -61,6 +65,7 @@ impl Command { "has" => Self::Has(Has::parse(bytes)?), "ttl" => Self::Ttl(Ttl::parse(bytes)?), "expire" => Self::Expire(Expire::parse(bytes)?), + "persist" => Self::Persist(Persist::parse(bytes)?), _ => return Err(AppError::UnknownCommand(command_name)), }; diff --git a/src/commands/persist.rs b/src/commands/persist.rs new file mode 100644 index 0000000..0288c44 --- /dev/null +++ b/src/commands/persist.rs @@ -0,0 +1,34 @@ +use std::io::Cursor; + +use bytes::{Buf as _, Bytes}; + +use crate::{Result, connection::Connection, database::Database, errors::AppError}; + +#[derive(Debug, Clone)] +pub struct Persist { + key: String, +} + +impl Persist { + pub async fn execute(self, db: &Database, connection: &mut Connection) -> Result<()> { + let value = db.persist(&self.key).await?; + + connection + .write(Bytes::from_static(if value { &[1] } else { &[0] })) + .await?; + + Ok(()) + } + + pub fn parse(bytes: &mut Cursor<&[u8]>) -> Result { + let key_length = bytes.try_get_u16()? as usize; + + if bytes.remaining() < key_length { + return Err(AppError::IncompleteCommandBuffer); + } + + let key = String::from_utf8(bytes.copy_to_bytes(key_length).to_vec())?; + + Ok(Self { key }) + } +} diff --git a/src/database.rs b/src/database.rs index 284fdbf..7b0d5e8 100644 --- a/src/database.rs +++ b/src/database.rs @@ -142,6 +142,8 @@ impl Database { pub async fn expire(&self, key: &str, seconds: u64) -> Result { let mut state = self.state.lock().await; + let key = Yarn::copy(key); + let expiration = Instant::now() + Duration::from_secs(seconds); let notify = @@ -153,8 +155,6 @@ impl Database { next_expiration_key == &key || instant > expiration }); - let key = Yarn::copy(key); - let Some(value) = state.entries.get_mut(&key) else { return Ok(false); }; @@ -177,6 +177,35 @@ impl Database { Ok(true) } + pub async fn persist(&self, key: &str) -> Result { + let mut state = self.state.lock().await; + + let key = Yarn::copy(key); + + let notify = state + .expirations + .iter() + .next() + .is_some_and(|&(_, ref next_expiration_key)| next_expiration_key == &key); + + let Some(value) = state.entries.get_mut(&key) else { + return Ok(false); + }; + + match value.expiration.take() { + Some(expiration) => { + state.expirations.remove(&(expiration, key)); + } + None => return Ok(false), + } + + if notify { + self.notify.notify_one(); + } + + Ok(true) + } + pub async fn shutdown(&mut self) { self.state.lock().await.shutdown = true; self.notify.notify_one(); diff --git a/src/tests.rs b/src/tests.rs index c3e9b3d..f2d1703 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -15,7 +15,7 @@ async fn expiration() -> Result<(), Box> { let mut client = client("127.0.0.1:6171").await?; client - .set("test-key", "test-value".as_bytes(), Some(3)) + .set("test-key", b"test-value", Some(3)) .await .unwrap(); @@ -34,6 +34,15 @@ async fn expiration() -> Result<(), Box> { assert!(!client.expire("test-key", 10).await?); + client.set("test-key", b"test-value", Some(2)).await?; + + assert_eq!(client.ttl("test-key").await?, Some(1)); + assert!(client.persist("test-key").await?); + + tokio::time::sleep(Duration::from_secs(2)).await; + + assert_eq!(client.get("test-key").await?, Some("test-value".into())); + shutdown_tx.send(()).unwrap(); server_handle.await??;