lists.arthurdejong.org
RSS feed

nss-pam-ldapd commit: r1641 - nss-pam-ldapd/pynslcd

[Date Prev][Date Next] [Thread Prev][Thread Next]

nss-pam-ldapd commit: r1641 - nss-pam-ldapd/pynslcd



Author: arthur
Date: Fri Mar 16 14:53:17 2012
New Revision: 1641
URL: http://arthurdejong.org/viewvc/nss-pam-ldapd?revision=1641&view=revision

Log:
refactor some of the attribute mapping code to introduce a mapping instance
that does the hard work, and support the lower() and upper() attribute mapping
functions.

Modified:
   nss-pam-ldapd/pynslcd/attmap.py
   nss-pam-ldapd/pynslcd/common.py
   nss-pam-ldapd/pynslcd/pam.py

Modified: nss-pam-ldapd/pynslcd/attmap.py
==============================================================================
--- nss-pam-ldapd/pynslcd/attmap.py     Fri Mar 16 13:48:28 2012        (r1640)
+++ nss-pam-ldapd/pynslcd/attmap.py     Fri Mar 16 14:53:17 2012        (r1641)
@@ -1,7 +1,7 @@
 
 # attmap.py - attribute mapping class
 #
-# Copyright (C) 2011 Arthur de Jong
+# Copyright (C) 2011, 2012 Arthur de Jong
 #
 # This library is free software; you can redistribute it and/or
 # modify it under the terms of the GNU Lesser General Public
@@ -37,6 +37,11 @@
 '"${gecos:-$cn}"'
 """
 
+import ldap
+import re
+from ldap.filter import escape_filter_chars as escape
+
+
 # exported names
 __all__ = ('Attributes', )
 
@@ -44,7 +49,10 @@
 # FIXME: support multiple attribute values
 # TODO: support objectSid attributes
 # TODO: do more expression validity checking
-# TODO: handle userPassword specially to do filtering of results
+
+
+# regular expression to match function attributes
+attribute_func_re = re.compile('^(?P<function>[a-z]+)\((?P<attribute>.*)\)$')
 
 
 class MyIter(object):
@@ -105,11 +113,11 @@
             return self.expr.value(variables) if value else ''
         return value
 
-    def attributes(self, results):
-        """Add the attributes used in the expression to results."""
+    def variables(self, results):
+        """Add the variables used in the expression to results."""
         results.add(self.name)
         if self.expr:
-            self.expr.attributes(results)
+            self.expr.variables(results)
 
 
 class Expression(object):
@@ -150,13 +158,13 @@
                 res += x
         return res
 
-    def attributes(self, results=None):
+    def variables(self, results=None):
         """Return the attributes defined in the expression."""
         if not results:
             results = set()
         for x in self.expr:
-            if hasattr(x, 'attributes'):
-                x.attributes(results)
+            if hasattr(x, 'variables'):
+                x.variables(results)
         return results
 
     def __str__(self):
@@ -166,36 +174,110 @@
         return repr(str(self))
 
 
+class SimpleMapping(str):
+    """Simple mapping to another attribute name."""
+
+    def attributes(self):
+        return [self]
+
+    def mk_filter(self, value):
+        return '(%s=%s)' % (self, escape(str(value)))
+
+    def values(self, variables):
+        """Expand the expression using the variables specified."""
+        return variables.get(self, [])
+
+
+class ExpressionMapping(str):
+    """Class for parsing and expanding an expression."""
+
+    def __init__(self, value):
+        """Parse the expression as a string."""
+        self.expression = Expression(value)
+
+    def values(self, variables):
+        """Expand the expression using the variables specified."""
+        return [self.expression.value(variables)]
+
+    def attributes(self):
+        """Return the attributes defined in the expression."""
+        return self.expression.variables()
+
+
+class FunctionMapping(str):
+    """Mapping to a function to another attribute."""
+
+    def __init__(self, mapping):
+        self.mapping = mapping
+        m = attribute_func_re.match(mapping)
+        self.attribute = m.group('attribute')
+        self.function = getattr(self, m.group('function'))
+
+    def upper(self, value):
+        return value.upper()
+
+    def lower(self, value):
+        return value.lower()
+
+    def attributes(self):
+        return [self.attribute]
+
+    def mk_filter(self, value):
+        return '(%s=%s)' % (self.attribute, escape(value))
+
+    def values(self, variables):
+        return [self.function(value)
+                for value in variables.get(self.attribute, [])]
+
+
 class Attributes(dict):
     """Dictionary-like class for handling attribute mapping."""
 
-    def _prepare(self):
-        """Go over all values to parse any expressions."""
-        updates = dict()
-        for k, v in self.iteritems():
-            if isinstance(v, basestring) and v[0] == '"':
-                updates[k] = Expression(v)
-        self.update(updates)
+    def __init__(self, *args, **kwargs):
+        self.update(*args, **kwargs)
+
+    def __setitem__(self, attribute, mapping):
+        # translate the mapping into a mapping object
+        if mapping[0] == '"':
+            mapping = ExpressionMapping(mapping)
+        elif '(' in mapping:
+            mapping = FunctionMapping(mapping)
+        else:
+            mapping = SimpleMapping(mapping)
+        super(Attributes, self).__setitem__(attribute, mapping)
+
+    def update(self, *args, **kwargs):
+        for arg in args:
+            other = dict(arg)
+            for key in other:
+                self[key] = other[key]
+        for key in kwargs:
+            self[key] = kwargs[key]
 
     def attributes(self):
-        """Return the list of attributes that are referenced in this attribute
-        mapping."""
-        self._prepare()
+        """Return the list of attributes that are referenced in this
+        attribute mapping. These are the attributes that should be
+        requested in the search."""
         attributes = set()
-        for value in self.itervalues():
-            if hasattr(value, 'attributes'):
-                attributes.update(value.attributes())
-            else:
-                attributes.add(value)
+        for mapping in self.itervalues():
+            attributes.update(mapping.attributes())
         return list(attributes)
 
+    def mk_filter(self, attribute, value):
+        """Construct a search filter for searching for the attribute value
+        combination."""
+        mapping = self.get(attribute, SimpleMapping(attribute))
+        return mapping.mk_filter(value)
+
     def translate(self, variables):
         """Return a dictionary with every attribute mapped to their value from
         the specified variables."""
         results = dict()
-        for k, v in self.iteritems():
-            if hasattr(v, 'value'):
-                results[k] = [v.value(variables)]
-            else:
-                results[k] = variables.get(v, [])
+        for attribute, mapping in self.iteritems():
+            results[attribute] = mapping.values(variables)
         return results
+
+    def get_rdn_value(self, dn, attribute):
+        """Extract the attribute value from from DN if possible. Return None
+        otherwise."""
+        return self.translate(dict((x, [y]) for x, y, z in ldap.dn.str2dn(dn)[0]))[attribute][0]

Modified: nss-pam-ldapd/pynslcd/common.py
==============================================================================
--- nss-pam-ldapd/pynslcd/common.py     Fri Mar 16 13:48:28 2012        (r1640)
+++ nss-pam-ldapd/pynslcd/common.py     Fri Mar 16 14:53:17 2012        (r1641)
@@ -108,7 +108,7 @@
         """Return the results from the search."""
         filter = self.mk_filter()
         for base in self.bases:
-            logging.debug('SEARCHING %s', base)
+            logging.debug('SEARCHING %s %s', base, filter)
             try:
                 for entry in self.conn.search_s(base, self.scope, filter, self.attributes):
                     if entry[0]:
@@ -126,9 +126,10 @@
     def mk_filter(self):
         """Return the active search filter (based on the read parameters)."""
         if self.parameters:
-            return '(&%s(%s))' % (self.filter,
-                ')('.join('%s=%s' % (self.attmap[attribute], self.escape(value))
-                          for attribute, value in self.parameters.items()))
+            return '(&%s%s)' % (
+                self.filter,
+                ''.join(self.attmap.mk_filter(attribute, value)
+                        for attribute, value in self.parameters.items()))
         return self.filter
 
     def handle_entry(self, dn, attributes):
@@ -138,7 +139,7 @@
         attributes = self.attmap.translate(attributes)
         # make sure value from DN is first value
         for attr in self.canonical_first:
-            primary_value = get_rdn_value(dn, self.attmap[attr])
+            primary_value = self.attmap.get_rdn_value(dn, attr)
             if primary_value:
                 values = attributes[attr]
                 if primary_value in values:
@@ -239,7 +240,3 @@
         if issubclass(cls, Request) and hasattr(cls, 'action'):
             res[cls.action] = cls
     return res
-
-
-def get_rdn_value(dn, attribute):
-    return dict((x, y) for x, y, z in ldap.dn.str2dn(dn)[0])[attribute]

Modified: nss-pam-ldapd/pynslcd/pam.py
==============================================================================
--- nss-pam-ldapd/pynslcd/pam.py        Fri Mar 16 13:48:28 2012        (r1640)
+++ nss-pam-ldapd/pynslcd/pam.py        Fri Mar 16 14:53:17 2012        (r1641)
@@ -56,10 +56,10 @@
             # save the DN
             parameters['userdn'] = entry[0]
             # get the "real" username
-            value = common.get_rdn_value(entry[0], passwd.attmap['uid'])
+            value = passwd.attmap.get_rdn_value(entry[0], 'uid')
             if not value:
                 # get the username from the uid attribute
-                values = myldap_get_values(entry, passwd.attmap['uid'])
+                values = entry[1]['uid']
                 if not values or not values[0]:
                     logging.warning('%s: is missing a %s attribute', dn, passwd.attmap['uid'])
                 value = values[0]
-- 
To unsubscribe send an email to
nss-pam-ldapd-commits-unsubscribe@lists.arthurdejong.org or see
http://lists.arthurdejong.org/nss-pam-ldapd-commits/