diff --git a/src/client.rs b/src/client.rs index c6520b1..366804c 100644 --- a/src/client.rs +++ b/src/client.rs @@ -170,6 +170,34 @@ impl Client { Ok(ttl) } + pub async fn expire(&mut self, key: &str, seconds: u64) -> Result { + let mut bytes = BytesMut::new(); + bytes.put_u16(6); + bytes.put_slice(b"expire"); + + let key_length: u16 = key + .len() + .try_into() + .map_err(|_| AppError::KeyLength(key.len()))?; + + bytes.put_u16(key_length); + bytes.put_slice(key.as_bytes()); + + bytes.put_u64(seconds); + + self.connection.write(bytes.into()).await?; + + let mut r = self.get_response().await?; + + let success = match r.try_get_u8()? { + 1 => true, + 0 => false, + _ => return Err(AppError::InvalidCommandResponse), + }; + + Ok(success) + } + async fn get_response(&mut self) -> Result { self.connection .read_bytes() diff --git a/src/commands/expire.rs b/src/commands/expire.rs new file mode 100644 index 0000000..ea7363f --- /dev/null +++ b/src/commands/expire.rs @@ -0,0 +1,37 @@ +use std::io::Cursor; + +use bytes::{Buf as _, Bytes}; + +use crate::{Result, connection::Connection, database::Database, errors::AppError}; + +#[derive(Debug, Clone)] +pub struct Expire { + key: String, + seconds: u64, +} + +impl Expire { + pub async fn execute(self, db: &Database, connection: &mut Connection) -> Result<()> { + let value = db.expire(&self.key, self.seconds).await?; + + connection + .write(Bytes::from_static(if value { &[1] } else { &[0] })) + .await?; + + Ok(()) + } + + pub fn parse(bytes: &mut Cursor<&[u8]>) -> Result { + let key_length = bytes.try_get_u16()? as usize; + + if bytes.remaining() < key_length { + return Err(AppError::IncompleteCommandBuffer); + } + + let key = String::from_utf8(bytes.copy_to_bytes(key_length).to_vec())?; + + let seconds = bytes.try_get_u64()?; + + Ok(Self { key, seconds }) + } +} diff --git a/src/commands/mod.rs b/src/commands/mod.rs index a33d14f..43cf882 100644 --- a/src/commands/mod.rs +++ b/src/commands/mod.rs @@ -1,4 +1,5 @@ mod delete; +mod expire; mod get; mod has; mod set; @@ -8,6 +9,7 @@ use std::io::Cursor; use bytes::{Buf, BytesMut}; use delete::Delete; +use expire::Expire; use get::Get; use has::Has; use set::Set; @@ -22,6 +24,7 @@ pub enum Command { Delete(Delete), Has(Has), Ttl(Ttl), + Expire(Expire), } impl Command { @@ -32,6 +35,7 @@ impl Command { Command::Delete(delete) => delete.execute(db, connection).await, Command::Has(has) => has.execute(db, connection).await, Command::Ttl(ttl) => ttl.execute(db, connection).await, + Command::Expire(expire) => expire.execute(db, connection).await, } } @@ -56,6 +60,7 @@ impl Command { "delete" => Self::Delete(Delete::parse(bytes)?), "has" => Self::Has(Has::parse(bytes)?), "ttl" => Self::Ttl(Ttl::parse(bytes)?), + "expire" => Self::Expire(Expire::parse(bytes)?), _ => return Err(AppError::UnknownCommand(command_name)), }; diff --git a/src/database.rs b/src/database.rs index 84922aa..284fdbf 100644 --- a/src/database.rs +++ b/src/database.rs @@ -1,6 +1,7 @@ use std::{ collections::{BTreeMap, BTreeSet}, sync::Arc, + time::Duration, }; use bytes::{BufMut, Bytes, BytesMut}; @@ -138,6 +139,44 @@ impl Database { .flatten() } + pub async fn expire(&self, key: &str, seconds: u64) -> Result { + let mut state = self.state.lock().await; + + let expiration = Instant::now() + Duration::from_secs(seconds); + + let notify = + state + .expirations + .iter() + .next() + .is_none_or(|&(instant, ref next_expiration_key)| { + next_expiration_key == &key || instant > expiration + }); + + let key = Yarn::copy(key); + + let Some(value) = state.entries.get_mut(&key) else { + return Ok(false); + }; + + let previous_expiration = value.expiration.take(); + value.expiration = Some(expiration); + + if let Some(previous_expiration) = previous_expiration { + state + .expirations + .remove(&(previous_expiration, key.clone())); + }; + + state.expirations.insert((expiration, key)); + + if notify { + self.notify.notify_one(); + } + + Ok(true) + } + pub async fn shutdown(&mut self) { self.state.lock().await.shutdown = true; self.notify.notify_one(); diff --git a/src/tests.rs b/src/tests.rs index 9e1f8cd..c3e9b3d 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -27,9 +27,13 @@ async fn expiration() -> Result<(), Box> { assert!(client.has("test-key").await.unwrap()); assert_eq!(client.ttl("test-key").await.unwrap(), Some(1)); + assert!(client.expire("test-key", 2).await?); + tokio::time::sleep(Duration::from_secs(2)).await; assert!(!client.has("test-key").await.unwrap()); + assert!(!client.expire("test-key", 10).await?); + shutdown_tx.send(()).unwrap(); server_handle.await??;