tuikit/treeview.py
author Radek Brich <radek.brich@devl.cz>
Wed, 20 Aug 2014 15:06:52 +0200
changeset 102 29a8a26a721f
parent 77 fc1989059e19
permissions -rw-r--r--
Update TreeView (old uncommitted work).

# -*- coding: utf-8 -*-

from tuikit.events import Event, Emitter
from tuikit.widget import Widget


class TreeEvent(Event):
    """Event referring to a single tree node.

    Registered for 'node_added' by TreeModel and for 'expand' / 'collapse'
    by TreeView; handlers find the affected node in the ``node`` attribute.

    """

    def __init__(self, node):
        # base initialisation first, then attach the payload
        Event.__init__(self)
        self.node = node  # the node this event is about


class TreeIter:
    """Iterates nodes under root in depth-first order.

    This is useful for displaying the tree. Each step returns a tuple
    ``(level, index, count, node)``:
     * level - depth of node below root (root's children have level 1)
     * index - 1-based position of node among its siblings
     * count - number of siblings (including node)
     * node - the node itself

    Children of nodes contained in ``collapsed_nodes`` are skipped; the
    collapsed node itself is still returned. The root is never returned.

    """

    def __init__(self, root, collapsed_nodes=None):
        # node whose children are currently being visited
        self._node = root
        # index of the next child of self._node to visit
        self._index = 0
        # saved (node, next child index) pairs of ancestors to resume later
        self._stack = []
        # None default avoids a shared mutable default; only membership is
        # tested, so an empty tuple is enough. A list passed in by the caller
        # is kept by reference, so later changes to it are seen.
        if collapsed_nodes is None:
            collapsed_nodes = ()
        self._collapsed_nodes = collapsed_nodes

    def __iter__(self):
        # Makes TreeIter a proper iterator usable with list(), for, etc.
        return self

    def __next__(self):
        # Find the next unvisited child, climbing back up the stack when the
        # current node is collapsed or all its children were visited.
        while True:
            if self._node not in self._collapsed_nodes:
                children = self._node.children
                if self._index < len(children):
                    node = children[self._index]
                    if node is None:
                        raise ValueError('Bad node: None')
                    break
            if not self._stack:
                raise StopIteration
            self._node, self._index = self._stack.pop()

        level = len(self._stack) + 1
        index = self._index + 1
        count = len(self._node.children)

        # Remember where to continue in the parent, then descend into node.
        self._stack.append((self._node, self._index + 1))
        self._node = node
        self._index = 0

        return (level, index, count, node)


class TreeNode:
    """Node of tree.

    Maintains its parent and children.

    Attributes:
     * name - used for searching
     * title - this is displayed on screen (defaults to name)
     * parent - parent TreeNode, None for a root / detached node
     * children - list of child TreeNodes, in display order
     * model - model the node belongs to, None when detached

    """

    def __init__(self, name, title=None, parent=None, model=None):
        self.model = model
        self.children = []
        self.parent = parent
        self.name = name
        self.title = title or name

    def __iter__(self):
        return iter(self.children)

    def __str__(self):
        return self.title

    def __repr__(self):
        return 'TreeNode(%r)' % self.name

    @property
    def path(self):
        """Path of this node in model (names joined by '/')."""
        if self.parent is not None:
            return self.parent.path + '/' + self.name
        return self.name

    def add(self, node):
        """Add child and connect it to self.

        The child and its whole subtree are attached to this node's model.
        When this node belongs to a model, 'node_added' is emitted on it for
        the child; a detached node (model is None) can be used to build a
        subtree before it is added to a model.

        """
        node.parent = self
        # Propagate the model to the entire subtree so prebuilt subtrees
        # don't keep a stale model on deeper nodes (iterative: no recursion
        # limit on deep trees).
        pending = [node]
        while pending:
            current = pending.pop()
            current.model = self.model
            pending.extend(current.children)
        self.children.append(node)
        if self.model is not None:
            self.model.emit('node_added', node)


class TreeModel(Emitter):
    """Tree data model.

    Tree model stores all nodes of tree but knows nothing about displaying them.
    Same model can be used in many views.

    The 'node_added' event (TreeEvent) is emitted by TreeNode.add for nodes
    added under any node belonging to this model.

    """

    def __init__(self):
        self.add_events('node_added', TreeEvent)
        # unnamed root; it is not displayed, its children are the top level
        self.root = TreeNode('', model=self)

    def __iter__(self):
        """Iterate all nodes depth-first as (level, index, count, node)."""
        return TreeIter(self.root)

    def find(self, path):
        """Find node by path.

        Supports two variants of path:
          '/name1/name2/name3'
          [0,2,1]

        Numeric variant uses index of node on each level. A non-str path
        may be any sequence (list, tuple); it is never modified.
        An empty path ('', '/', []) returns the root.

        Raises ValueError for unknown component in str variant,
        IndexError for bad index in numeric variant.

        """
        if isinstance(path, str):
            path = path.split('/')
        else:
            # Work on a copy: the stripping below must not modify the
            # caller's list (this also makes tuples acceptable).
            path = list(path)
        # strip empty strings from both ends ('/a/b/' -> ['a', 'b'])
        while path and path[0] == '':
            del path[0]
        while path and path[-1] == '':
            del path[-1]

        node = self.root
        for component in path:
            if isinstance(component, int):
                node = node.children[component]
            else:
                for subnode in node:
                    if subnode.name == component:
                        node = subnode
                        break
                else:
                    raise ValueError('Node not found at component %r of path %r' % (component, path))

        return node

    def add(self, path, nodes):
        """Add node(s) to model at path.

        There are four variants for nodes parameter:
          add('/', 'name')
          add('/', TreeNode('name'))
          add('/', ['name1', 'name2'])
          add('/', [TreeNode('name1'), TreeNode('name2')])

        First two add one node to root, next two add two nodes.

        The first and third variants use strings; each string is wrapped
        in TreeNode, so they do exactly the same as the second and fourth.

        Raises the same exceptions as find() for a bad path.

        """
        parent_node = self.find(path)

        if isinstance(nodes, (str, TreeNode)):
            nodes = [nodes]

        for node in nodes:
            if isinstance(node, str):
                node = TreeNode(node)
            parent_node.add(node)


class TreeView(Widget):
    """Tree view displays data from tree model.

    Visible nodes are drawn one per line with connector lines. One node is
    the cursor node (highlighted). Keys: up/down move the cursor, left
    collapses and right expands the cursor node.

    Events:
     * expand - node expanded, event carries the affected node
     * collapse - node collapsed, event carries the affected node

    """

    def __init__(self, model=None):
        Widget.__init__(self)
        self._default_size.update(20, 20)

        self.allow_focus = True

        # model
        self._model = None

        # cursor
        self._cursor_node = None

        # nodes whose children are currently hidden
        self.collapsed_nodes = []

        self.add_events(
            'expand', TreeEvent,  # node expanded, event carries the affected node
            'collapse', TreeEvent)  # node collapsed, event carries the affected node

        if model is not None:
            self.model = model

    def __iter__(self):
        """Iterate visible nodes as (level, index, count, node) tuples.

        Yields nothing when no model is set.

        """
        if self._model is None:
            return iter(())
        return TreeIter(self._model.root, self.collapsed_nodes)

    @property
    def model(self):
        """TreeModel in use by this TreeView."""
        return self._model

    @model.setter
    def model(self, value):
        if self._model is not None:
            self._model.remove_handler('node_added', self.on_model_node_added)
        self._model = value
        # A cursor from the previous model must never survive a model change.
        self._cursor_node = None
        if self._model is not None:
            self._model.add_handler('node_added', self.on_model_node_added)
            if self._model.root.children:
                self.cursor_node = self._model.root.children[0]
        self._update_sizereq()

    def on_model_node_added(self, ev):
        # First node added to an empty tree gets the cursor.
        if self.cursor_node is None:
            self.cursor_node = ev.node
        self._update_sizereq()
        self.redraw()

    @property
    def cursor_node(self):
        """Currently selected node, or None for an empty tree."""
        return self._cursor_node

    @cursor_node.setter
    def cursor_node(self, value):
        self._cursor_node = value
        self._update_spot()

    def collapse(self, path, collapse=True):
        """Collapse (or expand when collapse=False) node at path.

        See TreeModel.find for accepted path forms and raised exceptions.

        """
        node = self._model.find(path)
        self.collapse_node(node, collapse)

    def collapse_node(self, node, collapse=True):
        """Collapse (or expand when collapse=False) node.

        Nodes without children cannot be collapsed. Emits 'collapse' or
        'expand' only when the state actually changes.

        """
        if collapse:
            if node not in self.collapsed_nodes and node.children:
                self.collapsed_nodes.append(node)
                self.emit('collapse', node)
                # Keep the cursor visible: if it was inside the collapsed
                # subtree, move it onto the collapsed node itself.
                if self._is_ancestor(node, self._cursor_node):
                    self._cursor_node = node
        else:
            try:
                self.collapsed_nodes.remove(node)
            except ValueError:
                pass  # node was not collapsed, nothing to expand
            else:
                # outside the try so ValueError from handlers isn't swallowed
                self.emit('expand', node)
        self._update_sizereq()
        # rows below the changed node moved; keep the spot on the cursor
        self._update_spot()

    @staticmethod
    def _is_ancestor(ancestor, node):
        """Return True if ancestor is a proper ancestor of node (None-safe)."""
        while node is not None:
            node = node.parent
            if node is ancestor:
                return True
        return False

    def on_draw(self, ev):
        ev.driver.pushcolor('normal')
        ev.driver.fill_clip()

        lines = 0  # bit array, bit 0 - draw vertical line on first column, etc.
        y = ev.y
        for level, index, count, node in self:
            # prepare string with vertical lines where they should be
            head = []
            for l in range(level-1):
                if lines & (1 << l):
                    head.append(ev.driver.unigraph.VLINE + ' ')
                else:
                    head.append('  ')
            # add vertical line if needed
            if index < count:
                # more siblings follow: tee, and continue line on this column
                head.append(ev.driver.unigraph.LTEE)
                lines |= 1 << level-1
            else:
                # last sibling: corner, and stop the line on this column
                head.append(ev.driver.unigraph.LLCORNER)
                lines &= ~(1 << level-1)
            # draw lines and titles
            head = ''.join(head)
            if node in self.collapsed_nodes:
                sep = '+'
            else:
                sep = ' '
            ev.driver.puts(ev.x, y, head + sep + str(node))
            if node is self.cursor_node:
                ev.driver.pushcolor('active')
                ev.driver.puts(ev.x + len(head), y, sep + str(node) + ' ')
                ev.driver.popcolor()
            y += 1

        ev.driver.popcolor()

    def on_keypress(self, ev):
        key_map = {
            'up': self.move_up,
            'down': self.move_down,
            'left': self.move_left,
            'right': self.move_right}
        if ev.keyname in key_map:
            key_map[ev.keyname]()
            self.redraw()
            return True

    def prev_node(self, node):
        """Return visible node displayed just above node, or None at top.

        That is the bottom-most visible node of the previous sibling's
        subtree, or the parent when node is the first child.

        """
        parent = node.parent
        if parent is None:
            return None  # root has nothing above it
        i = parent.children.index(node)
        if i > 0:
            node = parent.children[i-1]
            while node not in self.collapsed_nodes and node.children:
                node = node.children[-1]
            return node
        if parent.parent is None:
            return None  # parent is the root, which is not displayed
        return parent

    def next_node(self, node):
        """Return visible node displayed just below node, or None at bottom."""
        if node in self.collapsed_nodes or len(node.children) == 0:
            # next sibling, or next sibling of the nearest ancestor having one
            parent = node.parent
            while parent is not None:
                i = parent.children.index(node)
                try:
                    return parent.children[i+1]
                except IndexError:
                    node = parent
                    parent = node.parent
            return None
        else:
            return node.children[0]

    def move_up(self):
        """Move cursor one row up; return True when it moved."""
        if self.cursor_node is None:
            return False
        prev = self.prev_node(self.cursor_node)
        if prev is not None:
            self.cursor_node = prev
            return True
        return False

    def move_down(self):
        """Move cursor one row down; return True when it moved."""
        if self.cursor_node is None:
            return False
        node = self.next_node(self.cursor_node)
        if node is not None:
            self.cursor_node = node
            return True
        return False

    def move_left(self):
        """Collapse the cursor node (no-op without a cursor)."""
        if self.cursor_node is not None:
            self.collapse_node(self.cursor_node, True)

    def move_right(self):
        """Expand the cursor node (no-op without a cursor)."""
        if self.cursor_node is not None:
            self.collapse_node(self.cursor_node, False)

    def _update_sizereq(self):
        """Request height equal to the number of visible nodes."""
        height = sum(1 for _ in self)
        self.sizereq.update(h = height)

    def _update_spot(self):
        """Update spot to current position of cursor node."""
        for num, (level, _index, _count, node) in enumerate(self):
            if node is self.cursor_node:
                # x = column where the title starts (after connectors and sep)
                self._spot.update(level * 2, num)
                break