diff --git a/Cargo.toml b/Cargo.toml index 76d0059..eee7582 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,8 +32,8 @@ error-stack = {version="0.4.0", default-features=false, features=["std"]} async-trait = {version="0.1.76", default-features=false} regex = {version="1.9.4", features=["perf"], default-features = false} smallvec = {version="1.13.1", features=["union", "serde"], default-features=false} -futures = {version="0.3.28", default-features=false} -dhat = {version="0.3.3", optional = true, default-features=false} +futures = {version="0.3.30", default-features=false, features=["alloc"]} +dhat = {version="0.3.2", optional = true, default-features=false} mimalloc = { version = "0.1.38", default-features = false } async-once-cell = {version="0.5.3", default-features=false} actix-governor = {version="0.5.0", default-features=false} diff --git a/src/bin/websurfx.rs b/src/bin/websurfx.rs index 1852695..c3d8c38 100644 --- a/src/bin/websurfx.rs +++ b/src/bin/websurfx.rs @@ -5,7 +5,7 @@ #[cfg(not(feature = "dhat-heap"))] use mimalloc::MiMalloc; -use std::net::TcpListener; +use std::{net::TcpListener, sync::OnceLock}; use websurfx::{cache::cacher::create_cache, config::parser::Config, run}; /// A dhat heap memory profiler @@ -17,6 +17,9 @@ static ALLOC: dhat::Alloc = dhat::Alloc; #[global_allocator] static GLOBAL: MiMalloc = MiMalloc; +/// A static constant for holding the parsed config. +static CONFIG: OnceLock = OnceLock::new(); + /// The function that launches the main server and registers all the routes of the website. /// /// # Error @@ -29,10 +32,10 @@ async fn main() -> std::io::Result<()> { #[cfg(feature = "dhat-heap")] let _profiler = dhat::Profiler::new_heap(); - // Initialize the parsed config file. - let config = Config::parse(false).unwrap(); + // Initialize the parsed config globally. + let config = CONFIG.get_or_init(|| Config::parse(false).unwrap()); - let cache = create_cache(&config).await; + let cache = create_cache(config).await; log::info!( "started server on port {} and IP {}", diff --git a/src/cache/redis_cacher.rs b/src/cache/redis_cacher.rs index c775963..a334996 100644 --- a/src/cache/redis_cacher.rs +++ b/src/cache/redis_cacher.rs @@ -1,11 +1,13 @@ //! This module provides the functionality to cache the aggregated results fetched and aggregated //! from the upstream search engines in a json format. +use super::error::CacheError; use error_stack::Report; -use futures::future::try_join_all; +use futures::stream::FuturesUnordered; use redis::{aio::ConnectionManager, AsyncCommands, Client, RedisError}; -use super::error::CacheError; +/// A constant holding the redis pipeline size. +const REDIS_PIPELINE_SIZE: usize = 3; /// A named struct which stores the redis Connection url address to which the client will /// connect to. @@ -20,6 +22,8 @@ pub struct RedisCache { current_connection: u8, /// It stores the max TTL for keys. cache_ttl: u16, + /// It stores the redis pipeline struct of size 3. + pipeline: redis::Pipeline, } impl RedisCache { @@ -30,6 +34,8 @@ impl RedisCache { /// * `redis_connection_url` - It takes the redis Connection url address. /// * `pool_size` - It takes the size of the connection pool (in other words the number of /// connections that should be stored in the pool). + /// * `cache_ttl` - It takes the the time to live for cached results to live in the redis + /// server. /// /// # Error /// @@ -41,18 +47,28 @@ impl RedisCache { cache_ttl: u16, ) -> Result> { let client = Client::open(redis_connection_url)?; - let mut tasks: Vec<_> = Vec::new(); + let tasks: FuturesUnordered<_> = FuturesUnordered::new(); for _ in 0..pool_size { - tasks.push(client.get_connection_manager()); + let client_partially_cloned = client.clone(); + tasks.push(tokio::spawn(async move { + client_partially_cloned.get_connection_manager().await + })); + } + + let mut outputs = Vec::new(); + for task in tasks { + outputs.push(task.await??); } let redis_cache = RedisCache { - connection_pool: try_join_all(tasks).await?, + connection_pool: outputs, pool_size, current_connection: Default::default(), cache_ttl, + pipeline: redis::Pipeline::with_capacity(REDIS_PIPELINE_SIZE), }; + Ok(redis_cache) } @@ -122,13 +138,14 @@ impl RedisCache { keys: impl Iterator, ) -> Result<(), Report> { self.current_connection = Default::default(); - let mut pipeline = redis::Pipeline::with_capacity(3); for (key, json_result) in keys.zip(json_results) { - pipeline.set_ex(key, json_result, self.cache_ttl.into()); + self.pipeline + .set_ex(key, json_result, self.cache_ttl.into()); } - let mut result: Result<(), RedisError> = pipeline + let mut result: Result<(), RedisError> = self + .pipeline .query_async(&mut self.connection_pool[self.current_connection as usize]) .await; @@ -149,7 +166,8 @@ impl RedisCache { CacheError::PoolExhaustionWithConnectionDropError, )); } - result = pipeline + result = self + .pipeline .query_async( &mut self.connection_pool[self.current_connection as usize], ) diff --git a/src/engines/bing.rs b/src/engines/bing.rs index 84dbf93..ec582e4 100644 --- a/src/engines/bing.rs +++ b/src/engines/bing.rs @@ -48,7 +48,7 @@ impl SearchEngine for Bing { user_agent: &str, client: &Client, _safe_search: u8, - ) -> Result, EngineError> { + ) -> Result, EngineError> { // Bing uses `start results from this number` convention // So, for 10 results per page, page 0 starts at 1, page 1 // starts at 11, and so on. diff --git a/src/engines/brave.rs b/src/engines/brave.rs index 49626e3..65067fc 100644 --- a/src/engines/brave.rs +++ b/src/engines/brave.rs @@ -44,7 +44,7 @@ impl SearchEngine for Brave { user_agent: &str, client: &Client, safe_search: u8, - ) -> Result, EngineError> { + ) -> Result, EngineError> { let url = format!("https://search.brave.com/search?q={query}&offset={page}"); let safe_search_level = match safe_search { diff --git a/src/engines/duckduckgo.rs b/src/engines/duckduckgo.rs index c48522f..02ee481 100644 --- a/src/engines/duckduckgo.rs +++ b/src/engines/duckduckgo.rs @@ -47,7 +47,7 @@ impl SearchEngine for DuckDuckGo { user_agent: &str, client: &Client, _safe_search: u8, - ) -> Result, EngineError> { + ) -> Result, EngineError> { // Page number can be missing or empty string and so appropriate handling is required // so that upstream server recieves valid page number. let url: String = match page { diff --git a/src/engines/librex.rs b/src/engines/librex.rs index b34393f..69e4611 100644 --- a/src/engines/librex.rs +++ b/src/engines/librex.rs @@ -62,7 +62,7 @@ impl SearchEngine for LibreX { user_agent: &str, client: &Client, _safe_search: u8, - ) -> Result, EngineError> { + ) -> Result, EngineError> { // Page number can be missing or empty string and so appropriate handling is required // so that upstream server recieves valid page number. let url: String = format!( diff --git a/src/engines/mojeek.rs b/src/engines/mojeek.rs index 3f7fbb1..e376828 100644 --- a/src/engines/mojeek.rs +++ b/src/engines/mojeek.rs @@ -47,7 +47,7 @@ impl SearchEngine for Mojeek { user_agent: &str, client: &Client, safe_search: u8, - ) -> Result, EngineError> { + ) -> Result, EngineError> { // Mojeek uses `start results from this number` convention // So, for 10 results per page, page 0 starts at 1, page 1 // starts at 11, and so on. @@ -72,8 +72,23 @@ impl SearchEngine for Mojeek { "Yep", "You", ]; + let qss = search_engines.join("%2C"); - let safe = if safe_search == 0 { "0" } else { "1" }; + + // A branchless condition to check whether the `safe_search` parameter has the + // value 0 or not. If it is zero then it sets the value 0 otherwise it sets + // the value to 1 for all other values of `safe_search` + // + // Moreover, the below branchless code is equivalent to the following code below: + // + // ```rust + // let safe = if safe_search == 0 { 0 } else { 1 }.to_string(); + // ``` + // + // For more information on branchless programming. See: + // + // * https://piped.video/watch?v=bVJ-mWWL7cE + let safe = u8::from(safe_search != 0).to_string(); // Mojeek detects automated requests, these are preferences that are // able to circumvent the countermeasure. Some of these are @@ -89,7 +104,7 @@ impl SearchEngine for Mojeek { ("hp", "minimal"), ("lb", "en"), ("qss", &qss), - ("safe", safe), + ("safe", &safe), ]; let mut query_params_string = String::new(); diff --git a/src/engines/search_result_parser.rs b/src/engines/search_result_parser.rs index 0512bdd..8c16b65 100644 --- a/src/engines/search_result_parser.rs +++ b/src/engines/search_result_parser.rs @@ -1,5 +1,4 @@ //! This modules provides helper functionalities for parsing a html document into internal SearchResult. -use std::collections::HashMap; use crate::models::{aggregation_models::SearchResult, engine_models::EngineError}; use error_stack::{Report, Result}; @@ -47,7 +46,7 @@ impl SearchResultParser { &self, document: &Html, builder: impl Fn(&ElementRef<'_>, &ElementRef<'_>, &ElementRef<'_>) -> Option, - ) -> Result, EngineError> { + ) -> Result, EngineError> { let res = document .select(&self.results) .filter_map(|result| { diff --git a/src/engines/searx.rs b/src/engines/searx.rs index 9bb297c..df96857 100644 --- a/src/engines/searx.rs +++ b/src/engines/searx.rs @@ -43,12 +43,21 @@ impl SearchEngine for Searx { user_agent: &str, client: &Client, mut safe_search: u8, - ) -> Result, EngineError> { - // Page number can be missing or empty string and so appropriate handling is required - // so that upstream server recieves valid page number. - if safe_search == 3 { - safe_search = 2; - }; + ) -> Result, EngineError> { + // A branchless condition to check whether the `safe_search` parameter has the + // value greater than equal to three or not. If it is, then it modifies the + // `safesearch` parameters value to 2. + // + // Moreover, the below branchless code is equivalent to the following code below: + // + // ```rust + // safe_search = u8::from(safe_search == 3) * 2; + // ``` + // + // For more information on branchless programming. See: + // + // * https://piped.video/watch?v=bVJ-mWWL7cE + safe_search = u8::from(safe_search >= 3) * 2; let url: String = format!( "https://searx.be/search?q={query}&pageno={}&safesearch={safe_search}", diff --git a/src/engines/startpage.rs b/src/engines/startpage.rs index 540a0ce..97b7a40 100644 --- a/src/engines/startpage.rs +++ b/src/engines/startpage.rs @@ -47,7 +47,7 @@ impl SearchEngine for Startpage { user_agent: &str, client: &Client, _safe_search: u8, - ) -> Result, EngineError> { + ) -> Result, EngineError> { // Page number can be missing or empty string and so appropriate handling is required // so that upstream server recieves valid page number. let url: String = format!( diff --git a/src/lib.rs b/src/lib.rs index ec35273..19702db 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -14,7 +14,7 @@ pub mod results; pub mod server; pub mod templates; -use std::net::TcpListener; +use std::{net::TcpListener, sync::OnceLock}; use crate::server::router; @@ -31,6 +31,9 @@ use cache::cacher::{Cacher, SharedCache}; use config::parser::Config; use handler::{file_path, FileType}; +/// A static constant for holding the cache struct. +static SHARED_CACHE: OnceLock = OnceLock::new(); + /// Runs the web server on the provided TCP listener and returns a `Server` instance. /// /// # Arguments @@ -57,14 +60,14 @@ use handler::{file_path, FileType}; /// ``` pub fn run( listener: TcpListener, - config: Config, + config: &'static Config, cache: impl Cacher + 'static, ) -> std::io::Result { let public_folder_path: &str = file_path(FileType::Theme)?; let cloned_config_threads_opt: u8 = config.threads; - let cache = web::Data::new(SharedCache::new(cache)); + let cache = SHARED_CACHE.get_or_init(|| SharedCache::new(cache)); let server = HttpServer::new(move || { let cors: Cors = Cors::default() @@ -81,8 +84,8 @@ pub fn run( // Compress the responses provided by the server for the client requests. .wrap(Compress::default()) .wrap(Logger::default()) // added logging middleware for logging. - .app_data(web::Data::new(config.clone())) - .app_data(cache.clone()) + .app_data(web::Data::new(config)) + .app_data(web::Data::new(cache)) .wrap(cors) .wrap(Governor::new( &GovernorConfigBuilder::default() diff --git a/src/models/aggregation_models.rs b/src/models/aggregation_models.rs index 680d222..6be3958 100644 --- a/src/models/aggregation_models.rs +++ b/src/models/aggregation_models.rs @@ -154,8 +154,8 @@ impl SearchResults { } /// A setter function that sets the filtered to true. - pub fn set_filtered(&mut self) { - self.filtered = true; + pub fn set_filtered(&mut self, filtered: bool) { + self.filtered = filtered; } /// A getter function that gets the value of `engine_errors_info`. diff --git a/src/models/engine_models.rs b/src/models/engine_models.rs index 4d56836..932afce 100644 --- a/src/models/engine_models.rs +++ b/src/models/engine_models.rs @@ -4,7 +4,7 @@ use super::aggregation_models::SearchResult; use error_stack::{Report, Result, ResultExt}; use reqwest::Client; -use std::{collections::HashMap, fmt}; +use std::fmt; /// A custom error type used for handle engine associated errors. #[derive(Debug)] @@ -147,7 +147,7 @@ pub trait SearchEngine: Sync + Send { user_agent: &str, client: &Client, safe_search: u8, - ) -> Result, EngineError>; + ) -> Result, EngineError>; } /// A named struct which stores the engine struct with the name of the associated engine. diff --git a/src/results/aggregator.rs b/src/results/aggregator.rs index a95fb33..2759d9b 100644 --- a/src/results/aggregator.rs +++ b/src/results/aggregator.rs @@ -9,22 +9,24 @@ use crate::models::{ engine_models::{EngineError, EngineHandler}, }; use error_stack::Report; +use futures::stream::FuturesUnordered; use regex::Regex; use reqwest::{Client, ClientBuilder}; use std::time::{SystemTime, UNIX_EPOCH}; +use std::{fs::File, io::BufRead}; use std::{ - collections::HashMap, io::{BufReader, Read}, time::Duration, }; -use std::{fs::File, io::BufRead}; use tokio::task::JoinHandle; /// A constant for holding the prebuilt Client globally in the app. static CLIENT: std::sync::OnceLock = std::sync::OnceLock::new(); /// Aliases for long type annotations -type FutureVec = Vec, Report>>>; + +type FutureVec = + FuturesUnordered, Report>>>; /// The function aggregates the scraped results from the user-selected upstream search engines. /// These engines can be chosen either from the user interface (UI) or from the configuration file. @@ -37,7 +39,7 @@ type FutureVec = Vec, Report = Vec::with_capacity(0); // create tasks for upstream result fetching - let mut tasks: FutureVec = FutureVec::new(); + let tasks: FutureVec = FutureVec::new(); for engine_handler in upstream_search_engines { let (name, search_engine) = engine_handler.to_owned().into_name_engine(); @@ -117,7 +119,7 @@ pub async fn aggregate( } // aggregate search results, removing duplicates and handling errors the upstream engines returned - let mut result_map: HashMap = HashMap::new(); + let mut result_map: Vec<(String, SearchResult)> = Vec::new(); let mut engine_errors_info: Vec = Vec::new(); let mut handle_error = |error: &Report, engine_name: &'static str| { @@ -134,35 +136,27 @@ pub async fn aggregate( if result_map.is_empty() { match response { - Ok(results) => { - result_map = results.clone(); - } - Err(error) => { - handle_error(&error, engine); - } - } + Ok(results) => result_map = results, + Err(error) => handle_error(&error, engine), + }; continue; } match response { Ok(result) => { result.into_iter().for_each(|(key, value)| { - result_map - .entry(key) - .and_modify(|result| { - result.add_engines(engine); - }) - .or_insert_with(|| -> SearchResult { value }); + match result_map.iter().find(|(key_s, _)| key_s == &key) { + Some(value) => value.1.to_owned().add_engines(engine), + None => result_map.push((key, value)), + }; }); } - Err(error) => { - handle_error(&error, engine); - } - } + Err(error) => handle_error(&error, engine), + }; } if safe_search >= 3 { - let mut blacklist_map: HashMap = HashMap::new(); + let mut blacklist_map: Vec<(String, SearchResult)> = Vec::new(); filter_with_lists( &mut result_map, &mut blacklist_map, @@ -178,7 +172,7 @@ pub async fn aggregate( drop(blacklist_map); } - let results: Vec = result_map.into_values().collect(); + let results: Vec = result_map.iter().map(|(_, value)| value.clone()).collect(); Ok(SearchResults::new(results, &engine_errors_info)) } @@ -187,16 +181,16 @@ pub async fn aggregate( /// /// # Arguments /// -/// * `map_to_be_filtered` - A mutable reference to a `HashMap` of search results to filter, where the filtered results will be removed from. -/// * `resultant_map` - A mutable reference to a `HashMap` to hold the filtered results. +/// * `map_to_be_filtered` - A mutable reference to a `Vec` of search results to filter, where the filtered results will be removed from. +/// * `resultant_map` - A mutable reference to a `Vec` to hold the filtered results. /// * `file_path` - A `&str` representing the path to a file containing regex patterns to use for filtering. /// /// # Errors /// /// Returns an error if the file at `file_path` cannot be opened or read, or if a regex pattern is invalid. pub fn filter_with_lists( - map_to_be_filtered: &mut HashMap, - resultant_map: &mut HashMap, + map_to_be_filtered: &mut Vec<(String, SearchResult)>, + resultant_map: &mut Vec<(String, SearchResult)>, file_path: &str, ) -> Result<(), Box> { let mut reader = BufReader::new(File::open(file_path)?); @@ -204,18 +198,23 @@ pub fn filter_with_lists( for line in reader.by_ref().lines() { let re = Regex::new(line?.trim())?; + let mut length = map_to_be_filtered.len(); + let mut idx: usize = Default::default(); // Iterate over each search result in the map and check if it matches the regex pattern - for (url, search_result) in map_to_be_filtered.clone().into_iter() { - if re.is_match(&url.to_lowercase()) - || re.is_match(&search_result.title.to_lowercase()) - || re.is_match(&search_result.description.to_lowercase()) + while idx < length { + let ele = &map_to_be_filtered[idx]; + let ele_inner = &ele.1; + match re.is_match(&ele.0.to_lowercase()) + || re.is_match(&ele_inner.title.to_lowercase()) + || re.is_match(&ele_inner.description.to_lowercase()) { - // If the search result matches the regex pattern, move it from the original map to the resultant map - resultant_map.insert( - url.to_owned(), - map_to_be_filtered.remove(&url.to_owned()).unwrap(), - ); - } + true => { + // If the search result matches the regex pattern, move it from the original map to the resultant map + resultant_map.push(map_to_be_filtered.swap_remove(idx)); + length -= 1; + } + false => idx += 1, + }; } } @@ -226,15 +225,14 @@ pub fn filter_with_lists( mod tests { use super::*; use smallvec::smallvec; - use std::collections::HashMap; use std::io::Write; use tempfile::NamedTempFile; #[test] fn test_filter_with_lists() -> Result<(), Box> { // Create a map of search results to filter - let mut map_to_be_filtered = HashMap::new(); - map_to_be_filtered.insert( + let mut map_to_be_filtered = Vec::new(); + map_to_be_filtered.push(( "https://www.example.com".to_owned(), SearchResult { title: "Example Domain".to_owned(), @@ -243,15 +241,15 @@ mod tests { .to_owned(), engine: smallvec!["Google".to_owned(), "Bing".to_owned()], }, - ); - map_to_be_filtered.insert( + )); + map_to_be_filtered.push(( "https://www.rust-lang.org/".to_owned(), SearchResult { title: "Rust Programming Language".to_owned(), url: "https://www.rust-lang.org/".to_owned(), description: "A systems programming language that runs blazingly fast, prevents segfaults, and guarantees thread safety.".to_owned(), engine: smallvec!["Google".to_owned(), "DuckDuckGo".to_owned()], - }, + },) ); // Create a temporary file with regex patterns @@ -260,7 +258,7 @@ mod tests { writeln!(file, "rust")?; file.flush()?; - let mut resultant_map = HashMap::new(); + let mut resultant_map = Vec::new(); filter_with_lists( &mut map_to_be_filtered, &mut resultant_map, @@ -268,8 +266,12 @@ mod tests { )?; assert_eq!(resultant_map.len(), 2); - assert!(resultant_map.contains_key("https://www.example.com")); - assert!(resultant_map.contains_key("https://www.rust-lang.org/")); + assert!(resultant_map + .iter() + .any(|(key, _)| key == "https://www.example.com")); + assert!(resultant_map + .iter() + .any(|(key, _)| key == "https://www.rust-lang.org/")); assert_eq!(map_to_be_filtered.len(), 0); Ok(()) @@ -277,8 +279,8 @@ mod tests { #[test] fn test_filter_with_lists_wildcard() -> Result<(), Box> { - let mut map_to_be_filtered = HashMap::new(); - map_to_be_filtered.insert( + let mut map_to_be_filtered = Vec::new(); + map_to_be_filtered.push(( "https://www.example.com".to_owned(), SearchResult { title: "Example Domain".to_owned(), @@ -287,8 +289,8 @@ mod tests { .to_owned(), engine: smallvec!["Google".to_owned(), "Bing".to_owned()], }, - ); - map_to_be_filtered.insert( + )); + map_to_be_filtered.push(( "https://www.rust-lang.org/".to_owned(), SearchResult { title: "Rust Programming Language".to_owned(), @@ -296,14 +298,14 @@ mod tests { description: "A systems programming language that runs blazingly fast, prevents segfaults, and guarantees thread safety.".to_owned(), engine: smallvec!["Google".to_owned(), "DuckDuckGo".to_owned()], }, - ); + )); // Create a temporary file with a regex pattern containing a wildcard let mut file = NamedTempFile::new()?; writeln!(file, "ex.*le")?; file.flush()?; - let mut resultant_map = HashMap::new(); + let mut resultant_map = Vec::new(); filter_with_lists( &mut map_to_be_filtered, @@ -312,18 +314,22 @@ mod tests { )?; assert_eq!(resultant_map.len(), 1); - assert!(resultant_map.contains_key("https://www.example.com")); + assert!(resultant_map + .iter() + .any(|(key, _)| key == "https://www.example.com")); assert_eq!(map_to_be_filtered.len(), 1); - assert!(map_to_be_filtered.contains_key("https://www.rust-lang.org/")); + assert!(map_to_be_filtered + .iter() + .any(|(key, _)| key == "https://www.rust-lang.org/")); Ok(()) } #[test] fn test_filter_with_lists_file_not_found() { - let mut map_to_be_filtered = HashMap::new(); + let mut map_to_be_filtered = Vec::new(); - let mut resultant_map = HashMap::new(); + let mut resultant_map = Vec::new(); // Call the `filter_with_lists` function with a non-existent file path let result = filter_with_lists( @@ -337,8 +343,8 @@ mod tests { #[test] fn test_filter_with_lists_invalid_regex() { - let mut map_to_be_filtered = HashMap::new(); - map_to_be_filtered.insert( + let mut map_to_be_filtered = Vec::new(); + map_to_be_filtered.push(( "https://www.example.com".to_owned(), SearchResult { title: "Example Domain".to_owned(), @@ -347,9 +353,9 @@ mod tests { .to_owned(), engine: smallvec!["Google".to_owned(), "Bing".to_owned()], }, - ); + )); - let mut resultant_map = HashMap::new(); + let mut resultant_map = Vec::new(); // Create a temporary file with an invalid regex pattern let mut file = NamedTempFile::new().unwrap(); diff --git a/src/server/router.rs b/src/server/router.rs index c46e79d..aa2a9cc 100644 --- a/src/server/router.rs +++ b/src/server/router.rs @@ -11,7 +11,9 @@ use std::fs::read_to_string; /// Handles the route of index page or main page of the `websurfx` meta search engine website. #[get("/")] -pub async fn index(config: web::Data) -> Result> { +pub async fn index( + config: web::Data<&'static Config>, +) -> Result> { Ok(HttpResponse::Ok().content_type(ContentType::html()).body( crate::templates::views::index::index( &config.style.colorscheme, @@ -25,7 +27,7 @@ pub async fn index(config: web::Data) -> Result, + config: web::Data<&'static Config>, ) -> Result> { Ok(HttpResponse::Ok().content_type(ContentType::html()).body( crate::templates::views::not_found::not_found( @@ -49,7 +51,9 @@ pub async fn robots_data(_req: HttpRequest) -> Result) -> Result> { +pub async fn about( + config: web::Data<&'static Config>, +) -> Result> { Ok(HttpResponse::Ok().content_type(ContentType::html()).body( crate::templates::views::about::about( &config.style.colorscheme, @@ -63,7 +67,7 @@ pub async fn about(config: web::Data) -> Result, + config: web::Data<&'static Config>, ) -> Result> { Ok(HttpResponse::Ok().content_type(ContentType::html()).body( crate::templates::views::settings::settings( diff --git a/src/server/routes/search.rs b/src/server/routes/search.rs index a5b2558..4f44f03 100644 --- a/src/server/routes/search.rs +++ b/src/server/routes/search.rs @@ -37,8 +37,8 @@ use tokio::join; #[get("/search")] pub async fn search( req: HttpRequest, - config: web::Data, - cache: web::Data, + config: web::Data<&'static Config>, + cache: web::Data<&'static SharedCache>, ) -> Result> { use std::sync::Arc; let params = web::Query::::from_query(req.query_string())?; @@ -70,8 +70,8 @@ pub async fn search( }); search_settings.safe_search_level = get_safesearch_level( - &Some(search_settings.safe_search_level), - ¶ms.safesearch, + params.safesearch, + search_settings.safe_search_level, config.safe_search, ); @@ -158,8 +158,8 @@ pub async fn search( /// It returns the `SearchResults` struct if the search results could be successfully fetched from /// the cache or from the upstream search engines otherwise it returns an appropriate error. async fn results( - config: &Config, - cache: &web::Data, + config: &'static Config, + cache: &'static SharedCache, query: &str, page: u32, search_settings: &server_models::Cookie<'_>, @@ -225,12 +225,12 @@ async fn results( search_results } }; - if results.engine_errors_info().is_empty() - && results.results().is_empty() - && !results.no_engines_selected() - { - results.set_filtered(); - } + let (engine_errors_info, results_empty_check, no_engines_selected) = ( + results.engine_errors_info().is_empty(), + results.results().is_empty(), + results.no_engines_selected(), + ); + results.set_filtered(engine_errors_info & results_empty_check & !no_engines_selected); cache .cache_results(&[results.clone()], &[cache_key.clone()]) .await?; @@ -267,24 +267,95 @@ fn is_match_from_filter_list( Ok(false) } -/// A helper function to modify the safe search level based on the url params. -/// The `safe_search` is the one in the user's cookie or -/// the default set by the server config if the cookie was missing. +/// A helper function to choose the safe search level value based on the URL parameters, +/// cookie value and config value. /// /// # Argurments /// -/// * `url_level` - Safe search level from the url. -/// * `safe_search` - User's cookie, or the safe search level set by the server -/// * `config_level` - Safe search level to fall back to -fn get_safesearch_level(cookie_level: &Option, url_level: &Option, config_level: u8) -> u8 { - match url_level { - Some(url_level) => { - if *url_level >= 3 { - config_level - } else { - *url_level - } +/// * `safe_search_level_from_url` - Safe search level from the URL parameters. +/// * `cookie_safe_search_level` - Safe search level value from the cookie. +/// * `config_safe_search_level` - Safe search level value from the config file. +/// +/// # Returns +/// +/// Returns an appropriate safe search level value based on the safe search level values +/// from the URL parameters, cookie and the config file. +fn get_safesearch_level( + safe_search_level_from_url: Option, + cookie_safe_search_level: u8, + config_safe_search_level: u8, +) -> u8 { + (u8::from(safe_search_level_from_url.is_some()) + * ((u8::from(config_safe_search_level >= 3) * config_safe_search_level) + + (u8::from(config_safe_search_level < 3) * safe_search_level_from_url.unwrap_or(0)))) + + (u8::from(safe_search_level_from_url.is_none()) + * ((u8::from(config_safe_search_level >= 3) * config_safe_search_level) + + (u8::from(config_safe_search_level < 3) * cookie_safe_search_level))) +} + +#[cfg(test)] +mod tests { + use std::time::{SystemTime, UNIX_EPOCH}; + + /// A helper function which creates a random mock safe search level value. + /// + /// # Returns + /// + /// Returns an optional u8 value. + fn mock_safe_search_level_value() -> Option { + let nanos = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .subsec_nanos() as f32; + let delay = ((nanos / 1_0000_0000 as f32).floor() as i8) - 1; + + match delay { + -1 => None, + some_num => Some(if some_num > 4 { some_num - 4 } else { some_num } as u8), } - None => cookie_level.unwrap_or(config_level), + } + + #[test] + /// A test function to test whether the output of the branchless and branched code + /// for the code to choose the appropriate safe search level is same or not. + fn get_safesearch_level_branched_branchless_code_test() { + // Get mock values for the safe search level values for URL parameters, cookie + // and config. + let safe_search_level_from_url = mock_safe_search_level_value(); + let cookie_safe_search_level = mock_safe_search_level_value().unwrap_or(0); + let config_safe_search_level = mock_safe_search_level_value().unwrap_or(0); + + // Branched code + let safe_search_level_value_from_branched_code = match safe_search_level_from_url { + Some(safe_search_level_from_url_parsed) => { + if config_safe_search_level >= 3 { + config_safe_search_level + } else { + safe_search_level_from_url_parsed + } + } + None => { + if config_safe_search_level >= 3 { + config_safe_search_level + } else { + cookie_safe_search_level + } + } + }; + + // branchless code + let safe_search_level_value_from_branchless_code = + (u8::from(safe_search_level_from_url.is_some()) + * ((u8::from(config_safe_search_level >= 3) * config_safe_search_level) + + (u8::from(config_safe_search_level < 3) + * safe_search_level_from_url.unwrap_or(0)))) + + (u8::from(safe_search_level_from_url.is_none()) + * ((u8::from(config_safe_search_level >= 3) * config_safe_search_level) + + (u8::from(config_safe_search_level < 3) * cookie_safe_search_level))); + + assert_eq!( + safe_search_level_value_from_branched_code, + safe_search_level_value_from_branchless_code + ); } } diff --git a/tests/index.rs b/tests/index.rs index 563c2d9..010795d 100644 --- a/tests/index.rs +++ b/tests/index.rs @@ -1,14 +1,17 @@ -use std::net::TcpListener; +use std::{net::TcpListener, sync::OnceLock}; use websurfx::{config::parser::Config, run, templates::views}; +/// A static constant for holding the parsed config. +static CONFIG: OnceLock = OnceLock::new(); + // Starts a new instance of the HTTP server, bound to a random available port async fn spawn_app() -> String { // Binding to port 0 will trigger the OS to assign a port for us. let listener = TcpListener::bind("127.0.0.1:0").expect("Failed to bind random port"); let port = listener.local_addr().unwrap().port(); - let config = Config::parse(false).unwrap(); - let cache = websurfx::cache::cacher::create_cache(&config).await; + let config = CONFIG.get_or_init(|| Config::parse(false).unwrap()); + let cache = websurfx::cache::cacher::create_cache(config).await; let server = run(listener, config, cache).expect("Failed to bind address"); tokio::spawn(server);