diff --git a/pybars/_compiler.py b/pybars/_compiler.py index 611709e..c284472 100644 --- a/pybars/_compiler.py +++ b/pybars/_compiler.py @@ -42,6 +42,7 @@ except NameError: # Python 3 support str_class = str + basestring = str # Flag for testing @@ -163,31 +164,43 @@ class PybarsError(Exception): pass -class strlist(list): +class strlist(object): + __slots__ = ['value'] + def __init__(self, default=None): + self.value = u'' + if default: + self.grow(default) - """A quasi-list to let the template code avoid special casing.""" + def __str__(self): + return self.value + + def __unicode__(self): + return self.value - def __str__(self): # Python 3 - return ''.join(self) + def append(self, other): + self.value += other - def __unicode__(self): # Python 2 - return u''.join(self) + def extend(self, other): + if type(other) is strlist: + self.value += other.value + else: + self.value += ''.join(other) - def grow(self, thing): - """Make the list longer, appending for unicode, extending otherwise.""" - if type(thing) == str_class: - self.append(thing) + def __iter__(self): + return iter([self]) - # This will only ever match in Python 2 since str_class is str in - # Python 3. - elif type(thing) == str: - self.append(unicode(thing)) + def __add__(self, other): + self.value += other.value + return self + def grow(self, other): + if isinstance(other, basestring): + self.value += other + elif type(other) is strlist: + self.value += other.value else: - # Recursively expand to a flat list; may deserve a C accelerator at - # some point. - for element in thing: - self.grow(element) + for item in other: + self.grow(item) _map = { @@ -212,14 +225,15 @@ def escape(something, _escape_re=_escape_re, substitute=substitute): def pick(context, name, default=None): - if isinstance(name, str) and hasattr(context, name): + if type(name) is str and hasattr(context, name): return getattr(context, name) if hasattr(context, 'get'): return context.get(name) try: return context[name] except (KeyError, TypeError): - return default + pass + return default sentinel = object() @@ -270,34 +284,40 @@ def __unicode__(self): return unicode(self.context) +ITERABLE_TYPES = (list, tuple) + def resolve(context, *segments): - carryover_data = False + context_type = type(context) # This makes sure that bare "this" paths don't return a Scope object - if segments == ('',) and isinstance(context, Scope): + if segments == ('',) and context_type is Scope: return context.get('this') + carryover_data = False for segment in segments: + if context is None: + return None + # Handle @../index syntax by popping the extra @ along the segment path if carryover_data: + segment = u'@' + segment carryover_data = False - segment = u'@%s' % segment - if len(segment) > 1 and segment[0:2] == '@@': + + if segment[:2] == '@@': segment = segment[1:] carryover_data = True - if context is None: - return None - if segment in (None, ""): + if not segment: continue - if type(context) in (list, tuple): - offset = int(segment) - context = context[offset] - elif isinstance(context, Scope): + + if context_type is Scope: context = context.get(segment) + elif context_type in ITERABLE_TYPES: + context = context[int(segment)] else: context = pick(context, segment) + context_type = type(context) return context @@ -336,7 +356,7 @@ def prepare(value, should_escape): def ensure_scope(context, root): - return context if isinstance(context, Scope) else Scope(context, context, root) + return context if type(context) is Scope else Scope(context, context, root) def _each(this, options, context): @@ -509,8 +529,8 @@ def start(self): ]) else: self._result.grow(u"def %s(context, helpers, partials, root):\n" % function_name) - self._result.grow(u" result = strlist()\n") self._result.grow(u" context = ensure_scope(context, root)\n") + self._result.grow(u" result = strlist()\n") def finish(self): lines, ns, function_name = self.stack.pop(-1) @@ -521,7 +541,7 @@ def finish(self): self._result.grow(u" result = %s(result)\n" % str_class.__name__) self._result.grow(u" return result\n") - source = str_class(u"".join(lines)) + source = str_class(lines) self._result = self.stack and self.stack[-1][0] self._locals = self.stack and self.stack[-1][1] @@ -578,11 +598,12 @@ def add_block(self, symbol, arguments, nested, alt_nested): u" value = helper(context, options%s\n" % call, u" else:\n" u" value = helpers['blockHelperMissing'](context, options, value)\n" - u" result.grow(value or '')\n" + u" if value:\n" + u" result.grow(value)\n" ]) def add_literal(self, value): - self._result.grow(u" result.append(%s)\n" % repr(value)) + self._result.grow(u" result.value += %s\n" % (repr(value),)) def _lookup_arg(self, arg): if not arg: