0
0
mirror of https://github.com/neon-mmd/websurfx.git synced 2024-12-21 20:08:21 -05:00

perf: replace Vec<T> with Box<[T]> & refactor calculate_tfidf function (#603)

This commit is contained in:
neon_arch 2024-09-05 22:07:10 +05:30
parent 39af9096ef
commit b1bcf41061

View File

@ -3,7 +3,6 @@
use super::engine_models::EngineError;
use serde::{Deserialize, Serialize};
use smallvec::SmallVec;
#[cfg(any(
feature = "use-synonyms-search",
feature = "use-non-static-synonyms-search"
@ -23,7 +22,7 @@ pub struct SearchResult {
/// The description of the search result.
pub description: String,
/// The names of the upstream engines from which this results were provided.
pub engine: SmallVec<[String; 0]>,
pub engine: Vec<String>,
/// The td-tdf score of the result in regards to the title, url and description and the user's query
pub relevance_score: f32,
}
@ -153,10 +152,10 @@ impl EngineErrorInfo {
#[serde(rename_all = "camelCase")]
pub struct SearchResults {
/// Stores the individual serializable `SearchResult` struct into a vector of
pub results: Vec<SearchResult>,
pub results: Box<[SearchResult]>,
/// Stores the information on which engines failed with their engine name
/// and the type of error that caused it.
pub engine_errors_info: Vec<EngineErrorInfo>,
pub engine_errors_info: Box<[EngineErrorInfo]>,
/// Stores the flag option which holds the check value that the following
/// search query was disallowed when the safe search level set to 4 and it
/// was present in the `Blocklist` file.
@ -183,10 +182,10 @@ impl SearchResults {
/// the search url.
/// * `engine_errors_info` - Takes an array of structs which contains information regarding
/// which engines failed with their names, reason and their severity color name.
pub fn new(results: Vec<SearchResult>, engine_errors_info: &[EngineErrorInfo]) -> Self {
pub fn new(results: Box<[SearchResult]>, engine_errors_info: Box<[EngineErrorInfo]>) -> Self {
Self {
results,
engine_errors_info: engine_errors_info.to_owned(),
engine_errors_info,
disallowed: Default::default(),
filtered: Default::default(),
safe_search_level: Default::default(),
@ -205,11 +204,11 @@ impl SearchResults {
}
/// A getter function that gets the value of `engine_errors_info`.
pub fn engine_errors_info(&mut self) -> Vec<EngineErrorInfo> {
pub fn engine_errors_info(&mut self) -> Box<[EngineErrorInfo]> {
std::mem::take(&mut self.engine_errors_info)
}
/// A getter function that gets the value of `results`.
pub fn results(&mut self) -> Vec<SearchResult> {
pub fn results(&mut self) -> Box<[SearchResult]> {
self.results.clone()
}
@ -254,27 +253,50 @@ fn calculate_tf_idf(
let tf_idf = TfIdf::new(params);
let tokener = Tokenizer::new(query, stop_words, Some(punctuation));
let query_tokens = tokener.split_into_words();
let mut search_tokens = vec![];
for token in query_tokens {
#[cfg(any(
feature = "use-synonyms-search",
feature = "use-non-static-synonyms-search"
))]
{
// find some synonyms and add them to the search (from wordnet or moby if feature is enabled)
let synonyms = synonyms(&token);
search_tokens.extend(synonyms)
}
search_tokens.push(token);
}
#[cfg(any(
feature = "use-synonyms-search",
feature = "use-non-static-synonyms-search"
))]
let mut extra_tokens = vec![];
let mut total_score = 0.0f32;
for token in search_tokens.iter() {
total_score += tf_idf.get_score(token);
}
let total_score: f32 = query_tokens
.iter()
.map(|token| {
#[cfg(any(
feature = "use-synonyms-search",
feature = "use-non-static-synonyms-search"
))]
{
// find some synonyms and add them to the search (from wordnet or moby if feature is enabled)
extra_tokens.extend(synonyms(token))
}
let result = total_score / (search_tokens.len() as f32);
tf_idf.get_score(token)
})
.sum();
#[cfg(not(any(
feature = "use-synonyms-search",
feature = "use-non-static-synonyms-search"
)))]
let result = total_score / (query_tokens.len() as f32);
#[cfg(any(
feature = "use-synonyms-search",
feature = "use-non-static-synonyms-search"
))]
let extra_total_score: f32 = extra_tokens
.iter()
.map(|token| tf_idf.get_score(token))
.sum();
#[cfg(any(
feature = "use-synonyms-search",
feature = "use-non-static-synonyms-search"
))]
let result =
(extra_total_score + total_score) / ((query_tokens.len() + extra_tokens.len()) as f32);
f32::from(!result.is_nan()) * result
}