diff --git a/kovot/bot.py b/kovot/bot.py index 0abcbcc..09d9eef 100644 --- a/kovot/bot.py +++ b/kovot/bot.py @@ -3,7 +3,9 @@ import logging +from typing import Optional from kovot.mod import ModManager +from kovot.message import Message # Kovot @@ -42,7 +44,7 @@ def __init__(self, # debug self.module_manager.show_mods() - def talk(self, message): + def talk(self, message) -> Optional[Message]: message = self.preprocessor.transform(message) # get responses from mods @@ -55,6 +57,10 @@ def talk(self, message): postprocessed_reponses = [self.postprocessor.transform(res) for res in selected_resposes] + if len(postprocessed_reponses) == 0: + logging.warn("### no answer candidate ###") + return None + # log logging.info("### answer candidates ###") for response in postprocessed_reponses: @@ -72,4 +78,5 @@ def run(self, stream): for message in stream: res = self.talk(message) - stream.post(res) + if res is not None: + stream.post(res) diff --git a/test/test_bot.py b/test/test_bot.py index e547b7b..083a9fd 100644 --- a/test/test_bot.py +++ b/test/test_bot.py @@ -3,6 +3,7 @@ import unittest +from typing import Callable, Iterator, Optional from kovot import Bot from kovot import Response from kovot import Message @@ -16,6 +17,26 @@ def generate_responses(self, bot, message): return [res] +class SilentMod: + def generate_responses(self, bot, message): + return [] + + +class StubStream(object): + def __init__(self, callback: Optional[Callable[[Response], bool]] = None): + self.callback = callback + self.callback_num = 0 + + def __iter__(self) -> Iterator[Message]: + return iter([Message(text="テスト", speaker=Speaker(name="話し✋"))]) + + def post(self, response: Response) -> bool: + if self.callback: + self.callback_num += 1 + return self.callback(response) + return True + + class BotTest(unittest.TestCase): def test_talk(self): msg = Message(text="テスト", @@ -24,3 +45,46 @@ def test_talk(self): res = bot.talk(msg) self.assertEqual(res, Response(score=1.0, text="テスト")) + + def test_talk_nocandidate(self): + msg = Message(text="テスト", + speaker=Speaker(name="話し✋")) + bot = Bot(mods=[SilentMod()]) + + res = bot.talk(msg) + self.assertIsNone(res) + + def test_talk_multimod(self): + msg = Message(text="テスト", + speaker=Speaker(name="話し✋")) + bot = Bot(mods=[EchoMod(), SilentMod()]) + + res = bot.talk(msg) + self.assertEqual(res, Response(score=1.0, text="テスト")) + + def test_run(self): + def cb(res): + self.assertEqual(res, Response(score=1.0, text="テスト")) + return True + bot = Bot(mods=[EchoMod()]) + stream = StubStream(callback=cb) + bot.run(stream=stream) + self.assertEqual(stream.callback_num, 1) + + def test_run_nocandidate(self): + def cb(res): + self.fail() + return True + bot = Bot(mods=[SilentMod()]) + stream = StubStream(callback=cb) + bot.run(stream=stream) + self.assertEqual(stream.callback_num, 0) + + def test_run_multimod(self): + def cb(res): + self.assertEqual(res, Response(score=1.0, text="テスト")) + return True + bot = Bot(mods=[EchoMod(), SilentMod()]) + stream = StubStream(callback=cb) + bot.run(stream=stream) + self.assertEqual(stream.callback_num, 1)