#
# Generic binary search tree ordered by a caller-supplied comparator
# function cmp(a, b) that returns < 0, 0 or > 0.
#
class BinTree:
    """Binary search tree ordered by a caller-supplied comparator.

    Every operation takes a ``comparator(a, b)`` callable returning a negative
    number when ``a`` sorts before ``b``, zero when they are equal, and a
    positive number when ``a`` sorts after ``b``.  Values comparing less than a
    node live in its left subtree; values comparing greater-or-equal live in
    its right subtree (so duplicates go right).

    The tree is handled through its root node.  An empty tree is a root whose
    ``_value`` is ``None``, so ``None`` itself cannot be stored meaningfully.
    """

    def __init__(self, value=None):
        self._parent = None
        self._left_child = None
        self._right_child = None
        self._value = value

    def _find_node(self, value, comparator):
        """Search the subtree rooted here for ``value``.

        Returns ``(node, found)``.  When ``found`` is 1, ``node`` holds a value
        comparing equal to ``value``.  When it is 0, ``node`` is the last node
        visited: the would-be parent of ``value``, or an empty (``None``-valued)
        node such as the root of an empty tree.
        """
        # Iterative walk: a recursive one overflows the interpreter's
        # recursion limit on degenerate (e.g. sorted-input) trees.
        node = self
        while node._value is not None:
            order = comparator(value, node._value)
            if order == 0:
                return (node, 1)
            child = node._left_child if order < 0 else node._right_child
            if child is None:
                break
            node = child
        return (node, 0)

    def _get_successor(self):
        """Return the leftmost node of the right subtree (in-order successor).

        The caller must ensure this node has a right child.
        """
        succ = self._right_child
        while succ._left_child is not None:
            succ = succ._left_child
        return succ

    def insert(self, value, comparator):
        """Insert ``value`` into the subtree rooted at this node.

        A value comparing equal to an existing one descends into that node's
        right subtree, so duplicates are kept and no existing subtree is ever
        overwritten.
        """
        node = self
        while node._value is not None:
            go_left = comparator(value, node._value) < 0
            child = node._left_child if go_left else node._right_child
            if child is None:
                new_node = BinTree(value)
                new_node._parent = node
                if go_left:
                    node._left_child = new_node
                else:
                    node._right_child = new_node
                return
            node = child
        # Only reached for an empty node, i.e. the root of an empty tree.
        node._value = value

    def find(self, value, comparator):
        """Return the stored value comparing equal to ``value``, or ``None``."""
        node, found = self._find_node(value, comparator)
        return node._value if found else None

    def remove(self, value, comparator):
        """Remove one stored value comparing equal to ``value``.

        Does nothing if no such value is present.  Call this on the root:
        the root object is the caller's handle to the tree and is never
        detached, so removing the root's own value is done by copying data
        into it.
        """
        node, found = self._find_node(value, comparator)
        if not found:
            return

        # A node with two children is not unlinked itself: its in-order
        # successor (which has no left child) is unlinked instead and its
        # value is copied into ``node`` at the end.
        if node._left_child is not None and node._right_child is not None:
            spliced = node._get_successor()
        else:
            spliced = node

        # ``spliced`` has at most one child, which takes its place.
        if spliced._left_child is not None:
            child = spliced._left_child
        else:
            child = spliced._right_child
        parent = spliced._parent

        if parent is None:
            # ``spliced`` is the root.  Pull the child's value AND subtrees
            # up into it; copying only the value would leave the child still
            # linked below the root, duplicating it.
            if child is None:
                spliced._value = None
            else:
                spliced._value = child._value
                spliced._left_child = child._left_child
                spliced._right_child = child._right_child
                for grandchild in (spliced._left_child, spliced._right_child):
                    if grandchild is not None:
                        grandchild._parent = spliced
        else:
            if child is not None:
                child._parent = parent
            if parent._left_child is spliced:
                parent._left_child = child
            else:
                parent._right_child = child

        if spliced is not node:
            node._value = spliced._value

    def dump(self, indent=0):
        """Print the subtree rooted here, one node per line.

        Each node prints as ``+ value`` indented by ``indent`` spaces.  Its
        left and then right children follow, indented two more spaces.  A bare
        ``+`` stands in for a missing child.
        """
        pad = " " * indent
        # Single-argument print(...) behaves identically on Python 2 and 3.
        print(pad + "+ " + str(self._value))
        for child in (self._left_child, self._right_child):
            if child is not None:
                child.dump(indent + 2)
            else:
                print(pad + "  +")

    def count(self):
        """Return the number of values stored in the subtree rooted here."""
        total = 0
        pending = [self]
        while pending:
            node = pending.pop()
            # Test for None explicitly: falsy values such as 0 or "" are
            # real entries and must be counted and descended into.
            if node._value is None:
                continue
            total += 1
            if node._left_child is not None:
                pending.append(node._left_child)
            if node._right_child is not None:
                pending.append(node._right_child)
        return total


if __name__ == "__main__":

    def cmp3(a, b):
        """Three-way compare returning -1, 0 or 1 (Python 2 ``cmp`` contract).

        Defined locally because the ``cmp`` builtin does not exist in Python 3.
        """
        return (a > b) - (a < b)

    def neq(a, b):
        # Never reports equality, so duplicates are always placed to the right.
        return cmp3(a, b) or 1

    values = [5, 34, 2, 21, 5, 3, 23, 6, 2, 49, 10]

    t = BinTree()
    for v in values:
        t.insert(v, neq)
    for v in values:
        print(t.find(v, cmp3))

    for v in values:
        t.remove(v, cmp3)

    print("search")
    for v in values:
        print(t.find(v, cmp3))
