persist command

This commit is contained in:
2025-06-17 21:11:09 +02:00
parent c51c90b597
commit 0a9c8f81aa
5 changed files with 106 additions and 3 deletions

View File

@@ -198,6 +198,32 @@ impl Client {
Ok(success) Ok(success)
} }
pub async fn persist(&mut self, key: &str) -> Result<bool> {
let mut bytes = BytesMut::new();
bytes.put_u16(7);
bytes.put_slice(b"persist");
let key_length: u16 = key
.len()
.try_into()
.map_err(|_| AppError::KeyLength(key.len()))?;
bytes.put_u16(key_length);
bytes.put_slice(key.as_bytes());
self.connection.write(bytes.into()).await?;
let mut r = self.get_response().await?;
let success = match r.try_get_u8()? {
1 => true,
0 => false,
_ => return Err(AppError::InvalidCommandResponse),
};
Ok(success)
}
async fn get_response(&mut self) -> Result<Bytes> { async fn get_response(&mut self) -> Result<Bytes> {
self.connection self.connection
.read_bytes() .read_bytes()

View File

@@ -2,6 +2,7 @@ mod delete;
mod expire; mod expire;
mod get; mod get;
mod has; mod has;
mod persist;
mod set; mod set;
mod ttl; mod ttl;
@@ -12,6 +13,7 @@ use delete::Delete;
use expire::Expire; use expire::Expire;
use get::Get; use get::Get;
use has::Has; use has::Has;
use persist::Persist;
use set::Set; use set::Set;
use ttl::Ttl; use ttl::Ttl;
@@ -25,6 +27,7 @@ pub enum Command {
Has(Has), Has(Has),
Ttl(Ttl), Ttl(Ttl),
Expire(Expire), Expire(Expire),
Persist(Persist),
} }
impl Command { impl Command {
@@ -36,6 +39,7 @@ impl Command {
Command::Has(has) => has.execute(db, connection).await, Command::Has(has) => has.execute(db, connection).await,
Command::Ttl(ttl) => ttl.execute(db, connection).await, Command::Ttl(ttl) => ttl.execute(db, connection).await,
Command::Expire(expire) => expire.execute(db, connection).await, Command::Expire(expire) => expire.execute(db, connection).await,
Command::Persist(persist) => persist.execute(db, connection).await,
} }
} }
@@ -61,6 +65,7 @@ impl Command {
"has" => Self::Has(Has::parse(bytes)?), "has" => Self::Has(Has::parse(bytes)?),
"ttl" => Self::Ttl(Ttl::parse(bytes)?), "ttl" => Self::Ttl(Ttl::parse(bytes)?),
"expire" => Self::Expire(Expire::parse(bytes)?), "expire" => Self::Expire(Expire::parse(bytes)?),
"persist" => Self::Persist(Persist::parse(bytes)?),
_ => return Err(AppError::UnknownCommand(command_name)), _ => return Err(AppError::UnknownCommand(command_name)),
}; };

34
src/commands/persist.rs Normal file
View File

@@ -0,0 +1,34 @@
use std::io::Cursor;
use bytes::{Buf as _, Bytes};
use crate::{Result, connection::Connection, database::Database, errors::AppError};
#[derive(Debug, Clone)]
pub struct Persist {
key: String,
}
impl Persist {
pub async fn execute(self, db: &Database, connection: &mut Connection) -> Result<()> {
let value = db.persist(&self.key).await?;
connection
.write(Bytes::from_static(if value { &[1] } else { &[0] }))
.await?;
Ok(())
}
pub fn parse(bytes: &mut Cursor<&[u8]>) -> Result<Self> {
let key_length = bytes.try_get_u16()? as usize;
if bytes.remaining() < key_length {
return Err(AppError::IncompleteCommandBuffer);
}
let key = String::from_utf8(bytes.copy_to_bytes(key_length).to_vec())?;
Ok(Self { key })
}
}

View File

@@ -142,6 +142,8 @@ impl Database {
pub async fn expire(&self, key: &str, seconds: u64) -> Result<bool> { pub async fn expire(&self, key: &str, seconds: u64) -> Result<bool> {
let mut state = self.state.lock().await; let mut state = self.state.lock().await;
let key = Yarn::copy(key);
let expiration = Instant::now() + Duration::from_secs(seconds); let expiration = Instant::now() + Duration::from_secs(seconds);
let notify = let notify =
@@ -153,8 +155,6 @@ impl Database {
next_expiration_key == &key || instant > expiration next_expiration_key == &key || instant > expiration
}); });
let key = Yarn::copy(key);
let Some(value) = state.entries.get_mut(&key) else { let Some(value) = state.entries.get_mut(&key) else {
return Ok(false); return Ok(false);
}; };
@@ -177,6 +177,35 @@ impl Database {
Ok(true) Ok(true)
} }
pub async fn persist(&self, key: &str) -> Result<bool> {
let mut state = self.state.lock().await;
let key = Yarn::copy(key);
let notify = state
.expirations
.iter()
.next()
.is_some_and(|&(_, ref next_expiration_key)| next_expiration_key == &key);
let Some(value) = state.entries.get_mut(&key) else {
return Ok(false);
};
match value.expiration.take() {
Some(expiration) => {
state.expirations.remove(&(expiration, key));
}
None => return Ok(false),
}
if notify {
self.notify.notify_one();
}
Ok(true)
}
pub async fn shutdown(&mut self) { pub async fn shutdown(&mut self) {
self.state.lock().await.shutdown = true; self.state.lock().await.shutdown = true;
self.notify.notify_one(); self.notify.notify_one();

View File

@@ -15,7 +15,7 @@ async fn expiration() -> Result<(), Box<dyn std::error::Error>> {
let mut client = client("127.0.0.1:6171").await?; let mut client = client("127.0.0.1:6171").await?;
client client
.set("test-key", "test-value".as_bytes(), Some(3)) .set("test-key", b"test-value", Some(3))
.await .await
.unwrap(); .unwrap();
@@ -34,6 +34,15 @@ async fn expiration() -> Result<(), Box<dyn std::error::Error>> {
assert!(!client.expire("test-key", 10).await?); assert!(!client.expire("test-key", 10).await?);
client.set("test-key", b"test-value", Some(2)).await?;
assert_eq!(client.ttl("test-key").await?, Some(1));
assert!(client.persist("test-key").await?);
tokio::time::sleep(Duration::from_secs(2)).await;
assert_eq!(client.get("test-key").await?, Some("test-value".into()));
shutdown_tx.send(()).unwrap(); shutdown_tx.send(()).unwrap();
server_handle.await??; server_handle.await??;