Pythonの練習

Pythonの練習として、PriorityQueue(優先度付きキュー)とUnionFind(コード中のクラス名は DisjointSets)を実装してみた。まだPythonをよく分かっていないので、間違ったことをしているかもしれない。

また、テストには unittest を使ってみた。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

class PriorityQueue(object):
  """
  Priority queue (min-heap).

  Stores any objects that can be compared with the "<" operator;
  pop() returns them smallest first. Implemented as a binary heap kept
  in a Python list, where heap[(i - 1) // 2] is the parent of heap[i].
  """

  def __init__(self):
    """
    Create an empty queue. O(1)
    """
    self.heap = []

  def push(self, x):
    """
    Add x to the queue. O(log N)
    """
    self.heap.append(x)
    self._siftUp()

  def pop(self):
    """
    Remove and return the smallest object in the queue. O(log N)

    Raises IndexError if the queue is empty.
    """
    # An explicit raise (unlike assert) is not stripped by `python -O`;
    # IndexError matches list.pop() and heapq.heappop() on empty input.
    if self.empty():
      raise IndexError('pop from empty PriorityQueue')

    last = self.heap.pop()
    if not self.heap:
      # That was the only element, so it is also the smallest.
      return last
    smallest = self.heap[0]
    # Move the last element to the root and let it sink into place.
    self.heap[0] = last
    self._siftDown()
    return smallest

  def empty(self):
    """
    Return True if the queue is empty. O(1)
    """
    return len(self.heap) == 0

  def _siftUp(self):
    """
    Move the last element up while it is smaller than its parent.
    """
    heap = self.heap
    i = len(heap) - 1
    while i > 0:
      # Floor division: "/" would yield a float index on Python 3.
      parent = (i - 1) // 2
      if not (heap[i] < heap[parent]):
        break
      heap[i], heap[parent] = heap[parent], heap[i]
      i = parent

  def _siftDown(self):
    """
    Move the root element down, swapping it with its smaller child,
    until no child is smaller than it.
    """
    heap = self.heap
    size = len(heap)
    i = 0
    while True:
      child1 = i * 2 + 1
      child2 = i * 2 + 2
      if child2 < size:
        # Both children exist (child2 < size implies child1 < size).
        if not (heap[child1] < heap[i]) and not (heap[child2] < heap[i]):
          break
        # Swap with the smaller child so the invariant holds for both subtrees.
        child = child1 if heap[child1] < heap[child2] else child2
      elif child1 < size and heap[child1] < heap[i]:
        child = child1
      else:
        break
      heap[i], heap[child] = heap[child], heap[i]
      i = child


if __name__ == '__main__':
  import unittest

  class PriorityQueueTest(unittest.TestCase):
    """Unit tests for PriorityQueue."""

    def setUp(self):
      self.q = PriorityQueue()

    def test(self):
      # Use the assertX methods rather than bare assert: bare asserts are
      # stripped under `python -O`, which would make this test pass vacuously.
      q = self.q
      self.assertTrue(q.empty())
      q.push(10)
      self.assertFalse(q.empty())
      self.assertEqual(q.pop(), 10)
      self.assertTrue(q.empty())
      for x in (10, 5, 100, 20, 77, 0, 140, 123):
        q.push(x)
      for expected in (0, 5, 10, 20, 77):
        self.assertEqual(q.pop(), expected)
      for x in (8, 130, 130, 200):
        q.push(x)
      for expected in (8, 100, 123, 130, 130, 140):
        self.assertEqual(q.pop(), expected)
      q.push(-1000)
      self.assertEqual(q.pop(), -1000)
      self.assertEqual(q.pop(), 200)
      self.assertTrue(q.empty())

  suite = unittest.TestLoader().loadTestsFromTestCase(PriorityQueueTest)
  _unit = unittest.TextTestRunner(verbosity=2).run(suite)
  # A single parenthesized argument prints identically on Python 2 and 3.
  print(_unit)
#!/usr/bin/env python
# -*- coding: utf-8 -*-

class DisjointSets:
  """
  Union-find (disjoint-set forest) over the elements 0 .. n-1.

  Supports two fast operations on a family of disjoint sets:
  - union: merge the sets containing two elements
  - find: test whether two elements belong to the same set

  Uses union by size and path compression.
  """

  def __init__(self, n):
    # Encoding of self.parent[i]:
    #   >= 0 : index of i's parent in the forest
    #   <  0 : i is a root, and -parent[i] is the size of its set
    self.parent = [-1] * n

  def union(self, x, y):
    """
    Merge the sets containing x and y.

    Returns True if they were in different sets (a merge happened),
    False if they were already in the same set.
    """
    small, large = self.root(x), self.root(y)
    if small == large:
      return False
    # Sizes are stored negated, so a smaller value means a bigger set.
    # Make `large` the root of the strictly bigger set; on a tie the
    # root of y's set stays `large`.
    if self.parent[small] < self.parent[large]:
      small, large = large, small
    self.parent[large] += self.parent[small]
    self.parent[small] = large
    return True

  def find(self, x, y):
    """
    Return True if x and y belong to the same set.
    """
    rx = self.root(x)
    ry = self.root(y)
    return rx == ry

  def root(self, x):
    """
    Return the root (representative) of the set containing x.

    Every node visited on the way is re-pointed directly at the root
    (path compression), so later lookups are shorter.
    """
    parent = self.parent
    top = x
    while parent[top] >= 0:
      top = parent[top]
    node = x
    while node != top:
      nxt = parent[node]
      parent[node] = top
      node = nxt
    return top


if __name__ == '__main__':
  import unittest

  class DisjointSetsTest(unittest.TestCase):
    """Unit tests for DisjointSets."""

    def test(self):
      # assertTrue/assertFalse instead of bare assert: bare asserts are
      # removed under `python -O`, which would make this test pass vacuously.
      ds = DisjointSets(5)
      self.assertFalse(ds.find(0, 1))
      self.assertFalse(ds.find(1, 2))
      self.assertFalse(ds.find(0, 3))
      self.assertFalse(ds.find(3, 4))
      self.assertTrue(ds.union(1, 2))
      self.assertTrue(ds.find(1, 2))
      self.assertFalse(ds.find(0, 1))
      self.assertFalse(ds.find(3, 4))
      self.assertTrue(ds.union(4, 2))
      self.assertTrue(ds.find(4, 2))
      self.assertTrue(ds.find(1, 4))
      self.assertFalse(ds.find(0, 1))
      self.assertFalse(ds.find(3, 4))
      self.assertTrue(ds.union(0, 3))
      self.assertFalse(ds.find(0, 1))
      self.assertFalse(ds.find(3, 4))
      self.assertTrue(ds.find(0, 3))
      self.assertTrue(ds.union(3, 1))
      self.assertTrue(ds.find(3, 1))
      self.assertTrue(ds.find(4, 0))

  suite = unittest.TestLoader().loadTestsFromTestCase(DisjointSetsTest)
  _unit = unittest.TextTestRunner(verbosity=2).run(suite)
  # A single parenthesized argument prints identically on Python 2 and 3.
  print(_unit)