read expiration from Set command + tests

This commit is contained in:
2025-06-17 01:52:11 +02:00
parent 28b42c786c
commit 20e3fbd5d3
8 changed files with 119 additions and 41 deletions

View File

@@ -1,7 +1,7 @@
use bytes::{Buf, BufMut as _, Bytes, BytesMut}; use bytes::{Buf, BufMut as _, Bytes, BytesMut};
use tokio::net::{TcpStream, ToSocketAddrs}; use tokio::net::{TcpStream, ToSocketAddrs};
use crate::{Result, connection::Connection, database::Value, errors::AppError}; use crate::{Result, connection::Connection, errors::AppError};
pub struct Client { pub struct Client {
connection: Connection, connection: Connection,
@@ -50,7 +50,12 @@ impl Client {
Ok(response) Ok(response)
} }
pub async fn set(&mut self, key: &str, value: Value) -> Result<()> { pub async fn set(
&mut self,
key: &str,
data: &[u8],
expiration_secs: Option<u64>,
) -> Result<()> {
let mut bytes = BytesMut::new(); let mut bytes = BytesMut::new();
bytes.put_u16(3); bytes.put_u16(3);
@@ -64,7 +69,18 @@ impl Client {
bytes.put_u16(key_length); bytes.put_u16(key_length);
bytes.put_slice(key.as_bytes()); bytes.put_slice(key.as_bytes());
value.write_to_bytes(&mut bytes); bytes.put_u32(data.len() as u32);
bytes.put_slice(data);
match expiration_secs {
Some(seconds) => {
bytes.put_u8(1);
bytes.put_u64(seconds);
}
None => {
bytes.put_u8(0);
}
}
self.connection.write(bytes.into()).await?; self.connection.write(bytes.into()).await?;

View File

@@ -1,6 +1,7 @@
use std::io::Cursor; use std::{io::Cursor, time::Duration};
use bytes::{Buf as _, Bytes}; use bytes::{Buf as _, Bytes};
use tokio::time::Instant;
use crate::{ use crate::{
Result, Result,
@@ -41,9 +42,15 @@ impl Set {
let data = bytes.copy_to_bytes(value_length); let data = bytes.copy_to_bytes(value_length);
let expiration: Option<Instant> = match bytes.try_get_u8()? {
1 => Some(Instant::now() + Duration::from_secs(bytes.try_get_u64()?)),
0 => None,
_ => return Err(AppError::UnexpectedCommandData),
};
Ok(Self { Ok(Self {
key, key,
value: Value::new(data, None), value: Value::new(data, expiration),
}) })
} }
} }

View File

@@ -1,6 +1,6 @@
use bon::Builder; use bon::Builder;
#[derive(Debug, Builder)] #[derive(Debug, Builder, Clone)]
pub struct ServerConfig { pub struct ServerConfig {
#[builder(default = String::from("0.0.0.0"))] #[builder(default = String::from("0.0.0.0"))]
pub host: String, pub host: String,

View File

@@ -22,6 +22,7 @@ pub struct Database {
pub struct DatabaseState { pub struct DatabaseState {
entries: BTreeMap<Yarn, Value>, entries: BTreeMap<Yarn, Value>,
expirations: BTreeSet<(Instant, Yarn)>, expirations: BTreeSet<(Instant, Yarn)>,
shutdown: bool,
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@@ -126,14 +127,22 @@ impl Database {
state.entries.contains_key(key) state.entries.contains_key(key)
} }
pub async fn shutdown(&mut self) {
self.state.lock().await.shutdown = true;
self.notify.notify_one();
}
} }
// TODO: Add shutdown stuff
pub async fn key_expiration_manager(db: Database) { pub async fn key_expiration_manager(db: Database) {
'outer: loop { 'outer: loop {
let mut state_lock = db.state.lock().await; let mut state_lock = db.state.lock().await;
let state = &mut *state_lock; let state = &mut *state_lock;
if state.shutdown {
break;
}
let now = Instant::now(); let now = Instant::now();
while let Some((expiration, key)) = state.expirations.iter().next().as_ref() { while let Some((expiration, key)) = state.expirations.iter().next().as_ref() {
@@ -157,4 +166,6 @@ pub async fn key_expiration_manager(db: Database) {
drop(state_lock); drop(state_lock);
db.notify.notified().await; db.notify.notified().await;
} }
log::debug!("key_expiration_manager has finished");
} }

View File

@@ -18,4 +18,6 @@ pub enum AppError {
NoResponse, NoResponse,
#[error("Expected a different response for the executed command")] #[error("Expected a different response for the executed command")]
InvalidCommandResponse, InvalidCommandResponse,
#[error("The binary command data is not structured correctly")]
UnexpectedCommandData,
} }

View File

@@ -1,6 +1,10 @@
use config::ServerConfig; use config::ServerConfig;
use errors::AppError; use errors::AppError;
use server::Server; use server::Server;
use tokio::{signal::ctrl_c, sync::oneshot};
#[cfg(test)]
pub mod tests;
pub mod client; pub mod client;
pub mod commands; pub mod commands;
@@ -18,6 +22,7 @@ async fn main() -> Result<()> {
env_logger::builder() env_logger::builder()
.format_target(false) .format_target(false)
.filter_level(log::LevelFilter::Info) .filter_level(log::LevelFilter::Info)
.parse_default_env()
.init(); .init();
let config = ServerConfig::builder() let config = ServerConfig::builder()
@@ -38,43 +43,23 @@ async fn main() -> Result<()> {
let mut server = Server::new(&config).await?; let mut server = Server::new(&config).await?;
// tokio::spawn(test());
log::info!("The server is listening on {}:{}", config.host, config.port); log::info!("The server is listening on {}:{}", config.host, config.port);
log::info!( log::info!(
"The maximum amount of concurrent connections is {}", "The maximum amount of concurrent connections is {}",
config.max_connections config.max_connections
); );
server.run().await?; let (shutdown_sender, shutdown_receiver) = oneshot::channel();
tokio::spawn(async move {
if ctrl_c().await.is_ok() {
let _ = shutdown_sender.send(());
}
});
server.run(shutdown_receiver).await?;
log::info!("Goodbye");
Ok(()) Ok(())
} }
/* async fn test() -> Result<()> {
let mut client = Client::new("127.0.0.1:6171").await?;
let key = String::from("my-key");
client
.set(
&key,
Value::from_string(
"my-value".into(),
Some(Instant::now() + Duration::from_secs(5)),
),
)
.await?;
assert!(client.has(&key).await?);
let value = client.get(&key).await?;
tokio::time::sleep(Duration::from_secs(6)).await;
assert!(!client.has(&key).await?);
let value = client.get(&key).await?;
Ok(())
} */

View File

@@ -1,6 +1,8 @@
use std::sync::Arc; use std::sync::Arc;
use tokio::net::ToSocketAddrs; use tokio::net::ToSocketAddrs;
use tokio::sync::oneshot;
use tokio::task::JoinHandle;
use tokio::{net::TcpListener, sync::Semaphore}; use tokio::{net::TcpListener, sync::Semaphore};
use crate::Result; use crate::Result;
@@ -14,6 +16,7 @@ pub struct Server {
db: Database, db: Database,
listener: TcpListener, listener: TcpListener,
connection_limit: Arc<Semaphore>, connection_limit: Arc<Semaphore>,
expiration_manager_handle: JoinHandle<()>,
} }
impl Server { impl Server {
@@ -27,23 +30,39 @@ impl Server {
let db = Database::new(); let db = Database::new();
tokio::spawn(key_expiration_manager(db.clone())); let expiration_manager_handle = tokio::spawn(key_expiration_manager(db.clone()));
Ok(Self { Ok(Self {
db, db,
connection_limit: Arc::new(Semaphore::const_new(max_connections)), connection_limit: Arc::new(Semaphore::const_new(max_connections)),
listener, listener,
expiration_manager_handle,
}) })
} }
pub async fn run(&mut self) -> Result<()> { pub async fn run(&mut self, mut shutdown: oneshot::Receiver<()>) -> Result<()> {
let shutdown = &mut shutdown;
loop { loop {
let permit = Arc::clone(&self.connection_limit) let permit = Arc::clone(&self.connection_limit)
.acquire_owned() .acquire_owned()
.await .await
.unwrap(); .unwrap();
let socket = self.listener.accept().await?.0; let Some(socket) = ({
tokio::select! {
socket = self.listener.accept() => Some(socket?.0),
_ = &mut *shutdown => None,
}
}) else {
log::info!("Shutting down");
self.db.shutdown().await;
let _ = (&mut self.expiration_manager_handle).await;
return Ok(());
};
let addr = socket.peer_addr()?; let addr = socket.peer_addr()?;
let connection = Connection::new(socket); let connection = Connection::new(socket);

38
src/tests.rs Normal file
View File

@@ -0,0 +1,38 @@
use std::time::Duration;
use tokio::sync::oneshot;
use crate::{client::Client, config::ServerConfig, server::Server};
#[tokio::test]
async fn expiration() -> Result<(), Box<dyn std::error::Error>> {
let config = ServerConfig::builder().host("127.0.0.1".into()).build();
let mut server = Server::new(&config).await?;
let (shutdown_tx, shutdown_rx) = oneshot::channel();
let server_handle = tokio::spawn(async move { server.run(shutdown_rx).await });
let mut client = client("127.0.0.1:6171").await?;
client
.set("test-key", "test-value".as_bytes(), Some(2))
.await
.unwrap();
assert!(client.has("test-key").await.unwrap());
tokio::time::sleep(Duration::from_secs(1)).await;
assert!(client.has("test-key").await.unwrap());
tokio::time::sleep(Duration::from_secs(2)).await;
assert!(!client.has("test-key").await.unwrap());
shutdown_tx.send(()).unwrap();
server_handle.await??;
Ok(())
}
async fn client<A: tokio::net::ToSocketAddrs>(
addr: A,
) -> Result<Client, Box<dyn std::error::Error>> {
Ok(Client::new(addr).await?)
}