# -*- coding: utf-8 -*-

from tuikit.eventsource import EventSource
from tuikit.widget import Widget


class TreeIter:
    '''Depth-first (pre-order) iterator over a tree of nested lists.

    Each step yields a tuple ``(level, index, count, node)``:

    * ``level`` -- depth of ``node``; direct children of the root are level 1.
    * ``index`` -- 1-based position of ``node`` among its siblings.
    * ``count`` -- number of siblings, including ``node`` itself.
    * ``node``  -- the child item itself.

    Children of any node found in ``collapsed`` (including the root) are
    skipped; a collapsed node itself is still yielded by its parent.
    Membership is tested with ``in``, so for TreeNode objects (whose
    ``__eq__`` is identity) it is by identity.
    '''

    def __init__(self, root, collapsed=None):
        self._node = root
        self._index = 0
        # stack of (parent, index of the next sibling to visit) pairs
        self._stack = []
        # None default avoids sharing one mutable default between instances
        self._collapsed = collapsed if collapsed is not None else ()

    def __iter__(self):
        # iterator protocol: an iterator must also be iterable
        return self

    def __next__(self):
        node = None
        while node is None:
            try:
                if self._node in self._collapsed:
                    # treat a collapsed node as if it had no children
                    raise IndexError()
                node = self._node[self._index]
                if node is None:
                    raise Exception('Bad node: None')
            except IndexError:
                # current node exhausted -- resume with its parent
                if self._stack:
                    self._node, self._index = self._stack.pop()
                else:
                    raise StopIteration

        level = len(self._stack) + 1
        index = self._index + 1
        count = len(self._node)

        # remember where to continue among the siblings, then descend
        self._stack.append((self._node, self._index + 1))
        self._node = node
        self._index = 0

        return (level, index, count, node)


class TreeNode(list):
    '''A tree node: a list of child nodes plus a parent link and a name.

    Equality is identity-based, so list operations such as ``in``,
    ``index`` and ``remove`` locate this exact node rather than any node
    whose children happen to compare equal.
    '''

    def __init__(self, parent=None, name=''):
        list.__init__(self)
        self.parent = parent  # None for a root node
        self.name = name

    def __eq__(self, other):
        # do not compare by list content
        return self is other

    def __ne__(self, other):
        # list.__ne__ would compare content; keep != consistent with ==
        return self is not other

    # identity equality is compatible with the identity hash; defining
    # __eq__ alone would leave the class unhashable
    __hash__ = object.__hash__


class TreeModel(EventSource):
    '''Tree data for TreeView widgets.

    Holds a root TreeNode and emits a ``'change'`` event whenever nodes are
    added through ``add``.
    '''

    def __init__(self):
        EventSource.__init__(self)
        self.addevents('change')
        self.root = TreeNode()

    def __iter__(self):
        '''Iterate over all nodes, yielding (level, index, count, node).'''
        return TreeIter(self.root)

    def find(self, path):
        '''Return the node addressed by ``path``.

        ``path`` is either a '/'-separated string or a sequence of items.
        An int item is used as a child index.  Any other item is matched
        against child names first; when no child has that name it is
        converted with ``int()`` and used as an index.  Empty strings at
        either end are ignored, so '', '/' and [] all address the root.
        The caller's sequence is never modified.

        Raises ValueError (from ``int()``) when a name matches no child and
        is not numeric, and IndexError when an index is out of range.
        '''
        if isinstance(path, str):
            items = path.split('/')
        else:
            # work on a copy: never mutate the caller's list (or fail on a tuple)
            items = list(path)
        # strip empty strings from both ends
        while items and items[0] == '':
            del items[0]
        while items and items[-1] == '':
            del items[-1]

        node = self.root
        for item in items:
            if isinstance(item, int):
                node = node[item]
                continue
            match = next((sub for sub in node if sub.name == item), None)
            if match is None:
                # not a known name -- fall back to a numeric index
                match = node[int(item)]
            node = match

        return node

    def add(self, path, names):
        '''Append child nodes under the node at ``path``, then emit 'change'.

        ``names`` is a single name string or an iterable of names; one
        TreeNode is created per name, in order.
        '''
        node = self.find(path)

        if isinstance(names, str):
            names = (names,)

        for name in names:
            node.append(TreeNode(node, name))

        self.emit('change')


class TreeView(Widget):
    '''Widget drawing a TreeModel as an indented tree with a cursor.

    Keys: up/down move the cursor over visible nodes, left collapses the
    node under the cursor, right expands it.
    '''

    def __init__(self, model=None, width=20, height=20):
        Widget.__init__(self, width, height)

        # cursor: the selected node, or None when nothing is selected
        self.cnode = None

        # nodes whose children are hidden (must exist before setmodel runs)
        self.collapsed = []

        # model
        self._model = None
        self.setmodel(model)

        self.connect('draw', self.on_draw)
        self.connect('keypress', self.on_keypress)

    def __iter__(self):
        # visible nodes only: children of collapsed nodes are skipped
        if self._model is None:
            return iter(())
        return TreeIter(self._model.root, self.collapsed)

    def getmodel(self):
        '''TreeModel in use by this TreeView.'''
        return self._model

    def setmodel(self, value):
        '''Switch to model ``value`` (or None) and reset the cursor.

        The view redraws on the model's 'change' events.  The cursor moves
        to the first top-level node, or to None when there is no such node.
        '''
        if self._model is not None:
            self._model.disconnect('change', self.redraw)
        if value is not self._model:
            # nodes of the previous model can never match the new one
            self.collapsed = []
        self._model = value
        # never keep a cursor pointing into the previous model
        self.cnode = None
        if value is not None:
            value.connect('change', self.redraw)
            if value.root:
                self.cnode = value.root[0]

    model = property(getmodel, setmodel)

    def collapse(self, path, collapse=True):
        '''Collapse (or, with collapse=False, expand) the node at ``path``.'''
        node = self._model.find(path)
        self.collapse_node(node, collapse)

    def collapse_node(self, node, collapse=True):
        '''Hide (collapse=True) or show (collapse=False) children of ``node``.

        Leaf nodes are never marked collapsed.  If the cursor sits inside
        the subtree being hidden, it moves to ``node`` so it stays visible.
        '''
        if collapse:
            if node not in self.collapsed and len(node) > 0:
                self.collapsed.append(node)
                if self._is_descendant(self.cnode, node):
                    self.cnode = node
        else:
            try:
                self.collapsed.remove(node)
            except ValueError:
                pass  # was not collapsed -- nothing to expand

    @staticmethod
    def _is_descendant(node, ancestor):
        '''Return True if ``ancestor`` is a proper ancestor of ``node``.'''
        if node is None:
            return False
        parent = node.parent
        while parent is not None:
            if parent is ancestor:
                return True
            parent = parent.parent
        return False

    def on_draw(self, screen, x, y):
        '''Draw the visible nodes, one per line, starting at (x, y).'''
        screen.pushcolor('normal')

        # bit array: bit N set means the node last drawn at level N+1 still
        # has siblings below, so its column needs a vertical line
        lines = 0
        for level, index, count, node in self:
            bit = 1 << (level - 1)
            # prepare string with vertical lines where they should be
            head = []
            for depth in range(level - 1):
                if lines & (1 << depth):
                    head.append(screen.unigraph.VLINE + ' ')
                else:
                    head.append('  ')
            # branch glyph: tee when more siblings follow, corner for the last
            if index < count:
                head.append(screen.unigraph.LTEE)
                lines |= bit
            else:
                head.append(screen.unigraph.LLCORNER)
                lines &= ~bit
            # draw lines and name; '+' marks a collapsed node
            head = ''.join(head)
            sep = '+' if node in self.collapsed else ' '
            screen.puts(x, y, head + sep + node.name)
            if node is self.cnode:
                # overdraw the label (plus one trailing cell) in cursor color
                screen.pushcolor('active')
                screen.puts(x + len(head), y, sep + node.name + ' ')
                screen.popcolor()

            y += 1

        screen.popcolor()

    def on_keypress(self, keyname, char):
        '''Handle arrow keys (cursor movement / collapse), then redraw.'''
        action = {
            'up': self.move_up,
            'down': self.move_down,
            'left': self.move_left,
            'right': self.move_right,
        }.get(keyname) if keyname else None
        if action is not None:
            action()

        self.redraw()

    def prev_node(self, node):
        '''Return the visible node drawn just above ``node``, or None.'''
        # previous sibling
        parent = node.parent
        i = parent.index(node)
        if i > 0:
            # descend to the last visible descendant of the previous sibling
            node = parent[i - 1]
            while node not in self.collapsed and len(node) > 0:
                node = node[-1]
            return node
        if parent.parent is None:
            # first top-level node: the root itself is never shown
            return None
        return parent

    def next_node(self, node):
        '''Return the visible node drawn just below ``node``, or None.'''
        if node in self.collapsed or len(node) == 0:
            # next sibling, walking up while the current level is exhausted
            parent = node.parent
            while parent is not None:
                i = parent.index(node)
                try:
                    return parent[i + 1]
                except IndexError:
                    node = parent
                    parent = node.parent
            return None
        # first child
        return node[0]

    def _select_first(self):
        '''Put the cursor on the first top-level node; True if possible.'''
        if self._model is None or not self._model.root:
            return False
        self.cnode = self._model.root[0]
        return True

    def move_up(self):
        '''Move the cursor up one visible node; return True if it moved.

        Without a cursor (e.g. nodes were added to a model that was empty),
        the first node is selected instead.
        '''
        if self.cnode is None:
            return self._select_first()
        prev = self.prev_node(self.cnode)
        if prev is not None:
            self.cnode = prev
            return True
        return False

    def move_down(self):
        '''Move the cursor down one visible node; return True if it moved.

        Without a cursor (e.g. nodes were added to a model that was empty),
        the first node is selected instead.
        '''
        if self.cnode is None:
            return self._select_first()
        node = self.next_node(self.cnode)
        if node is not None:
            self.cnode = node
            return True
        return False

    def move_left(self):
        '''Collapse the node under the cursor (no-op without a cursor).'''
        if self.cnode is not None:
            self.collapse_node(self.cnode, True)

    def move_right(self):
        '''Expand the node under the cursor (no-op without a cursor).'''
        if self.cnode is not None:
            self.collapse_node(self.cnode, False)

