|
| 1 | +use java_asm::StrRef; |
| 2 | +use nucleo_matcher::pattern::{CaseMatching, Normalization, Pattern}; |
| 3 | +use nucleo_matcher::{Config, Matcher, Utf32Str}; |
| 4 | +use std::cmp::{min, Reverse}; |
| 5 | +use std::sync::Arc; |
| 6 | + |
| 7 | +pub struct FuzzyMatchModel { |
| 8 | + // the query string |
| 9 | + input: StrRef, |
| 10 | + items: Vec<StrRef>, |
| 11 | + // how many results to return |
| 12 | + top_n: usize, |
| 13 | + matcher: Matcher, |
| 14 | + pattern: Pattern, |
| 15 | + // length is the same as input |
| 16 | + inc_infos: IncrementalInfos, |
| 17 | +} |
| 18 | + |
| 19 | +type IncrementalInfos = Vec<Option<IncrementalInfo>>; |
| 20 | + |
| 21 | +// if change `a` to `ab`, just search things from (previous matched items + remaining items). |
| 22 | +// if change `ab` to `a`, reset incremental info to the specific index. |
| 23 | +// in other cases, which means totally different, clear all incremental info. |
| 24 | +#[derive(Debug, Clone)] |
| 25 | +struct IncrementalInfo { |
| 26 | + // previous last index when search stops |
| 27 | + stop_index: usize, |
| 28 | + // previous matched items |
| 29 | + items: Vec<StrRef>, |
| 30 | +} |
| 31 | + |
| 32 | +impl FuzzyMatchModel { |
| 33 | + pub fn new( |
| 34 | + input: StrRef, items: &[StrRef], top_n: usize, |
| 35 | + ) -> Self { |
| 36 | + let config = Config::DEFAULT.match_paths(); |
| 37 | + let matcher = Matcher::new(config); |
| 38 | + let pattern = Pattern::parse(&input, CaseMatching::Ignore, Normalization::Never); |
| 39 | + let items = items.to_vec(); |
| 40 | + let inc_info = vec![None; items.len() + 1]; |
| 41 | + FuzzyMatchModel { |
| 42 | + input, |
| 43 | + items, |
| 44 | + top_n, |
| 45 | + matcher, |
| 46 | + pattern, |
| 47 | + inc_infos: inc_info, |
| 48 | + } |
| 49 | + } |
| 50 | + |
| 51 | + pub fn search_with_new_input(&mut self, new_input: StrRef) -> Vec<StrRef> { |
| 52 | + let old_input = self.input.clone(); |
| 53 | + let old_len = old_input.len(); |
| 54 | + let new_len = new_input.len(); |
| 55 | + // 1. check if same input |
| 56 | + let mut previous_info: Option<IncrementalInfo> = None; |
| 57 | + if let Some(Some(previous)) = self.inc_infos.get(old_len) { |
| 58 | + previous_info = Some(previous.clone()); |
| 59 | + } |
| 60 | + if old_len == new_len && new_input == old_input { |
| 61 | + if let Some(IncrementalInfo { items, .. }) = previous_info { |
| 62 | + // same input, skip search |
| 63 | + return items.clone(); |
| 64 | + } |
| 65 | + } else { |
| 66 | + // change pattern |
| 67 | + self.input = new_input.clone(); |
| 68 | + self.pattern.reparse(&new_input, CaseMatching::Ignore, Normalization::Never); |
| 69 | + } |
| 70 | + // 2. do search by incremental info |
| 71 | + let result: Vec<StrRef> = previous_info |
| 72 | + .map(|inc_info| { |
| 73 | + // if has inc info |
| 74 | + self.inc_case_1(&old_input, &new_input, &inc_info) |
| 75 | + .or_else(|| self.inc_case_2(&old_input, &new_input)) |
| 76 | + .unwrap_or_else(|| self.full_search()) |
| 77 | + }) |
| 78 | + .unwrap_or_else(|| self.full_search()); |
| 79 | + result |
| 80 | + } |
| 81 | + |
| 82 | + // case 1, change `a` to `ab`, just search things from (previous matched items + remaining items). |
| 83 | + // returns `None` means not applicable for this case. |
| 84 | + fn inc_case_1( |
| 85 | + &mut self, old_input: &str, new_input: &str, |
| 86 | + old_inc_info: &IncrementalInfo, |
| 87 | + ) -> Option<Vec<StrRef>> { |
| 88 | + let old_len = old_input.len(); |
| 89 | + let new_len = new_input.len(); |
| 90 | + if new_len < old_len || !new_input.starts_with(old_input) { return None; }; |
| 91 | + |
| 92 | + let IncrementalInfo { stop_index, items: inc_items } = old_inc_info; |
| 93 | + let search_result = self.search_in_ranges(&inc_items, *stop_index); |
| 94 | + let new_inc_info = IncrementalInfo { |
| 95 | + stop_index: search_result.stop_idx, |
| 96 | + items: search_result.items.clone(), |
| 97 | + }; |
| 98 | + self.inc_infos.resize(new_len + 1, None); |
| 99 | + self.inc_infos[new_len] = Some(new_inc_info); |
| 100 | + Some(search_result.items) |
| 101 | + } |
| 102 | + |
| 103 | + // case 2, if change `ab` to `a`, reset incremental info to the specific index. |
| 104 | + fn inc_case_2( |
| 105 | + &mut self, old_input: &str, new_input: &str, |
| 106 | + ) -> Option<Vec<StrRef>> { |
| 107 | + // 1. find common prefix len |
| 108 | + let mut common_prefix_len = 0usize; |
| 109 | + let old_bytes = old_input.as_bytes(); |
| 110 | + let new_bytes = new_input.as_bytes(); |
| 111 | + let min_len = min(old_bytes.len(), new_bytes.len()); |
| 112 | + for i in 0..min_len { |
| 113 | + if old_bytes[i] == new_bytes[i] { |
| 114 | + common_prefix_len = i + 1; |
| 115 | + } else { |
| 116 | + break; |
| 117 | + } |
| 118 | + } |
| 119 | + |
| 120 | + // 2. if the inc info in common prefix idx not exists, means not applicable for this case. |
| 121 | + let Some(target_inc_info) = self.inc_infos.get(common_prefix_len)? else { |
| 122 | + return None; |
| 123 | + }; |
| 124 | + |
| 125 | + let result = target_inc_info.items.clone(); |
| 126 | + // 3. resize incremental infos to new length to fit new query. |
| 127 | + self.inc_infos.resize(new_input.len() + 1, None); |
| 128 | + Some(result) |
| 129 | + } |
| 130 | + |
| 131 | + pub fn full_search(&mut self) -> Vec<StrRef> { |
| 132 | + let SearchResult { stop_idx, items } = self.search_in_ranges(&[], 0); |
| 133 | + let new_inc_info = IncrementalInfo { |
| 134 | + stop_index: stop_idx, |
| 135 | + items, |
| 136 | + }; |
| 137 | + self.inc_infos.clear(); |
| 138 | + self.inc_infos.resize(self.input.len() + 1, None); |
| 139 | + self.inc_infos[self.input.len()] = Some(new_inc_info.clone()); |
| 140 | + new_inc_info.items |
| 141 | + } |
| 142 | + |
| 143 | + fn search_in_ranges( |
| 144 | + &mut self, items_1: &[StrRef], items_2_start_idx: usize, |
| 145 | + ) -> SearchResult { |
| 146 | + let Self { top_n, items: all_items, matcher, pattern, .. } = self; |
| 147 | + let top_n = *top_n; |
| 148 | + let mut buf = Vec::new(); |
| 149 | + let mut result: Vec<(&StrRef, u32)> = Vec::with_capacity(top_n); |
| 150 | + |
| 151 | + let mut stop_idx = items_2_start_idx; |
| 152 | + for item in items_1 { |
| 153 | + if result.len() >= top_n { |
| 154 | + break; |
| 155 | + } |
| 156 | + let haystack = Utf32Str::new(item.as_ref(), &mut buf); |
| 157 | + let single_result = pattern.score(haystack, matcher) |
| 158 | + .map(|score| (item, score)); |
| 159 | + let Some(single_result) = single_result else { |
| 160 | + continue; |
| 161 | + }; |
| 162 | + result.push(single_result); |
| 163 | + } |
| 164 | + |
| 165 | + for (idx, item) in all_items.iter().enumerate().skip(stop_idx) { |
| 166 | + if result.len() >= top_n { |
| 167 | + break; |
| 168 | + } |
| 169 | + stop_idx = idx; |
| 170 | + let haystack = Utf32Str::new(item.as_ref(), &mut buf); |
| 171 | + let single_result = pattern.score(haystack, matcher) |
| 172 | + .map(|score| (item, score)); |
| 173 | + let Some(single_result) = single_result else { |
| 174 | + continue; |
| 175 | + }; |
| 176 | + result.push(single_result); |
| 177 | + } |
| 178 | + |
| 179 | + result.sort_by_key(|(_, score)| Reverse(*score)); |
| 180 | + let result_items = result.iter().map(|(item, _)| Arc::clone(item)).collect(); |
| 181 | + SearchResult { |
| 182 | + stop_idx, |
| 183 | + items: result_items, |
| 184 | + } |
| 185 | + } |
| 186 | +} |
| 187 | + |
| 188 | +struct SearchResult { |
| 189 | + stop_idx: usize, |
| 190 | + items: Vec<StrRef>, |
| 191 | +} |
| 192 | + |
#[cfg(test)]
mod tests {
    use crate::impls::fuzzy::FuzzyMatchModel;
    use java_asm::{vec_str_ref, StrRef};
    use rand::prelude::SliceRandom;
    use rand::rng;

    /// Walks one model through three queries: an initial one, an edited one and
    /// one that matches nothing.
    #[test]
    fn test_fuzzy_match() {
        let candidates: Vec<StrRef> = vec_str_ref![
            "apple/banana",
            "apple/banana/cake",
        ];
        let query: StrRef = "abc".into();
        let mut model = FuzzyMatchModel::new(query, &candidates, 10);

        // initial query: only the path containing a `c` matches
        let only_cake: Vec<StrRef> = vec_str_ref![
            "apple/banana/cake"
        ];
        assert_eq!(model.full_search(), only_cake);

        // edited query: both paths match
        let both_paths: Vec<StrRef> = vec_str_ref![
            "apple/banana",
            "apple/banana/cake"
        ];
        let after_edit = model.search_with_new_input("abn".into());
        assert_eq!(after_edit, both_paths);

        // query without any match
        let nothing: Vec<StrRef> = vec_str_ref![];
        let after_miss = model.search_with_new_input("abcd".into());
        assert_eq!(after_miss, nothing);
    }

    /// Full search over 100K candidates that all match; prints the elapsed time.
    #[test]
    fn test_huge_input() {
        let sample_size = 100_000;
        let query: StrRef = "im2z".into();
        let candidates: Vec<StrRef> = (0..sample_size)
            .map(|n| format!("item/{}/i12k/pc1i1/z", n).into())
            .collect();

        let timer = std::time::Instant::now();
        let mut model = FuzzyMatchModel::new(query, &candidates, 100_000);
        let found = model.full_search();
        println!("Cost time: {:?}ms for 100K items", timer.elapsed().as_millis());
        assert_eq!(found.len(), sample_size);
    }

    /// Capped search over a shuffled mix of long and short paths, followed by an
    /// incremental search with one more query character; prints both timings.
    #[test]
    fn test_huge_input_take_1000() {
        let sample_size = 50_000;
        let query: StrRef = "im21".into();
        // long paths: always matched
        let long_paths = (0..sample_size).map(|n|
            format!("item/{}/i12k/pc1i1/z", n).into()
        );
        // short paths: not matched
        let short_paths = (0..sample_size).map(|n|
            format!("item/{}", n).into()
        );
        let mut candidates: Vec<StrRef> = long_paths.chain(short_paths).collect();
        candidates.shuffle(&mut rng());

        let timer = std::time::Instant::now();
        let mut model = FuzzyMatchModel::new(query, &candidates, 10000);
        let first_round = model.full_search();
        println!("Cost time: {:?}ms for take 1000 items", timer.elapsed().as_millis());
        assert_eq!(first_round.len(), 10000);

        let timer = std::time::Instant::now();
        let second_round = model.search_with_new_input("im21z".into());
        println!("Cost time: {:?}ms for take 1000 items next time", timer.elapsed().as_millis());
        assert_eq!(second_round.len(), 10000);
    }
}
| 269 | + |
0 commit comments