[Checkins] SVN: zc.buildout/branches/sylvain-distribution-cache/src/zc/buildout/easy_install.py Improve code by factoring a RequirementSet.

Sylvain Viollon sylvain at infrae.com
Thu Jan 29 11:00:55 EST 2009


Log message for revision 95444:
  Improve code by factoring a RequirementSet.
  
  

Changed:
  U   zc.buildout/branches/sylvain-distribution-cache/src/zc/buildout/easy_install.py

-=-
Modified: zc.buildout/branches/sylvain-distribution-cache/src/zc/buildout/easy_install.py
===================================================================
--- zc.buildout/branches/sylvain-distribution-cache/src/zc/buildout/easy_install.py	2009-01-29 14:55:57 UTC (rev 95443)
+++ zc.buildout/branches/sylvain-distribution-cache/src/zc/buildout/easy_install.py	2009-01-29 16:00:54 UTC (rev 95444)
@@ -128,6 +128,49 @@
     'from setuptools.command.easy_install import main; main()'
     )
 
+class RequirementSet:
+    """Requirement objects, grouped by their project key.
+    """
+
+    def __init__(self):
+        self.req = {}
+
+    def append(self, requirement):
+        """Record requirement under its project key.
+        """
+        existing = self.req.get(requirement.key, [])
+        existing.append(requirement)
+        self.req[requirement.key] = existing
+
+    def extend(self, other):
+        """Merge in every requirement recorded by the RequirementSet other.
+        """
+        for key, requirements in other.req.items():
+            existing = self.req.get(key, [])
+            existing.extend(requirements)
+            self.req[key] = existing
+
+    def integrate(self, dists, ws):
+        """Add the distributions dists to the working set ws. A dist that
+        has recorded requirements and is already in ws is checked instead;
+        a mismatch raises VersionConflict."""
+        for dist in dists:
+            dist_requirements = self.req.get(dist.key, None)
+            # NOTE(review): requirements are only checked when this exact
+            # dist is in ws; otherwise ws.add runs even if the key is taken.
+            if dist_requirements and dist in ws:
+                try:
+                    # ws.find raises VersionConflict on an unmet requirement.
+                    for requirement in dist_requirements:
+                        ws.find(requirement)
+                except pkg_resources.VersionConflict, err:
+                    # TODO: better conflict resolution, e.g. replace a dist
+                    # in ws whose requirements are absent or match ours.
+                    raise VersionConflict(err, ws)
+            else:
+                ws.add(dist)
+
+
 class Installer:
 
     _versions = {}
@@ -600,6 +643,17 @@
 
         return requirement
 
+    def _constrain_safe(self, requirement):
+        pinned = self._constrain(requirement)
+        if requirement != pinned:
+            return pinned, False  # _constrain replaced it: flag is off
+        return requirement, bool(requirement.specs)  # flag: has own specs
+
+    def _get_cached_distribution(self, spec, ws):
+        cached_dists, requirements = distribution_cache[spec]
+        requirements.integrate(cached_dists, ws)  # may raise VersionConflict
+        return cached_dists, requirements
+
     def install(self, specs, working_set=None):
 
         logger.debug('Installing %s.', repr(specs)[1:-1])
@@ -609,22 +663,22 @@
         if dest is not None and dest not in path:
             path.insert(0, dest)
 
-        requirements = [self._constrain(pkg_resources.Requirement.parse(spec))
-                        for spec in specs]
-
         if working_set is None:
             ws = pkg_resources.WorkingSet([])
         else:
             ws = working_set
 
-        dists_ws = []
-        for requirement in requirements:
+
+        for spec in specs:
+            requirement = pkg_resources.Requirement.parse(spec)
+            requirement, not_safe = self._constrain_safe(requirement)
+
+            # Clean spec
             spec = str(requirement)
-
             if distribution_cache.has_key(spec):
                 logger.debug(
                     "Hit distribution cache for %r", spec)
-                dists_ws.extend(distribution_cache[spec])
+                self._get_cached_distribution(spec, ws)
                 continue
 
             # That's a problem. Two different requirements can pick
@@ -632,6 +686,9 @@
 
             tmp_ws = pkg_resources.WorkingSet([])
             tmp_dists = []
+            tmp_requirement = RequirementSet()
+            if not_safe:
+                tmp_requirement.append(requirement)
 
             for dist in self._get_dist(
                 requirement, tmp_ws, self._always_unzip):
@@ -653,38 +710,44 @@
                     tmp_ws.resolve([requirement])
                 except pkg_resources.DistributionNotFound, err:
                     [missing_requirement] = err
-                    missing_requirement = self._constrain(missing_requirement)
+                    missing_requirement, missing_not_safe = \
+                        self._constrain_safe(missing_requirement)
                     missing_spec = str(missing_requirement)
+
+                    if distribution_cache.has_key(missing_spec):
+                        logger.debug(
+                            "Hit distribution cache for %r", missing_spec)
+                        missing_dists, missing_reqs = \
+                            self._get_cached_distribution(
+                            missing_spec, tmp_ws)
+                        # Merge cache results
+                        tmp_dists.extend(missing_dists)
+                        tmp_requirement.extend(missing_reqs)
+                        continue
+
                     if dest:
                         logger.debug('Getting required %r', missing_spec)
                     else:
                         logger.debug('Adding required %r', missing_spec)
                     _log_requirement(tmp_ws, missing_requirement)
 
-                    if distribution_cache.has_key(missing_spec):
-                        logger.debug(
-                            "Hit distribution cache for %r", missing_spec)
-                        tmp_dists.extend(distribution_cache[missing_spec])
-                        for missing_dist in distribution_cache[missing_spec]:
-                            tmp_ws.add(missing_dist)
-                        continue
+                    if missing_not_safe:
+                        tmp_requirement.append(missing_requirement)
 
                     for dist in self._get_dist(
                         missing_requirement, tmp_ws, self._always_unzip):
-
                         tmp_ws.add(dist)
                         tmp_dists.append(dist)
                         self._maybe_add_setuptools(tmp_ws, dist)
+
                 except pkg_resources.VersionConflict, err:
                     raise VersionConflict(err, tmp_ws)
                 else:
                     break
 
-            dists_ws.extend(tmp_dists)
-            distribution_cache[spec] = tmp_dists
+            distribution_cache[spec] = [tmp_dists, tmp_requirement]
+            self._get_cached_distribution(spec, ws)
 
-        for dist in dists_ws:
-            ws.add(dist)
         return ws
 
     def build(self, spec, build_ext):



More information about the Checkins mailing list