0
0
mirror of https://github.com/neon-mmd/websurfx.git synced 2024-11-23 22:48:21 -05:00

perf: replace Vec<T> with Box<[T]> & refactor calculate_tfidf function (#603)

This commit is contained in:
neon_arch 2024-09-05 22:07:10 +05:30
parent 39af9096ef
commit b1bcf41061

View File

@ -3,7 +3,6 @@
use super::engine_models::EngineError; use super::engine_models::EngineError;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use smallvec::SmallVec;
#[cfg(any( #[cfg(any(
feature = "use-synonyms-search", feature = "use-synonyms-search",
feature = "use-non-static-synonyms-search" feature = "use-non-static-synonyms-search"
@ -23,7 +22,7 @@ pub struct SearchResult {
/// The description of the search result. /// The description of the search result.
pub description: String, pub description: String,
/// The names of the upstream engines from which this results were provided. /// The names of the upstream engines from which this results were provided.
pub engine: SmallVec<[String; 0]>, pub engine: Vec<String>,
/// The td-tdf score of the result in regards to the title, url and description and the user's query /// The td-tdf score of the result in regards to the title, url and description and the user's query
pub relevance_score: f32, pub relevance_score: f32,
} }
@ -153,10 +152,10 @@ impl EngineErrorInfo {
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct SearchResults { pub struct SearchResults {
/// Stores the individual serializable `SearchResult` struct into a vector of /// Stores the individual serializable `SearchResult` struct into a vector of
pub results: Vec<SearchResult>, pub results: Box<[SearchResult]>,
/// Stores the information on which engines failed with their engine name /// Stores the information on which engines failed with their engine name
/// and the type of error that caused it. /// and the type of error that caused it.
pub engine_errors_info: Vec<EngineErrorInfo>, pub engine_errors_info: Box<[EngineErrorInfo]>,
/// Stores the flag option which holds the check value that the following /// Stores the flag option which holds the check value that the following
/// search query was disallowed when the safe search level set to 4 and it /// search query was disallowed when the safe search level set to 4 and it
/// was present in the `Blocklist` file. /// was present in the `Blocklist` file.
@ -183,10 +182,10 @@ impl SearchResults {
/// the search url. /// the search url.
/// * `engine_errors_info` - Takes an array of structs which contains information regarding /// * `engine_errors_info` - Takes an array of structs which contains information regarding
/// which engines failed with their names, reason and their severity color name. /// which engines failed with their names, reason and their severity color name.
pub fn new(results: Vec<SearchResult>, engine_errors_info: &[EngineErrorInfo]) -> Self { pub fn new(results: Box<[SearchResult]>, engine_errors_info: Box<[EngineErrorInfo]>) -> Self {
Self { Self {
results, results,
engine_errors_info: engine_errors_info.to_owned(), engine_errors_info,
disallowed: Default::default(), disallowed: Default::default(),
filtered: Default::default(), filtered: Default::default(),
safe_search_level: Default::default(), safe_search_level: Default::default(),
@ -205,11 +204,11 @@ impl SearchResults {
} }
/// A getter function that gets the value of `engine_errors_info`. /// A getter function that gets the value of `engine_errors_info`.
pub fn engine_errors_info(&mut self) -> Vec<EngineErrorInfo> { pub fn engine_errors_info(&mut self) -> Box<[EngineErrorInfo]> {
std::mem::take(&mut self.engine_errors_info) std::mem::take(&mut self.engine_errors_info)
} }
/// A getter function that gets the value of `results`. /// A getter function that gets the value of `results`.
pub fn results(&mut self) -> Vec<SearchResult> { pub fn results(&mut self) -> Box<[SearchResult]> {
self.results.clone() self.results.clone()
} }
@ -254,27 +253,50 @@ fn calculate_tf_idf(
let tf_idf = TfIdf::new(params); let tf_idf = TfIdf::new(params);
let tokener = Tokenizer::new(query, stop_words, Some(punctuation)); let tokener = Tokenizer::new(query, stop_words, Some(punctuation));
let query_tokens = tokener.split_into_words(); let query_tokens = tokener.split_into_words();
let mut search_tokens = vec![];
for token in query_tokens { #[cfg(any(
#[cfg(any( feature = "use-synonyms-search",
feature = "use-synonyms-search", feature = "use-non-static-synonyms-search"
feature = "use-non-static-synonyms-search" ))]
))] let mut extra_tokens = vec![];
{
// find some synonyms and add them to the search (from wordnet or moby if feature is enabled)
let synonyms = synonyms(&token);
search_tokens.extend(synonyms)
}
search_tokens.push(token);
}
let mut total_score = 0.0f32; let total_score: f32 = query_tokens
for token in search_tokens.iter() { .iter()
total_score += tf_idf.get_score(token); .map(|token| {
} #[cfg(any(
feature = "use-synonyms-search",
feature = "use-non-static-synonyms-search"
))]
{
// find some synonyms and add them to the search (from wordnet or moby if feature is enabled)
extra_tokens.extend(synonyms(token))
}
let result = total_score / (search_tokens.len() as f32); tf_idf.get_score(token)
})
.sum();
#[cfg(not(any(
feature = "use-synonyms-search",
feature = "use-non-static-synonyms-search"
)))]
let result = total_score / (query_tokens.len() as f32);
#[cfg(any(
feature = "use-synonyms-search",
feature = "use-non-static-synonyms-search"
))]
let extra_total_score: f32 = extra_tokens
.iter()
.map(|token| tf_idf.get_score(token))
.sum();
#[cfg(any(
feature = "use-synonyms-search",
feature = "use-non-static-synonyms-search"
))]
let result =
(extra_total_score + total_score) / ((query_tokens.len() + extra_tokens.len()) as f32);
f32::from(!result.is_nan()) * result f32::from(!result.is_nan()) * result
} }