Clean up and update TreeView.
authorRadek Brich <radek.brich@devl.cz>
Mon, 17 Dec 2012 21:07:59 +0100
changeset 37 54dd866b8951
parent 36 caf927c3f10b
child 38 c6e170452c7f
Clean up and update TreeView.
tests/test_treeview.py
tuikit/treeview.py
--- a/tests/test_treeview.py	Mon Dec 17 00:24:34 2012 +0100
+++ b/tests/test_treeview.py	Mon Dec 17 21:07:59 2012 +0100
@@ -1,16 +1,16 @@
 #!/usr/bin/env python3
 
 import sys
-sys.path.append('..')
+sys.path.insert(0, '..')
 
 from tuikit.treeview import *
 import unittest
 
 
 class TestTreeView(unittest.TestCase):
-    def test_treemodel(self):
-        '''Build tree model, iterate through the tree, test result.'''
-        # build tree model
+    def setUp(self):
+        """Build tree model
+
         # root
         # ├ a
         # │ ├ c
@@ -20,19 +20,31 @@
         # │   └ f
         # │     └ h
         # └ b
-        model = TreeModel()
-        model.add('/',  ['a', 'b'])
-        model.add('/a', ['c', 'd'])
-        model.add((0,1), ['e', 'f'])
-        model.add('/0/1/0', 'g')
-        model.add('/a/d/f', 'h')
-        
+
+        """
+        self.model = TreeModel()
+        self.model.add('/',  ['a', 'b'])
+        self.model.add('/a', [TreeNode('c'), TreeNode('d')])
+        self.model.add((0,1), ['e', TreeNode('f')])
+        self.model.add([0,1,0], 'g')
+        self.model.add('/a/d/f', TreeNode('h'))
+
+    def test_treeiter(self):
+        """Iterate through the tree, test result."""
         res = ''
-        for l, i, c, n in model:
+        for l, i, c, n in self.model:
             res += str(l) + str(i) + str(c) + n.name
-            
+
         self.assertEqual(res, '112a212c222d312e411g322f411h122b')
 
+    def test_treemodel_find(self):
+        # good path
+        node = self.model.find('/a/d/f/h')
+        self.assertEqual(node.name, 'h')
+        self.assertEqual(node.path, '/a/d/f/h')
+        # bad path
+        self.assertRaises(ValueError, self.model.find, '/a/b/c')
+        self.assertRaises(IndexError, self.model.find, [0,1,3])
 
 if __name__ == '__main__':
     unittest.main()
--- a/tuikit/treeview.py	Mon Dec 17 00:24:34 2012 +0100
+++ b/tuikit/treeview.py	Mon Dec 17 21:07:59 2012 +0100
@@ -19,7 +19,7 @@
             try:
                 if self._node in self._collapsed:
                     raise IndexError()
-                node = self._node[self._index]
+                node = self._node.children[self._index]
                 if node is None:
                     raise Exception('Bad node: None')
             except IndexError:
@@ -30,7 +30,7 @@
 
         level = len(self._stack) + 1
         index = self._index + 1
-        count = len(self._node)
+        count = len(self._node.children)
 
         self._index += 1
 
@@ -41,26 +41,58 @@
         return (level, index, count, node)
 
 
-class TreeNode(list):
-    def __init__(self, parent=None, name=''):
-        list.__init__(self)
+class TreeNode:
+    def __init__(self, name, title=None, parent=None, model=None):
+        self.model = model
+        self.children = []
         self.parent = parent
         self.name = name
+        self.title = title or name
 
-    def __eq__(self, other):
-        # do not compare by list content
-        return self is other
+    def __iter__(self):
+        return iter(self.children)
+
+    def __str__(self):
+        return self.title
+
+    def __repr__(self):
+        return 'TreeNode(%r)' % self.name
+
+    @property
+    def path(self):
+        if self.parent:
+            return self.parent.path + '/' + self.name
+        else:
+            return self.name
+
+    def add(self, node):
+        node.parent = self
+        node.model = self.model
+        self.children.append(node)
+        self.model.emit('node_added', node)
 
 
 class TreeModel(Emitter):
     def __init__(self):
-        self.add_events('change')
-        self.root = TreeNode()
+        self.add_events('node_added')  # node added, arg is the node
+        self.root = TreeNode('', model=self)
 
     def __iter__(self):
         return TreeIter(self.root)
 
     def find(self, path):
+        """Find node by path.
+
+        Supports two variants of path:
+          '/name1/name2/name3'
+          [0,2,1]
+
+        Numeric variant uses index of node on each level.
+
+        Raises ValueError for unknown component in str variant,
+        IndexError for bad index in numeric variant.
+
+        """
         if isinstance(path, str):
             path = path.split('/')
         # strip empty strings from both ends
@@ -70,32 +102,45 @@
             del path[-1]
 
         node = self.root
-        for item in path:
-            if isinstance(item, int):
-                node = node[item]
+        for component in path:
+            if isinstance(component, int):
+                node = node.children[component]
             else:
                 found = False
                 for subnode in node:
-                    if subnode.name == item:
+                    if subnode.name == component:
                         node = subnode
                         found = True
                         break
                 if not found:
-                    item = int(item)
-                    node = node[item]
+                    raise ValueError('Node not found at component %r of path %r' % (component, path))
 
         return node
 
-    def add(self, path, names):
-        node = self.find(path)
+    def add(self, path, nodes):
+        """Add node(s) to model at path.
+
+        There are four variants for nodes parameter:
+          add('/', 'name')
+          add('/', TreeNode('name'))
+          add('/', ['name1', 'name2'])
+          add('/', [TreeNode('name1'), TreeNode('name2')])
+
+        First two will add one node to root, next two adds two nodes.
 
-        if isinstance(names, str):
-            names = (names,)
+        First and third variant uses strings.
+        It does exactly the same as second and fourth variant.
+
+        """
+        parent_node = self.find(path)
 
-        for name in names:
-            node.append(TreeNode(node, name))
+        if isinstance(nodes, str) or isinstance(nodes, TreeNode):
+            nodes = [nodes]
 
-        self.emit('change')
+        for node in nodes:
+            if isinstance(node, str):
+                node = TreeNode(node)
+            parent_node.add(node)
 
 
 class TreeView(Widget):
@@ -109,36 +154,38 @@
 
         # model
         self._model = None
-        self.setmodel(model)
+        if model:
+            self.model = model
 
         self.collapsed = []
 
+        self.add_events(
+            'expand',   # node expanded, the affected node is given in args
+            'collapse') # node collapsed, the affected node is given in args
+
     def __iter__(self):
         return TreeIter(self._model.root, self.collapsed)
 
-    def getmodel(self):
-        '''TreeModel in use by this TreeView.'''
+    @property
+    def model(self):
+        """TreeModel in use by this TreeView."""
         return self._model
 
-    def setmodel(self, value):
+    @model.setter
+    def model(self, value):
         if self._model:
-            self._model.disconnect('change', self.model_change)
+            self._model.disconnect('node_added', self.on_model_node_added)
         self._model = value
         if self._model:
-            self._model.connect('change', self.model_change)
+            self._model.connect('node_added', self.on_model_node_added)
             try:
-                self.cnode = self._model.root[0]
+                self.cnode = self._model.root.children[0]
             except IndexError:
                 pass
 
-    model = property(getmodel, setmodel)
-
-    def model_change(self):
+    def on_model_node_added(self, node):
         if self.cnode is None:
-            try:
-                self.cnode = self._model.root[0]
-            except IndexError:
-                pass
+            self.cnode = node
         self.redraw()
 
     def collapse(self, path, collapse=True):
@@ -147,17 +194,19 @@
 
     def collapse_node(self, node, collapse=True):
         if collapse:
-            if not node in self.collapsed and len(node) > 0:
+            if not node in self.collapsed and len(node.children) > 0:
                 self.collapsed.append(node)
+                self.emit('collapse', node)
         else:
             try:
                 self.collapsed.remove(node)
+                self.emit('expand', node)
             except ValueError:
                 pass
 
-    def on_draw(self, screen, x, y):
-        super().on_draw(screen, x, y)
-        screen.pushcolor('normal')
+    def on_draw(self, driver, x, y):
+        super().on_draw(driver, x, y)
+        driver.pushcolor('normal')
 
         lines = 0  # bit array, bit 0 - draw vertical line on first column, etc.
         for level, index, count, node in self:
@@ -165,31 +214,31 @@
             head = []
             for l in range(level-1):
                 if lines & (1 << l):
-                    head.append(screen.unigraph.VLINE + ' ')
+                    head.append(driver.unigraph.VLINE + ' ')
                 else:
                     head.append('  ')
             # add vertical line if needed
             if index < count:
-                head.append(screen.unigraph.LTEE)
+                head.append(driver.unigraph.LTEE)
                 lines |= 1 << level-1
             else:
-                head.append(screen.unigraph.LLCORNER)
+                head.append(driver.unigraph.LLCORNER)
                 lines &= ~(1 << level-1)
-            # draw lines and name
+            # draw lines and titles
             head = ''.join(head)
             if node in self.collapsed:
                 sep = '+'
             else:
                 sep = ' '
-            screen.puts(x, y, head + sep + node.name)
+            driver.puts(x, y, head + sep + str(node))
             if node is self.cnode:
-                screen.pushcolor('active')
-                screen.puts(x + len(head), y, sep + node.name + ' ')
-                screen.popcolor()
+                driver.pushcolor('active')
+                driver.puts(x + len(head), y, sep + str(node) + ' ')
+                driver.popcolor()
 
             y += 1
 
-        screen.popcolor()
+        driver.popcolor()
 
     def on_keypress(self, keyname, char):
         super().on_keypress(keyname, char)
@@ -204,11 +253,11 @@
     def prev_node(self, node):
         # previous sibling
         parent = node.parent
-        i = parent.index(node)
+        i = parent.children.index(node)
         if i > 0:
-            node = parent[i-1]
-            while node not in self.collapsed and len(node) > 0:
-                node = node[-1]
+            node = parent.children[i-1]
+            while node not in self.collapsed and len(node.children) > 0:
+                node = node.children[-1]
             return node
         else:
             if parent.parent is None:
@@ -216,20 +265,19 @@
             return parent
 
     def next_node(self, node):
-        if node in self.collapsed or len(node) == 0:
+        if node in self.collapsed or len(node.children) == 0:
             # next sibling
             parent = node.parent
             while parent is not None:
-                i = parent.index(node)
+                i = parent.children.index(node)
                 try:
-                    return parent[i+1]
+                    return parent.children[i+1]
                 except IndexError:
                     node = parent
                     parent = node.parent
             return None
         else:
-            # first child
-            return node[0]
+            return node.children[0]
 
     def move_up(self):
         prev = self.prev_node(self.cnode)