[Checkins] SVN: relstorage/trunk/relstorage/adapters/ Began adding failover support.

Shane Hathaway shane at hathawaymix.org
Fri Oct 2 02:50:02 EDT 2009


Log message for revision 104731:
  Began adding failover support.
  

Changed:
  U   relstorage/trunk/relstorage/adapters/connmanager.py
  U   relstorage/trunk/relstorage/adapters/interfaces.py
  U   relstorage/trunk/relstorage/adapters/mysql.py
  U   relstorage/trunk/relstorage/adapters/oracle.py
  U   relstorage/trunk/relstorage/adapters/postgresql.py
  A   relstorage/trunk/relstorage/adapters/tests/
  A   relstorage/trunk/relstorage/adapters/tests/__init__.py
  A   relstorage/trunk/relstorage/adapters/tests/test_connmanager.py

-=-
Modified: relstorage/trunk/relstorage/adapters/connmanager.py
===================================================================
--- relstorage/trunk/relstorage/adapters/connmanager.py	2009-10-02 05:39:33 UTC (rev 104730)
+++ relstorage/trunk/relstorage/adapters/connmanager.py	2009-10-02 06:50:02 UTC (rev 104731)
@@ -13,8 +13,12 @@
 ##############################################################################
 
 from relstorage.adapters.interfaces import IConnectionManager
+from relstorage.adapters.interfaces import ReplicaClosedException
 from zope.interface import implements
+import os
+import time
 
+
 class AbstractConnectionManager(object):
     """Abstract base class for connection management.
 
@@ -24,7 +28,7 @@
 
     # disconnected_exceptions contains the exception types that might be
     # raised when the connection to the database has been broken.
-    disconnected_exceptions = ()
+    disconnected_exceptions = (ReplicaClosedException,)
 
     # close_exceptions contains the exception types to ignore
     # when the adapter attempts to close a database connection.
@@ -34,6 +38,12 @@
     # will be called whenever a store cursor is opened or rolled back.
     on_store_opened = None
 
+    def __init__(self, replica_conf=None):
+        if replica_conf:
+            self.replicas = ReplicaSelector(replica_conf)
+        else:
+            self.replicas = None
+
     def set_on_store_opened(self, f):
         """Set the on_store_opened hook"""
         self.on_store_opened = f
@@ -78,6 +88,11 @@
 
     def restart_load(self, conn, cursor):
         """Reinitialize a connection for loading objects."""
+        if self.replicas is not None:
+            if conn.replica != self.replicas.current():
+                # Prompt the change to a new replica by raising an exception.
+                self.close(conn, cursor)
+                raise ReplicaClosedException()
         conn.rollback()
 
     def open_for_store(self):
@@ -96,6 +111,11 @@
 
     def restart_store(self, conn, cursor):
         """Reuse a store connection."""
+        if self.replicas is not None:
+            if conn.replica != self.replicas.current():
+                # Prompt the change to a new replica by raising an exception.
+                self.close(conn, cursor)
+                raise ReplicaClosedException()
         conn.rollback()
         if self.on_store_opened is not None:
             self.on_store_opened(cursor, restart=True)
@@ -105,3 +125,96 @@
         Returns (conn, cursor).
         """
         return self.open()
+
+
+class ReplicaSelector(object):
+
+    def __init__(self, replica_conf, alt_timeout=600):
+        self.replica_conf = replica_conf
+        self.alt_timeout = alt_timeout
+        self._read_config()
+        self._select(0)
+        self._iterating = False
+        self._skip_index = None
+
+    def _read_config(self):
+        self._config_modified = os.path.getmtime(self.replica_conf)
+        self._config_checked = time.time()
+        f = open(self.replica_conf, 'r')
+        try:
+            lines = f.readlines()
+        finally:
+            f.close()
+        replicas = []
+        for line in lines:
+            line = line.strip()
+            if not line or line.startswith('#'):
+                continue
+            replicas.append(line)
+        if not replicas:
+            raise IndexError(
+                "No replicas specified in %s" % self.replica_conf)
+        self._replicas = replicas
+
+    def _is_config_modified(self):
+        now = time.time()
+        if now < self._config_checked + 1:
+            # don't check the last mod time more often than once per second
+            return False
+        self._config_checked = now
+        t = os.path.getmtime(self.replica_conf)
+        return t != self._config_modified
+
+    def _select(self, index):
+        self._current_replica = self._replicas[index]
+        self._current_index = index
+        if index > 0 and self.alt_timeout:
+            self._expiration = time.time() + self.alt_timeout
+        else:
+            self._expiration = None
+
+    def current(self):
+        """Get the current replica."""
+        self._iterating = False
+        if self._is_config_modified():
+            self._read_config()
+            self._select(0)
+        elif self._expiration is not None and time.time() >= self._expiration:
+            self._select(0)
+        return self._current_replica
+
+    def next(self):
+        """Return the next replica to try.
+
+        Return None if there are no more replicas defined.
+        """
+        if self._is_config_modified():
+            # Start over even if iteration was already in progress.
+            self._read_config()
+            self._select(0)
+            self._skip_index = None
+            self._iterating = True
+        elif not self._iterating:
+            # Start iterating.
+            self._skip_index = self._current_index
+            i = 0
+            if i == self._skip_index:
+                i = 1
+                if i >= len(self._replicas):
+                    # There are no more replicas to try.
+                    self._select(0)
+                    return None
+            self._select(i)
+            self._iterating = True
+        else:
+            # Continue iterating.
+            i = self._current_index + 1
+            if i == self._skip_index:
+                i += 1
+            if i >= len(self._replicas):
+                # There are no more replicas to try.
+                self._select(0)
+                return None
+            self._select(i)
+
+        return self._current_replica

Modified: relstorage/trunk/relstorage/adapters/interfaces.py
===================================================================
--- relstorage/trunk/relstorage/adapters/interfaces.py	2009-10-02 05:39:33 UTC (rev 104730)
+++ relstorage/trunk/relstorage/adapters/interfaces.py	2009-10-02 06:50:02 UTC (rev 104731)
@@ -417,3 +417,7 @@
     def abort(conn, cursor, txn=None):
         """Abort the commit.  If txn is not None, phase 1 is also aborted."""
 
+
+class ReplicaClosedException(Exception):
+    """The connection to the replica has been closed"""
+

Modified: relstorage/trunk/relstorage/adapters/mysql.py
===================================================================
--- relstorage/trunk/relstorage/adapters/mysql.py	2009-10-02 05:39:33 UTC (rev 104730)
+++ relstorage/trunk/relstorage/adapters/mysql.py	2009-10-02 06:50:02 UTC (rev 104731)
@@ -56,6 +56,7 @@
 from relstorage.adapters.dbiter import HistoryFreeDatabaseIterator
 from relstorage.adapters.dbiter import HistoryPreservingDatabaseIterator
 from relstorage.adapters.interfaces import IRelStorageAdapter
+from relstorage.adapters.interfaces import ReplicaClosedException
 from relstorage.adapters.locker import MySQLLocker
 from relstorage.adapters.mover import ObjectMover
 from relstorage.adapters.oidallocator import MySQLOIDAllocator
@@ -71,7 +72,11 @@
 
 # disconnected_exceptions contains the exception types that might be
 # raised when the connection to the database has been broken.
-disconnected_exceptions = (MySQLdb.OperationalError, MySQLdb.InterfaceError)
+disconnected_exceptions = (
+    MySQLdb.OperationalError,
+    MySQLdb.InterfaceError,
+    ReplicaClosedException,
+    )
 
 # close_exceptions contains the exception types to ignore
 # when the adapter attempts to close a database connection.
@@ -85,7 +90,7 @@
     def __init__(self, keep_history=True, **params):
         self.keep_history = keep_history
         self._params = params
-        self.connmanager = MySQLdbConnectionManager(params)
+        self.connmanager = MySQLdbConnectionManager(**params)
         self.runner = ScriptRunner()
         self.locker = MySQLLocker(
             keep_history=self.keep_history,
@@ -142,9 +147,7 @@
             )
 
     def new_instance(self):
-        # This adapter and its components are stateless, so it's
-        # safe to share it between threads.
-        return self
+        return MySQLAdapter(keep_history=self.keep_history, **self._params)
 
     def __str__(self):
         if self.keep_history:
@@ -168,23 +171,58 @@
     disconnected_exceptions = disconnected_exceptions
     close_exceptions = close_exceptions
 
-    def __init__(self, params):
-        self._params = params.copy()
+    def __init__(self, replica_conf=None, **params):
+        self._orig_params = params.copy()
+        self._params = self._orig_params
+        self._current_replica = None
+        super(MySQLdbConnectionManager, self).__init__(
+            replica_conf=replica_conf)
 
+    def _set_params(self, replica):
+        """Alter the connection parameters to use the specified replica.
+
+        The replica parameter is a string specifying either host or host:port.
+        """
+        if replica != self._current_replica:
+            params = self._orig_params.copy()
+            if ':' in replica:
+                host, port = replica.split(':')
+                params['host'] = host
+                params['port'] = int(port)
+            else:
+                params['host'] = replica
+            self._current_replica = replica
+            self._params = params
+
     def open(self, transaction_mode="ISOLATION LEVEL READ COMMITTED"):
         """Open a database connection and return (conn, cursor)."""
-        try:
-            conn = MySQLdb.connect(**self._params)
-            cursor = conn.cursor()
-            cursor.arraysize = 64
-            if transaction_mode:
-                conn.autocommit(True)
-                cursor.execute("SET SESSION TRANSACTION %s" % transaction_mode)
-                conn.autocommit(False)
-            return conn, cursor
-        except MySQLdb.OperationalError, e:
-            log.warning("Unable to connect: %s", e)
-            raise
+        if self.replicas is not None:
+            self._set_params(self.replicas.current())
+        while True:
+            try:
+                conn = MySQLdb.connect(**self._params)
+                cursor = conn.cursor()
+                cursor.arraysize = 64
+                if transaction_mode:
+                    conn.autocommit(True)
+                    cursor.execute(
+                        "SET SESSION TRANSACTION %s" % transaction_mode)
+                    conn.autocommit(False)
+                conn.replica = self._current_replica
+                return conn, cursor
+            except MySQLdb.OperationalError, e:
+                if self._current_replica:
+                    log.warning("Unable to connect to replica %s: %s",
+                        self._current_replica, e)
+                else:
+                    log.warning("Unable to connect: %s", e)
+                if self.replicas is not None:
+                    replica = self.replicas.next()
+                    if replica is not None:
+                        # try the new replica
+                        self._set_params(replica)
+                        continue
+                raise
 
     def open_for_load(self):
         """Open and initialize a connection for loading objects.

Modified: relstorage/trunk/relstorage/adapters/oracle.py
===================================================================
--- relstorage/trunk/relstorage/adapters/oracle.py	2009-10-02 05:39:33 UTC (rev 104730)
+++ relstorage/trunk/relstorage/adapters/oracle.py	2009-10-02 06:50:02 UTC (rev 104731)
@@ -21,6 +21,7 @@
 from relstorage.adapters.dbiter import HistoryFreeDatabaseIterator
 from relstorage.adapters.dbiter import HistoryPreservingDatabaseIterator
 from relstorage.adapters.interfaces import IRelStorageAdapter
+from relstorage.adapters.interfaces import ReplicaClosedException
 from relstorage.adapters.locker import OracleLocker
 from relstorage.adapters.mover import ObjectMover
 from relstorage.adapters.oidallocator import OracleOIDAllocator
@@ -40,6 +41,7 @@
     cx_Oracle.OperationalError,
     cx_Oracle.InterfaceError,
     cx_Oracle.DatabaseError,
+    ReplicaClosedException,
     )
 
 # close_exceptions contains the exception types to ignore
@@ -50,8 +52,8 @@
     """Oracle adapter for RelStorage."""
     implements(IRelStorageAdapter)
 
-    def __init__(self, user, password, dsn, twophase=False, arraysize=64,
-            use_inline_lobs=None, keep_history=True):
+    def __init__(self, user, password, dsn, twophase=False,
+            keep_history=True, replica_conf=None):
         """Create an Oracle adapter.
 
         The user, password, and dsn parameters are provided to cx_Oracle
@@ -60,28 +62,22 @@
         If twophase is true, all commits go through an Oracle-level two-phase
         commit process.  This is disabled by default.  Even when this option
         is disabled, the ZODB two-phase commit is still in effect.
-
-        arraysize sets the number of rows to buffer in cx_Oracle.  The default
-        is 64.
-
-        use_inline_lobs enables Oracle to send BLOBs inline in response to
-        queries.  It depends on features in cx_Oracle 5.  The default is None,
-        telling the adapter to auto-detect the presence of cx_Oracle 5.
         """
-        if use_inline_lobs is None:
-            use_inline_lobs = (cx_Oracle.version >= '5.0')
-        self.keep_history = keep_history
         self._user = user
+        self._password = password
         self._dsn = dsn
+        self._twophase = twophase
+        self.keep_history = keep_history
+        self.replica_conf = replica_conf
 
         self.connmanager = CXOracleConnectionManager(
-            params=(user, password, dsn),
-            arraysize=arraysize,
+            user=user,
+            password=password,
+            dsn=dsn,
             twophase=twophase,
+            replica_conf=replica_conf,
             )
-        self.runner = CXOracleScriptRunner(
-            use_inline_lobs=use_inline_lobs,
-            )
+        self.runner = CXOracleScriptRunner()
         self.locker = OracleLocker(
             keep_history=self.keep_history,
             lock_exceptions=(cx_Oracle.DatabaseError,),
@@ -145,21 +141,32 @@
     def new_instance(self):
         # This adapter and its components are stateless, so it's
         # safe to share it between threads.
-        return self
+        return OracleAdapter(
+            user=self._user,
+            password=self._password,
+            dsn=self._dsn,
+            twophase=self._twophase,
+            keep_history=self.keep_history,
+            replica_conf=self.replica_conf,
+            )
 
     def __str__(self):
+        parts = [self.__class__.__name__]
         if self.keep_history:
-            t = 'history preserving'
+            parts.append('history preserving')
         else:
-            t = 'history free'
-        return "%s, %s, user=%r, dsn=%r" % (
-            self.__class__.__name__, t, self._user, self._dsn)
+            parts.append('history free')
+        parts.append('user=%r' % self._user)
+        parts.append('dsn=%r' % self._dsn)
+        parts.append('twophase=%r' % self._twophase)
+        parts.append('replica_conf=%r' % self.replica_conf)
+        return ", ".join(parts)
 
 
 class CXOracleScriptRunner(OracleScriptRunner):
 
-    def __init__(self, use_inline_lobs):
-        self.use_inline_lobs = use_inline_lobs
+    def __init__(self):
+        self.use_inline_lobs = (cx_Oracle.version >= '5.0')
 
     def _outputtypehandler(self,
             cursor, name, defaultType, size, precision, scale):
@@ -225,26 +232,39 @@
     disconnected_exceptions = disconnected_exceptions
     close_exceptions = close_exceptions
 
-    def __init__(self, params, arraysize, twophase):
-        self._params = params
-        self._arraysize = arraysize
+    def __init__(self, user, password, dsn, twophase, replica_conf=None):
+        self._user = user
+        self._password = password
+        self._dsn = dsn
         self._twophase = twophase
+        super(CXOracleConnectionManager, self).__init__(
+            replica_conf=replica_conf)
 
     def open(self, transaction_mode="ISOLATION LEVEL READ COMMITTED",
             twophase=False):
         """Open a database connection and return (conn, cursor)."""
-        try:
-            kw = {'twophase': twophase}  #, 'threaded': True}
-            conn = cx_Oracle.connect(*self._params, **kw)
-            cursor = conn.cursor()
-            cursor.arraysize = self._arraysize
-            if transaction_mode:
-                cursor.execute("SET TRANSACTION %s" % transaction_mode)
-            return conn, cursor
+        if self.replicas is not None:
+            self._dsn = self.replicas.current()
+        while True:
+            try:
+                kw = {'twophase': twophase}  #, 'threaded': True}
+                conn = cx_Oracle.connect(
+                    self._user, self._password, self._dsn, **kw)
+                cursor = conn.cursor()
+                cursor.arraysize = 64
+                if transaction_mode:
+                    cursor.execute("SET TRANSACTION %s" % transaction_mode)
+                return conn, cursor
 
-        except cx_Oracle.OperationalError, e:
-            log.warning("Unable to connect: %s", e)
-            raise
+            except cx_Oracle.OperationalError, e:
+                log.warning("Unable to connect to DSN %s: %s", self._dsn, e)
+                if self.replicas is not None:
+                    replica = self.replicas.next()
+                    if replica is not None:
+                        # try the new replica
+                        self._dsn = replica
+                        continue
+                raise
 
     def open_for_load(self):
         """Open and initialize a connection for loading objects.
@@ -255,6 +275,11 @@
 
     def restart_load(self, conn, cursor):
         """Reinitialize a connection for loading objects."""
+        if self.replicas is not None:
+            if conn.dsn != self.replicas.current():
+                # Prompt the change to a new replica by raising an exception.
+                self.close(conn, cursor)
+                raise ReplicaClosedException()
         conn.rollback()
         cursor.execute("SET TRANSACTION READ ONLY")
 
@@ -288,6 +313,11 @@
 
     def restart_store(self, conn, cursor):
         """Reuse a store connection."""
+        if self.replicas is not None:
+            if conn.dsn != self.replicas.current():
+                # Prompt the change to a new replica by raising an exception.
+                self.close(conn, cursor)
+                raise ReplicaClosedException()
         conn.rollback()
         if self._twophase:
             self._set_xid(conn, cursor)

Modified: relstorage/trunk/relstorage/adapters/postgresql.py
===================================================================
--- relstorage/trunk/relstorage/adapters/postgresql.py	2009-10-02 05:39:33 UTC (rev 104730)
+++ relstorage/trunk/relstorage/adapters/postgresql.py	2009-10-02 06:50:02 UTC (rev 104731)
@@ -22,6 +22,7 @@
 from relstorage.adapters.dbiter import HistoryFreeDatabaseIterator
 from relstorage.adapters.dbiter import HistoryPreservingDatabaseIterator
 from relstorage.adapters.interfaces import IRelStorageAdapter
+from relstorage.adapters.interfaces import ReplicaClosedException
 from relstorage.adapters.locker import PostgreSQLLocker
 from relstorage.adapters.mover import ObjectMover
 from relstorage.adapters.oidallocator import PostgreSQLOIDAllocator
@@ -40,6 +41,7 @@
 disconnected_exceptions = (
     psycopg2.OperationalError,
     psycopg2.InterfaceError,
+    ReplicaClosedException,
     )
 
 # close_exceptions contains the exception types to ignore
@@ -50,12 +52,14 @@
     """PostgreSQL adapter for RelStorage."""
     implements(IRelStorageAdapter)
 
-    def __init__(self, dsn='', keep_history=True):
+    def __init__(self, dsn='', keep_history=True, replica_conf=None):
+        self._dsn = dsn
         self.keep_history = keep_history
-        self._dsn = dsn
+        self.replica_conf = replica_conf
         self.connmanager = Psycopg2ConnectionManager(
             dsn=dsn,
             keep_history=self.keep_history,
+            replica_conf=replica_conf,
             )
         self.runner = ScriptRunner()
         self.locker = PostgreSQLLocker(
@@ -109,20 +113,29 @@
             )
 
     def new_instance(self):
-        # This adapter and its components are stateless, so it's
-        # safe to share it between threads.
-        return self
+        return PostgreSQLAdapter(
+            dsn=self._dsn, keep_history=self.keep_history,
+            replica_conf=self.replica_conf)
 
     def __str__(self):
+        parts = [self.__class__.__name__]
         if self.keep_history:
-            t = 'history preserving'
+            parts.append('history preserving')
         else:
-            t = 'history free'
-        parts = self._dsn.split()
-        s = ' '.join(p for p in parts if not p.startswith('password'))
-        return "%s, %s, dsn=%r" % (self.__class__.__name__, t, s)
+            parts.append('history free')
+        dsnparts = self._dsn.split()
+        s = ' '.join(p for p in dsnparts if not p.startswith('password'))
+        parts.append('dsn=%r' % s)
+        parts.append('replica_conf=%r' % self.replica_conf)
+        return ", ".join(parts)
 
 
+class Psycopg2Connection(psycopg2.extensions.connection):
+    # The replica attribute holds the name of the replica this
+    # connection is bound to.
+    __slots__ = ('replica',)
+
+
 class Psycopg2ConnectionManager(AbstractConnectionManager):
 
     isolation_read_committed = (
@@ -133,22 +146,53 @@
     disconnected_exceptions = disconnected_exceptions
     close_exceptions = close_exceptions
 
-    def __init__(self, dsn, keep_history):
+    def __init__(self, dsn, keep_history, replica_conf=None):
+        self._orig_dsn = dsn
         self._dsn = dsn
         self.keep_history = keep_history
+        self._current_replica = None
+        super(Psycopg2ConnectionManager, self).__init__(
+            replica_conf=replica_conf)
 
+    def _set_dsn(self, replica):
+        """Alter the DSN to use the specified replica.
+
+        The replica parameter is a string specifying either host or host:port.
+        """
+        if replica != self._current_replica:
+            if ':' in replica:
+                host, port = replica.split(':')
+                self._dsn = self._orig_dsn + ' host=%s port=%s' % (host, port)
+            else:
+                self._dsn = self._orig_dsn + ' host=%s' % replica
+            self._current_replica = replica
+
     def open(self,
             isolation=psycopg2.extensions.ISOLATION_LEVEL_READ_COMMITTED):
         """Open a database connection and return (conn, cursor)."""
-        try:
-            conn = psycopg2.connect(self._dsn)
-            conn.set_isolation_level(isolation)
-            cursor = conn.cursor()
-            cursor.arraysize = 64
-        except psycopg2.OperationalError, e:
-            log.warning("Unable to connect: %s", e)
-            raise
-        return conn, cursor
+        if self.replicas is not None:
+            self._set_dsn(self.replicas.current())
+        while True:
+            try:
+                conn = Psycopg2Connection(self._dsn)
+                conn.set_isolation_level(isolation)
+                cursor = conn.cursor()
+                cursor.arraysize = 64
+                conn.replica = self._current_replica
+                return conn, cursor
+            except psycopg2.OperationalError, e:
+                if self._current_replica:
+                    log.warning("Unable to connect to replica %s: %s",
+                        self._current_replica, e)
+                else:
+                    log.warning("Unable to connect: %s", e)
+                if self.replicas is not None:
+                    replica = self.replicas.next()
+                    if replica is not None:
+                        # try the new replica
+                        self._set_dsn(replica)
+                        continue
+                raise
 
     def open_for_load(self):
         """Open and initialize a connection for loading objects.

Added: relstorage/trunk/relstorage/adapters/tests/__init__.py
===================================================================
--- relstorage/trunk/relstorage/adapters/tests/__init__.py	                        (rev 0)
+++ relstorage/trunk/relstorage/adapters/tests/__init__.py	2009-10-02 06:50:02 UTC (rev 104731)
@@ -0,0 +1 @@
+

Added: relstorage/trunk/relstorage/adapters/tests/test_connmanager.py
===================================================================
--- relstorage/trunk/relstorage/adapters/tests/test_connmanager.py	                        (rev 0)
+++ relstorage/trunk/relstorage/adapters/tests/test_connmanager.py	2009-10-02 06:50:02 UTC (rev 104731)
@@ -0,0 +1,191 @@
+##############################################################################
+#
+# Copyright (c) 2009 Zope Foundation and Contributors.
+# All Rights Reserved.
+#
+# This software is subject to the provisions of the Zope Public License,
+# Version 2.1 (ZPL).  A copy of the ZPL should accompany this distribution.
+# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
+# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
+# FOR A PARTICULAR PURPOSE.
+#
+##############################################################################
+
+import unittest
+
+class AbstractConnectionManagerTests(unittest.TestCase):
+
+    def test_without_replica_conf(self):
+        from relstorage.adapters.connmanager import AbstractConnectionManager
+        cm = AbstractConnectionManager()
+
+        conn = MockConnection()
+        cm.restart_load(conn, MockCursor())
+        self.assertTrue(conn.rolled_back)
+
+        conn = MockConnection()
+        cm.restart_store(conn, MockCursor())
+        self.assertTrue(conn.rolled_back)
+
+    def test_with_replica_conf(self):
+        import tempfile
+        f = tempfile.NamedTemporaryFile()
+        f.write("example.com:1234\n")
+        f.flush()
+
+        from relstorage.adapters.connmanager import AbstractConnectionManager
+        from relstorage.adapters.interfaces import ReplicaClosedException
+        cm = AbstractConnectionManager(f.name)
+
+        conn = MockConnection()
+        conn.replica = 'example.com:1234'
+        cm.restart_load(conn, MockCursor())
+        self.assertTrue(conn.rolled_back)
+        conn.replica = 'other'
+        self.assertRaises(ReplicaClosedException,
+            cm.restart_load, conn, MockCursor())
+
+        conn = MockConnection()
+        conn.replica = 'example.com:1234'
+        cm.restart_store(conn, MockCursor())
+        self.assertTrue(conn.rolled_back)
+        conn.replica = 'other'
+        self.assertRaises(ReplicaClosedException,
+            cm.restart_store, conn, MockCursor())
+
+
+class ReplicaSelectorTests(unittest.TestCase):
+
+    def setUp(self):
+        import tempfile
+        self.f = tempfile.NamedTemporaryFile()
+        self.f.write(
+            "# Replicas\n\nexample.com:1234\nlocalhost:4321\n"
+            "\nlocalhost:9999\n")
+        self.f.flush()
+
+    def tearDown(self):
+        self.f.close()
+
+    def test__read_config_normal(self):
+        from relstorage.adapters.connmanager import ReplicaSelector
+        rs = ReplicaSelector(self.f.name)
+        self.assertEqual(rs._replicas,
+            ['example.com:1234', 'localhost:4321', 'localhost:9999'])
+
+    def test__read_config_empty(self):
+        from relstorage.adapters.connmanager import ReplicaSelector
+        self.f.seek(0)
+        self.f.truncate()
+        self.assertRaises(IndexError, ReplicaSelector, self.f.name)
+
+    def test__is_config_modified(self):
+        from relstorage.adapters.connmanager import ReplicaSelector
+        import time
+        rs = ReplicaSelector(self.f.name)
+        self.assertEqual(rs._is_config_modified(), False)
+        # change the file
+        rs._config_modified = 0
+        # don't check the file yet
+        rs._config_checked = time.time() + 3600
+        self.assertEqual(rs._is_config_modified(), False)
+        # now check the file
+        rs._config_checked = 0
+        self.assertEqual(rs._is_config_modified(), True)
+
+    def test__select(self):
+        from relstorage.adapters.connmanager import ReplicaSelector
+        rs = ReplicaSelector(self.f.name)
+        rs._select(0)
+        self.assertEqual(rs._current_replica, 'example.com:1234')
+        self.assertEqual(rs._current_index, 0)
+        self.assertEqual(rs._expiration, None)
+        rs._select(1)
+        self.assertEqual(rs._current_replica, 'localhost:4321')
+        self.assertEqual(rs._current_index, 1)
+        self.assertNotEqual(rs._expiration, None)
+
+    def test_current(self):
+        from relstorage.adapters.connmanager import ReplicaSelector
+        rs = ReplicaSelector(self.f.name)
+        self.assertEqual(rs.current(), 'example.com:1234')
+        # change the file and get the new current replica
+        self.f.seek(0)
+        self.f.write('localhost\nalternate\n')
+        self.f.flush()
+        rs._config_checked = 0
+        rs._config_modified = 0
+        self.assertEqual(rs.current(), 'localhost')
+        # switch to the alternate
+        rs._select(1)
+        self.assertEqual(rs.current(), 'alternate')
+        # expire the alternate
+        rs._expiration = 0
+        self.assertEqual(rs.current(), 'localhost')
+
+    def test_next_iteration(self):
+        from relstorage.adapters.connmanager import ReplicaSelector
+        rs = ReplicaSelector(self.f.name)
+
+        # test forward iteration
+        self.assertEqual(rs.current(), 'example.com:1234')
+        self.assertEqual(rs.next(), 'localhost:4321')
+        self.assertEqual(rs.next(), 'localhost:9999')
+        self.assertEqual(rs.next(), None)
+
+        # test iteration that skips over the replica that failed
+        self.assertEqual(rs.current(), 'example.com:1234')
+        self.assertEqual(rs.next(), 'localhost:4321')
+        self.assertEqual(rs.current(), 'localhost:4321')
+        # next() after current() indicates the last replica failed
+        self.assertEqual(rs.next(), 'example.com:1234')
+        self.assertEqual(rs.next(), 'localhost:9999')
+        self.assertEqual(rs.next(), None)
+
+    def test_next_only_one_server(self):
+        from relstorage.adapters.connmanager import ReplicaSelector
+        self.f.seek(0)
+        self.f.write('localhost\n')
+        self.f.flush()
+        self.f.truncate()
+        rs = ReplicaSelector(self.f.name)
+        self.assertEqual(rs.current(), 'localhost')
+        self.assertEqual(rs.next(), None)
+
+    def test_next_with_new_conf(self):
+        from relstorage.adapters.connmanager import ReplicaSelector
+        rs = ReplicaSelector(self.f.name)
+        self.assertEqual(rs.current(), 'example.com:1234')
+        self.assertEqual(rs.next(), 'localhost:4321')
+        # interrupt the iteration by changing the replica conf file
+        self.f.seek(0)
+        self.f.write('example.com:9999\n')
+        self.f.flush()
+        self.f.truncate()
+        rs._config_checked = 0
+        rs._config_modified = 0
+        self.assertEqual(rs.next(), 'example.com:9999')
+        self.assertEqual(rs.next(), None)
+
+
+class MockConnection:
+    def rollback(self):
+        self.rolled_back = True
+
+    def close(self):
+        self.closed = True
+
+class MockCursor:
+    def close(self):
+        self.closed = True
+
+
+def test_suite():
+    suite = unittest.TestSuite()
+    for klass in [
+            AbstractConnectionManagerTests,
+            ReplicaSelectorTests,
+            ]:
+        suite.addTest(unittest.makeSuite(klass))
+    return suite



More information about the checkins mailing list