refactor commands

This commit is contained in:
2025-06-12 19:04:52 +02:00
parent 06a503f67d
commit 34818ce050
8 changed files with 236 additions and 145 deletions

View File

@@ -1,139 +0,0 @@
use std::io::{Cursor, Read};
use bytes::{Buf, BufMut, Bytes, BytesMut};
use crate::{
Result,
connection::Connection,
database::{Database, Value},
errors::AppError,
};
#[derive(Debug)]
pub enum Command {
Get { key: String },
Set { key: String, value: Value },
Delete { key: String },
Has { key: String },
}
impl Command {
pub async fn execute(self, db: &Database, connection: &mut Connection) -> Result<()> {
match self {
Command::Get { ref key } => {
let value = db.get(&key).await;
let mut buf = BytesMut::new();
match value {
Some(v) => {
buf.put_u8(1);
buf.put_u32(v.len() as u32);
buf.put_slice(&v);
}
None => {
buf.put_u8(0);
}
}
connection.write(buf.into()).await?;
}
Command::Set { key, value } => {
db.set(key.clone(), value.clone()).await?;
connection.write(Bytes::from_static(&[1])).await?;
}
Command::Delete { ref key } => {
let value = db.delete(key).await;
let mut buf = BytesMut::new();
match value {
Some(v) => {
buf.put_u8(1);
buf.put_u32(v.len() as u32);
buf.put_slice(&v);
}
None => buf.put_u8(0),
}
connection.write(buf.into()).await?;
}
Command::Has { ref key } => {
let value = db.has(key).await;
let buf = Bytes::copy_from_slice(&[if value { 1 } else { 0 }]);
connection.write(buf.into()).await?;
}
}
Ok(())
}
pub fn parse(bytes: &BytesMut) -> Result<(Self, u64)> {
let mut buffer = Cursor::new(&bytes[..]);
let name = read_string(&mut buffer)?;
Self::parse_inner(name, &mut buffer)
}
fn parse_inner(command_name: String, bytes: &mut Cursor<&[u8]>) -> Result<(Self, u64)> {
let command = match command_name.as_str() {
"get" => {
let key = read_string(bytes)?;
Self::Get { key }
}
"set" => {
let key = read_string(bytes)?;
let data = read_bytes(bytes)?;
Self::Set {
key,
value: Value::new(data),
}
}
"delete" => {
let key = read_string(bytes)?;
Self::Delete { key }
}
"has" => {
let key = read_string(bytes)?;
Self::Has { key }
}
_ => return Err(AppError::UnknownCommand(command_name)),
};
Ok((command, bytes.position()))
}
}
fn read_string(buffer: &mut Cursor<&[u8]>) -> Result<String> {
let length = buffer.try_get_u16()? as usize;
if buffer.remaining() < length {
return Err(AppError::IncompleteCommandBuffer);
}
let mut contents = Vec::with_capacity(length);
for _ in 0..length {
contents.push(buffer.try_get_u8()?);
}
let string = String::from_utf8(contents)?;
Ok(string)
}
fn read_bytes(buffer: &mut Cursor<&[u8]>) -> Result<Bytes> {
let len = buffer.try_get_u32()? as usize;
if buffer.remaining() < len {
return Err(AppError::IncompleteCommandBuffer);
}
Ok(buffer.copy_to_bytes(len))
}

42
src/commands/delete.rs Normal file
View File

@@ -0,0 +1,42 @@
use std::io::Cursor;
use bytes::{Buf as _, BufMut as _, BytesMut};
use crate::{Result, connection::Connection, database::Database, errors::AppError};
#[derive(Debug, Clone)]
pub struct Delete {
key: String,
}
impl Delete {
pub async fn execute(self, db: &Database, connection: &mut Connection) -> Result<()> {
let value = db.delete(&self.key).await;
let mut buf = BytesMut::new();
match value {
Some(v) => {
buf.put_u8(1);
buf.put_u32(v.len() as u32);
buf.put_slice(&v);
}
None => buf.put_u8(0),
}
connection.write(buf.into()).await?;
Ok(())
}
pub fn parse(bytes: &mut Cursor<&[u8]>) -> Result<Self> {
let key_length = bytes.try_get_u16()? as usize;
if bytes.remaining() < key_length {
return Err(AppError::IncompleteCommandBuffer);
}
let key = String::from_utf8(bytes.copy_to_bytes(key_length).to_vec())?;
Ok(Self { key })
}
}

44
src/commands/get.rs Normal file
View File

@@ -0,0 +1,44 @@
use std::io::Cursor;
use bytes::{Buf as _, BufMut as _, BytesMut};
use crate::{Result, connection::Connection, database::Database, errors::AppError};
#[derive(Debug, Clone)]
pub struct Get {
key: String,
}
impl Get {
pub async fn execute(self, db: &Database, connection: &mut Connection) -> Result<()> {
let value = db.get(&self.key).await;
let mut buf = BytesMut::new();
match value {
Some(v) => {
buf.put_u8(1);
buf.put_u32(v.len() as u32);
buf.put_slice(&v);
}
None => {
buf.put_u8(0);
}
}
connection.write(buf.into()).await?;
Ok(())
}
pub fn parse(bytes: &mut Cursor<&[u8]>) -> Result<Self> {
let key_length = bytes.try_get_u16()? as usize;
if bytes.remaining() < key_length {
return Err(AppError::IncompleteCommandBuffer);
}
let key = String::from_utf8(bytes.copy_to_bytes(key_length).to_vec())?;
Ok(Self { key })
}
}

34
src/commands/has.rs Normal file
View File

@@ -0,0 +1,34 @@
use std::io::Cursor;
use bytes::{Buf as _, Bytes};
use crate::{Result, connection::Connection, database::Database, errors::AppError};
#[derive(Debug, Clone)]
pub struct Has {
key: String,
}
impl Has {
pub async fn execute(self, db: &Database, connection: &mut Connection) -> Result<()> {
let value = db.has(&self.key).await;
let buf = Bytes::copy_from_slice(&[if value { 1 } else { 0 }]);
connection.write(buf.into()).await?;
Ok(())
}
pub fn parse(bytes: &mut Cursor<&[u8]>) -> Result<Self> {
let key_length = bytes.try_get_u16()? as usize;
if bytes.remaining() < key_length {
return Err(AppError::IncompleteCommandBuffer);
}
let key = String::from_utf8(bytes.copy_to_bytes(key_length).to_vec())?;
Ok(Self { key })
}
}

64
src/commands/mod.rs Normal file
View File

@@ -0,0 +1,64 @@
pub mod delete;
mod get;
pub mod has;
pub mod set;
use std::io::Cursor;
use bytes::{Buf, BytesMut};
use delete::Delete;
use get::Get;
use has::Has;
use set::Set;
use crate::{
Result,
connection::Connection,
database::{Database, Value},
errors::AppError,
};
#[derive(Debug)]
pub enum Command {
Get(Get),
Set(Set),
Delete(Delete),
Has(Has),
}
impl Command {
pub async fn execute(self, db: &Database, connection: &mut Connection) -> Result<()> {
match self {
Command::Get(get) => get.execute(db, connection).await,
Command::Set(set) => set.execute(db, connection).await,
Command::Delete(delete) => delete.execute(db, connection).await,
Command::Has(has) => has.execute(db, connection).await,
}
}
pub fn parse(bytes: &BytesMut) -> Result<(Self, u64)> {
let mut buffer = Cursor::new(&bytes[..]);
let name_length = buffer.try_get_u16()? as usize;
if buffer.remaining() < name_length {
return Err(AppError::IncompleteCommandBuffer);
}
let name = String::from_utf8(buffer.copy_to_bytes(name_length).to_vec())?;
Self::parse_inner(name, &mut buffer)
}
fn parse_inner(command_name: String, bytes: &mut Cursor<&[u8]>) -> Result<(Self, u64)> {
let command = match command_name.to_lowercase().as_str() {
"get" => Self::Get(Get::parse(bytes)?),
"set" => Self::Set(Set::parse(bytes)?),
"delete" => Self::Delete(Delete::parse(bytes)?),
"has" => Self::Has(Has::parse(bytes)?),
_ => return Err(AppError::UnknownCommand(command_name)),
};
Ok((command, bytes.position()))
}
}

49
src/commands/set.rs Normal file
View File

@@ -0,0 +1,49 @@
use std::io::Cursor;
use bytes::{Buf as _, Bytes};
use crate::{
Result,
connection::Connection,
database::{Database, Value},
errors::AppError,
};
#[derive(Debug, Clone)]
pub struct Set {
key: String,
value: Value,
}
impl Set {
pub async fn execute(self, db: &Database, connection: &mut Connection) -> Result<()> {
db.set(self.key, self.value).await?;
connection.write(Bytes::from_static(&[1])).await?;
Ok(())
}
pub fn parse(bytes: &mut Cursor<&[u8]>) -> Result<Self> {
let key_length = bytes.try_get_u16()? as usize;
if bytes.remaining() < key_length {
return Err(AppError::IncompleteCommandBuffer);
}
let key = String::from_utf8(bytes.copy_to_bytes(key_length).to_vec())?;
let value_length = bytes.try_get_u32()? as usize;
if bytes.remaining() < value_length {
return Err(AppError::IncompleteCommandBuffer);
}
let data = bytes.copy_to_bytes(value_length);
Ok(Self {
key,
value: Value::new(data),
})
}
}

View File

@@ -4,7 +4,7 @@ use tokio::{
net::TcpStream, net::TcpStream,
}; };
use crate::{Result, command::Command}; use crate::{Result, commands::Command};
#[derive(Debug)] #[derive(Debug)]
pub struct Connection { pub struct Connection {

View File

@@ -5,7 +5,7 @@ use errors::AppError;
use server::Server; use server::Server;
pub mod client; pub mod client;
pub mod command; pub mod commands;
pub mod connection; pub mod connection;
pub mod database; pub mod database;
pub mod errors; pub mod errors;
@@ -18,10 +18,7 @@ pub type Result<T> = std::result::Result<T, AppError>;
async fn main() -> Result<()> { async fn main() -> Result<()> {
let mut server = Server::new("127.0.0.1:6171").await?; let mut server = Server::new("127.0.0.1:6171").await?;
// Testing tokio::spawn(client("client-1".into()));
for i in 0..256 {
tokio::spawn(client(format!("client-{}", i + 1)));
}
server.run().await?; server.run().await?;