diff --git a/comet_core/app.py b/comet_core/app.py index dfa69af..c549a25 100644 --- a/comet_core/app.py +++ b/comet_core/app.py @@ -144,6 +144,7 @@ def __init__(self, database_uri='sqlite://'): self.inputs = list() self.instantiated_inputs = list() self.hydrators = dict() + self.filters = dict() self.parsers = dict() self.routers = SourceTypeFunction() self.escalators = SourceTypeFunction() @@ -189,8 +190,14 @@ def message_callback(self, source_type, message): if hydrate: hydrate(event) + # Filter event + filter_event = self.filters.get(source_type) + if filter_event: + event = filter_event(event) + # Add to datastore - self.data_store.add_record(event.get_record()) + if event: + self.data_store.add_record(event.get_record()) return True def set_config(self, source_type, config): @@ -278,6 +285,28 @@ def decorator(func): else: self.hydrators[source_type] = func + def register_filter(self, source_type, func=None): + """Register a filter function to filter events before saving them to the db. + + This method can be used either as a decorator or with a filter function passed in. + + Args: + source_type (str): the source type to register the filter for + func (Optional[function]): a function that filter a message of type source_type, or None if used as a + decorator + Return: + function or None: if no func is given returns a decorator function, otherwise None + """ + if not func: + # pylint: disable=missing-docstring, missing-return-doc, missing-return-type-doc + def decorator(func): + self.filters[source_type] = func + return func + + return decorator + else: + self.filters[source_type] = func + def register_router(self, source_types=None, func=None): """Register a hydrator. diff --git a/setup.py b/setup.py index d69f3f9..63dc556 100644 --- a/setup.py +++ b/setup.py @@ -25,7 +25,7 @@ setuptools.setup( name="comet-core", - version="1.0.10", + version="1.0.11", url="https://github.com/spotify/comet-core", author="Spotify Platform Security", diff --git a/tests/test_app.py b/tests/test_app.py index 4c8fbc8..8c90b6c 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -135,11 +135,34 @@ def loads(self, msg): hydrator_mock = mock.Mock() app.register_hydrator('test', hydrator_mock) + filter_return_value = EventContainer('test', {"a": "b"}) + filter_mock = mock.Mock(return_value=filter_return_value) + app.register_filter('test', filter_mock) + assert not app.message_callback('test1', '{}') assert not app.message_callback('test', '{ "c": "d" }') app.message_callback('test', '{ "a": "b" }') assert hydrator_mock.called + assert filter_mock.called + assert filter_mock.return_value, filter_return_value + + +def test_message_callback_filter(app): + @app.register_parser('test') + class TestParser: + def loads(self, msg): + ev = json.loads(msg) + if 'a' in ev: + return ev, None + return None, 'fail' + + filter_mock = mock.Mock(return_value=None) + app.register_filter('test', filter_mock) + + app.message_callback('test', '{ "a": "b" }') + assert filter_mock.called + assert filter_mock.return_value is None def test_register_input(app): @@ -184,6 +207,22 @@ def test_hydrator(*args): assert len(app.hydrators) == 2, app.hydrators +def test_register_filter(app): + assert not app.filters + + @app.register_filter('test1') + def test_filter(*args): + pass + + # Override existing + app.register_filter('test1', test_filter) + assert len(app.filters) == 1, app.filters + + # Add another + app.register_filter('test2', test_filter) + assert len(app.filters) == 2, app.filters + + def test_set_config(app): assert not app.specific_configs