diff --git a/main.py b/main.py
index 1cda71a..186676d 100755
--- a/main.py
+++ b/main.py
@@ -84,6 +84,7 @@
 slack_app_token = args.slack_app_token
 user_map = args.user_map
 bot_name = args.name
+brain = Markov(args.brain, args.output, args.user_map, [bot_name])
 
 intents = discord.Intents(guild_messages=True, message_content=True)
 
@@ -123,6 +124,12 @@ async def on_ready():
 async def on_message(message):
     if message.author == discord_client.user:
         return
+
+    num_words_in_brain = len(brain.words)
+    if num_words_in_brain >= (Markov.SIGNED_INT_MAX_VALUE * 0.95):
+        await message.channel.send('Brain hurty. Brain almost full :sob:')
+        return
+
     # print(f"Discord message from {message.author}: {message.content}")
     response = create_raw_response(message.content, False)
     if response and response.strip() != "":
diff --git a/markov.py b/markov.py
index f0756fe..359f88a 100755
--- a/markov.py
+++ b/markov.py
@@ -1,23 +1,94 @@
 import random
+import struct
+from typing import List, Dict, Tuple, Optional
+import array
 import yaml
 from itertools import chain, groupby
 
-START_TOK = ""
-STOP_TOK = ""
-STOP = object()
-START = object()
+START_TOKEN = ""
+STOP_TOKEN = ""
+
+START_INDEX = -1
+STOP_INDEX = -2
 
 
 # instantiate a Markov object with the source file
 class Markov:
+
+    # Signed int (4 bytes)
+    COMPRESSED_NUMBER_FORMAT: str = 'i'
+    # Two signed ints in native byte order (8 bytes total)
+    COMPRESSED_COMBO_NUMBER_FORMAT: str = f'{COMPRESSED_NUMBER_FORMAT}{COMPRESSED_NUMBER_FORMAT}'
+    SIGNED_INT_MAX_VALUE: int = 2147483647
+
     def __init__(self, input_file: str, output_file: str, user_map, ignore_words):
         if input_file == output_file:
             raise ValueError("input and output files must be different")
         self.user_map = self._init_user_map(user_map)
         self.ignore_words = set(w.upper() for w in ignore_words)
         self.output_file = output_file
+
+        # Map of n-gram transitions
+        self.graph: Dict[bytes, array.array] = dict()
+
+        # Word -> word-index map
+        self.word_index_map: Dict[str, bytes] = dict()
+
+        # List of all unique words. word_index_map maps to indices in this list
+        self.words: List[str] = list()
         self.update_graph_and_corpus(self.corpus_iter(input_file), init=True)
         print(f'Found {len(self.words)} unique words')
 
+    def to_graph_key(self, word_index: int | Tuple[int, int]) -> bytes:
+        """Convert 1 or 2 integers into a primitive bytes key, packed into 8 bytes.
+        If a single int is passed, it fills the first slot and the second slot is padded with -1.
+        If a tuple is passed, it fills both slots."""
+        if isinstance(word_index, int):
+            return struct.pack(Markov.COMPRESSED_COMBO_NUMBER_FORMAT, word_index, -1)
+        elif isinstance(word_index, tuple):
+            return struct.pack(Markov.COMPRESSED_COMBO_NUMBER_FORMAT, word_index[0], word_index[1])
+        else:
+            raise TypeError(f'word index must be int or tuple but was "{word_index}"')
+
+    def unpack_graph_key(self, key: bytes) -> Tuple[int, int]:
+        """Convert a key bytes object back into a tuple of two ints. If the original key held a single
+        int value, it will be at index 0"""
+        return struct.unpack(Markov.COMPRESSED_COMBO_NUMBER_FORMAT, key)
+
+    def get_word_index(self, word: str) -> int:
+        """Get the index for the provided word.
+        If we don't know it, it's a new word and it'll be inserted"""
+        if word in self.word_index_map:
+            word_index, _ = self.unpack_graph_key(self.word_index_map[word])
+            return word_index
+
+        if word == START_TOKEN:
+            return START_INDEX
+
+        if word == STOP_TOKEN:
+            return STOP_INDEX
+
+        self.words.append(word)
+        index = len(self.words) - 1
+        self.word_index_map[word] = self.to_graph_key(index)
+        return index
+
+    def get_candidate_indices_for_graph_key(self, graph_key: int | Tuple[int, int]) -> array.array:
+        """Get the candidate next-word indices for the transition at the given graph key"""
+        array_at_index: array.array = self.graph[self.to_graph_key(graph_key)]
+        return array_at_index
+
+    def try_append_at_graph_key(self, graph_key: int | Tuple[int, int], value_to_append: int) -> Tuple[bool, array.array]:
+        """Try to append the given integer value to the array at graph_key in the graph. If it's already there, no-op.
+
+        Returns a (modified, array) tuple; modified is True if the array was changed, False otherwise"""
+        key = self.to_graph_key(graph_key)
+        array_at_word_index = self.graph.setdefault(key, array.array(Markov.COMPRESSED_NUMBER_FORMAT))
+        if value_to_append not in array_at_word_index:
+            array_at_word_index.extend((value_to_append,))
+            return True, array_at_word_index
+        return False, array_at_word_index
 
     def corpus_iter(self, source_file: str):
         """
@@ -27,7 +98,7 @@ def corpus_iter(self, source_file: str):
         # this is dumb
         if source_file.endswith(".yml") or source_file.endswith(".yaml"):
             words = yaml.load(infile.read(), Loader=yaml.Loader)
-            for is_delim, phrase in groupby(words, lambda w: w in (START_TOK, STOP_TOK)):
+            for is_delim, phrase in groupby(words, lambda w: w in (START_TOKEN, STOP_TOKEN)):
                 if not is_delim:
                     yield list(phrase)
         else:
@@ -40,7 +111,7 @@ def triples_and_stop(cls, words):
         """
         Emit 3-grams from the sequence of words, the last one ending with the special STOP token
         """
-        words = chain(words, [STOP])
+        words = chain(words, [STOP_TOKEN])
         try:
             w1 = next(words)
             w2 = next(words)
@@ -75,9 +146,9 @@ def tokenize(self, sentence: str):
         if cur:
             yield cur
 
-    def _update_graph_and_emit_changes(self, token_seqs, init=False):
+    def _update_graph_and_emit_changes(self, token_seqs: List[List[str]], init=False):
         """
-        self.graph stores the graph of n-gram trasitions.
+        self.graph stores the graph of n-gram transitions.
         The keys are single tokens or pairs and the values possible next words
         in the n-gram. Initial tokens are also specially added to the list at
         the key START.
@@ -90,25 +161,29 @@ def _update_graph_and_emit_changes(self, token_seqs, init=False):
         if init is True reinitialize from an empty graph
         """
         if init:
-            self.graph = {START: []}
+            self.graph.clear()
+            self.graph[self.to_graph_key(START_INDEX)] = array.array(Markov.COMPRESSED_NUMBER_FORMAT)
 
         for seq in token_seqs:
             first = True
             learned = False
             for w1, w2, w3 in self.triples_and_stop(seq):
+
+                w1_index = self.get_word_index(w1)
+                w2_index = self.get_word_index(w2)
+                w3_index = self.get_word_index(w3)
+
                 if first:
-                    if w1 not in self.graph[START]:
-                        self.graph[START].append(w1)
-                        learned = True
-                    next_words = self.graph.setdefault(w1, [])
-                    if w2 not in next_words:
-                        next_words.append(w2)
-                        learned = True
+                    added_to_start, new_start = self.try_append_at_graph_key(START_INDEX, w1_index)
+                    learned |= added_to_start
+
+                    added_to_w1, new_w1 = self.try_append_at_graph_key(w1_index, w2_index)
+                    learned |= added_to_w1
                     first = False
-                next_words = self.graph.setdefault((w1, w2), [])
-                if w3 not in next_words:
-                    next_words.append(w3)
-                    learned = True
+
+                combined_key = (w1_index, w2_index)
+                added_to_combined, new_combined = self.try_append_at_graph_key(combined_key, w3_index)
+                learned |= added_to_combined
 
             if learned:
                 yield seq
@@ -129,21 +204,31 @@ def update_corpus(self, token_seqs, init=False):
                 f.write(" ".join(seq))
                 f.write("\n")
 
-    def generate_markov_text(self, seed=None):
-        if seed and seed in self.graph:
-            w1 = seed
-        else:
-            w1 = random.choice(self.graph[START])
-        w2 = random.choice(self.graph[w1])
+    def generate_markov_text(self, seed: Optional[str] = None):
+
+        w1_index = None
+        if seed is not None and seed in self.word_index_map:
+            seed_index = self.get_word_index(seed)
+            if self.to_graph_key(seed_index) in self.graph:
+                w1_index = seed_index
+
+        if w1_index is None:
+            choices = self.get_candidate_indices_for_graph_key(START_INDEX)
+            w1_index = random.choice(choices)
+
+        choices = self.get_candidate_indices_for_graph_key(w1_index)
+        w2_index = random.choice(choices)
 
-        gen_words = [w1]
+        generated_index_list = [w1_index]
         while True:
-            if w2 == STOP:
+            if w2_index == STOP_INDEX:
                 break
-            w1, w2 = w2, random.choice(self.graph[(w1, w2)])
-            gen_words.append(w1)
+            next_key = (w1_index, w2_index)
+            choices = self.get_candidate_indices_for_graph_key(next_key)
+            w1_index, w2_index = w2_index, random.choice(choices)
+            generated_index_list.append(w1_index)
 
-        message = ' '.join(gen_words)
+        message = ' '.join(self.words[idx] for idx in generated_index_list)
         return message
 
     def _map_users(self, response, slack):
@@ -161,8 +246,10 @@ def _map_users(self, response, slack):
     def create_response(self, prompt="", learn=False, slack=False):
         # set seedword from somewhere in words if there's no prompt
         prompt_tokens = prompt.split()
-        valid_seeds = [tok for tok in prompt_tokens[:-2] if tok in self.graph]
-        seed_word = random.choice(valid_seeds) if valid_seeds else None
+        prompt_indices = [self.get_word_index(tok) for tok in prompt_tokens[:-2] if tok in self.word_index_map]
+        valid_seed_indices = [idx for idx in prompt_indices if self.to_graph_key(idx) in self.graph]
+        seed_word_index = random.choice(valid_seed_indices) if valid_seed_indices else None
+        seed_word = self.words[seed_word_index] if seed_word_index is not None else None
         response = self.generate_markov_text(seed_word)
         if learn:
             self.update_graph_and_corpus(self.tokenize(prompt))
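
A quick standalone sanity check of the fixed-width key scheme the new `Markov` internals rely on. This is a reviewer's sketch, not part of the diff; the `'ii'` format string simply mirrors `COMPRESSED_COMBO_NUMBER_FORMAT` above, and the index values are arbitrary:

```python
import array
import struct

# Mirrors COMPRESSED_COMBO_NUMBER_FORMAT: two native-order signed 32-bit ints.
COMBO_FMT = 'ii'

# A single word index is padded with -1 in the second slot, as in to_graph_key.
single_key = struct.pack(COMBO_FMT, 42, -1)
assert len(single_key) == 8
assert struct.unpack(COMBO_FMT, single_key) == (42, -1)

# A bigram key (w1, w2) fills both slots.
pair_key = struct.pack(COMBO_FMT, 42, 7)
assert struct.unpack(COMBO_FMT, pair_key) == (42, 7)

# Candidate next-word indices live in a compact array of signed ints:
# 4 bytes per entry instead of a list of full Python int objects.
candidates = array.array('i', [7, 99])
candidates.extend((3,))
assert list(candidates) == [7, 99, 3]
```

Every graph key is thus a fixed 8-byte `bytes` value and every adjacency entry 4 bytes, versus string keys and lists of strings in the old code, which appears to be the point of the change.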