Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 16 additions & 14 deletions regroup/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def match(strings):
return DAWG.from_iter(strings).serialize()


class StringSet:
class StringSet(object):

'''
a set of strings
Expand All @@ -35,7 +35,7 @@ def __iter__(self):
return iter(self.strings.keys())


class TaggedString:
class TaggedString(object):

def __init__(self, string, tokenizer=None):
tokenizer = tokenizer or Tokenizer()
Expand All @@ -51,7 +51,7 @@ def escape(s):
return re.escape(s).replace(r'\ ', ' ')


class Trie:
class Trie(object):

'''
Trie
Expand Down Expand Up @@ -99,7 +99,7 @@ def _build(self, strings):
return root


class DAWG:
class DAWG(object):

'''
Directed Acyclic Word Graph
Expand All @@ -109,6 +109,7 @@ class DAWG:

def __init__(self, trie=None):
self.dawg = DAWG._build(trie)
self.trie = trie

@classmethod
def from_iter(cls, strings):
Expand Down Expand Up @@ -167,7 +168,8 @@ def flatten(d, clusters=None):
def _flatten(cls, d, path):
for k, v in sorted(d.items()):
if k:
yield from cls._flatten(v, path + k)
for item in cls._flatten(v, path + k):
yield item
else:
yield path

Expand All @@ -193,18 +195,18 @@ def _cluster_by_prefixlen(cls, length, clusters, d, path):
else:
cls._cluster_by_prefixlen(length, clusters, v, path2)

def dawg_weights(d, l):
def dawg_weights(self, d, l):
"""given a DAWG and the original list, calculate weights at each branchpoint"""
weights = defaultdict(int)
_dawg_weights(l, weights, d, [])
self._dawg_weights(l, weights, d, [])
return dict(weights)

def _dawg_weights(l, weights, d, path):
def _dawg_weights(self, l, weights, d, path):
for k, v in d.items():
path2 = path + [k]
path2str = ''.join(path)
weights[tuple(path2)] += sum(1 for x in l if x.startswith(path2str))
_dawg_weights(l, weights, v, path2)
self._dawg_weights(l, weights, v, path2)

def top_weights(weights, n):
wsorted = sorted(weights.items(),
Expand Down Expand Up @@ -329,7 +331,7 @@ def as_optional_group(strings):


def all_len01(l):
return set(map(len, l)) == {0, 1}
return set(map(len, l)) == set([0, 1])


def is_optional_char_class(d):
Expand Down Expand Up @@ -484,7 +486,8 @@ def _relaxable(cls, d):
yield (diffcnt, d)
for k, v in d.items():
if len(v) > 1:
yield from cls._relaxable(v)
for item in cls._relaxable(v):
yield item

def relax(self, threshold=1):
'''
Expand All @@ -503,7 +506,7 @@ def relax(self, threshold=1):

def do_relax(self, d):
merged = reduce(dict_merge, d.values(), {})
d2 = {k: merged for k in d}
d2 = dict([(k, merged) for k in d])
# print('merged', merged)
# print('d2', d2)
return DAWGRelaxer._replace(self.dawg.dawg, d, d2)
Expand All @@ -512,5 +515,4 @@ def do_relax(self, d):
def _replace(cls, dawg, find, replace):
if dawg == find:
return replace
return {k: cls._replace(v, find, replace)
for k, v in dawg.items()}
return dict([(k, cls._replace(v, find, replace)) for k, v in dawg.items()])
36 changes: 29 additions & 7 deletions regroup/cluster.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@

import re

class Cluster:
class Cluster(object):
def __init__(self):
self.left = None
self.dist = None
self.right = None


def __repr__(self):
return '({} {} {})'.format(self.left, self.dist, self.right)


def dump(self, indent=0):
if isinstance(self.left, str):
print(' ' * indent, self.left)
Expand All @@ -17,33 +22,49 @@ def dump(self, indent=0):
print(' ' * indent, self.right)
else:
self.right.dump(indent=indent+1)


def __iter__(self):
yield self
if self.left:
yield from self.left
for item in self.left:
yield item
if self.right:
yield from self.right
for item in self.right:
yield item


def leaves(self):
if isinstance(self.left, str):
yield self.left
else:
yield from self.left.leaves()
for item in self.left.leaves():
yield item
if isinstance(self.right, str):
yield self.right
else:
yield from self.right.leaves()
for item in self.right.leaves():
yield item


def distances(self):
yield self.dist
if isinstance(self.left, Cluster):
yield from self.left.distances()
for item in self.left.distances():
yield item
if isinstance(self.right, Cluster):
yield from self.right.distances()
for item in self.right.distances():
yield item


def clusters_by(self, dist):
if self.dist <= dist:
return list(self.leaves())
else:
return (self.left.clusters_by(dist),
self.right.clusters_by(dist))


def add(self, clusters, grid, lefti, righti):
self.left = clusters[lefti]
self.right = clusters[righti]
Expand Down Expand Up @@ -77,6 +98,7 @@ def agglomerate(labels, grid):


def strdist(x, y):
#TODO: function levenshtein never got imported or defined
return levenshtein(x, y)


Expand Down
7 changes: 4 additions & 3 deletions regroup/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def tokenize_regex_case_sensitive(string):
yield token


class Tokenizer:
class Tokenizer(object):

def __init__(self):
pass
Expand All @@ -26,6 +26,7 @@ def tokenize(self, string):
class DictionaryTokenizer(Tokenizer):

def __init__(self, wordset=None):
Tokenizer.__init__(self)
wordset = set(wordset) if wordset else set()
self.wordset = wordset
self.wordsetlen = defaultdict(set)
Expand Down Expand Up @@ -58,7 +59,7 @@ def fallback(self, string):
return string[0]


class Tagged:
class Tagged(object):

def __init__(self, string, tag):
self.string = string
Expand All @@ -68,7 +69,7 @@ def __str__(self):
return self.string


class TaggingTokenizer:
class TaggingTokenizer(object):

def __init__(self, tags):
self.tags = tags
Expand Down