[Zope-CVS] CVS: Products/Ape/lib/apelib/config - __init__.py:1.1.2.1 apeconf.py:1.1.2.1 common.py:1.1.2.1 interfaces.py:1.1.2.1 minitables.py:1.1.2.1

Shane Hathaway shane@zope.com
Mon, 7 Jul 2003 18:59:22 -0400


Update of /cvs-repository/Products/Ape/lib/apelib/config
In directory cvs.zope.org:/tmp/cvs-serv2793/config

Added Files:
      Tag: ape-newconf-branch
	__init__.py apeconf.py common.py interfaces.py minitables.py 
Log Message:
Implemented XML-based configuration.  The filesystem tests pass.

Used an experimental approach for mixing configuration from multiple sources.
Take a look at zope2/apeconf.xml.


=== Added File Products/Ape/lib/apelib/config/__init__.py ===
##############################################################################
#
# Copyright (c) 2003 Zope Corporation and Contributors.
# All Rights Reserved.
#
# This software is subject to the provisions of the Zope Public License,
# Version 2.0 (ZPL).  A copy of the ZPL should accompany this distribution.
# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
# FOR A PARTICULAR PURPOSE.
#
##############################################################################
"""Ape configuration package.

$Id: __init__.py,v 1.1.2.1 2003/07/07 22:59:13 shane Exp $
"""


=== Added File Products/Ape/lib/apelib/config/apeconf.py ===
##############################################################################
#
# Copyright (c) 2003 Zope Corporation and Contributors.
# All Rights Reserved.
#
# This software is subject to the provisions of the Zope Public License,
# Version 2.0 (ZPL).  A copy of the ZPL should accompany this distribution.
# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
# FOR A PARTICULAR PURPOSE.
#
##############################################################################
"""Ape configuration assembler.

$Id: apeconf.py,v 1.1.2.1 2003/07/07 22:59:13 shane Exp $
"""

from apelib.core.mapper import Mapper
from apelib.core.serializers import CompositeSerializer, AnyObjectSerializer
from apelib.core.gateways import CompositeGateway
from apelib.core.interfaces import IDatabaseInitializer

from minitables import Table, TableSchema
from common import Directive, DirectiveReader, ComponentSystem

class AssemblyError(Exception):
    """Error while assembling components.

    Raised by the assemblers in this module when the configured
    directives cannot be turned into components, for example when a
    named component or mapper does not exist, when mapper extensions
    are circular, or when two mappers contend over one registration.
    """


class MapperDeclaration(Directive):
    """Directive stating that a mapper with the given name exists.

    Emitted by the 'mapper' element handler.  The mapper name is the
    only column and is also the primary key.
    """
    schema = TableSchema()
    schema.addColumn('mapper_name', primary=1)


class MapperAttribute(Directive):
    """Directive recording one attribute of a mapper.

    The primary key is (mapper_name, name), so each attribute may be
    given at most one value per mapper.
    """
    schema = TableSchema()
    schema.addColumn('mapper_name', primary=1)
    # Attribute names: 'class', 'parent', 'extends'
    # The name column is indexed as well as being part of the primary key.
    schema.addColumn('name', primary=1, indexed=1)
    schema.addColumn('value')


class ComponentDefinition(Directive):
    """Directive defining a reusable, named component.

    Emitted for component elements that appear outside of any mapper.
    The primary key is (comptype, name); 'producer' is the callable
    that makes the component when given the component system.
    """
    schema = TableSchema()
    # comptypes: 'serializer', 'gateway', 'classifier', 'keygen'
    schema.addColumn('comptype', primary=1)
    schema.addColumn('name', primary=1)
    schema.addColumn('producer')


class MapperComponent(Directive):
    """Directive giving a mapper its singular component of some type.

    The primary key is (mapper_name, comptype), so a mapper has at most
    one such component per type.  The element handlers in this module
    use it for the 'classifier' and 'keygen' component types.
    """
    schema = TableSchema()
    schema.addColumn('mapper_name', primary=1)
    schema.addColumn('comptype', primary=1)
    schema.addColumn('producer')


class MapperCompositeComponent(Directive):
    """Directive adding one named part to a composite component of a mapper.

    The primary key is (mapper_name, comptype, name), so a mapper may
    hold several parts of the same type under different names.  The
    element handlers in this module use it for the 'serializer' and
    'gateway' component types.
    """
    schema = TableSchema()
    schema.addColumn('mapper_name', primary=1)
    schema.addColumn('comptype', primary=1)
    schema.addColumn('name', primary=1)
    schema.addColumn('producer')
    # The element handler stores 'middle' here when no order is given.
    schema.addColumn('order')


class MapperRegistration(Directive):
    """Directive registering a mapper for some (attr, value) pair.

    All three columns are part of the primary key, which lets one
    mapper register several values for the same attribute.
    """
    # Holds the content of a 'use-for' element.
    schema = TableSchema()
    schema.addColumn('mapper_name', primary=1)
    schema.addColumn('attr', primary=1)
    schema.addColumn('value', primary=1)  # Multiple values allowed


class ClassifierOption(Directive):
    """Directive carrying a classifier option set on behalf of a mapper.

    The primary key is (mapper_name, option), so each option has at
    most one value per mapper.  Emitted by the 'option' element handler
    and, as 'default_extension', by the 'use-for' element handler.
    """
    schema = TableSchema()
    schema.addColumn('mapper_name', primary=1)
    schema.addColumn('option', primary=1)
    schema.addColumn('value')


class DisabledProducer:
    """Producer used for components configured with enabled="false".

    Calling it never yields a component.
    """

    def __init__(self, source):
        # Whatever the caller passed to identify the originating element.
        self.source = source

    def __call__(self, compsys):
        """Returns None, whatever component system is given."""
        return


class UseProducer:
    """Producer that reuses a named component of the component system."""

    def __init__(self, source, comptype, name):
        self.source, self.comptype, self.name = source, comptype, name

    def __call__(self, compsys):
        """Returns compsys.get(comptype, name) for the stored pair."""
        lookup = compsys.get
        return lookup(self.comptype, self.name)


class FactoryProducer:
    """Producer that instantiates a class named by a dotted path.

    The factory name is split at its last dot into a module name and a
    class name.  When the producer is called, the class is invoked with
    up to two positional arguments: 'param' (if not None), followed by
    the result of 'sub_producer' (if one has been assigned).
    """

    def __init__(self, source, factory_name, param=None):
        """Splits factory_name; raises ValueError without a module part."""
        dot = factory_name.rfind('.')
        if dot <= 0:
            # Either there is no dot or nothing precedes the last dot.
            raise ValueError("Module and class name required")
        self.source = source
        self.module_name, self.class_name = (
            factory_name[:dot], factory_name[dot + 1:])
        self.param = param
        # May be assigned after construction to supply a second argument.
        self.sub_producer = None

    def __call__(self, compsys):
        """Imports the module and returns a new instance of the class.

        Raises ImportError if the module lacks the named attribute.
        """
        if self.param is None:
            args = []
        else:
            args = [self.param]
        if self.sub_producer is not None:
            args.append(self.sub_producer(compsys))
        # A non-empty fromlist makes __import__ return the named module
        # itself rather than the top-level package.
        module = __import__(self.module_name, {}, {}, ('__doc__',))
        try:
            factory = getattr(module, self.class_name)
        except AttributeError:
            raise ImportError("No class %s in module %s" % (
                self.class_name, self.module_name))
        return factory(*args)


def makeProducer(source, comptype, attrs, raise_exc=1):
    """Chooses a producer according to the attributes of an element.

    enabled="false" (in any letter case) gives a DisabledProducer,
    'use' gives a UseProducer and 'factory' (with optional 'param')
    gives a FactoryProducer.  Supplying both 'use' and 'factory' raises
    ValueError.  When none of these apply, ValueError is raised if
    raise_exc is true and None is returned otherwise.
    """
    if attrs.get('enabled', '').lower() == 'false':
        return DisabledProducer(source)
    # has_key() is kept because attrs may be a SAX attributes object.
    if attrs.has_key('use'):
        if attrs.has_key('factory'):
            raise ValueError("Both 'use' and 'factory' not allowed")
        return UseProducer(source, comptype, attrs['use'])
    if attrs.has_key('factory'):
        return FactoryProducer(source, attrs['factory'], attrs.get('param'))
    if not raise_exc:
        return None
    raise ValueError("Need a 'use', 'factory', or 'enabled' attribute")


def getElementHandlers():
    """Returns a dictionary of XML element handlers.

    The keys are element names and the values are functions called as
    handler(source, vars, attrs).  'source' identifies where the element
    was found, 'vars' is the mutable dictionary of variables in scope
    for the element and 'attrs' holds the attributes of the element.
    Most handlers append directives to the list in vars['directives'].
    """

    def handle_configuration(source, vars, attrs):
        # The root element only verifies that a directive list is in scope.
        assert 'directives' in vars

    def handle_variant(source, vars, attrs):
        # Redirect the directives of nested elements to the named variant,
        # creating its directive list on first use.
        variants = vars['variants']
        name = attrs['name']
        directives = variants.get(name)
        if directives is None:
            directives = variants[name] = []
        vars['directives'] = directives

    def handle_mapper(source, vars, attrs):
        directives = vars['directives']
        mapper_name = attrs['name']
        # Nested elements find out which mapper they belong to through vars.
        vars['mapper_name'] = mapper_name
        directives.append(MapperDeclaration(source, mapper_name))
        for key in ('class', 'parent', 'extends'):
            if attrs.has_key(key):
                directives.append(
                    MapperAttribute(source, mapper_name, key, attrs[key]))

    def handle_component(source, vars, attrs, comptype, multiple):
        # Shared by all component elements.  Returns the producer so that
        # callers can keep a reference to it.
        directives = vars['directives']
        producer = makeProducer(source, comptype, attrs)
        mapper_name = vars.get('mapper_name')
        if mapper_name is None:
            # Reusable component, defined outside of any mapper
            directive = ComponentDefinition(
                source, comptype, attrs['name'], producer)
        elif multiple:
            # Named part of a composite component of a mapper
            directive = MapperCompositeComponent(
                source, mapper_name, comptype,
                attrs['name'], producer, attrs.get('order', 'middle'))
        else:
            # Singular component of a mapper
            directive = MapperComponent(
                source, mapper_name, comptype, producer)
        directives.append(directive)
        return producer

    def handle_serializer(source, vars, attrs):
        handle_component(source, vars, attrs, 'serializer', multiple=1)

    def handle_gateway(source, vars, attrs):
        classifier_producer = vars.get('classifier_producer')
        if classifier_producer is None:
            handle_component(source, vars, attrs, 'gateway', multiple=1)
            return
        # The gateway is nested in a classifier element: attach it to the
        # producer of the classifier instead of emitting a directive.
        # 'source' is normally a (system id, line number) tuple, so it is
        # wrapped in a 1-tuple to be formatted as a single value.
        if not hasattr(classifier_producer, 'sub_producer'):
            raise ValueError(
                "Classifier at %s needs a factory in order to "
                "use a gateway" % (source,))
        if classifier_producer.sub_producer is not None:
            raise ValueError(
                "Multiple gateways in classifiers not allowed at %s" %
                (source,))
        classifier_producer.sub_producer = makeProducer(
            source, 'gateway', attrs)

    def handle_classifier(source, vars, attrs):
        # Remember the producer so that a nested gateway can attach to it.
        vars['classifier_producer'] = handle_component(
            source, vars, attrs, 'classifier', multiple=0)

    def handle_keygen(source, vars, attrs):
        handle_component(source, vars, attrs, 'keygen', multiple=0)

    def handle_use_for(source, vars, attrs):
        directives = vars['directives']
        mapper_name = vars['mapper_name']
        for attr in ('class', 'extensions', 'fallback', 'key'):
            if not attrs.has_key(attr):
                continue
            value = attrs[attr]
            if attr != 'extensions':
                directives.append(MapperRegistration(
                    source, mapper_name, attr, value))
                continue
            # 'extensions' is a whitespace-separated list.  Each entry is
            # normalized to lower case with a leading dot and registered
            # under the singular attribute name 'extension'.
            first = 1
            for ext in value.split():
                if not ext.startswith('.'):
                    ext = '.' + ext
                ext = ext.lower()
                directives.append(MapperRegistration(
                    source, mapper_name, 'extension', ext))
                if first:
                    # Use a classifier option to set the default
                    # extension.
                    first = 0
                    directives.append(ClassifierOption(
                        source, mapper_name, 'default_extension', ext))

    def handle_option(source, vars, attrs):
        directives = vars['directives']
        mapper_name = vars['mapper_name']
        name = attrs['name']
        value = attrs['value']
        directives.append(ClassifierOption(source, mapper_name, name, value))

    handlers = {
        'configuration': handle_configuration,
        'variant':       handle_variant,
        'mapper':        handle_mapper,
        'serializer':    handle_serializer,
        'gateway':       handle_gateway,
        'classifier':    handle_classifier,
        'keygen':        handle_keygen,
        'use-for':       handle_use_for,
        'option':        handle_option,
        }

    return handlers



class BasicComponentAssembler:
    """Assembler for simple components.

    The component is fully built by create(); configure() has nothing
    left to do.
    """

    def __init__(self, compsys, comptype, name):
        """Looks up the producer of the (comptype, name) definition.

        Raises AssemblyError when no such component has been defined.
        """
        self.compsys = compsys
        matches = compsys.dtables.query(
            ComponentDefinition, comptype=comptype, name=name)
        if not matches:
            raise AssemblyError("No %s component named %s exists"
                                % (comptype, repr(name)))
        # (comptype, name) is the primary key of ComponentDefinition.
        assert len(matches) == 1
        self.producer = matches[0]['producer']

    def create(self):
        """Returns the object made by the producer."""
        make = self.producer
        return make(self.compsys)

    def configure(self):
        """Does nothing."""
        return None


class ClassifierAssembler:
    """Produces a classifier.

    Unlike BasicComponentAssembler, both creation and configuration are
    delegated to create() and configure() methods of the producer.

    NOTE(review): the producer classes defined in this module only
    implement __call__ and makeComponentSystem() does not register this
    assembler, so it looks unfinished -- confirm before relying on it.
    """

    def __init__(self, compsys, comptype, name):
        """Looks up the producer of the (comptype, name) definition.

        Raises AssemblyError when no such component has been defined.
        """
        self.compsys = compsys
        matches = compsys.dtables.query(
            ComponentDefinition, comptype=comptype, name=name)
        if not matches:
            raise AssemblyError("No %s component named %s exists"
                                % (comptype, repr(name)))
        assert len(matches) == 1
        self.producer = matches[0]['producer']

    def create(self):
        """Returns producer.create(compsys)."""
        producer = self.producer
        return producer.create(self.compsys)

    def configure(self):
        """Calls producer.configure(compsys)."""
        producer = self.producer
        producer.configure(self.compsys)


class MapperAssembler:
    """Assembler for mapper components.

    create() returns an empty Mapper; configure() then fills it in from
    the directive tables of the component system: parent, serializers,
    gateways, classifier, key generator, initializers, sub-mappers and
    the classifier registrations of the sub-mappers.
    """

    def __init__(self, compsys, comptype, name):
        """Collects the directives relevant to the mapper called 'name'.

        Raises AssemblyError if the mapper was never declared or if its
        chain of 'extends' attributes is circular.
        """
        self.compsys = compsys
        self.dtables = tables = compsys.dtables
        self.mapper_name = name
        if not tables.query(MapperDeclaration, mapper_name=name):
            raise AssemblyError("No mapper named %s exists" % repr(name))
        self.subobjs = []  # every sub-component added to the mapper
        self.attrs = {}    # attribute name -> value, for this mapper only
        for rec in tables.query(MapperAttribute, mapper_name=name):
            self.attrs[rec['name']] = rec['value']
        self.prepareSubComponents()

    def prepareSubComponents(self):
        """Gathers component records and the names of the sub-mappers.

        Components are collected from this mapper and then from each
        mapper along its 'extends' chain.  Because setdefault() is used,
        a record found earlier in the chain wins over later ones.
        """
        singles = self.single_comps = {}  # comptype -> record
        multis = self.multi_comps = {}    # comptype -> name -> record
        query = self.dtables.query
        lineage = []  # mapper_name and all of its base mapper_names
        current = self.mapper_name
        while current:
            lineage.append(current)
            for rec in query(MapperComponent, mapper_name=current):
                singles.setdefault(rec['comptype'], rec)
            for rec in query(MapperCompositeComponent, mapper_name=current):
                by_name = multis.setdefault(rec['comptype'], {})
                by_name.setdefault(rec['name'], rec)
            current = self.dtables.queryField(
                MapperAttribute, 'value', mapper_name=current, name='extends')
            if current and current in lineage:
                raise AssemblyError(
                    "Circular extension in mappers %s" % repr(lineage))
        # Sub-mappers are the mappers that name this one as their parent.
        children = self.sub_mapper_names = []
        for rec in query(
            MapperAttribute, name='parent', value=self.mapper_name):
            children.append(rec['mapper_name'])

    def create(self):
        """Makes, remembers and returns a new, unconfigured Mapper."""
        mapper = Mapper()
        self.obj = mapper
        return mapper

    def configure(self):
        """Runs every configuration step on the created mapper."""
        # addInitializers() reads self.subobjs, which the four steps
        # before it fill in, so the order of these calls matters.
        self.setParent()
        self.addSerializers()
        self.addGateways()
        self.setClassifier()
        self.setKeygen()
        self.addInitializers()
        self.addSubMappers()
        self.registerClassifications()

    def setParent(self):
        """Sets the parent mapper, or None if no parent is configured."""
        parent_name = self.attrs.get('parent')
        if parent_name:
            parent = self.compsys.get('mapper', parent_name)
        else:
            parent = None
        self.obj.setParent(parent)

    def _makeSerializer(self):
        # Picks the kind of serializer according to the 'class' attribute.
        cname = self.attrs.get('class')
        if cname == 'any':
            # This mapper is usable for many classes
            return AnyObjectSerializer()
        if cname == 'none':
            # This mapper is abstract (usable for no classes)
            return CompositeSerializer(None, None)
        # This mapper is concrete (usable for one class only).  Without
        # a 'class' attribute, the mapper name doubles as the class name.
        if cname is None:
            cname = self.mapper_name
        dot = cname.rfind('.')
        if dot < 0:
            raise AssemblyError("Class name must include a module name")
        return CompositeSerializer(cname[:dot], cname[dot + 1:])

    def addSerializers(self):
        """Builds the serializer of the mapper and adds the named parts.

        Parts are added sorted by their 'order' value, then by name.
        """
        serializer = self._makeSerializer()
        parts = self.multi_comps.get('serializer')
        if parts:
            ordered = [(rec.get('order', ''), name, rec)
                       for name, rec in parts.items()]
            ordered.sort()
            for order, name, rec in ordered:
                part = rec['producer'](self.compsys)
                if part is None:
                    # The producer yielded nothing; leave the part out.
                    continue
                serializer.addSerializer(str(name), part)
                self.subobjs.append(part)
        self.obj.setSerializer(serializer)

    def addGateways(self):
        """Builds the composite gateway of the mapper from its parts."""
        gateway = CompositeGateway()
        parts = self.multi_comps.get('gateway')
        if parts:
            for name, rec in parts.items():
                part = rec['producer'](self.compsys)
                if part is None:
                    continue
                gateway.addGateway(str(name), part)
                self.subobjs.append(part)
        self.obj.setGateway(gateway)

    def _produceSingle(self, comptype):
        # Returns the singular component of the given type, or None if
        # none is configured or its producer yields nothing.
        rec = self.single_comps.get(comptype)
        if not rec:
            return None
        return rec['producer'](self.compsys)

    def setClassifier(self):
        """Installs the classifier of the mapper, if there is one."""
        classifier = self._produceSingle('classifier')
        if classifier is not None:
            self.obj.setClassifier(classifier)
            self.subobjs.append(classifier)

    def setKeygen(self):
        """Installs the keychain generator of the mapper, if there is one."""
        keygen = self._produceSingle('keygen')
        if keygen is not None:
            self.obj.setKeychainGenerator(keygen)
            self.subobjs.append(keygen)

    def addInitializers(self):
        """Registers each sub-component that is a database initializer."""
        for part in self.subobjs:
            if IDatabaseInitializer.isImplementedBy(part):
                self.obj.addInitializer(part)

    def addSubMappers(self):
        """Fetches every sub-mapper and adds it under its own name."""
        for name in self.sub_mapper_names:
            child = self.compsys.get('mapper', name)
            self.obj.addSubMapper(name, child)

    def registerClassifications(self):
        """Registers classifications on behalf of sub-mappers."""
        regs = {}     # { (attr, value) -> mapper_name }
        options = {}  # { (mapper_name, option) -> value }
        need_classifier = 0

        def claim(attr, value, owner):
            # Records that 'owner' handles (attr, value); two different
            # mappers may not claim the same pair.
            key = (attr, value)
            if key in regs and regs[key] != owner:
                raise AssemblyError(
                    "Mappers %s and %s are contending over %s == %s" % (
                    owner, regs[key], attr, repr(value)))
            regs[key] = owner

        for name in self.sub_mapper_names:
            # use-for directives
            for rec in self.dtables.query(
                MapperRegistration, mapper_name=name):
                claim(rec['attr'], rec['value'], name)
                need_classifier = 1

            # class="" attributes
            class_name = self.dtables.queryField(
                MapperAttribute, 'value', mapper_name=name, name='class')
            if class_name is None:
                class_name = name
            elif class_name in ('none', 'any'):
                class_name = None
            if class_name is not None:
                # Add an implicit use-for directive.  On its own this
                # does not make a classifier mandatory.
                claim('class', class_name, name)

            # options
            for rec in self.dtables.query(
                ClassifierOption, mapper_name=name):
                options[(name, rec['option'])] = rec['value']
                need_classifier = 1

        if not (regs or options):
            return
        cfr = self.obj.getClassifier()
        if cfr is None:
            if not need_classifier:
                return
            raise AssemblyError(
                "Mapper %s needs a classifier because it has "
                "sub-mappers with registrations" % self.mapper_name)
        for (attr, value), name in regs.items():
            cfr.register(attr, value, name)
        for (name, option), value in options.items():
            cfr.setOption(name, option, value)


def makeComponentSystem(filenames, vnames=('',)):
    """Returns an Ape component system.

    Every file in 'filenames' is read as XML configuration.  'vnames'
    lists the variant names used to select directives; the default
    selects only the unnamed variant.  Mappers are assembled by
    MapperAssembler and all other component types by
    BasicComponentAssembler.
    """
    reader = DirectiveReader(getElementHandlers())
    for filename in filenames:
        reader.read(filename)
    compsys = ComponentSystem(reader.getDirectives(vnames))
    assemblers = [('mapper', MapperAssembler)]
    for comptype in ('serializer', 'gateway', 'classifier', 'keygen'):
        assemblers.append((comptype, BasicComponentAssembler))
    for comptype, factory in assemblers:
        compsys.addComponentType(comptype, factory)
    return compsys



=== Added File Products/Ape/lib/apelib/config/common.py ===
##############################################################################
#
# Copyright (c) 2003 Zope Corporation and Contributors.
# All Rights Reserved.
#
# This software is subject to the provisions of the Zope Public License,
# Version 2.0 (ZPL).  A copy of the ZPL should accompany this distribution.
# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
# FOR A PARTICULAR PURPOSE.
#
##############################################################################
"""Bits useful for configuration.  May move to its own package.

$Id: common.py,v 1.1.2.1 2003/07/07 22:59:13 shane Exp $
"""

import xml.sax.handler
from xml.sax import parse

import minitables


class Directive:
    """Abstract base class for table-oriented directives.

    Subclasses set 'schema' to a table schema.  The column values of a
    directive are kept in the 'data' dictionary.  The directive class
    followed by the values of the primary columns forms the unique key
    of the directive.
    """

    schema = None  # override

    def __init__(self, source, *args, **kw):
        """Stores the source and the column values of the directive.

        Positional values are matched to the schema columns in order;
        keyword arguments name their columns directly.  Raises TypeError
        if there are more positional values than columns or if a column
        is given both ways, and KeyError if a primary column has no
        value.
        """
        self.source = source
        if args:
            columns = self.schema.getColumns()
            if len(args) > len(columns):
                # Fail with a clear message rather than an IndexError.
                raise TypeError(
                    '%s accepts at most %d positional values (%d given)'
                    % (self.__class__.__name__, len(columns), len(args)))
            for column, value in zip(columns, args):
                key = column.name
                if key in kw:
                    raise TypeError(
                        '%s supplied as both positional and keyword argument'
                        % repr(key))
                kw[key] = value
        self.data = kw
        unique_key = [self.__class__]
        for name in self.schema.getPrimaryNames():
            unique_key.append(kw[name])
        self.unique_key = tuple(unique_key)

    def getUniqueKey(self):
        """Returns the tuple (class, primary value, ...) made at creation."""
        return self.unique_key

    def index(self, tables):
        """Inserts the data of this directive into the table for its class.

        'tables' maps directive classes to tables; the table is created
        from the schema on first use.
        """
        t = tables.get(self.__class__)
        if t is None:
            t = minitables.Table(self.schema)
            tables[self.__class__] = t
        t.insert(self.data)

    def __eq__(self, other):
        # Directives are equal when they have exactly the same class and
        # the same data; the source is deliberately left out.
        if other.__class__ is self.__class__:
            return other.data == self.data
        return 0

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__, so spell it out.
        return not self.__eq__(other)

    def __repr__(self):
        return "<%s from %s with %s>" % (
            self.__class__.__name__, repr(self.source), repr(self.data))



class XMLConfigReader(xml.sax.handler.ContentHandler):
    """Reads configuration from XML files.

    Each start tag is dispatched to the handler registered for the
    element name.  Handlers receive a dictionary of variables that
    starts out as a copy of the variables of the enclosing element, so
    anything a handler stores is visible to nested elements only.
    """

    def __init__(self, handlers):
        """'handlers' maps element names to handler(source, vars, attrs)."""
        xml.sax.handler.ContentHandler.__init__(self)
        self.handlers = handlers
        # Set up a directive list in a default variant.
        directives = []
        self.variants = {'': directives}
        # One dictionary of variables per open element; the bottom entry
        # holds the document-level variables.
        self.stack = [{'directives': directives, 'variants': self.variants}]
        self.locator = None

    def setDocumentLocator(self, locator):
        self.locator = locator

    def startElement(self, name, attrs):
        """Calls the handler for 'name' with a new variable scope.

        Raises KeyError, naming the element and where it was found, if
        no handler is registered for the element.
        """
        locator = self.locator
        if locator is not None:
            source = (locator.getSystemId(), locator.getLineNumber())
        else:
            source = ("unknown", 0)
        try:
            handler = self.handlers[name]
        except KeyError:
            raise KeyError(
                "No handler for element %s at %s"
                % (repr(name), repr(source)))
        vars = self.stack[-1].copy()
        self.stack.append(vars)
        handler(source, vars, attrs)

    def endElement(self, name):
        # Drop the variables of the element being closed.
        del self.stack[-1]
        


class DirectiveReader:
    """Collects directives from several sources and detects conflicts.

    Directives are grouped by unique key and then by variant name.  Two
    directives with the same key in the same variant must be equal.
    """

    def __init__(self, handlers):
        self.directives = {}  # { unique key -> variant -> directive }
        self.handlers = handlers

    def read(self, filename):
        """Parses an XML file and adds the directives of every variant."""
        reader = XMLConfigReader(self.handlers)
        parse(filename, reader)
        for vname, directives in reader.variants.items():
            self.add(directives, vname)

    def add(self, directives, vname):
        """Adds directives to the variant called 'vname'.

        Raises KeyError if a directive has the same unique key as an
        unequal directive already present in that variant.  A directive
        whose unique key is None never conflicts.
        """
        for d in directives:
            key = d.getUniqueKey()
            if key is None:
                # No unique key means the directive conflicts with
                # nothing, so file it under a key that is private to
                # this object.  The directive stays referenced from
                # self.directives, so its id() cannot be reused.
                key = (None, id(d))
            info = self.directives.setdefault(key, {})
            if vname not in info:
                info[vname] = d
            elif not (d == info[vname]):
                raise KeyError(
                    'Conflicting directives: %s != %s' % (
                    repr(d), repr(info[vname])))
            # Otherwise an equal directive is already present.

    def getDirectives(self, vnames=('',)):
        """Returns one directive per unique key for the given variants.

        For each key, the directive of the first variant in 'vnames'
        that has one is used; keys present in none of them are skipped.
        """
        res = []
        for info in self.directives.values():
            for vname in vnames:
                if vname in info:
                    res.append(info[vname])
                    break  # Go to next directive
        return res



class DirectiveTables:
    """Directives indexed into tables that can be queried by column."""

    def __init__(self, directives):
        """Asks every directive to index itself into the table mapping."""
        tables = self.tables = {}      # {table name -> table}
        for directive in directives:
            directive.index(tables)

    def query(self, table_name, **filter):
        """Returns the specified directive records.

        The keyword arguments are handed to the select() method of the
        table.  An unknown table gives an empty list.
        """
        table = self.tables.get(table_name)
        if table is not None:
            return table.select(filter)
        return []

    def queryField(self, table_name, field, **filter):
        """Returns one field of the only matching record, or None.

        Raises LookupError if more than one record matches.
        """
        # An unknown table yields no records and therefore None.
        records = self.query(table_name, **filter)
        if len(records) > 1:
            raise LookupError(
                "More than one record returned from field query")
        if records:
            return records[0][field]
        return None



class ComponentSystem:
    """Creates components on demand and caches them by (comptype, name).

    Each component type has an assembler factory, which is called as
    factory(compsys, comptype, name) and must return an object with
    create() and configure() methods.
    """

    def __init__(self, directives):
        self.dtables = DirectiveTables(directives)
        self.factories = {}   # {comptype -> assembler factory}
        self.components = {}  # {(comptype, name) -> component}

    def addComponentType(self, comptype, assembler_factory):
        """Registers the assembler factory to use for a component type."""
        self.factories[comptype] = assembler_factory

    def get(self, comptype, name):
        """Returns the named component, assembling it on first use.

        Raises KeyError for an unregistered component type and lets any
        error from the assembler through.  A component whose
        configuration fails is not kept in the cache.
        """
        key = (comptype, name)
        obj = self.components.get(key)
        if obj is not None:
            return obj
        make_assembler = self.factories[comptype]
        assembler = make_assembler(self, comptype, name)
        obj = assembler.create()
        # Cache the object before configuring it so that components that
        # refer back to it during configuration receive this same object.
        self.components[key] = obj
        configured = 0
        try:
            assembler.configure()
            configured = 1
        finally:
            if not configured and key in self.components:
                # Do not leave a half-configured component behind for
                # later callers to pick up.
                del self.components[key]
        return obj




=== Added File Products/Ape/lib/apelib/config/interfaces.py ===
##############################################################################
#
# Copyright (c) 2003 Zope Corporation and Contributors.
# All Rights Reserved.
#
# This software is subject to the provisions of the Zope Public License,
# Version 2.0 (ZPL).  A copy of the ZPL should accompany this distribution.
# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
# FOR A PARTICULAR PURPOSE.
#
##############################################################################
"""Configuration interfaces.

$Id: interfaces.py,v 1.1.2.1 2003/07/07 22:59:13 shane Exp $
"""

from Interface import Interface

class IDirective(Interface):
    """A configuration directive.

    A directive can identify itself with a unique key, which is used
    for conflict detection, and can add itself to a set of tables.
    """

    def getUniqueKey():
        """Returns a key that distinguishes this directive from all others.

        This is used to detect conflicting directives.  The result
        must be hashable.  It normally includes the type (class or
        interface) of the directive.  If this returns None, the
        directive conflicts with nothing.
        """

    def index(tables):
        """Adds self to a table.

        tables is a mapping from table name to table.  The table name
        is usually the class of the directive.  If the mapping has no
        table for the directive yet, the directive may add one.
        """

# TODO: also declare interfaces such as IAssembler and IComponentSystem.



=== Added File Products/Ape/lib/apelib/config/minitables.py ===
##############################################################################
#
# Copyright (c) 2003 Zope Corporation and Contributors.
# All Rights Reserved.
#
# This software is subject to the provisions of the Zope Public License,
# Version 2.0 (ZPL).  A copy of the ZPL should accompany this distribution.
# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
# FOR A PARTICULAR PURPOSE.
#
##############################################################################
"""In-memory tables with support for basic relational operations.

$Id: minitables.py,v 1.1.2.1 2003/07/07 22:59:13 shane Exp $
"""


from BTrees.IIBTree import IISet, intersection


class DuplicateError(Exception):
    """Duplicated data record.

    Raised by Table.insert() when the primary key of the new record is
    already in use.
    """


class Column:
    """Description of one column of a table schema.

    name is a string; primary and indexed are boolean flags.
    """

    def __init__(self, name, primary, indexed):
        self.name, self.primary, self.indexed = name, primary, indexed


class TableSchema:
    """Ordered collection of uniquely named columns."""

    # Column names that may not be used in a schema.
    reserved_names = ('rid',)

    def __init__(self):
        self.columns = []        # Column objects, in the order added
        self.column_names = {}   # {column name -> 1}, for duplicate checks
        self.primary_names = []  # names of the primary columns, in order

    def addColumn(self, name, primary=0, indexed=0):
        """Appends a column to the schema.

        Raises ValueError if the name is reserved or already in use.
        """
        if name in self.reserved_names:
            raise ValueError("Column name %s is reserved" % repr(name))
        if name in self.column_names:
            raise ValueError("Column %s already exists" % repr(name))
        self.column_names[name] = 1
        column = Column(name, primary, indexed)
        self.columns.append(column)
        if primary:
            self.primary_names.append(name)

    def getColumns(self):
        """Returns the columns as a tuple."""
        return tuple(self.columns)

    def getPrimaryNames(self):
        """Returns the list of primary column names (not a copy)."""
        return self.primary_names



class Table:
    """Simple, generic relational table held in memory.

    Each record is a dictionary {column name -> value} that also
    carries its record id under the reserved key 'rid'.  A record
    holds only the columns that were given a value; other columns are
    simply absent from the dictionary.

    The table keeps a unique index on the primary key (when the schema
    declares primary columns) and one index for each column flagged as
    indexed.

    Subclasses may provide the schema as a class attribute instead of
    passing it to the constructor.
    """
    schema = None

    def __init__(self, schema=None):
        if schema is not None:
            self.schema = schema
        self.columns = self.schema.getColumns()
        self.next_rid = 1
        self.data = {}     # {rid -> record as {column name -> value}}
        self.indexes = {}  # {index_name -> {value -> IISet}}
        self.primary_index = {}  # {primary key -> rid}
        for c in self.columns:
            if c.indexed:
                self.indexes[c.name] = {}


    def makeRecord(self, params):
        """Returns a record containing {name -> value}.

        Copies the values of params that correspond to columns of the
        table.  Raises ValueError if params contains any other name.
        """
        res = {}
        for column in self.columns:
            name = column.name
            if name in params:
                res[name] = params[name]
        if len(params) > len(res):
            raise ValueError("Too many parameters")
        return res


    def _index_add(self, index, value, rid):
        """Adds rid to the set stored in index under value.
        """
        rid_set = index.get(value)
        if rid_set is None:
            rid_set = IISet()
            index[value] = rid_set
        rid_set.insert(rid)


    def _index_remove(self, index, value, rid):
        """Removes rid from the set stored in index under value.

        The entry for value is deleted once its set becomes empty.
        """
        rid_set = index[value]
        rid_set.remove(rid)
        if not rid_set:
            del index[value]


    def insert(self, params):
        """Adds one record built from params.

        Raises ValueError if params names an unknown column or omits a
        primary key column, and DuplicateError if the primary key is
        already in use.  Returns the number of rows inserted (1).
        """
        record = self.makeRecord(params)

        # Determine the primary key.
        primary_key = []
        for column in self.columns:
            if column.primary:
                if column.name not in record:
                    raise ValueError(
                        "No value provided for primary key column %s"
                        % repr(column.name))
                primary_key.append(record[column.name])
        if primary_key:
            primary_key = tuple(primary_key)
            if primary_key in self.primary_index:
                raise DuplicateError(
                    "Primary key %s in use" % repr(primary_key))

        # Add a record.
        rid = self.next_rid
        self.next_rid += 1
        record['rid'] = rid
        self.data[rid] = record
        if primary_key:
            self.primary_index[primary_key] = rid

        # Add to indexes.
        for name, value in record.items():
            index = self.indexes.get(name)
            if index is not None:
                self._index_add(index, value, rid)

        # Return the number of rows inserted.
        return 1


    def update(self, filter, changes):
        """Applies changes to every record matching filter.

        filter and changes are both {column name -> value}.  An empty
        filter matches all records.  Raises ValueError if either names
        an unknown column, and DuplicateError if the update would give
        two records the same primary key.  All checks are made before
        anything is modified.  Returns the number of rows affected.
        """
        rids = self._select_rids(self.makeRecord(filter))
        if rids is None:
            rids = list(self.data.keys())
        elif not rids:
            # Nothing needs to be updated.
            return 0

        # Identify changes.
        old_data = {}    # rid -> old record
        new_data = {}    # rid -> new record
        old_to_new = {}  # old primary key -> new primary key
        new_to_rid = {}  # new primary key -> rid

        changes = self.makeRecord(changes)
        for rid in rids:
            old_r = self.data[rid]
            old_data[rid] = old_r
            new_r = old_r.copy()
            new_r.update(changes)
            new_data[rid] = new_r
            opk = []
            npk = []
            for column in self.columns:
                if column.primary:
                    opk.append(old_r[column.name])
                    npk.append(new_r[column.name])
            if opk != npk:
                opk = tuple(opk)
                npk = tuple(npk)
                if npk in new_to_rid:
                    # Two of the selected records would end up with
                    # the same primary key.
                    raise DuplicateError(
                        "Primary key %s in use" % repr(npk))
                old_to_new[opk] = npk
                new_to_rid[npk] = rid

        # Look for primary key conflicts.  A primary key conflict can
        # occur when changing a record to a different primary key and
        # the new primary key is already in use.
        for pk in old_to_new.values():
            if pk in self.primary_index and pk not in old_to_new:
                raise DuplicateError("Primary key %s in use" % repr(pk))

        # Update the data.
        self.data.update(new_data)

        # Remove old primary key indexes and insert new primary key indexes.
        for pk in old_to_new.keys():
            del self.primary_index[pk]
        self.primary_index.update(new_to_rid)

        # Update indexes.
        for rid, old_r in old_data.items():
            for name, new_value in changes.items():
                index = self.indexes.get(name)
                if index is None:
                    continue
                if name in old_r:
                    old_value = old_r[name]
                    if new_value != old_value:
                        self._index_remove(index, old_value, rid)
                        self._index_add(index, new_value, rid)
                else:
                    # The record had no value for this column, so
                    # there is no old index entry to remove.
                    self._index_add(index, new_value, rid)

        # Return the number of rows affected.  rids may be a set taken
        # directly from an index that was just modified above, so
        # count the records that were collected before the changes.
        return len(old_data)


    def select(self, filter):
        """Returns a list of the records matching filter.

        filter is {column name -> value}; an empty filter matches all
        records.  Raises ValueError if filter names an unknown column.
        The dictionaries returned are the table's own records, not
        copies.
        """
        rids = self._select_rids(self.makeRecord(filter))
        if rids is None:
            # All
            return list(self.data.values())
        # Some, or none when rids is empty.
        return [self.data[rid] for rid in rids]


    def _select_rids(self, query):
        """Searches the table for matches, returning record ids.

        Returns a sequence of record ids, or None for all records.
        The sequence may be a set that belongs to an index, so callers
        must not modify it.
        """
        # Shortcut: if no query, return all.
        if not query:
            return None

        # First strategy: try to satisfy the request by consulting
        # the primary key index.
        primary_key = []
        filter_columns = []
        for column in self.columns:
            name = column.name
            if name in query:
                if column.primary:
                    primary_key.append(query[name])
                else:
                    # Specified a value that's not in the primary key
                    filter_columns.append((name, query[name]))
            elif column.primary:
                # Didn't fully specify a primary key
                break
        else:
            if primary_key:
                # To satisfy the request, we only need to look at
                # primary index.
                primary_key = tuple(primary_key)
                rid = self.primary_index.get(primary_key)

                # Possibly filter out the single item
                if rid is not None and filter_columns:
                    cand = self.data[rid]
                    for name, value in filter_columns:
                        # A record without a value for the column
                        # does not match.
                        if name not in cand or cand[name] != value:
                            # Not a match.
                            rid = None
                            break

                if rid is None:
                    return ()
                else:
                    return (rid,)

        # Second strategy: try to satisfy the request by intersecting
        # indexes.
        filter_columns = []
        rids = None
        for name, value in query.items():
            if name in self.indexes:
                rid_set = self.indexes[name].get(value)
                if rid_set is None:
                    # No rows satisfy this criterion.
                    return ()
                if rids is None:
                    rids = rid_set
                else:
                    rids = intersection(rids, rid_set)
                if not rids:
                    # No rows satisfy all criteria.
                    return ()
            else:
                filter_columns.append((name, value))
        if not filter_columns:
            # No need to search each record.
            return rids

        # Fallback strategy: Eliminate items one by one.
        if rids is None:
            # Use the whole data set.
            candidates = self.data.items()
        else:
            # Use the specified records.
            candidates = [(rid, self.data[rid]) for rid in rids]

        rids = []
        for rid, cand in candidates:
            for name, value in filter_columns:
                if name not in cand or cand[name] != value:
                    # Not a match.
                    break
            else:
                # A match.
                rids.append(rid)
        return rids