From f0ba265bbb0a0b6c4755500445478368f6dacb8f Mon Sep 17 00:00:00 2001 From: David Blewett Date: Wed, 12 Sep 2018 14:18:28 -0400 Subject: [PATCH 1/6] Add UnionSet class: * Supports range queries across multiple sets: a = Set.from_iter(["bar", "foo"]) b = Set.from_iter(["baz", "foo"]) list(UnionSet(a, b)['ba':'bb']) ['bar', 'baz'] * Add StreamBuilder, to correspond with the Rust library's abstraction * Update OpBuilder to support constructing operations against multiple underlying types (Set and StreamBuilder for now) --- rust/rust_fst.h | 4 ++ rust/src/set.rs | 31 +++++++++ rust_fst/__init__.py | 4 +- rust_fst/set.py | 149 ++++++++++++++++++++++++++++++++++++++----- tests/test_set.py | 34 +++++++++- 5 files changed, 204 insertions(+), 18 deletions(-) diff --git a/rust/rust_fst.h b/rust/rust_fst.h index 401a94e..15bfc2e 100644 --- a/rust/rust_fst.h +++ b/rust/rust_fst.h @@ -64,6 +64,7 @@ SetStream* fst_set_stream(Set*); SetLevStream* fst_set_levsearch(Set*, Levenshtein*); SetRegexStream* fst_set_regexsearch(Set*, Regex*); SetOpBuilder* fst_set_make_opbuilder(Set*); +SetOpBuilder* fst_set_make_opbuilder_streambuilder(SetStreamBuilder*); void fst_set_free(Set*); char* fst_set_stream_next(SetStream*); @@ -76,6 +77,7 @@ char* fst_set_regexstream_next(SetRegexStream*); void fst_set_regexstream_free(SetRegexStream*); void fst_set_opbuilder_push(SetOpBuilder*, Set*); +void fst_set_opbuilder_push_streambuilder(SetOpBuilder*, SetStreamBuilder*); void fst_set_opbuilder_free(SetOpBuilder*); SetUnion* fst_set_opbuilder_union(SetOpBuilder*); SetIntersection* fst_set_opbuilder_intersection(SetOpBuilder*); @@ -97,6 +99,8 @@ void fst_set_symmetricdifference_free(SetSymmetricDifference*); SetStreamBuilder* fst_set_streambuilder_new(Set*); SetStreamBuilder* fst_set_streambuilder_add_ge(SetStreamBuilder*, char*); +SetStreamBuilder* fst_set_streambuilder_add_gt(SetStreamBuilder*, char*); +SetStreamBuilder* fst_set_streambuilder_add_le(SetStreamBuilder*, char*); SetStreamBuilder* 
fst_set_streambuilder_add_lt(SetStreamBuilder*, char*); SetStream* fst_set_streambuilder_finish(SetStreamBuilder*); diff --git a/rust/src/set.rs b/rust/src/set.rs index bd28aa1..1c87ac1 100644 --- a/rust/src/set.rs +++ b/rust/src/set.rs @@ -146,6 +146,14 @@ pub extern "C" fn fst_set_make_opbuilder(ptr: *mut Set) -> *mut set::OpBuilder<' } make_free_fn!(fst_set_opbuilder_free, *mut set::OpBuilder); +#[no_mangle] +pub extern "C" fn fst_set_make_opbuilder_streambuilder(ptr: *mut set::StreamBuilder<'static>) -> *mut set::OpBuilder<'static> { + let sb = val_from_ptr!(ptr); + let mut ob = set::OpBuilder::new(); + ob.push(sb.into_stream()); + to_raw_ptr(ob) +} + #[no_mangle] pub extern "C" fn fst_set_opbuilder_push(ptr: *mut set::OpBuilder, set_ptr: *mut Set) { let set = ref_from_ptr!(set_ptr); @@ -153,6 +161,13 @@ pub extern "C" fn fst_set_opbuilder_push(ptr: *mut set::OpBuilder, set_ptr: *mut ob.push(set); } +#[no_mangle] +pub extern "C" fn fst_set_opbuilder_push_streambuilder(ptr: *mut set::OpBuilder<'static>, sb_ptr: *mut set::StreamBuilder<'static>) { + let sb = val_from_ptr!(sb_ptr); + let ob = mutref_from_ptr!(ptr); + ob.push(sb.into_stream()); +} + #[no_mangle] pub extern "C" fn fst_set_opbuilder_union(ptr: *mut set::OpBuilder) -> *mut set::Union { @@ -205,6 +220,22 @@ pub extern "C" fn fst_set_streambuilder_add_ge(ptr: *mut set::StreamBuilder<'sta to_raw_ptr(sb.ge(cstr_to_str(c_bound))) } +#[no_mangle] +pub extern "C" fn fst_set_streambuilder_add_gt(ptr: *mut set::StreamBuilder<'static>, + c_bound: *mut libc::c_char) + -> *mut set::StreamBuilder<'static> { + let sb = val_from_ptr!(ptr); + to_raw_ptr(sb.gt(cstr_to_str(c_bound))) +} + +#[no_mangle] +pub extern "C" fn fst_set_streambuilder_add_le(ptr: *mut set::StreamBuilder<'static>, + c_bound: *mut libc::c_char) + -> *mut set::StreamBuilder<'static> { + let sb = val_from_ptr!(ptr); + to_raw_ptr(sb.le(cstr_to_str(c_bound))) +} + #[no_mangle] pub extern "C" fn fst_set_streambuilder_add_lt(ptr: *mut 
set::StreamBuilder<'static>, c_bound: *mut libc::c_char) diff --git a/rust_fst/__init__.py b/rust_fst/__init__.py index 5dd3cfb..04de867 100644 --- a/rust_fst/__init__.py +++ b/rust_fst/__init__.py @@ -1,4 +1,4 @@ -from .set import Set +from .set import Set, UnionSet from .map import Map -__all__ = ["Set", "Map"] +__all__ = ["Set", "UnionSet", "Map"] diff --git a/rust_fst/set.py b/rust_fst/set.py index cc08541..32b86e4 100644 --- a/rust_fst/set.py +++ b/rust_fst/set.py @@ -1,4 +1,5 @@ from contextlib import contextmanager +from enum import Enum from .common import KeyStreamIterator from .lib import ffi, lib, checked_call @@ -55,14 +56,40 @@ def get_set(self): return Set(None, _pointer=self._set_ptr) +class OpBuilderInputType(Enum): + SET = 1 + STREAM_BUILDER = 2 + + class OpBuilder(object): - def __init__(self, set_ptr): + + _BUILDERS = { + OpBuilderInputType.SET: lib.fst_set_make_opbuilder, + OpBuilderInputType.STREAM_BUILDER: lib.fst_set_make_opbuilder_streambuilder, + } + _PUSHERS = { + OpBuilderInputType.SET: lib.fst_set_opbuilder_push, + OpBuilderInputType.STREAM_BUILDER: lib.fst_set_opbuilder_push_streambuilder, + } + + @classmethod + def from_slice(cls, set_ptr, s): + sb = StreamBuilder.from_slice(set_ptr, s) + opbuilder = OpBuilder(sb._ptr, + input_type=OpBuilderInputType.STREAM_BUILDER) + return opbuilder + + def __init__(self, ptr, input_type=OpBuilderInputType.SET): + if input_type not in self._BUILDERS: + raise ValueError( + "input_type must be a member of OpBuilderInputType.") + self._input_type = input_type # NOTE: No need for `ffi.gc`, since the struct will be free'd # once we call union/intersection/difference - self._ptr = lib.fst_set_make_opbuilder(set_ptr) + self._ptr = OpBuilder._BUILDERS[self._input_type](ptr) - def push(self, set_ptr): - lib.fst_set_opbuilder_push(self._ptr, set_ptr) + def push(self, ptr): + OpBuilder._PUSHERS[self._input_type](self._ptr, ptr) def union(self): stream_ptr = lib.fst_set_opbuilder_union(self._ptr) @@ -86,6 
+113,44 @@ def symmetric_difference(self): lib.fst_set_symmetricdifference_free) +class StreamBuilder(object): + + @classmethod + def from_slice(cls, set_ptr, slice_bounds): + sb = StreamBuilder(set_ptr) + if slice_bounds.start: + sb.ge(slice_bounds.start) + if slice_bounds.stop: + sb.lt(slice_bounds.stop) + return sb + + def __init__(self, set_ptr): + # NOTE: No need for `ffi.gc`, since the struct will be free'd + # once we call union/intersection/difference + self._ptr = lib.fst_set_streambuilder_new(set_ptr) + + def finish(self): + stream_ptr = lib.fst_set_streambuilder_finish(self._ptr) + return KeyStreamIterator(stream_ptr, lib.fst_set_stream_next, + lib.fst_set_stream_free) + + def ge(self, bound): + c_start = ffi.new("char[]", bound.encode('utf8')) + self._ptr = lib.fst_set_streambuilder_add_ge(self._ptr, c_start) + + def gt(self, bound): + c_start = ffi.new("char[]", bound.encode('utf8')) + self._ptr = lib.fst_set_streambuilder_add_gt(self._ptr, c_start) + + def le(self, bound): + c_end = ffi.new("char[]", bound.encode('utf8')) + self._ptr = lib.fst_set_streambuilder_add_le(self._ptr, c_end) + + def lt(self, bound): + c_end = ffi.new("char[]", bound.encode('utf8')) + self._ptr = lib.fst_set_streambuilder_add_lt(self._ptr, c_end) + + class Set(object): """ An immutable ordered string set backed by a finite state transducer. 
@@ -203,19 +268,11 @@ def __getitem__(self, s): if s.start and s.stop and s.start > s.stop: raise ValueError( "Start key must be lexicographically smaller than stop.") - sb_ptr = lib.fst_set_streambuilder_new(self._ptr) - if s.start: - c_start = ffi.new("char[]", s.start.encode('utf8')) - sb_ptr = lib.fst_set_streambuilder_add_ge(sb_ptr, c_start) - if s.stop: - c_stop = ffi.new("char[]", s.stop.encode('utf8')) - sb_ptr = lib.fst_set_streambuilder_add_lt(sb_ptr, c_stop) - stream_ptr = lib.fst_set_streambuilder_finish(sb_ptr) - return KeyStreamIterator(stream_ptr, lib.fst_set_stream_next, - lib.fst_set_stream_free) + sb = StreamBuilder.from_slice(self._ptr, s) + return sb.finish() def _make_opbuilder(self, *others): - opbuilder = OpBuilder(self._ptr) + opbuilder = OpBuilder(self._ptr, input_type=OpBuilderInputType.SET) for oth in others: opbuilder.push(oth._ptr) return opbuilder @@ -333,3 +390,65 @@ def search(self, term, max_dist): return KeyStreamIterator(stream_ptr, lib.fst_set_levstream_next, lib.fst_set_levstream_free, lev_ptr, lib.fst_levenshtein_free) + + +class UnionSet(object): + """ A collection of Set objects that offer efficient operations across all + members. + """ + def __init__(self, *sets): + self.sets = list(sets) + + def __contains__(self, val): + """ Check if the set contains the value. """ + return any([ + lib.fst_set_contains(fst._ptr, + ffi.new("char[]", + val.encode('utf8'))) + for fst in self.sets + ]) + + def __getitem__(self, s): + """ Get an iterator over a range of set contents. + + Start and stop indices of the slice must be unicode strings. + + .. important:: + Slicing follows the semantics for numerical indices, i.e. the + `stop` value is **exclusive**. For example, given the set + `s = Set.from_iter(["bar", "baz", "foo", "moo"])`, `s['b': 'f']` + will only return `"bar"` and `"baz"`. 
+ + :param s: A slice that specifies the range of the set to retrieve + :type s: :py:class:`slice` + """ + if not isinstance(s, slice): + raise ValueError( + "Value must be a string slice (e.g. `['foo':]`)") + if s.start and s.stop and s.start > s.stop: + raise ValueError( + "Start key must be lexicographically smaller than stop.") + if len(self.sets) <= 1: + raise ValueError( + "Must have more than one set to operate on.") + + opbuilder = OpBuilder.from_slice(self.sets[0]._ptr, s) + streams = [] + for fst in self.sets[1:]: + sb = StreamBuilder.from_slice(fst._ptr, s) + streams.append(sb) + for sb in streams: + opbuilder.push(sb._ptr) + return opbuilder.union() + + def __iter__(self): + """ Get an iterator over all keys in all sets in lexicographical order. + """ + if len(self.sets) <= 1: + raise ValueError( + "Must have more than one set to operate on.") + opbuilder = OpBuilder(self.sets[0]._ptr, + input_type=OpBuilderInputType.SET) + for fst in self.sets[1:]: + opbuilder.push(fst._ptr) + return opbuilder.union() diff --git a/tests/test_set.py b/tests/test_set.py index 509b412..47581cf 100644 --- a/tests/test_set.py +++ b/tests/test_set.py @@ -2,10 +2,11 @@ import pytest import rust_fst.lib as lib -from rust_fst import Set +from rust_fst import Set, UnionSet TEST_KEYS = [u"möö", "bar", "baz", "foo"] +TEST_KEYS2 = ["bing", "baz", "bap", "foo"] def do_build(path, keys=TEST_KEYS, sorted_=True): @@ -21,6 +22,17 @@ def fst_set(tmpdir): return Set(fst_path) +@pytest.fixture +def fst_unionset(tmpdir): + fst_path1 = str(tmpdir.join('test1.fst')) + fst_path2 = str(tmpdir.join('test2.fst')) + do_build(fst_path1, keys=TEST_KEYS) + do_build(fst_path2, keys=TEST_KEYS2) + a = Set(fst_path1) + b = Set(fst_path2) + return UnionSet(a, b) + + def test_build(tmpdir): fst_path = tmpdir.join('test.fst') do_build(str(fst_path)) @@ -147,3 +159,23 @@ def test_range(fst_set): fst_set['c':'a'] with pytest.raises(ValueError): fst_set['c'] + + +def test_unionset_contains(fst_unionset): + for 
key in TEST_KEYS+TEST_KEYS2: + assert key in fst_unionset + + +def test_unionset_iter(fst_unionset): + stored_keys = list(fst_unionset) + assert stored_keys == sorted(set(TEST_KEYS+TEST_KEYS2)) + + +def test_unionset_range(fst_unionset): + assert list(fst_unionset['f':]) == ['foo', u'möö'] + assert list(fst_unionset[:'m']) == ['bap', 'bar', 'baz', 'bing', 'foo'] + assert list(fst_unionset['baz':'m']) == ['baz', 'bing', 'foo'] + with pytest.raises(ValueError): + fst_unionset['c':'a'] + with pytest.raises(ValueError): + fst_unionset['c'] From 358d5a8ae54749f468e6de08fe127c3379fa3d61 Mon Sep 17 00:00:00 2001 From: David Blewett Date: Wed, 12 Sep 2018 14:36:57 -0400 Subject: [PATCH 2/6] Add support for set operations with UnionSet: * difference, intersection, symmetric_difference, union --- rust/rust_fst.h | 2 ++ rust/src/set.rs | 15 +++++++++++ rust_fst/set.py | 66 +++++++++++++++++++++++++++++++++++++++++++++++ tests/test_set.py | 28 ++++++++++++++++++++ 4 files changed, 111 insertions(+) diff --git a/rust/rust_fst.h b/rust/rust_fst.h index 15bfc2e..cc53bf0 100644 --- a/rust/rust_fst.h +++ b/rust/rust_fst.h @@ -65,6 +65,7 @@ SetLevStream* fst_set_levsearch(Set*, Levenshtein*); SetRegexStream* fst_set_regexsearch(Set*, Regex*); SetOpBuilder* fst_set_make_opbuilder(Set*); SetOpBuilder* fst_set_make_opbuilder_streambuilder(SetStreamBuilder*); +SetOpBuilder* fst_set_make_opbuilder_union(SetUnion*); void fst_set_free(Set*); char* fst_set_stream_next(SetStream*); @@ -78,6 +79,7 @@ void fst_set_regexstream_free(SetRegexStream*); void fst_set_opbuilder_push(SetOpBuilder*, Set*); void fst_set_opbuilder_push_streambuilder(SetOpBuilder*, SetStreamBuilder*); +void fst_set_opbuilder_push_union(SetOpBuilder*, SetUnion*); void fst_set_opbuilder_free(SetOpBuilder*); SetUnion* fst_set_opbuilder_union(SetOpBuilder*); SetIntersection* fst_set_opbuilder_intersection(SetOpBuilder*); diff --git a/rust/src/set.rs b/rust/src/set.rs index 1c87ac1..b3394fc 100644 --- a/rust/src/set.rs +++ 
b/rust/src/set.rs @@ -154,6 +154,14 @@ pub extern "C" fn fst_set_make_opbuilder_streambuilder(ptr: *mut set::StreamBuil to_raw_ptr(ob) } +#[no_mangle] +pub extern "C" fn fst_set_make_opbuilder_union(ptr: *mut set::Union<'static>) -> *mut set::OpBuilder<'static> { + let union = val_from_ptr!(ptr); + let mut ob = set::OpBuilder::new(); + ob.push(union.into_stream()); + to_raw_ptr(ob) +} + #[no_mangle] pub extern "C" fn fst_set_opbuilder_push(ptr: *mut set::OpBuilder, set_ptr: *mut Set) { let set = ref_from_ptr!(set_ptr); @@ -168,6 +176,13 @@ pub extern "C" fn fst_set_opbuilder_push_streambuilder(ptr: *mut set::OpBuilder< ob.push(sb.into_stream()); } +#[no_mangle] +pub extern "C" fn fst_set_opbuilder_push_union(ptr: *mut set::OpBuilder<'static>, union_ptr: *mut set::Union<'static>) { + let union = val_from_ptr!(union_ptr); + let ob = mutref_from_ptr!(ptr); + ob.push(union.into_stream()); +} + #[no_mangle] pub extern "C" fn fst_set_opbuilder_union(ptr: *mut set::OpBuilder) -> *mut set::Union { diff --git a/rust_fst/set.py b/rust_fst/set.py index 32b86e4..2bca368 100644 --- a/rust_fst/set.py +++ b/rust_fst/set.py @@ -59,6 +59,7 @@ def get_set(self): class OpBuilderInputType(Enum): SET = 1 STREAM_BUILDER = 2 + UNION = 3 class OpBuilder(object): @@ -66,10 +67,12 @@ class OpBuilder(object): _BUILDERS = { OpBuilderInputType.SET: lib.fst_set_make_opbuilder, OpBuilderInputType.STREAM_BUILDER: lib.fst_set_make_opbuilder_streambuilder, + OpBuilderInputType.UNION: lib.fst_set_make_opbuilder_union, } _PUSHERS = { OpBuilderInputType.SET: lib.fst_set_opbuilder_push, OpBuilderInputType.STREAM_BUILDER: lib.fst_set_opbuilder_push_streambuilder, + OpBuilderInputType.UNION: lib.fst_set_opbuilder_push_union, } @classmethod @@ -452,3 +455,66 @@ def __iter__(self): for fst in self.sets[1:]: opbuilder.push(fst._ptr) return opbuilder.union() + + def _make_opbuilder(self, *others): + others = list(others) + if len(self.sets) <= 1: + raise ValueError( + "Must have more than one set to operate 
on.") + if not others: + raise ValueError( + "Must have at least one set to compare against.") + our_opbuilder = OpBuilder(self.sets[0]._ptr, + input_type=OpBuilderInputType.SET) + for fst in self.sets[1:]: + our_opbuilder.push(fst._ptr) + our_stream = lib.fst_set_opbuilder_union(our_opbuilder._ptr) + + their_opbuilder = OpBuilder(others.pop()._ptr, + input_type=OpBuilderInputType.SET) + for fst in others: + their_opbuilder.push(fst._ptr) + their_stream = lib.fst_set_opbuilder_union(their_opbuilder._ptr) + + opbuilder = OpBuilder(our_stream, input_type=OpBuilderInputType.UNION) + opbuilder.push(their_stream) + return opbuilder + + def difference(self, *others): + """ Get an iterator over the keys in the difference of this set and + others. + + :param others: List of :py:class:`Set` objects + :returns: Iterator over all keys that exists in this set, but in + none of the other sets, in lexicographical order + """ + return self._make_opbuilder(*others).difference() + + def intersection(self, *others): + """ Get an iterator over the keys in the intersection of this set and + others. + + :param others: List of :py:class:`Set` objects + :returns: Iterator over all keys that exists in all of the passed + sets in lexicographical order + """ + return self._make_opbuilder(*others).intersection() + + def symmetric_difference(self, *others): + """ Get an iterator over the keys in the symmetric difference of this + set and others. + + :param others: List of :py:class:`Set` objects + :returns: Iterator over all keys that exists in only one of the + sets in lexicographical order + """ + return self._make_opbuilder(*others).symmetric_difference() + + def union(self, *others): + """ Get an iterator over the keys in the union of this set and others. 
+ + :param others: List of :py:class:`Set` objects + :returns: Iterator over all keys in all sets in lexicographical + order + """ + return self._make_opbuilder(*others).union() diff --git a/tests/test_set.py b/tests/test_set.py index 47581cf..0854b1c 100644 --- a/tests/test_set.py +++ b/tests/test_set.py @@ -166,6 +166,20 @@ def test_unionset_contains(fst_unionset): assert key in fst_unionset +def test_unionset_difference(): + a = Set.from_iter(["bar", "foo"]) + b = Set.from_iter(["baz", "foo"]) + c = Set.from_iter(["bonk", "foo"]) + assert list(UnionSet(a, b).difference(c)) == ["bar", "baz"] + + +def test_unionset_intersection(): + a = Set.from_iter(["bar", "foo"]) + b = Set.from_iter(["baz", "foo"]) + c = Set.from_iter(["bonk", "foo"]) + assert list(UnionSet(a, b).intersection(c)) == ["foo"] + + def test_unionset_iter(fst_unionset): stored_keys = list(fst_unionset) assert stored_keys == sorted(set(TEST_KEYS+TEST_KEYS2)) @@ -179,3 +193,17 @@ def test_unionset_range(fst_unionset): fst_unionset['c':'a'] with pytest.raises(ValueError): fst_unionset['c'] + + +def test_unionset_symmetric_difference(): + a = Set.from_iter(["bar", "foo"]) + b = Set.from_iter(["baz", "foo"]) + c = Set.from_iter(["bonk", "foo"]) + assert list(UnionSet(a, b).symmetric_difference(c)) == ["bar", "baz", "bonk"] + + +def test_unionset_union(): + a = Set.from_iter(["bar", "foo"]) + b = Set.from_iter(["baz", "foo"]) + c = Set.from_iter(["bonk", "foo"]) + assert list(UnionSet(a, b).union(c)) == ["bar", "baz", "bonk", "foo"] From c82a480122e6984ff73d4620fc2a3506110cdfc4 Mon Sep 17 00:00:00 2001 From: David Blewett Date: Wed, 12 Sep 2018 14:41:41 -0400 Subject: [PATCH 3/6] Add support for Levenshtein fuzzy search to UnionSet. 
--- rust/rust_fst.h | 2 ++ rust/src/set.rs | 15 +++++++++++++++ rust_fst/set.py | 35 +++++++++++++++++++++++++++++++++++ tests/test_set.py | 5 +++++ 4 files changed, 57 insertions(+) diff --git a/rust/rust_fst.h b/rust/rust_fst.h index cc53bf0..233366a 100644 --- a/rust/rust_fst.h +++ b/rust/rust_fst.h @@ -65,6 +65,7 @@ SetLevStream* fst_set_levsearch(Set*, Levenshtein*); SetRegexStream* fst_set_regexsearch(Set*, Regex*); SetOpBuilder* fst_set_make_opbuilder(Set*); SetOpBuilder* fst_set_make_opbuilder_streambuilder(SetStreamBuilder*); +SetOpBuilder* fst_set_make_opbuilder_levstream(SetLevStream*); SetOpBuilder* fst_set_make_opbuilder_union(SetUnion*); void fst_set_free(Set*); @@ -78,6 +79,7 @@ char* fst_set_regexstream_next(SetRegexStream*); void fst_set_regexstream_free(SetRegexStream*); void fst_set_opbuilder_push(SetOpBuilder*, Set*); +void fst_set_opbuilder_push_levstream(SetOpBuilder*, SetLevStream*); void fst_set_opbuilder_push_streambuilder(SetOpBuilder*, SetStreamBuilder*); void fst_set_opbuilder_push_union(SetOpBuilder*, SetUnion*); void fst_set_opbuilder_free(SetOpBuilder*); diff --git a/rust/src/set.rs b/rust/src/set.rs index b3394fc..1767e16 100644 --- a/rust/src/set.rs +++ b/rust/src/set.rs @@ -146,6 +146,14 @@ pub extern "C" fn fst_set_make_opbuilder(ptr: *mut Set) -> *mut set::OpBuilder<' } make_free_fn!(fst_set_opbuilder_free, *mut set::OpBuilder); +#[no_mangle] +pub extern "C" fn fst_set_make_opbuilder_levstream(ptr: *mut SetLevStream) -> *mut set::OpBuilder<'static> { + let sls = val_from_ptr!(ptr); + let mut ob = set::OpBuilder::new(); + ob.push(sls.into_stream()); + to_raw_ptr(ob) +} + #[no_mangle] pub extern "C" fn fst_set_make_opbuilder_streambuilder(ptr: *mut set::StreamBuilder<'static>) -> *mut set::OpBuilder<'static> { let sb = val_from_ptr!(ptr); @@ -169,6 +177,13 @@ pub extern "C" fn fst_set_opbuilder_push(ptr: *mut set::OpBuilder, set_ptr: *mut ob.push(set); } +#[no_mangle] +pub extern "C" fn fst_set_opbuilder_push_levstream(ptr: *mut 
set::OpBuilder<'static>, sls_ptr: *mut SetLevStream) { + let sls = val_from_ptr!(sls_ptr); + let ob = mutref_from_ptr!(ptr); + ob.push(sls.into_stream()); +} + #[no_mangle] pub extern "C" fn fst_set_opbuilder_push_streambuilder(ptr: *mut set::OpBuilder<'static>, sb_ptr: *mut set::StreamBuilder<'static>) { let sb = val_from_ptr!(sb_ptr); diff --git a/rust_fst/set.py b/rust_fst/set.py index 2bca368..5a42bbc 100644 --- a/rust_fst/set.py +++ b/rust_fst/set.py @@ -60,6 +60,16 @@ class OpBuilderInputType(Enum): SET = 1 STREAM_BUILDER = 2 UNION = 3 + SEARCH = 4 + + +def _build_levsearch(fst, term, max_dist): + lev_ptr = checked_call( + lib.fst_levenshtein_new, + fst._ctx, + ffi.new("char[]", term.encode('utf8')), + max_dist) + return lib.fst_set_levsearch(fst._ptr, lev_ptr) class OpBuilder(object): @@ -68,13 +78,22 @@ class OpBuilder(object): OpBuilderInputType.SET: lib.fst_set_make_opbuilder, OpBuilderInputType.STREAM_BUILDER: lib.fst_set_make_opbuilder_streambuilder, OpBuilderInputType.UNION: lib.fst_set_make_opbuilder_union, + OpBuilderInputType.SEARCH: lib.fst_set_make_opbuilder_levstream, } _PUSHERS = { OpBuilderInputType.SET: lib.fst_set_opbuilder_push, OpBuilderInputType.STREAM_BUILDER: lib.fst_set_opbuilder_push_streambuilder, OpBuilderInputType.UNION: lib.fst_set_opbuilder_push_union, + OpBuilderInputType.SEARCH: lib.fst_set_opbuilder_push_levstream, } + @classmethod + def from_search(cls, fst, term, max_dist): + stream_ptr = _build_levsearch(fst, term, max_dist) + opbuilder = OpBuilder(stream_ptr, + input_type=OpBuilderInputType.SEARCH) + return opbuilder + @classmethod def from_slice(cls, set_ptr, s): sb = StreamBuilder.from_slice(set_ptr, s) @@ -500,6 +519,22 @@ def intersection(self, *others): """ return self._make_opbuilder(*others).intersection() + def search(self, term, max_dist): + """ Search the set with a Levenshtein automaton. 
+ + :param term: The search term + :param max_dist: The maximum edit distance for search results + :returns: Iterator over matching values in the set + :rtype: :py:class:`KeyStreamIterator` + """ + if len(self.sets) <= 1: + raise ValueError( + "Must have more than one set to operate on.") + opbuilder = OpBuilder.from_search(self.sets[0], term, max_dist) + for fst in self.sets[1:]: + opbuilder.push(_build_levsearch(fst, term, max_dist)) + return opbuilder.union() + def symmetric_difference(self, *others): """ Get an iterator over the keys in the symmetric difference of this set and others. diff --git a/tests/test_set.py b/tests/test_set.py index 0854b1c..4402543 100644 --- a/tests/test_set.py +++ b/tests/test_set.py @@ -195,6 +195,11 @@ def test_unionset_range(fst_unionset): fst_unionset['c'] +def test_unionset_search(fst_unionset): + matches = list(fst_unionset.search("bam", 1)) + assert matches == ["bap", "bar", "baz"] + + def test_unionset_symmetric_difference(): a = Set.from_iter(["bar", "foo"]) b = Set.from_iter(["baz", "foo"]) From e745f2585effe6fed176411aff1c75f2a92bb411 Mon Sep 17 00:00:00 2001 From: David Blewett Date: Wed, 12 Sep 2018 14:42:56 -0400 Subject: [PATCH 4/6] Add support for regex search to UnionSet. 
--- rust/rust_fst.h | 2 ++ rust/src/set.rs | 15 +++++++++++++++ rust_fst/set.py | 49 +++++++++++++++++++++++++++++++++++++++++++++++ tests/test_set.py | 5 +++++ 4 files changed, 71 insertions(+) diff --git a/rust/rust_fst.h b/rust/rust_fst.h index 233366a..116bfbc 100644 --- a/rust/rust_fst.h +++ b/rust/rust_fst.h @@ -66,6 +66,7 @@ SetRegexStream* fst_set_regexsearch(Set*, Regex*); SetOpBuilder* fst_set_make_opbuilder(Set*); SetOpBuilder* fst_set_make_opbuilder_streambuilder(SetStreamBuilder*); SetOpBuilder* fst_set_make_opbuilder_levstream(SetLevStream*); +SetOpBuilder* fst_set_make_opbuilder_regexstream(SetRegexStream*); SetOpBuilder* fst_set_make_opbuilder_union(SetUnion*); void fst_set_free(Set*); @@ -80,6 +81,7 @@ void fst_set_regexstream_free(SetRegexStream*); void fst_set_opbuilder_push(SetOpBuilder*, Set*); void fst_set_opbuilder_push_levstream(SetOpBuilder*, SetLevStream*); +void fst_set_opbuilder_push_regexstream(SetOpBuilder*, SetRegexStream*); void fst_set_opbuilder_push_streambuilder(SetOpBuilder*, SetStreamBuilder*); void fst_set_opbuilder_push_union(SetOpBuilder*, SetUnion*); void fst_set_opbuilder_free(SetOpBuilder*); diff --git a/rust/src/set.rs b/rust/src/set.rs index 1767e16..d3d2c3d 100644 --- a/rust/src/set.rs +++ b/rust/src/set.rs @@ -154,6 +154,14 @@ pub extern "C" fn fst_set_make_opbuilder_levstream(ptr: *mut SetLevStream) -> *m to_raw_ptr(ob) } +#[no_mangle] +pub extern "C" fn fst_set_make_opbuilder_regexstream(ptr: *mut SetRegexStream) -> *mut set::OpBuilder<'static> { + let srs = val_from_ptr!(ptr); + let mut ob = set::OpBuilder::new(); + ob.push(srs.into_stream()); + to_raw_ptr(ob) +} + #[no_mangle] pub extern "C" fn fst_set_make_opbuilder_streambuilder(ptr: *mut set::StreamBuilder<'static>) -> *mut set::OpBuilder<'static> { let sb = val_from_ptr!(ptr); @@ -184,6 +192,13 @@ pub extern "C" fn fst_set_opbuilder_push_levstream(ptr: *mut set::OpBuilder<'sta ob.push(sls.into_stream()); } +#[no_mangle] +pub extern "C" fn 
fst_set_opbuilder_push_regexstream(ptr: *mut set::OpBuilder<'static>, srs_ptr: *mut SetRegexStream) { + let srs = val_from_ptr!(srs_ptr); + let ob = mutref_from_ptr!(ptr); + ob.push(srs.into_stream()); +} + #[no_mangle] pub extern "C" fn fst_set_opbuilder_push_streambuilder(ptr: *mut set::OpBuilder<'static>, sb_ptr: *mut set::StreamBuilder<'static>) { let sb = val_from_ptr!(sb_ptr); diff --git a/rust_fst/set.py b/rust_fst/set.py index 5a42bbc..9f703c6 100644 --- a/rust_fst/set.py +++ b/rust_fst/set.py @@ -61,6 +61,7 @@ class OpBuilderInputType(Enum): STREAM_BUILDER = 2 UNION = 3 SEARCH = 4 + SEARCH_RE = 5 def _build_levsearch(fst, term, max_dist): @@ -72,6 +73,13 @@ def _build_levsearch(fst, term, max_dist): return lib.fst_set_levsearch(fst._ptr, lev_ptr) +def _build_research(fst, pattern): + re_ptr = checked_call( + lib.fst_regex_new, fst._ctx, + ffi.new("char[]", pattern.encode('utf8'))) + return lib.fst_set_regexsearch(fst._ptr, re_ptr) + + class OpBuilder(object): _BUILDERS = { @@ -79,12 +87,14 @@ class OpBuilder(object): OpBuilderInputType.STREAM_BUILDER: lib.fst_set_make_opbuilder_streambuilder, OpBuilderInputType.UNION: lib.fst_set_make_opbuilder_union, OpBuilderInputType.SEARCH: lib.fst_set_make_opbuilder_levstream, + OpBuilderInputType.SEARCH_RE: lib.fst_set_make_opbuilder_regexstream, } _PUSHERS = { OpBuilderInputType.SET: lib.fst_set_opbuilder_push, OpBuilderInputType.STREAM_BUILDER: lib.fst_set_opbuilder_push_streambuilder, OpBuilderInputType.UNION: lib.fst_set_opbuilder_push_union, OpBuilderInputType.SEARCH: lib.fst_set_opbuilder_push_levstream, + OpBuilderInputType.SEARCH_RE: lib.fst_set_opbuilder_push_regexstream, } @classmethod @@ -94,6 +104,13 @@ def from_search(cls, fst, term, max_dist): input_type=OpBuilderInputType.SEARCH) return opbuilder + @classmethod + def from_search_re(cls, fst, pattern): + stream_ptr = _build_research(fst, pattern) + opbuilder = OpBuilder(stream_ptr, + input_type=OpBuilderInputType.SEARCH_RE) + return opbuilder + 
@classmethod def from_slice(cls, set_ptr, s): sb = StreamBuilder.from_slice(set_ptr, s) @@ -535,6 +552,38 @@ def search(self, term, max_dist): opbuilder.push(_build_levsearch(fst, term, max_dist)) return opbuilder.union() + def search_re(self, pattern): + """ Search the set with a regular expression. + + Note that the regular expression syntax is not Python's, but the one + supported by the `regex` Rust crate, which is almost identical + to the engine of the RE2 engine. + + For a documentation of the syntax, see: + http://doc.rust-lang.org/regex/regex/index.html#syntax + + Due to limitations of the underlying FST, only a subset of this syntax + is supported. Most notably absent are: + + * Lazy quantifiers (``r'*?'``, ``r'+?'``) + * Word boundaries (``r'\\b'``) + * Other zero-width assertions (``r'^'``, ``r'$'``) + + For background on these limitations, consult the documentation of + the Rust crate: http://burntsushi.net/rustdoc/fst/struct.Regex.html + + :param pattern: A regular expression + :returns: An iterator over all matching keys in the set + :rtype: :py:class:`KeyStreamIterator` + """ + if len(self.sets) <= 1: + raise ValueError( + "Must have more than one set to operate on.") + opbuilder = OpBuilder.from_search_re(self.sets[0], pattern) + for fst in self.sets[1:]: + opbuilder.push(_build_research(fst, pattern)) + return opbuilder.union() + def symmetric_difference(self, *others): """ Get an iterator over the keys in the symmetric difference of this set and others. 
diff --git a/tests/test_set.py b/tests/test_set.py index 4402543..af9fc17 100644 --- a/tests/test_set.py +++ b/tests/test_set.py @@ -200,6 +200,11 @@ def test_unionset_search(fst_unionset): assert matches == ["bap", "bar", "baz"] +def test_unionset_search_re(fst_unionset): + matches = list(fst_unionset.search_re(r'ba.*')) + assert matches == ["bap", "bar", "baz"] + + def test_unionset_symmetric_difference(): a = Set.from_iter(["bar", "foo"]) b = Set.from_iter(["baz", "foo"]) From 5b1a20e66e634264889236824d56600b38d2c209 Mon Sep 17 00:00:00 2001 From: David Blewett Date: Wed, 12 Sep 2018 14:43:11 -0400 Subject: [PATCH 5/6] Update README with examples. --- README.md | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index ec473b3..f6c2920 100644 --- a/README.md +++ b/README.md @@ -41,7 +41,6 @@ package. ## Status The package exposes almost all functionality of the `fst` crate, except for: -- Combining the results of slicing, `search` and `search_re` with set operations - Using raw transducers @@ -83,6 +82,24 @@ m = Map.from_iter( file_iterator('/your/input/file/'), '/your/mmapped/output.fst # re-open a file you built previously with from_iter() m = Map(path='/path/to/existing.fst') + +# slicing multiple sets efficiently +a = Set.from_iter(["bar", "foo"]) +b = Set.from_iter(["baz", "foo"]) +list(UnionSet(a, b)['ba':'bb']) +['bar', 'baz'] + +# searching multiple sets efficiently +a = Set.from_iter(["bar", "foo"]) +b = Set.from_iter(["baz", "foo"]) +list(UnionSet(a, b).search('ba', 1)) +['bar', 'baz'] + +# searching multiple sets with a regex efficiently +a = Set.from_iter(["bar", "foo"]) +b = Set.from_iter(["baz", "foo"]) +list(UnionSet(a, b).search_re(r'b\w{2}')) +['bar', 'baz'] ``` From 3488ff76e414e6eebe67858bc2b611d0e9c1ef99 Mon Sep 17 00:00:00 2001 From: David Blewett Date: Wed, 3 Oct 2018 15:02:53 -0400 Subject: [PATCH 6/6] Remove extraneous checks. 
--- rust_fst/set.py | 25 ++++++++++--------------- 1 file changed, 10 insertions(+), 15 deletions(-) diff --git a/rust_fst/set.py b/rust_fst/set.py index 9f703c6..059ed66 100644 --- a/rust_fst/set.py +++ b/rust_fst/set.py @@ -467,10 +467,9 @@ def __getitem__(self, s): if s.start and s.stop and s.start > s.stop: raise ValueError( "Start key must be lexicographically smaller than stop.") - if len(self.sets) <= 1: - raise ValueError( - "Must have more than one set to operate on.") + if not self.sets: + return opbuilder = OpBuilder.from_slice(self.sets[0]._ptr, s) streams = [] for fst in self.sets[1:]: @@ -483,9 +482,8 @@ def __getitem__(self, s): def __iter__(self): """ Get an iterator over all keys in all sets in lexicographical order. """ - if len(self.sets) <= 1: - raise ValueError( - "Must have more than one set to operate on.") + if not self.sets: + return opbuilder = OpBuilder(self.sets[0]._ptr, input_type=OpBuilderInputType.SET) for fst in self.sets[1:]: @@ -494,12 +492,11 @@ def __iter__(self): def _make_opbuilder(self, *others): others = list(others) - if len(self.sets) <= 1: - raise ValueError( - "Must have more than one set to operate on.") if not others: raise ValueError( "Must have at least one set to compare against.") + if not self.sets: + return our_opbuilder = OpBuilder(self.sets[0]._ptr, input_type=OpBuilderInputType.SET) for fst in self.sets[1:]: @@ -544,9 +541,8 @@ def search(self, term, max_dist): :returns: Iterator over matching values in the set :rtype: :py:class:`KeyStreamIterator` """ - if len(self.sets) <= 1: - raise ValueError( - "Must have more than one set to operate on.") + if not self.sets: + return opbuilder = OpBuilder.from_search(self.sets[0], term, max_dist) for fst in self.sets[1:]: opbuilder.push(_build_levsearch(fst, term, max_dist)) @@ -576,9 +572,8 @@ def search_re(self, pattern): :returns: An iterator over all matching keys in the set :rtype: :py:class:`KeyStreamIterator` """ - if len(self.sets) <= 1: - raise ValueError( - 
"Must have more than one set to operate on.") + if not self.sets: + return opbuilder = OpBuilder.from_search_re(self.sets[0], pattern) for fst in self.sets[1:]: opbuilder.push(_build_research(fst, pattern))