From b1bcf41061acc8070f12d64191db714b5f4c6d83 Mon Sep 17 00:00:00 2001 From: neon_arch Date: Thu, 5 Sep 2024 22:07:10 +0530 Subject: [PATCH] :zap: perf: replace `Vec` with `Box<[T]>` & refactor `calculate_tfidf` function (#603) --- src/models/aggregation_models.rs | 74 +++++++++++++++++++++----------- 1 file changed, 48 insertions(+), 26 deletions(-) diff --git a/src/models/aggregation_models.rs b/src/models/aggregation_models.rs index c046b1c..01c67a6 100644 --- a/src/models/aggregation_models.rs +++ b/src/models/aggregation_models.rs @@ -3,7 +3,6 @@ use super::engine_models::EngineError; use serde::{Deserialize, Serialize}; -use smallvec::SmallVec; #[cfg(any( feature = "use-synonyms-search", feature = "use-non-static-synonyms-search" @@ -23,7 +22,7 @@ pub struct SearchResult { /// The description of the search result. pub description: String, /// The names of the upstream engines from which this results were provided. - pub engine: SmallVec<[String; 0]>, + pub engine: Vec<String>, /// The td-tdf score of the result in regards to the title, url and description and the user's query pub relevance_score: f32, } @@ -153,10 +152,10 @@ impl EngineErrorInfo { #[serde(rename_all = "camelCase")] pub struct SearchResults { /// Stores the individual serializable `SearchResult` struct into a vector of - pub results: Vec<SearchResult>, + pub results: Box<[SearchResult]>, /// Stores the information on which engines failed with their engine name /// and the type of error that caused it. - pub engine_errors_info: Vec<EngineErrorInfo>, + pub engine_errors_info: Box<[EngineErrorInfo]>, /// Stores the flag option which holds the check value that the following /// search query was disallowed when the safe search level set to 4 and it /// was present in the `Blocklist` file. @@ -183,10 +182,10 @@ impl SearchResults { /// the search url. /// * `engine_errors_info` - Takes an array of structs which contains information regarding /// which engines failed with their names, reason and their severity color name. 
- pub fn new(results: Vec<SearchResult>, engine_errors_info: &[EngineErrorInfo]) -> Self { + pub fn new(results: Box<[SearchResult]>, engine_errors_info: Box<[EngineErrorInfo]>) -> Self { Self { results, - engine_errors_info: engine_errors_info.to_owned(), + engine_errors_info, disallowed: Default::default(), filtered: Default::default(), safe_search_level: Default::default(), @@ -205,11 +204,11 @@ impl SearchResults { } /// A getter function that gets the value of `engine_errors_info`. - pub fn engine_errors_info(&mut self) -> Vec<EngineErrorInfo> { + pub fn engine_errors_info(&mut self) -> Box<[EngineErrorInfo]> { std::mem::take(&mut self.engine_errors_info) } /// A getter function that gets the value of `results`. - pub fn results(&mut self) -> Vec<SearchResult> { + pub fn results(&mut self) -> Box<[SearchResult]> { self.results.clone() } @@ -254,27 +253,50 @@ fn calculate_tf_idf( let tf_idf = TfIdf::new(params); let tokener = Tokenizer::new(query, stop_words, Some(punctuation)); let query_tokens = tokener.split_into_words(); - let mut search_tokens = vec![]; - for token in query_tokens { - #[cfg(any( - feature = "use-synonyms-search", - feature = "use-non-static-synonyms-search" - ))] - { - // find some synonyms and add them to the search (from wordnet or moby if feature is enabled) - let synonyms = synonyms(&token); - search_tokens.extend(synonyms) - } - search_tokens.push(token); - } + #[cfg(any( + feature = "use-synonyms-search", + feature = "use-non-static-synonyms-search" + ))] + let mut extra_tokens = vec![]; - let mut total_score = 0.0f32; - for token in search_tokens.iter() { - total_score += tf_idf.get_score(token); - } + let total_score: f32 = query_tokens + .iter() + .map(|token| { + #[cfg(any( + feature = "use-synonyms-search", + feature = "use-non-static-synonyms-search" + ))] + { + // find some synonyms and add them to the search (from wordnet or moby if feature is enabled) + extra_tokens.extend(synonyms(token)) + } - let result = total_score / (search_tokens.len() as f32); + 
tf_idf.get_score(token) + }) + .sum(); + + #[cfg(not(any( + feature = "use-synonyms-search", + feature = "use-non-static-synonyms-search" + )))] + let result = total_score / (query_tokens.len() as f32); + + #[cfg(any( + feature = "use-synonyms-search", + feature = "use-non-static-synonyms-search" + ))] + let extra_total_score: f32 = extra_tokens + .iter() + .map(|token| tf_idf.get_score(token)) + .sum(); + + #[cfg(any( + feature = "use-synonyms-search", + feature = "use-non-static-synonyms-search" + ))] + let result = + (extra_total_score + total_score) / ((query_tokens.len() + extra_tokens.len()) as f32); f32::from(!result.is_nan()) * result }