Back to index

plone3  3.1.7
SQLMethod.py
Go to the documentation of this file.
00001 from Products.Archetypes.debug import log_exc
00002 
00003 from Shared.DC.ZRDB import Aqueduct, RDB
00004 from Shared.DC.ZRDB.Results import Results
00005 from Shared.DC.ZRDB.DA import SQL
00006 from App.Extensions import getBrain
00007 from cStringIO import StringIO
00008 import sys, types
00009 from ZODB.POSException import ConflictError
00010 
00011 from string import atoi
00012 from time import time
00013 
try:
    from IOBTree import Bucket
except ImportError:
    # IOBTree could not be imported: fall back to a plain dict factory.
    # A dict offers the same mapping interface, but unlike an IOBTree
    # bucket it does not keep its keys sorted, so callers must not rely
    # on key order.
    def Bucket():
        """Return a new, empty mapping used in place of an IOBTree Bucket."""
        return {}
00018 
# Values that SQLMethod.__init__ installs on its context for every
# attribute the context does not already define.
_defaults = dict(
    max_rows_=1000,      # row limit passed to the database adapter per query
    cache_time_=0,       # seconds a cached result stays valid; 0 disables caching
    max_cache_=100,      # maximum number of cached results
    class_name_='',      # optional brain class name for result records
    class_file_='',      # file holding that brain class
    template_class=SQL   # class used to compile the SQL template source
)
00026 
00027 class SQLMethod(Aqueduct.BaseQuery):
00028 
00029     _arg=None
00030     _col=None
00031 
00032     def __init__(self, context):
00033         self.context = context
00034         self.id = str(context.__class__.__name__)
00035         self.title = ''
00036         for k, v in _defaults.items():
00037             if not hasattr(context, k):
00038                 setattr(context, k, v)
00039 
00040     def edit(self, connection_id, arguments, template):
00041         """Change database method  properties
00042 
00043         The 'connection_id' argument is the id of a database connection
00044         that resides in the current folder or in a folder above the
00045         current folder.  The database should understand SQL.
00046 
00047         The 'arguments' argument is a string containing an arguments
00048         specification, as would be given in the SQL method cration form.
00049 
00050         The 'template' argument is a string containing the source for the
00051         SQL Template.
00052         """
00053         context = self.context
00054         self.connection_id = str(connection_id)
00055         arguments = str(arguments)
00056         self.arguments_src = arguments
00057         self._arg = Aqueduct.parse(arguments)
00058         if not isinstance(template, (str, unicode)):
00059             template = str(template)
00060         self.src = template
00061         self.template = t = context.template_class(template)
00062         t.cook()
00063         context._v_query_cache={}, Bucket()
00064 
00065     def advanced_edit(self, max_rows=1000, max_cache=100, cache_time=0,
00066                         class_name='', class_file='',
00067                         REQUEST=None):
00068         """Change advanced properties
00069 
00070         The arguments are:
00071 
00072         max_rows -- The maximum number of rows to be returned from a query.
00073 
00074         max_cache -- The maximum number of results to cache
00075 
00076         cache_time -- The maximum amound of time to use a cached result.
00077 
00078         class_name -- The name of a class that provides additional
00079           attributes for result record objects. This class will be a
00080           base class of the result record class.
00081 
00082         class_file -- The name of the file containing the class
00083           definition.
00084 
00085         The class file normally resides in the 'Extensions'
00086         directory, however, the file name may have a prefix of
00087         'product.', indicating that it should be found in a product
00088         directory.
00089 
00090         For example, if the class file is: 'ACMEWidgets.foo', then an
00091         attempt will first be made to use the file
00092         'lib/python/Products/ACMEWidgets/Extensions/foo.py'. If this
00093         failes, then the file 'Extensions/ACMEWidgets.foo.py' will be
00094         used.
00095 
00096         """
00097         context = self.context
00098         # paranoid type checking
00099         if type(max_rows) is not type(1):
00100             max_rows = atoi(max_rows)
00101         if type(max_cache) is not type(1):
00102             max_cache = atoi(max_cache)
00103         if type(cache_time) is not type(1):
00104             cache_time = atoi(cache_time)
00105         class_name = str(class_name)
00106         class_file = str(class_file)
00107 
00108         context.max_rows_ = max_rows
00109         context.max_cache_, context.cache_time_ = max_cache, cache_time
00110         context._v_sql_cache = {}, Bucket()
00111         context.class_name_, context.class_file_ = class_name, class_file
00112         context._v_sql_brain = getBrain(context.class_file_,
00113                                         context.class_name_, 1)
00114 
00115     def _cached_result(self, DB__, query):
00116         context = self.context
00117         # Try to fetch from cache
00118         if hasattr(context,'_v_sql_cache'):
00119             cache = context._v_sql_cache
00120         else:
00121             cache = context._v_sql_cache={}, Bucket()
00122         cache, tcache = cache
00123         max_cache = context.max_cache_
00124         now = time()
00125         t = now - context.cache_time_
00126         if len(cache) > max_cache / 2:
00127             keys = tcache.keys()
00128             keys.reverse()
00129             while keys and (len(keys) > max_cache or keys[-1] < t):
00130                 key = keys[-1]
00131                 q = tcache[key]
00132                 del tcache[key]
00133                 if int(cache[q][0]) == key:
00134                     del cache[q]
00135                 del keys[-1]
00136 
00137         if cache.has_key(query):
00138             k, r = cache[query]
00139             if k > t: return r
00140 
00141         result = apply(DB__.query, query)
00142         if context.cache_time_ > 0:
00143             tcache[int(now)] = query
00144             cache[query] = now, result
00145 
00146         return result
00147 
00148     def _get_dbc(self):
00149         """Get the database connection"""
00150         context = self.context
00151 
00152         try:
00153             dbc = getattr(context, self.connection_id)
00154         except AttributeError:
00155             raise AttributeError, (
00156                 "The database connection <em>%s</em> cannot be found." % (
00157                 self.connection_id))
00158 
00159         try:
00160             DB__ = dbc()
00161         except ConflictError:
00162             raise
00163         except:
00164             raise 'Database Error', (
00165             '%s is not connected to a database' % self.id)
00166 
00167         return dbc, DB__
00168 
00169     def __call__(self, src__=0, test__=0, **kw):
00170         """Call the database method
00171 
00172         The arguments to the method should be passed via keyword
00173         arguments, or in a single mapping object. If no arguments are
00174         given, and if the method was invoked through the Web, then the
00175         method will try to acquire and use the Web REQUEST object as
00176         the argument mapping.
00177 
00178         The returned value is a sequence of record objects.
00179         """
00180         context = self.context
00181 
00182         dbc, DB__ = self._get_dbc()
00183 
00184         p = None
00185 
00186         argdata = self._argdata(kw)
00187         argdata['sql_delimiter'] = '\0'
00188         argdata['sql_quote__'] = dbc.sql_quote__
00189 
00190         # TODO: Review the argdata dictonary. The line bellow is receiving unicode
00191         # strings, mixed with standard strings. It is insane! Archetypes needs a policy
00192         # about unicode, and lots of tests on this way. I prefer to not correct it now,
00193         # only doing another workarround. We need to correct the cause of this problem,
00194         # not its side effects :-(
00195 
00196         try:
00197             query = apply(self.template, (p,), argdata)
00198         except TypeError, msg:
00199             msg = str(msg)
00200             if 'client' in msg:
00201                 raise NameError("'client' may not be used as an " +
00202                                 "argument name in this context")
00203             else: raise
00204 
00205         __traceback_info__ = query
00206 
00207         if src__: return query
00208 
00209         # Get the encoding arguments
00210         # We have two possible kw arguments:
00211         #   db_encoding:        The encoding used in the external database
00212         #   site_encoding:      The uncoding used for the site
00213         #                       If not specified, we use sys.getdefaultencoding()
00214         db_encoding = kw.get('db_encoding',None)
00215 
00216         try:
00217             site_encoding = kw.get('site_encoding', context.portal_properties.site_properties.default_charset)
00218         except AttributeError, KeyError:
00219             site_encoding = kw.get('site_encoding',sys.getdefaultencoding())
00220 
00221         if type(query) == type(u''):
00222             if db_encoding:
00223                 query = query.encode(db_encoding)
00224             else:
00225                 try:
00226                     query = query.encode(site_encoding)
00227                 except UnicodeEncodeError:
00228                     query = query.encode('UTF-8')
00229 
00230 
00231         if context.cache_time_ > 0 and context.max_cache_ > 0:
00232             result = self._cached_result(DB__, (query, context.max_rows_))
00233         else:
00234             try:
00235                 result = DB__.query(query, context.max_rows_)
00236             except ConflictError:
00237                 raise
00238             except:
00239                 log_exc(msg='Database query failed', reraise=1)
00240 
00241         if hasattr(context, '_v_sql_brain'):
00242             brain = context._v_sql_brain
00243         else:
00244             brain=context._v_sql_brain = getBrain(context.class_file_,
00245                                                 context.class_name_)
00246 
00247         if type(result) is type(''):
00248             f = StringIO()
00249             f.write(result)
00250             f.seek(0)
00251             result = RDB.File(f, brain, p, None)
00252         else:
00253             if db_encoding:
00254                 # Encode result before we wrap it in Result object
00255                 # We will change the encoding from source to either the specified target_encoding
00256                 # or the site default encoding
00257 
00258                 # The data is a list of tuples of column data
00259                 encoded_result = []
00260                 for row in result[1]:
00261                     columns = ()
00262                     for col in row:
00263                         if isinstance(col, types.StringType):
00264                             # coerce column to unicode with database encoding
00265                             newcol = unicode(col,db_encoding)
00266                             # Encode column as string with site_encoding
00267                             newcol = newcol.encode(site_encoding)
00268                         else:
00269                             newcol = col
00270 
00271                         columns += newcol,
00272 
00273                     encoded_result.append(columns)
00274 
00275                 result = (result[0],encoded_result)
00276 
00277             result = Results(result, brain, p, None)
00278 
00279         columns = result._searchable_result_columns()
00280 
00281         if test__ and columns != self._col:
00282             self._col=columns
00283 
00284         # If run in test mode, return both the query and results so
00285         # that the template doesn't have to be rendered twice!
00286         if test__: return query, result
00287 
00288         return result
00289 
00290     def abort(self):
00291         dbc, DB__ = self._get_dbc()
00292         try:
00293             DB__.tpc_abort()
00294         except ConflictError:
00295             raise
00296         except:
00297             log_exc(msg = 'Database abort failed')
00298 
00299     def connectionIsValid(self):
00300         context = self.context
00301         return (hasattr(context, self.connection_id) and
00302                 hasattr(getattr(context, self.connection_id), 'connected'))
00303 
00304     def connected(self):
00305         context = self.context
00306         return getattr(getattr(context, self.connection_id), 'connected')()