diff --git a/interegular/fsm.py b/interegular/fsm.py index 67e5ca4..43ec94f 100644 --- a/interegular/fsm.py +++ b/interegular/fsm.py @@ -2,9 +2,11 @@ Finite state machine library, extracted from `greenery.fsm` and adapted by MegaIng """ from _collections import deque +from dataclasses import dataclass from collections import defaultdict from functools import total_ordering -from typing import Any, Set, Dict, Union, NewType, Mapping, Tuple, Iterable +from typing import Any, Set, Dict, Union, NewType, Mapping, Tuple, Iterable, Callable, List, Optional +from itertools import chain from interegular.utils import soft_repr @@ -23,28 +25,29 @@ class _AnythingElseCls: fsm.anything_else, then follow the appropriate transition. """ - def __str__(self): + def __str__(self) -> str: return "anything_else" - def __repr__(self): + def __repr__(self) -> str: return "anything_else" - def __lt__(self, other): + def __lt__(self, other) -> bool: return False - def __eq__(self, other): - return self is other + def __eq__(self, other) -> bool: + return isinstance(other, _AnythingElseCls) - def __hash__(self): - return hash(id(self)) + def __hash__(self) -> int: + return hash(str(self)) # We use a class instance because that gives us control over how the special # value gets serialised. Otherwise this would just be `object()`. anything_else = _AnythingElseCls() +Symbol = Union[str, _AnythingElseCls] -def nice_char_group(chars: Iterable[Union[str, _AnythingElseCls]]): +def nice_char_group(chars: Iterable[Symbol]) -> str: out = [] current_range = [] for c in sorted(chars): @@ -69,10 +72,10 @@ def nice_char_group(chars: Iterable[Union[str, _AnythingElseCls]]): class Alphabet(Mapping[Any, TransitionKey]): @property - def by_transition(self): + def by_transition(self) -> Dict[TransitionKey, List[Symbol]]: return self._by_transition - def __str__(self): + def __str__(self) -> str: out = [] width = 0 for tk, symbols in sorted(self._by_transition.items()): @@ -81,7 +84,7 @@ def __str__(self): width = len(out[-1][0]) return '\n'.join(f"{a:{width}} | {b}" for a, b in out) - def __repr__(self): + def __repr__(self) -> str: return f"{type(self).__name__}({self._symbol_mapping!r})" def __len__(self) -> int: @@ -90,12 +93,12 @@ def __len__(self) -> int: def __iter__(self): return iter(self._symbol_mapping) - def __init__(self, symbol_mapping: Dict[Union[str, _AnythingElseCls], TransitionKey]): + def __init__(self, symbol_mapping: Dict[Symbol, TransitionKey]): self._symbol_mapping = symbol_mapping by_transition = defaultdict(list) for s, t in self._symbol_mapping.items(): by_transition[t].append(s) - self._by_transition = dict(by_transition) + self._by_transition: Dict[TransitionKey, List[Symbol]] = dict(by_transition) def __getitem__(self, item): if item not in self._symbol_mapping: @@ -106,7 +109,7 @@ def __getitem__(self, item): else: return self._symbol_mapping[item] - def __contains__(self, item): + def __contains__(self, item) -> bool: return item in self._symbol_mapping def union(*alphabets: 'Alphabet') -> 'Tuple[Alphabet, Tuple[Dict[TransitionKey, TransitionKey], ...]]': @@ -119,11 +122,11 @@ def union(*alphabets: 'Alphabet') -> 'Tuple[Alphabet, Tuple[Dict[TransitionKey, result = Alphabet({symbol: keys_to_key[keys] for keys, symbols in keys_to_symbols.items() for symbol in symbols}) - new_to_old_mappings = [{} for _ in alphabets] + new_to_old_mappings = tuple({} for _ in alphabets) for keys, new_key in keys_to_key.items(): for old_key, new_to_old in zip(keys, new_to_old_mappings): new_to_old[new_key] = old_key - return result, tuple(new_to_old_mappings) + return result, new_to_old_mappings @classmethod def from_groups(cls, *groups): @@ -139,13 +142,13 @@ def intersect(self, other: 'Alphabet') -> 'Tuple[Alphabet, Tuple[Dict[Transition result = Alphabet({symbol: keys_to_key[keys] for keys, symbols in keys_to_symbols.items() for symbol in symbols}) - old_to_new_mappings = [defaultdict(list) for _ in (self, other)] - new_to_old_mappings = [{} for _ in (self, other)] + old_to_new_mappings = defaultdict(list), defaultdict(list) + new_to_old_mappings = {}, {} for keys, new_key in keys_to_key.items(): for old_key, old_to_new, new_to_old in zip(keys, old_to_new_mappings, new_to_old_mappings): old_to_new[old_key].append(new_key) new_to_old[new_key] = old_key - return result, tuple(new_to_old_mappings) + return result, new_to_old_mappings def copy(self): return Alphabet(self._symbol_mapping.copy()) @@ -161,6 +164,7 @@ class OblivionError(Exception): pass +@dataclass(frozen=True) class FSM: """ A Finite State Machine or FSM has an alphabet and a set of states. At any @@ -176,16 +180,17 @@ class FSM: The majority of these methods are available using operator overloads. """ alphabet: Alphabet + states: frozenset[State] initial: State - states: Set[State] - finals: Set[State] - map: Dict[State, Dict[TransitionKey, State]] - - def __setattr__(self, name, value): - """Immutability prevents some potential problems.""" - raise Exception("This object is immutable.") + finals: frozenset[State] + transition_map: Dict[State, Dict[TransitionKey, State]] + __no_validation__: Optional[bool] = True + + @property + def map(self) -> Dict[State, Dict[TransitionKey, State]]: + return self.transition_map - def __init__(self, alphabet: Alphabet, states, initial, finals, map, *, __no_validation__=False): + def __init__(self, alphabet: Alphabet, states: frozenset[State], initial: State, finals: frozenset[State], transition_map: Optional[Dict[State, Dict[TransitionKey, State]]]=None, __no_validation__: Optional[bool] = True, map: Optional[Dict[State, Dict[TransitionKey, State]]]=None): """ `alphabet` is an iterable of symbols the FSM can be fed. `states` is the set of states for the FSM @@ -194,7 +199,7 @@ def __init__(self, alphabet: Alphabet, states, initial, finals, map, *, __no_val `map` may be sparse (i.e. it may omit transitions). In the case of omitted transitions, a non-final "oblivion" state is simulated. """ - + assert map is not None or transition_map is not None if not __no_validation__: # Validation. Thanks to immutability, this only needs to be carried out once. if not isinstance(alphabet, Alphabet): @@ -203,21 +208,23 @@ def __init__(self, alphabet: Alphabet, states, initial, finals, map, *, __no_val raise Exception("Initial state " + repr(initial) + " must be one of " + repr(states)) if not finals.issubset(states): raise Exception("Final states " + repr(finals) + " must be a subset of " + repr(states)) - for state in map.keys(): - for symbol in map[state]: - if not map[state][symbol] in states: + for state, transitions in transition_map.items(): + for symbol, next_state in transitions.items(): + if not next_state in states: raise Exception( "Transition for state " + repr(state) + " and symbol " + repr(symbol) + " leads to " + repr( - map[state][symbol]) + ", which is not a state") - - # Initialise the hard way due to immutability. - self.__dict__["alphabet"] = alphabet - self.__dict__["states"] = frozenset(states) - self.__dict__["initial"] = initial - self.__dict__["finals"] = frozenset(finals) - self.__dict__["map"] = map + next_state) + ", which is not a state") + + object.__setattr__(self, "alphabet", alphabet) + object.__setattr__(self, "states", states) + object.__setattr__(self, "initial", initial) + object.__setattr__(self, "finals", finals) + if transition_map is not None: + object.__setattr__(self, "transition_map", transition_map) + else: + object.__setattr__(self, "transition_map", map) - def accepts(self, input: str): + def accepts(self, input_str: str) -> bool: """ Test whether the present FSM accepts the supplied string (iterable of symbols). Equivalently, consider `self` as a possibly-infinite set of @@ -227,19 +234,37 @@ def accepts(self, input: str): alphabet will be converted to `fsm.anything_else`. """ state = self.initial - for symbol in input: - if anything_else in self.alphabet and not symbol in self.alphabet: - symbol = anything_else - transition = self.alphabet[symbol] + if anything_else in self.alphabet: + + for symbol in input_str: + if not symbol in self.alphabet: + symbol = anything_else + + if state not in self.transition_map: + return False + + transition = self.alphabet[symbol] + + # Missing transition = transition to dead state + if transition not in self.transition_map[state]: + return False - # Missing transition = transition to dead state - if not (state in self.map and transition in self.map[state]): - return False + state = self.transition_map[state][transition] + else: + for symbol in input_str: + if state not in self.transition_map: + return False + + transition = self.alphabet[symbol] - state = self.map[state][transition] + # Missing transition = transition to dead state + if transition not in self.transition_map[state]: + return False + + state = self.transition_map[state][transition] return state in self.finals - def __contains__(self, string): + def __contains__(self, string) -> bool: """ This lets you use the syntax `"a" in fsm1` to see whether the string "a" is in the set of strings accepted by `fsm1`. @@ -254,17 +279,17 @@ def reduce(self): """ return self.reversed().reversed() - def __repr__(self): + def __repr__(self) -> str: string = "fsm(" string += "alphabet = " + repr(self.alphabet) string += ", states = " + repr(self.states) string += ", initial = " + repr(self.initial) string += ", finals = " + repr(self.finals) - string += ", map = " + repr(self.map) + string += ", map = " + repr(self.transition_map) string += ")" return string - def __str__(self): + def __str__(self) -> str: rows = [] # top row @@ -286,8 +311,8 @@ def __str__(self): else: row.append("False") for symbol, transition in sorted(self.alphabet.items()): - if state in self.map and transition in self.map[state]: - row.append(str(self.map[state][transition])) + if state in self.transition_map and transition in self.transition_map[state]: + row.append(str(self.transition_map[state][transition])) else: row.append("") rows.append(row) @@ -322,20 +347,19 @@ def connect_all(i, substate): (if it's final) the first state from the next FSM, plus (if that's final) the first state from the next but one FSM, plus... """ - result = {(i, substate)} + result = [(i, substate)] while i < last_index and substate in fsms[i].finals: i += 1 substate = fsms[i].initial - result.add((i, substate)) - return result + result.append((i, substate)) + return frozenset(result) # Use a superset containing states from all FSMs at once. # We start at the start of the first FSM. If this state is final in the # first FSM, then we are also at the start of the second FSM. And so on. - initial = set() + initial = frozenset() if len(fsms) > 0: - initial.update(connect_all(0, fsms[0].initial)) - initial = frozenset(initial) + initial = connect_all(0, fsms[0].initial) def final(state): """If you're in a final state of the final FSM, it's final""" @@ -350,14 +374,15 @@ def follow(current, new_transition): next FSM if we reach the end of the current one TODO: improve all follow() implementations to allow for dead metastates? """ - next = set() + next_states = [] for (i, substate) in current: fsm = fsms[i] - if substate in fsm.map and new_to_old[i][new_transition] in fsm.map[substate]: - next.update(connect_all(i, fsm.map[substate][new_to_old[i][new_transition]])) - if not next: + current_vertex: TransitionKey = new_to_old[i][new_transition] + if substate in fsm.transition_map and current_vertex in fsm.transition_map[substate]: + next_states.append(connect_all(i, fsm.transition_map[substate][current_vertex])) + if not next_states: raise OblivionError - return frozenset(next) + return frozenset(chain.from_iterable(next_states)) return crawl(alphabet, initial, final, follow) @@ -382,22 +407,22 @@ def star(self): initial = {self.initial} def follow(state, transition): - next = set() + next_states = [] for substate in state: - if substate in self.map and transition in self.map[substate]: - next.add(self.map[substate][transition]) + if substate in self.transition_map and transition in self.transition_map[substate]: + next_states.append(self.transition_map[substate][transition]) # If one of our substates is final, then we can also consider # transitions from the initial state of the original FSM. if substate in self.finals \ - and self.initial in self.map \ - and transition in self.map[self.initial]: - next.add(self.map[self.initial][transition]) + and self.initial in self.transition_map \ + and transition in self.transition_map[self.initial]: + next_states.append(self.transition_map[self.initial][transition]) - if not next: + if not next_states: raise OblivionError - return frozenset(next) + return frozenset(next_states) def final(state): return any(substate in self.finals for substate in state) @@ -406,7 +431,7 @@ def final(state): base.__dict__['finals'] = base.finals | {base.initial} return base - def times(self, multiplier): + def times(self, multiplier: int): """ Given an FSM and a multiplier, return the multiplied FSM. """ @@ -427,18 +452,19 @@ def final(state): return False def follow(current, transition): - next = [] + next_state = [] for (substate, iteration) in current: if iteration < multiplier \ - and substate in self.map \ - and transition in self.map[substate]: - next.append((self.map[substate][transition], iteration)) + and substate in self.transition_map \ + and transition in self.transition_map[substate]: + current_state = self.transition_map[substate][transition] + next_state.append((current_state, iteration)) # final of self? merge with initial on next iteration - if self.map[substate][transition] in self.finals: - next.append((self.initial, iteration + 1)) - if len(next) == 0: + if current_state in self.finals: + next_state.append((self.initial, iteration + 1)) + if len(next_state) == 0: raise OblivionError - return frozenset(next) + return frozenset(next_state) return crawl(alphabet, initial, final, follow) @@ -513,8 +539,8 @@ def everythingbut(self): def follow(current, transition): next = {} - if 0 in current and current[0] in self.map and transition in self.map[current[0]]: - next[0] = self.map[current[0]][transition] + if 0 in current and current[0] in self.transition_map and transition in self.transition_map[current[0]]: + next[0] = self.transition_map[current[0]][transition] return next # state is final unless the original was @@ -531,12 +557,12 @@ def isdisjoint(self, other: 'FSM') -> bool: # obtained by following this transition in the new FSM def follow(current, transition): ss, os = current - if ss in self.map and new_to_old[0][transition] in self.map[ss]: - sn = self.map[ss][new_to_old[0][transition]] + if ss in self.transition_map and new_to_old[0][transition] in self.transition_map[ss]: + sn = self.transition_map[ss][new_to_old[0][transition]] else: sn = None - if os in other.map and new_to_old[1][transition] in other.map[os]: - on = other.map[os][new_to_old[1][transition]] + if os in other.transition_map and new_to_old[1][transition] in other.transition_map[os]: + on = other.transition_map[os][new_to_old[1][transition]] else: on = None if not sn or not on: @@ -568,22 +594,25 @@ def reversed(self): initial = frozenset(self.finals) # Speed up follow by pre-computing reverse-transition map - reverse_map = {} - for state, transition_map in self.map.items(): + reverse_map = defaultdict(set) + for state, transition_map in self.transition_map.items(): for transition, next_state in transition_map.items(): - if (next_state, transition) not in reverse_map: - reverse_map[(next_state, transition)] = set() reverse_map[(next_state, transition)].add(state) # Find every possible way to reach the current state-set # using this symbol. def follow(current, transition): - next_states = set() - for state in current: - next_states.update(reverse_map.get((state, transition), set())) + _empty_set = set() # reuse to avoid unnecessary allocations + + next_states_iter = ( + reverse_map.get((state, transition), _empty_set) + for state in current + ) + next_states = frozenset(chain.from_iterable(next_states_iter)) + if not next_states: raise OblivionError - return frozenset(next_states) + return next_states # A state-set is final if the initial state is in it. def final(state): @@ -610,12 +639,12 @@ def islive(self, state): current = reachable[i] if current in self.finals: return True - if current in self.map: - for transition in self.map[current]: - next = self.map[current][transition] - if next not in seen: - reachable.append(next) - seen.add(next) + if current in self.transition_map: + transitions = self.transition_map[current] + for next_state in transitions.values(): + if next_state not in seen: + reachable.append(next_state) + seen.add(next_state) i += 1 return False @@ -648,7 +677,7 @@ def strings(self, max_iterations=None): # Many FSMs have "dead states". Once you reach a dead state, you can no # longer reach a final state. Since many strings may end up here, it's # advantageous to constrain our search to live states only. - livestates = set(state for state in self.states if self.islive(state)) + livestates = frozenset(state for state in self.states if self.islive(state)) # We store a list of tuples. Each tuple consists of an input string and the # state that this input string leads to. This means we don't have to run the @@ -671,9 +700,9 @@ def strings(self, max_iterations=None): while strings: (cstring, cstate) = strings.popleft() i += 1 - if cstate in self.map: - for transition in sorted(self.map[cstate]): - nstate = self.map[cstate][transition] + if cstate in self.transition_map: + for transition in sorted(self.transition_map[cstate]): + nstate = self.transition_map[cstate][transition] if nstate in livestates: for symbol in sorted(self.alphabet.by_transition[transition]): nstring = cstring + [symbol] @@ -737,25 +766,25 @@ def cardinality(self): def get_num_strings(state): # Many FSMs have at least one oblivion state - if self.islive(state): - if state in num_strings: - if num_strings[state] is None: # "computing..." - # Recursion! There are infinitely many strings recognised - raise OverflowError(state) - return num_strings[state] - num_strings[state] = None # i.e. "computing..." - - n = 0 - if state in self.finals: - n += 1 - if state in self.map: - for transition in self.map[state]: - n += get_num_strings(self.map[state][transition]) * len(self.alphabet.by_transition[transition]) - num_strings[state] = n - - else: - # Dead state + if not self.islive(state): num_strings[state] = 0 + return 0 + + if state in num_strings: + if num_strings[state] is None: # "computing..." + # Recursion! There are infinitely many strings recognised + raise OverflowError(state) + return num_strings[state] + num_strings[state] = None # i.e. "computing..." + + n = 0 + if state in self.finals: + n += 1 + if state in self.transition_map: + transitions = self.transition_map[state] + for transition, next_state in transitions.items(): + n += get_num_strings(next_state) * len(self.alphabet.by_transition[transition]) + num_strings[state] = n return num_strings[state] @@ -834,7 +863,7 @@ def copy(self): states=self.states.copy(), initial=self.initial, finals=self.finals.copy(), - map=self.map.copy(), + transition_map=self.transition_map.copy(), __no_validation__=True, ) @@ -856,10 +885,10 @@ def derive(self, input): symbol = anything_else # Missing transition = transition to dead state - if not (state in self.map and self.alphabet[symbol] in self.map[state]): + if not (state in self.transition_map and self.alphabet[symbol] in self.transition_map[state]): raise OblivionError - state = self.map[state][self.alphabet[symbol]] + state = self.transition_map[state][self.alphabet[symbol]] # OK so now we have consumed that string, use the new location as the # starting point. @@ -868,7 +897,7 @@ def derive(self, input): states=self.states, initial=state, finals=self.finals, - map=self.map, + transition_map=self.transition_map, __no_validation__=True, ) @@ -885,10 +914,10 @@ def null(alphabet): """ return FSM( alphabet=alphabet, - states={0}, + states=frozenset({0}), initial=0, - finals=set(), - map={ + finals=frozenset(), + transition_map={ 0: dict([(transition, 0) for transition in alphabet.by_transition]), }, __no_validation__=True, @@ -902,10 +931,10 @@ def epsilon(alphabet): """ return FSM( alphabet=alphabet, - states={0}, + states=frozenset({0}), initial=0, - finals={0}, - map={}, + finals=frozenset({0}), + transition_map={}, __no_validation__=True, ) @@ -922,21 +951,28 @@ def parallel(fsms, test): # dedicated function accepts a "superset" and returns the next "superset" # obtained by following this transition in the new FSM - def follow(current, new_transition, fsm_range=tuple(enumerate(fsms))): - next = {} + def follow(current, new_transition, fsm_range=None): + fsm_range = fsm_range or enumerate(fsms) + next_state = {} + for i, f in fsm_range: + if i not in current: + continue + old_transition = new_to_old[i][new_transition] - if i in current \ - and current[i] in f.map \ - and old_transition in f.map[current[i]]: - next[i] = f.map[current[i]][old_transition] - if not next: + + current_i = current[i] + if current_i in f.transition_map and old_transition in f.transition_map[current_i]: + next_state[i] = f.transition_map[current_i][old_transition] + + if not next_state: raise OblivionError - return next + return next_state # Determine the "is final?" condition of each substate, then pass it to the # test to determine finality of the overall FSM. - def final(state, fsm_range=tuple(enumerate(fsms))): + def final(state, fsm_range=None): + fsm_range = fsm_range or enumerate(fsms) accepts = [i in state and state[i] in fsm.finals for (i, fsm) in fsm_range] return test(accepts) @@ -944,7 +980,7 @@ def final(state, fsm_range=tuple(enumerate(fsms))): def crawl_hash_no_result(alphabet, initial, final, follow): - unvisited = {initial} + unvisited = [initial] visited = set() while unvisited: @@ -963,10 +999,10 @@ def crawl_hash_no_result(alphabet, initial, final, follow): continue else: if new not in visited: - unvisited.add(new) + unvisited.append(new) -def crawl(alphabet, initial, final, follow): +def crawl(alphabet: Alphabet, initial: Any, final: Callable[[Any], bool], follow: Callable[[Any, TransitionKey], Any]): """ Given the above conditions and instructions, crawl a new unknown FSM, mapping its states, final states and transitions. Return the new FSM. @@ -974,9 +1010,19 @@ def crawl(alphabet, initial, final, follow): forever if you supply an evil version of follow(). """ + def get_hash(obj): + if isinstance(obj, set): + return hash(frozenset(obj)) + elif isinstance(obj, dict): + return hash(tuple(sorted(obj.items()))) + return hash(obj) + + transitions_in_alphabet = alphabet.by_transition.keys() + states = [initial] - finals = set() - map = {} + state_idx: Dict[int, int] = {get_hash(initial): 0} + finals = [] + transition_map = {} # iterate over a growing list i = 0 @@ -985,31 +1031,36 @@ def crawl(alphabet, initial, final, follow): # add to finals if final(state): - finals.add(i) + finals.append(i) # compute map for this state - map[i] = {} - for transition in alphabet.by_transition: + transition_map[i] = {} + for transition in transitions_in_alphabet: try: - next = follow(state, transition) + next_state = follow(state, transition) + except OblivionError: # Reached an oblivion state. Don't list it. continue + else: - try: - j = states.index(next) - except ValueError: + next_hash = get_hash(next_state) + if next_hash in state_idx: + j = state_idx[next_hash] + else: j = len(states) - states.append(next) - map[i][transition] = j + states.append(next_state) + state_idx[next_hash] = j + + transition_map[i][transition] = j i += 1 return FSM( alphabet=alphabet, - states=range(len(states)), + states=frozenset(range(len(states))), initial=0, - finals=finals, - map=map, + finals=frozenset(finals), + transition_map=transition_map, __no_validation__=True, - ) + ) \ No newline at end of file diff --git a/interegular/patterns.py b/interegular/patterns.py index bb09dbb..d9154d1 100644 --- a/interegular/patterns.py +++ b/interegular/patterns.py @@ -159,10 +159,10 @@ def to_fsm(self, alphabet=None, prefix_postfix=None, flags=REFlags(0)) -> FSM: return FSM( alphabet=alphabet, - states={0, 1}, + states=frozenset({0, 1}), initial=0, - finals={1}, - map=mapping, + finals=frozenset({1}), + transition_map=mapping, ) def simplify(self) -> '_CharGroup': @@ -190,10 +190,10 @@ def to_fsm(self, alphabet=None, prefix_postfix=None, flags=REFlags(0)) -> FSM: symbols = alphabet return FSM( alphabet=alphabet, - states={0, 1}, + states=frozenset({0, 1}), initial=0, - finals={1}, - map={0: {alphabet[sym]: 1 for sym in symbols}}, + finals=frozenset({1}), + transition_map={0: {alphabet[sym]: 1 for sym in symbols}}, ) def _get_alphabet(self, flags: REFlags) -> Alphabet: