diff --git a/src/client.rs b/src/client.rs index 4f74ca8..8e58307 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,7 +1,7 @@ use bytes::{Buf, BufMut as _, Bytes, BytesMut}; use tokio::net::{TcpStream, ToSocketAddrs}; -use crate::{Result, connection::Connection, database::Value, errors::AppError}; +use crate::{Result, connection::Connection, errors::AppError}; pub struct Client { connection: Connection, @@ -50,7 +50,12 @@ impl Client { Ok(response) } - pub async fn set(&mut self, key: &str, value: Value) -> Result<()> { + pub async fn set( + &mut self, + key: &str, + data: &[u8], + expiration_secs: Option, + ) -> Result<()> { let mut bytes = BytesMut::new(); bytes.put_u16(3); @@ -64,7 +69,18 @@ impl Client { bytes.put_u16(key_length); bytes.put_slice(key.as_bytes()); - value.write_to_bytes(&mut bytes); + bytes.put_u32(data.len() as u32); + bytes.put_slice(data); + + match expiration_secs { + Some(seconds) => { + bytes.put_u8(1); + bytes.put_u64(seconds); + } + None => { + bytes.put_u8(0); + } + } self.connection.write(bytes.into()).await?; diff --git a/src/commands/set.rs b/src/commands/set.rs index d714d92..33db36b 100644 --- a/src/commands/set.rs +++ b/src/commands/set.rs @@ -1,6 +1,7 @@ -use std::io::Cursor; +use std::{io::Cursor, time::Duration}; use bytes::{Buf as _, Bytes}; +use tokio::time::Instant; use crate::{ Result, @@ -41,9 +42,15 @@ impl Set { let data = bytes.copy_to_bytes(value_length); + let expiration: Option = match bytes.try_get_u8()? { + 1 => Some(Instant::now() + Duration::from_secs(bytes.try_get_u64()?)), + 0 => None, + _ => return Err(AppError::UnexpectedCommandData), + }; + Ok(Self { key, - value: Value::new(data, None), + value: Value::new(data, expiration), }) } } diff --git a/src/config.rs b/src/config.rs index 41f00b7..2af0ccb 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,6 +1,6 @@ use bon::Builder; -#[derive(Debug, Builder)] +#[derive(Debug, Builder, Clone)] pub struct ServerConfig { #[builder(default = String::from("0.0.0.0"))] pub host: String, diff --git a/src/database.rs b/src/database.rs index 3999801..e9623d7 100644 --- a/src/database.rs +++ b/src/database.rs @@ -22,6 +22,7 @@ pub struct Database { pub struct DatabaseState { entries: BTreeMap, expirations: BTreeSet<(Instant, Yarn)>, + shutdown: bool, } #[derive(Debug, Clone)] @@ -126,14 +127,22 @@ impl Database { state.entries.contains_key(key) } + + pub async fn shutdown(&mut self) { + self.state.lock().await.shutdown = true; + self.notify.notify_one(); + } } -// TODO: Add shutdown stuff pub async fn key_expiration_manager(db: Database) { 'outer: loop { let mut state_lock = db.state.lock().await; let state = &mut *state_lock; + if state.shutdown { + break; + } + let now = Instant::now(); while let Some((expiration, key)) = state.expirations.iter().next().as_ref() { @@ -157,4 +166,6 @@ pub async fn key_expiration_manager(db: Database) { drop(state_lock); db.notify.notified().await; } + + log::debug!("key_expiration_manager has finished"); } diff --git a/src/errors.rs b/src/errors.rs index 01b6771..b720822 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -18,4 +18,6 @@ pub enum AppError { NoResponse, #[error("Expected a different response for the executed command")] InvalidCommandResponse, + #[error("The binary command data is not structured correctly")] + UnexpectedCommandData, } diff --git a/src/main.rs b/src/main.rs index 503a307..7ba3083 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,6 +1,10 @@ use config::ServerConfig; use errors::AppError; use server::Server; +use tokio::{signal::ctrl_c, sync::oneshot}; + +#[cfg(test)] +pub mod tests; pub mod client; pub mod commands; @@ -18,6 +22,7 @@ async fn main() -> Result<()> { env_logger::builder() .format_target(false) .filter_level(log::LevelFilter::Info) + .parse_default_env() .init(); let config = ServerConfig::builder() @@ -38,43 +43,23 @@ async fn main() -> Result<()> { let mut server = Server::new(&config).await?; - // tokio::spawn(test()); - log::info!("The server is listening on {}:{}", config.host, config.port); log::info!( "The maximum amount of concurrent connections is {}", config.max_connections ); - server.run().await?; + let (shutdown_sender, shutdown_receiver) = oneshot::channel(); + + tokio::spawn(async move { + if ctrl_c().await.is_ok() { + let _ = shutdown_sender.send(()); + } + }); + + server.run(shutdown_receiver).await?; + + log::info!("Goodbye"); Ok(()) } - -/* async fn test() -> Result<()> { - let mut client = Client::new("127.0.0.1:6171").await?; - - let key = String::from("my-key"); - - client - .set( - &key, - Value::from_string( - "my-value".into(), - Some(Instant::now() + Duration::from_secs(5)), - ), - ) - .await?; - - assert!(client.has(&key).await?); - - let value = client.get(&key).await?; - - tokio::time::sleep(Duration::from_secs(6)).await; - - assert!(!client.has(&key).await?); - - let value = client.get(&key).await?; - - Ok(()) -} */ diff --git a/src/server.rs b/src/server.rs index 292523d..994a9c2 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,6 +1,8 @@ use std::sync::Arc; use tokio::net::ToSocketAddrs; +use tokio::sync::oneshot; +use tokio::task::JoinHandle; use tokio::{net::TcpListener, sync::Semaphore}; use crate::Result; @@ -14,6 +16,7 @@ pub struct Server { db: Database, listener: TcpListener, connection_limit: Arc, + expiration_manager_handle: JoinHandle<()>, } impl Server { @@ -27,23 +30,39 @@ impl Server { let db = Database::new(); - tokio::spawn(key_expiration_manager(db.clone())); + let expiration_manager_handle = tokio::spawn(key_expiration_manager(db.clone())); Ok(Self { db, connection_limit: Arc::new(Semaphore::const_new(max_connections)), listener, + expiration_manager_handle, }) } - pub async fn run(&mut self) -> Result<()> { + pub async fn run(&mut self, mut shutdown: oneshot::Receiver<()>) -> Result<()> { + let shutdown = &mut shutdown; + loop { let permit = Arc::clone(&self.connection_limit) .acquire_owned() .await .unwrap(); - let socket = self.listener.accept().await?.0; + let Some(socket) = ({ + tokio::select! { + socket = self.listener.accept() => Some(socket?.0), + _ = &mut *shutdown => None, + } + }) else { + log::info!("Shutting down"); + + self.db.shutdown().await; + let _ = (&mut self.expiration_manager_handle).await; + + return Ok(()); + }; + let addr = socket.peer_addr()?; let connection = Connection::new(socket); diff --git a/src/tests.rs b/src/tests.rs new file mode 100644 index 0000000..1602be9 --- /dev/null +++ b/src/tests.rs @@ -0,0 +1,38 @@ +use std::time::Duration; + +use tokio::sync::oneshot; + +use crate::{client::Client, config::ServerConfig, server::Server}; + +#[tokio::test] +async fn expiration() -> Result<(), Box> { + let config = ServerConfig::builder().host("127.0.0.1".into()).build(); + let mut server = Server::new(&config).await?; + + let (shutdown_tx, shutdown_rx) = oneshot::channel(); + let server_handle = tokio::spawn(async move { server.run(shutdown_rx).await }); + + let mut client = client("127.0.0.1:6171").await?; + + client + .set("test-key", "test-value".as_bytes(), Some(2)) + .await + .unwrap(); + assert!(client.has("test-key").await.unwrap()); + tokio::time::sleep(Duration::from_secs(1)).await; + assert!(client.has("test-key").await.unwrap()); + tokio::time::sleep(Duration::from_secs(2)).await; + assert!(!client.has("test-key").await.unwrap()); + + shutdown_tx.send(()).unwrap(); + + server_handle.await??; + + Ok(()) +} + +async fn client( + addr: A, +) -> Result> { + Ok(Client::new(addr).await?) +}