[Zodb-checkins] CVS: ZEO/ZEO/tests - CommitLockTests.py:1.2 ThreadTests.py:1.2 testTransactionBuffer.py:1.4 Cache.py:1.8 forker.py:1.16 multi.py:1.8 speed.py:1.7 stress.py:1.6 testZEO.py:1.25

Jeremy Hylton jeremy@zope.com
Tue, 11 Jun 2002 09:43:08 -0400


Update of /cvs-repository/ZEO/ZEO/tests
In directory cvs.zope.org:/tmp/cvs-serv5548/ZEO/tests

Modified Files:
	Cache.py forker.py multi.py speed.py stress.py testZEO.py 
Added Files:
	CommitLockTests.py ThreadTests.py testTransactionBuffer.py 
Log Message:
Merge ZEO2-branch to trunk.


=== ZEO/ZEO/tests/CommitLockTests.py 1.1 => 1.2 ===
+#
+# Copyright (c) 2002 Zope Corporation and Contributors.
+# All Rights Reserved.
+#
+# This software is subject to the provisions of the Zope Public License,
+# Version 2.0 (ZPL).  A copy of the ZPL should accompany this distribution.
+# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
+# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
+# FOR A PARTICULAR PURPOSE
+#
+##############################################################################
+"""Tests of the distributed commit lock."""
+
+import threading
+
+from ZODB.Transaction import Transaction
+from ZODB.tests.StorageTestBase import zodb_pickle, MinPO
+
+import ZEO.ClientStorage
+from ZEO.Exceptions import Disconnected
+
+ZERO = '\0'*8
+
+class DummyDB:
+    # Minimal stand-in handed to ClientStorage.registerDB() by
+    # _duplicate_client(); every invalidate() call is silently ignored.
+    def invalidate(self, *args):
+        pass
+
+class WorkerThread(threading.Thread):
+
+    # Each worker drives one complete transaction in its own thread, so
+    # that a tpc_vote() which blocks on the commit lock cannot hang the
+    # test suite.
+
+    def __init__(self, storage, trans, method="tpc_finish"):
+        threading.Thread.__init__(self)
+        self.storage = storage
+        self.trans = trans
+        self.method = method
+
+    def run(self):
+        storage, txn = self.storage, self.trans
+        try:
+            storage.tpc_begin(txn)
+            # Store two fresh objects in the same transaction.
+            for _ in range(2):
+                oid = storage.new_oid()
+                storage.store(oid, ZERO, zodb_pickle(MinPO("c")), '', txn)
+            storage.tpc_vote(txn)
+            if self.method != "tpc_finish":
+                storage.tpc_abort(txn)
+            else:
+                storage.tpc_finish(txn)
+        except Disconnected:
+            # Losing the connection (e.g. the test closes this client
+            # early) is tolerated silently.
+            pass
+
+class CommitLockTests:
+
+    # The commit lock tests verify that the storage successfully
+    # blocks and restarts transactions when there is contention for a
+    # single storage.  There are a lot of cases to cover.
+
+    # CommitLock1 checks the case where a single transaction delays
+    # other transactions before they actually block.  IOW, by the time
+    # the other transactions get to the vote stage, the first
+    # transaction has finished.
+
+    # CommitLock2 runs each competing transaction in its own
+    # WorkerThread, started while the first transaction is still open,
+    # and afterwards joins all of those threads.
+
+    def checkCommitLock1OnCommit(self):
+        self._storages = []
+        try:
+            self._checkCommitLock("tpc_finish", self._dosetup1, self._dowork1)
+        finally:
+            self._cleanup()
+
+    def checkCommitLock1OnAbort(self):
+        self._storages = []
+        try:
+            self._checkCommitLock("tpc_abort", self._dosetup1, self._dowork1)
+        finally:
+            self._cleanup()
+
+    def checkCommitLock2OnCommit(self):
+        self._storages = []
+        try:
+            self._checkCommitLock("tpc_finish", self._dosetup2, self._dowork2)
+        finally:
+            self._cleanup()
+
+    def checkCommitLock2OnAbort(self):
+        self._storages = []
+        try:
+            self._checkCommitLock("tpc_abort", self._dosetup2, self._dowork2)
+        finally:
+            self._cleanup()
+
+    def _cleanup(self):
+        """Abort and close every extra client recorded in self._storages."""
+        for store, trans in self._storages:
+            store.tpc_abort(trans)
+            store.close()
+        self._storages = []
+
+    def _checkCommitLock(self, method_name, dosetup, dowork):
+        """Run one commit-lock scenario against self._storage.
+
+        method_name -- "tpc_finish" or "tpc_abort"; selects how the main
+            transaction and the competing transactions are ended.
+        dosetup -- called as dosetup(storage, trans, tid, method_name)
+            for each extra client while the main transaction is open.
+        dowork -- called as dowork(method_name) after the main
+            transaction has ended, to drive the other clients to
+            completion.
+        """
+        # Check the commit lock when a client attempts a transaction,
+        # but fails/exits before finishing the commit.
+
+        # Start one transaction normally.
+        t = Transaction()
+        self._storage.tpc_begin(t)
+
+        # Start more transactions on different connections without
+        # blocking the test thread.
+        self._storages = []
+        # Reset once per scenario, not once per client: _dosetup2 appends
+        # every worker here so that _dowork2 can join all of them.
+        self._threads = []
+        for i in range(4):
+            storage2 = self._duplicate_client()
+            t2 = Transaction()
+            tid = `ZEO.ClientStorage.get_timestamp()` # XXX why?
+            dosetup(storage2, t2, tid, method_name)
+            if i == 0:
+                # The first extra client goes away with its transaction
+                # still unfinished.
+                storage2.close()
+            else:
+                self._storages.append((storage2, t2))
+
+        oid = self._storage.new_oid()
+        self._storage.store(oid, ZERO, zodb_pickle(MinPO(1)), '', t)
+        self._storage.tpc_vote(t)
+        if method_name == "tpc_finish":
+            self._storage.tpc_finish(t)
+            self._storage.load(oid, '')
+        else:
+            self._storage.tpc_abort(t)
+
+        dowork(method_name)
+
+        # Make sure the server is still responsive
+        self._dostore()
+
+    def _dosetup1(self, storage, trans, tid, method_name="tpc_finish"):
+        """Begin trans on storage; _dowork1 completes it later.
+
+        method_name is accepted for symmetry with _dosetup2 and unused.
+        """
+        storage.tpc_begin(trans, tid)
+
+    def _dowork1(self, method_name):
+        """Store one object in each open transaction, vote, then end it
+        with tpc_finish or tpc_abort according to method_name."""
+        for store, trans in self._storages:
+            oid = store.new_oid()
+            store.store(oid, ZERO, zodb_pickle(MinPO("c")), '', trans)
+            store.tpc_vote(trans)
+            if method_name == "tpc_finish":
+                store.tpc_finish(trans)
+            else:
+                store.tpc_abort(trans)
+
+    def _dosetup2(self, storage, trans, tid, method_name="tpc_finish"):
+        """Run the whole transaction for storage in a new WorkerThread.
+
+        The worker ends its transaction with method_name, so the OnAbort
+        variant really aborts.  Requires self._threads, which
+        _checkCommitLock initializes once per scenario.
+        """
+        t = WorkerThread(storage, trans, method_name)
+        self._threads.append(t)
+        t.start()
+
+    def _dowork2(self, method_name):
+        """Wait for every WorkerThread started by _dosetup2."""
+        for t in self._threads:
+            # Bounded wait so that a wedged worker cannot hang the suite,
+            # which is the reason WorkerThread exists in the first place.
+            t.join(30)
+
+    def _duplicate_client(self):
+        "Open another ClientStorage to the same server."
+        # XXX argh it's hard to find the actual address
+        # The rpc mgr addr attribute is a list.  Each element in the
+        # list is a socket domain (AF_INET, AF_UNIX, etc.) and an
+        # address.
+        addr = self._storage._rpc_mgr.addr[0][1]
+        new = ZEO.ClientStorage.ClientStorage(addr, wait=1)
+        new.registerDB(DummyDB(), None)
+        return new
+
+    def _get_timestamp(self):
+        """Return the repr() of a TimeStamp for the current UTC time."""
+        # Neither name is imported at module level in this file.
+        import time
+        from ZODB.TimeStamp import TimeStamp
+        t = time.time()
+        t = TimeStamp(*(time.gmtime(t)[:5] + (t % 60,)))
+        return repr(t)
+


=== ZEO/ZEO/tests/ThreadTests.py 1.1 => 1.2 ===
+#
+# Copyright (c) 2002 Zope Corporation and Contributors.
+# All Rights Reserved.
+# 
+# This software is subject to the provisions of the Zope Public License,
+# Version 2.0 (ZPL).  A copy of the ZPL should accompany this distribution.
+# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
+# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
+# FOR A PARTICULAR PURPOSE
+# 
+##############################################################################
+"""Compromising positions involving threads."""
+
+import threading
+
+from ZODB.Transaction import Transaction
+from ZODB.tests.StorageTestBase import zodb_pickle, MinPO
+
+import ZEO.ClientStorage
+from ZEO.Exceptions import Disconnected
+
+ZERO = '\0'*8
+
+class BasicThread(threading.Thread):
+    """Common base for the worker threads below.
+
+    Holds the shared storage, a fresh Transaction, the two events used
+    to sequence each test, and 0/1 flags recording which errors occurred.
+    """
+    def __init__(self, storage, doNextEvent, threadStartedEvent):
+        threading.Thread.__init__(self)
+        self.storage = storage
+        self.trans = Transaction()
+        self.doNextEvent = doNextEvent
+        self.threadStartedEvent = threadStartedEvent
+        # Result flags, read by the test after join().
+        self.gotValueError = self.gotDisconnected = 0
+
+
+class GetsThroughVoteThread(BasicThread):
+    # This thread gets partially through a transaction before it turns
+    # execution over to another thread.  We're trying to establish that a
+    # tpc_finish() after a storage has been closed by another thread will get
+    # a ClientStorageError.
+    #
+    # This class does a tpc_begin(), store(), and tpc_vote(), then sets
+    # threadStartedEvent and waits (up to 10 seconds) on doNextEvent; the
+    # main thread is expected to close the storage in that window, before
+    # tpc_finish() is attempted.
+    def run(self):
+        self.storage.tpc_begin(self.trans)
+        oid = self.storage.new_oid()
+        self.storage.store(oid, ZERO, zodb_pickle(MinPO("c")), '', self.trans)
+        self.storage.tpc_vote(self.trans)
+        # Tell the main thread we are parked just before tpc_finish().
+        self.threadStartedEvent.set()
+        self.doNextEvent.wait(10)
+        try:
+            self.storage.tpc_finish(self.trans)
+        except ZEO.ClientStorage.ClientStorageError:
+            # Expected outcome once the storage is closed: record it,
+            # then abort the transaction.
+            self.gotValueError = 1
+            self.storage.tpc_abort(self.trans)
+
+
+class GetsThroughBeginThread(BasicThread):
+    # Intended to start while another thread is already inside a
+    # transaction: this thread blocks in tpc_begin() until another thread
+    # closes the storage, and then records the ClientStorageError.
+    def run(self):
+        txn = self.trans
+        try:
+            self.storage.tpc_begin(txn)
+        except ZEO.ClientStorage.ClientStorageError:
+            self.gotValueError = 1
+
+
+class AbortsAfterBeginFailsThread(BasicThread):
+    # Same as GetsThroughBeginThread, except that after tpc_begin() fails
+    # it also calls tpc_abort().  That abort is expected to raise
+    # Disconnected, which implies this thread never held the lock -- the
+    # property we really want to test, though the threading module's API
+    # makes it hard to check directly.
+    def run(self):
+        storage, txn = self.storage, self.trans
+        try:
+            storage.tpc_begin(txn)
+        except ZEO.ClientStorage.ClientStorageError:
+            self.gotValueError = 1
+        try:
+            storage.tpc_abort(txn)
+        except Disconnected:
+            self.gotDisconnected = 1
+
+
+class ThreadTests:
+    # Mixin of thread-related checks; the test case mixing it in must
+    # provide self._storage.
+
+    # Thread 1 should start a transaction, but not get all the way through it.
+    # Main thread should close the connection.  Thread 1 should then get
+    # disconnected.
+    def checkDisconnectedOnThread2Close(self):
+        doNextEvent = threading.Event()
+        threadStartedEvent = threading.Event()
+        thread1 = GetsThroughVoteThread(self._storage,
+                                        doNextEvent, threadStartedEvent)
+        thread1.start()
+        threadStartedEvent.wait(10)
+        self._storage.close()
+        doNextEvent.set()
+        thread1.join()
+        self.assertEqual(thread1.gotValueError, 1)
+
+    # Thread 1 should start a transaction, but not get all the way through
+    # it.  While thread 1 is in the middle of the transaction, a second thread
+    # should start a transaction, and it will block in the tpc_begin() --
+    # because thread 1 has acquired the lock in its tpc_begin().  Now the main
+    # thread closes the storage and both sub-threads should get disconnected.
+    def checkSecondBeginFails(self):
+        thread1, thread2 = self._runTwoThreads(GetsThroughBeginThread)
+        self.assertEqual(thread1.gotValueError, 1)
+        self.assertEqual(thread2.gotValueError, 1)
+
+    # Like checkSecondBeginFails, but the second thread also calls
+    # tpc_abort() after its tpc_begin() fails; that abort should raise
+    # Disconnected, showing the failed begin never held the lock.
+    def checkThatFailedBeginDoesNotHaveLock(self):
+        thread1, thread2 = self._runTwoThreads(AbortsAfterBeginFailsThread)
+        self.assertEqual(thread1.gotValueError, 1)
+        self.assertEqual(thread2.gotValueError, 1)
+        self.assertEqual(thread2.gotDisconnected, 1)
+
+    def _runTwoThreads(self, thread2_class, timeout=10):
+        """Run the shared two-thread close scenario.
+
+        thread1 is a GetsThroughVoteThread; thread2 is an instance of
+        thread2_class, started once thread1 has voted (or after timeout
+        seconds).  The main thread then closes the storage, releases
+        thread1, and joins both threads.  Returns (thread1, thread2).
+        """
+        doNextEvent = threading.Event()
+        threadStartedEvent = threading.Event()
+        thread1 = GetsThroughVoteThread(self._storage,
+                                        doNextEvent, threadStartedEvent)
+        thread2 = thread2_class(self._storage,
+                                doNextEvent, threadStartedEvent)
+        thread1.start()
+        # Same generous timeout as checkDisconnectedOnThread2Close.  The
+        # event fires as soon as thread1 has voted, so a short timeout
+        # only adds the risk of closing the storage before the vote.
+        threadStartedEvent.wait(timeout)
+        thread2.start()
+        self._storage.close()
+        doNextEvent.set()
+        thread1.join()
+        thread2.join()
+        return thread1, thread2


=== ZEO/ZEO/tests/testTransactionBuffer.py 1.3 => 1.4 ===
+#
+# Copyright (c) 2001, 2002 Zope Corporation and Contributors.
+# All Rights Reserved.
+# 
+# This software is subject to the provisions of the Zope Public License,
+# Version 2.0 (ZPL).  A copy of the ZPL should accompany this distribution.
+# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
+# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
+# FOR A PARTICULAR PURPOSE
+# 
+##############################################################################
+import random
+import unittest
+
+from ZEO.TransactionBuffer import TransactionBuffer
+
+def random_string(size):
+    """Return a string of size characters with random ordinals 0-255."""
+    chars = []
+    for _ in range(size):
+        chars.append(chr(random.randrange(256)))
+    return "".join(chars)
+
+def new_store_data():
+    """Return random arguments for TransactionBuffer.store().
+
+    The tuple is (8 random bytes -- presumably an oid, '', a random
+    string of fewer than 1000 bytes).
+    """
+    key = random_string(8)
+    payload = random_string(random.randrange(1000))
+    return key, '', payload
+
+def new_invalidate_data():
+    """Return random arguments for TransactionBuffer.invalidate():
+    an 8-byte random string and an empty string."""
+    key = random_string(8)
+    return (key, '')
+
+class TransBufTests(unittest.TestCase):
+
+    def checkTypicalUsage(self):
+        tbuf = TransactionBuffer()
+        tbuf.store(*new_store_data())
+        tbuf.invalidate(*new_invalidate_data())
+        tbuf.begin_iterate()
+        while 1:
+            o = tbuf.next()
+            if o is None:
+                break
+        tbuf.clear()
+
+    def doUpdates(self, tbuf):
+        """Fill tbuf with 10 interleaved store/invalidate pairs, then check
+        that iteration returns them in insertion order and then stops."""
+        data = []
+        for i in range(10):
+            d = new_store_data()
+            tbuf.store(*d)
+            data.append(d)
+            d = new_invalidate_data()
+            tbuf.invalidate(*d)
+            data.append(d)
+
+        tbuf.begin_iterate()
+        for expected in data:
+            x = tbuf.next()
+            if x[2] is None:
+                # the tbuf adds a dummy None to invalidates
+                x = x[:2]
+            self.assertEqual(x, expected)
+        # After the last record next() reports exhaustion with None (the
+        # convention checkTypicalUsage relies on); anything else means
+        # the buffer returned extra records.
+        self.assertEqual(tbuf.next(), None)
+
+    def checkOrderPreserved(self):
+        tbuf = TransactionBuffer()
+        self.doUpdates(tbuf)
+
+    def checkReusable(self):
+        tbuf = TransactionBuffer()
+        self.doUpdates(tbuf)
+        tbuf.clear()
+        self.doUpdates(tbuf)
+        tbuf.clear()
+        self.doUpdates(tbuf)
+
+def test_suite():
+    """Collect every check* method of TransBufTests into one suite."""
+    return unittest.makeSuite(TransBufTests, prefix='check')


=== ZEO/ZEO/tests/Cache.py 1.7 => 1.8 ===
         # Make sure this doesn't load invalid data into the cache
         self._storage.load(oid, '')
-        
+
         self._storage.tpc_vote(t)
         self._storage.tpc_finish(t)
 


=== ZEO/ZEO/tests/forker.py 1.15 => 1.16 ===
 import asyncore
 import os
-import profile
 import random
 import socket
 import sys
+import traceback
 import types
-import ZEO.ClientStorage, ZEO.StorageServer
+import ZEO.ClientStorage
 
+# Change value of PROFILE to enable server-side profiling
 PROFILE = 0
+if PROFILE:
+    import hotshot
 
 def get_port():
     """Return a port that is not in use.
@@ -47,21 +50,23 @@
 
 if os.name == "nt":
 
-    def start_zeo_server(storage_name, args, port=None):
+    def start_zeo_server(storage_name, args, addr=None):
         """Start a ZEO server in a separate process.
 
         Returns the ZEO port, the test server port, and the pid.
         """
         import ZEO.tests.winserver
-        if port is None:
+        if addr is None:
             port = get_port()
+        else:
+            port = addr[1]
         script = ZEO.tests.winserver.__file__
         if script.endswith('.pyc'):
             script = script[:-1]
         args = (sys.executable, script, str(port), storage_name) + args
         d = os.environ.copy()
         d['PYTHONPATH'] = os.pathsep.join(sys.path)
-        pid = os.spawnve(os.P_NOWAIT, sys.executable, args, os.environ)
+        pid = os.spawnve(os.P_NOWAIT, sys.executable, args, d)
         return ('localhost', port), ('localhost', port + 1), pid
 
 else:
@@ -79,9 +84,11 @@
             buf = self.recv(4)
             if buf:
                 assert buf == "done"
+                server.close_server()
                 asyncore.socket_map.clear()
 
         def handle_close(self):
+            server.close_server()
             asyncore.socket_map.clear()
 
     class ZEOClientExit:
@@ -90,38 +97,56 @@
             self.pipe = pipe
 
         def close(self):
-            os.write(self.pipe, "done")
-            os.close(self.pipe)
+            try:
+                os.write(self.pipe, "done")
+                os.close(self.pipe)
+            except os.error:
+                pass
 
-    def start_zeo_server(storage, addr):
+    def start_zeo_server(storage_name, args, addr):
+        assert isinstance(args, types.TupleType)
         rd, wr = os.pipe()
         pid = os.fork()
         if pid == 0:
-            if PROFILE:
-                p = profile.Profile()
-                p.runctx("run_server(storage, addr, rd, wr)", globals(),
-                         locals())
-                p.dump_stats("stats.s.%d" % os.getpid())
-            else:
-                run_server(storage, addr, rd, wr)
+            import ZEO.zrpc.log
+            reload(ZEO.zrpc.log)
+            try:
+                if PROFILE:
+                    p = hotshot.Profile("stats.s.%d" % os.getpid())
+                    p.runctx("run_server(addr, rd, wr, storage_name, args)",
+                             globals(), locals())
+                    p.close()
+                else:
+                    run_server(addr, rd, wr, storage_name, args)
+            except:
+                print "Exception in ZEO server process"
+                traceback.print_exc()
             os._exit(0)
         else:
             os.close(rd)
             return pid, ZEOClientExit(wr)
 
-    def run_server(storage, addr, rd, wr):
+    def load_storage(name, args):
+        """Import ZODB.<name> and return its <name> class called with args."""
+        top = __import__("ZODB." + name)
+        storage_class = getattr(getattr(top, name), name)
+        return storage_class(*args)
+
+    def run_server(addr, rd, wr, storage_name, args):
         # in the child, run the storage server
+        global server
         os.close(wr)
         ZEOServerExit(rd)
-        serv = ZEO.StorageServer.StorageServer(addr, {'1':storage})
-        asyncore.loop()
-        os.close(rd)
+        import ZEO.StorageServer, ZEO.zrpc.server
+        storage = load_storage(storage_name, args)
+        server = ZEO.StorageServer.StorageServer(addr, {'1':storage})
+        ZEO.zrpc.server.loop()
         storage.close()
         if isinstance(addr, types.StringType):
             os.unlink(addr)
 
-    def start_zeo(storage, cache=None, cleanup=None, domain="AF_INET",
-                  storage_id="1", cache_size=20000000):
+    def start_zeo(storage_name, args, cache=None, cleanup=None,
+                  domain="AF_INET", storage_id="1", cache_size=20000000):
         """Setup ZEO client-server for storage.
 
         Returns a ClientStorage instance and a ZEOClientExit instance.
@@ -137,10 +162,10 @@
         else:
             raise ValueError, "bad domain: %s" % domain
 
-        pid, exit = start_zeo_server(storage, addr)
+        pid, exit = start_zeo_server(storage_name, args, addr)
         s = ZEO.ClientStorage.ClientStorage(addr, storage_id,
-                                            debug=1, client=cache,
+                                            client=cache,
                                             cache_size=cache_size,
-                                            min_disconnect_poll=0.5)
+                                            min_disconnect_poll=0.5,
+                                            wait=1)
         return s, exit, pid
-


=== ZEO/ZEO/tests/multi.py 1.7 => 1.8 ===
     pid = os.fork()
     if pid == 0:
-        import ZEO.ClientStorage
-        if VERBOSE:
-            print "Client process started:", os.getpid()
-        cli = ZEO.ClientStorage.ClientStorage(addr, client=CLIENT_CACHE)
-        if client_func is None:
-            run(cli)
-        else:
-            client_func(cli)
-        cli.close()
-        os._exit(0)
+        try:
+            import ZEO.ClientStorage
+            if VERBOSE:
+                print "Client process started:", os.getpid()
+            cli = ZEO.ClientStorage.ClientStorage(addr, client=CLIENT_CACHE)
+            if client_func is None:
+                run(cli)
+            else:
+                client_func(cli)
+            cli.close()
+        finally:
+            os._exit(0)
     else:
         return pid
 


=== ZEO/ZEO/tests/speed.py 1.6 => 1.7 ===
 """
 
-import asyncore  
+import asyncore
 import sys, os, getopt, string, time
 ##sys.path.insert(0, os.getcwd())
 
@@ -81,7 +81,7 @@
         for r in 1, 10, 100, 1000:
             t = time.time()
             conflicts = 0
-            
+
             jar = db.open()
             while 1:
                 try:
@@ -105,7 +105,7 @@
                 else:
                     break
             jar.close()
-            
+
             t = time.time() - t
             if detailed:
                 if threadno is None:
@@ -205,11 +205,11 @@
     for v in l:
         tot = tot + v
     return tot / len(l)
-    
+
 ##def compress(s):
 ##    c = zlib.compressobj()
 ##    o = c.compress(s)
-##    return o + c.flush()    
+##    return o + c.flush()
 
 if __name__=='__main__':
     main(sys.argv[1:])


=== ZEO/ZEO/tests/stress.py 1.5 => 1.6 ===
     if pid != 0:
         return pid
-    
-    storage = ClientStorage(zaddr, debug=1, min_disconnect_poll=0.5)
+    try:
+        _start_child(zaddr)
+    finally:
+        os._exit(0)
+
+def _start_child(zaddr):
+    storage = ClientStorage(zaddr, debug=1, min_disconnect_poll=0.5, wait=1)
     db = ZODB.DB(storage, pool_size=NUM_CONNECTIONS)
     setup(db.open())
     conns = []
@@ -128,8 +133,6 @@
         else:
             c.__count += 1
         work(c)
-
-    os._exit(0)
 
 if __name__ == "__main__":
     main()


=== ZEO/ZEO/tests/testZEO.py 1.24 => 1.25 === (612/712 lines abridged)
 import os
 import random
+import select
 import socket
 import sys
 import tempfile
+import thread
 import time
 import types
 import unittest
@@ -26,22 +28,20 @@
 import ZEO.ClientStorage, ZEO.StorageServer
 import ThreadedAsync, ZEO.trigger
 from ZODB.FileStorage import FileStorage
-from ZODB.TimeStamp import TimeStamp
 from ZODB.Transaction import Transaction
-import thread
+from ZODB.tests.StorageTestBase import zodb_pickle, MinPO
+import zLOG
 
-from ZEO.tests import forker, Cache
+from ZEO.tests import forker, Cache, CommitLockTests, ThreadTests
 from ZEO.smac import Disconnected
 
-# Sorry Jim...
 from ZODB.tests import StorageTestBase, BasicStorage, VersionStorage, \
      TransactionalUndoStorage, TransactionalUndoVersionStorage, \
-     PackableStorage, Synchronization, ConflictResolution
+     PackableStorage, Synchronization, ConflictResolution, RevisionStorage, \
+     MTStorage, ReadOnlyStorage
 from ZODB.tests.MinPO import MinPO
 from ZODB.tests.StorageTestBase import zodb_unpickle
 
-ZERO = '\0'*8
-
 class DummyDB:
     def invalidate(self, *args):
         pass
@@ -56,93 +56,22 @@
     def pack(self, t, f):
         self.storage.pack(t, f, wait=1)
 
-class ZEOTestBase(StorageTestBase.StorageTestBase):
-    """Version of the storage test class that supports ZEO.
-    
-    For ZEO, we don't always get the serialno/exception for a
-    particular store as the return value from the store.   But we
-    will get no later than the return value from vote.
-    """
-    

[-=- -=- -=- 612 lines omitted -=- -=- -=-]

-        for k, v in klass.__dict__.items():
-            if callable(v):
-                meth[k] = 1
-    return meth.keys()
+            # XXX waitpid() isn't available until Python 2.3
+            time.sleep(0.5)
 
 if os.name == "posix":
     test_classes = ZEOFileStorageTests, UnixConnectionTests
@@ -502,36 +430,12 @@
 else:
     raise RuntimeError, "unsupported os: %s" % os.name
 
-def makeTestSuite(testname=''):
+def test_suite():
     suite = unittest.TestSuite()
-    name = 'check' + testname
-    lname = len(name)
     for klass in test_classes:
-        for meth in get_methods(klass):
-            if meth[:lname] == name:
-                suite.addTest(klass(meth))
+        sub = unittest.makeSuite(klass, 'check')
+        suite.addTest(sub)
     return suite
 
-def test_suite():
-    return makeTestSuite()
-
-def main():
-    import sys, getopt
-
-    name_of_test = ''
-
-    opts, args = getopt.getopt(sys.argv[1:], 'n:')
-    for flag, val in opts:
-        if flag == '-n':
-            name_of_test = val
-
-    if args:
-        print "Did not expect arguments.  Got %s" % args
-        return 0
-    
-    tests = makeTestSuite(name_of_test)
-    runner = unittest.TextTestRunner()
-    runner.run(tests)
-
 if __name__ == "__main__":
-    main()
+    unittest.main(defaultTest='test_suite')