Back to index

rabbitmq-server  2.8.4
base.py
Go to the documentation of this file.
00001 import unittest
00002 import stomp
00003 import sys
00004 import threading
00005 
00006 
00007 class BaseTest(unittest.TestCase):
00008 
00009    def create_connection(self, version=None, heartbeat=None):
00010        conn = stomp.Connection(user="guest", passcode="guest",
00011                                version=version, heartbeat=heartbeat)
00012        conn.start()
00013        conn.connect()
00014        return conn
00015 
00016    def create_subscriber_connection(self, dest):
00017        conn = self.create_connection()
00018        listener = WaitableListener()
00019        conn.set_listener('', listener)
00020        conn.subscribe(destination=dest, receipt="sub.receipt")
00021        listener.await()
00022        self.assertEquals(1, len(listener.receipts))
00023        listener.reset()
00024        return conn, listener
00025 
00026    def setUp(self):
00027         self.conn = self.create_connection()
00028         self.listener = WaitableListener()
00029         self.conn.set_listener('', self.listener)
00030 
00031    def tearDown(self):
00032         if self.conn.is_connected():
00033             self.conn.stop()
00034 
00035    def simple_test_send_rec(self, dest, route = None):
00036         self.listener.reset()
00037 
00038         self.conn.subscribe(destination=dest)
00039         self.conn.send("foo", destination=dest)
00040 
00041         self.assertTrue(self.listener.await(), "Timeout, no message received")
00042 
00043         # assert no errors
00044         if len(self.listener.errors) > 0:
00045             self.fail(self.listener.errors[0]['message'])
00046 
00047         # check header content
00048         msg = self.listener.messages[0]
00049         self.assertEquals("foo", msg['message'])
00050         self.assertEquals(dest, msg['headers']['destination'])
00051 
00052    def assertListener(self, errMsg, numMsgs=0, numErrs=0, numRcts=0, timeout=1):
00053         if numMsgs + numErrs + numRcts > 0:
00054             self.assertTrue(self.listener.await(timeout), errMsg + " (#awaiting)")
00055         else:
00056             self.assertFalse(self.listener.await(timeout), errMsg + " (#awaiting)")
00057         self.assertEquals(numMsgs, len(self.listener.messages), errMsg + " (#messages)")
00058         self.assertEquals(numErrs, len(self.listener.errors), errMsg + " (#errors)")
00059         self.assertEquals(numRcts, len(self.listener.receipts), errMsg + " (#receipts)")
00060 
00061    def assertListenerAfter(self, verb, errMsg="", numMsgs=0, numErrs=0, numRcts=0, timeout=1):
00062         num = numMsgs + numErrs + numRcts
00063         self.listener.reset(num if num>0 else 1)
00064         verb()
00065         self.assertListener(errMsg=errMsg, numMsgs=numMsgs, numErrs=numErrs, numRcts=numRcts, timeout=timeout)
00066 
00067 class WaitableListener(object):
00068 
00069     def __init__(self):
00070         self.debug = False
00071         if self.debug:
00072             print '(listener) init'
00073         self.messages = []
00074         self.errors = []
00075         self.receipts = []
00076         self.latch = Latch(1)
00077 
00078     def on_receipt(self, headers, message):
00079         if self.debug:
00080             print '(on_receipt) message:', message, 'headers:', headers
00081         self.receipts.append({'message' : message, 'headers' : headers})
00082         self.latch.countdown()
00083 
00084     def on_error(self, headers, message):
00085         if self.debug:
00086             print '(on_error) message:', message, 'headers:', headers
00087         self.errors.append({'message' : message, 'headers' : headers})
00088         self.latch.countdown()
00089 
00090     def on_message(self, headers, message):
00091         if self.debug:
00092             print '(on_message) message:', message, 'headers:', headers
00093         self.messages.append({'message' : message, 'headers' : headers})
00094         self.latch.countdown()
00095 
00096     def reset(self, count=1):
00097         if self.debug:
00098             self.print_state('(reset listener--old state)')
00099         self.messages = []
00100         self.errors = []
00101         self.receipts = []
00102         self.latch = Latch(count)
00103         if self.debug:
00104             self.print_state('(reset listener--new state)')
00105 
00106     def await(self, timeout=10):
00107         return self.latch.await(timeout)
00108 
00109     def print_state(self, hdr=""):
00110         print hdr,
00111         print '#messages:', len(self.messages),
00112         print '#errors:', len(self.errors),
00113         print '#receipts:', len(self.receipts),
00114         print 'Remaining count:', self.latch.get_count()
00115 
00116 class Latch(object):
00117 
00118    def __init__(self, count=1):
00119       self.cond = threading.Condition()
00120       self.cond.acquire()
00121       self.count = count
00122       self.cond.release()
00123 
00124    def countdown(self):
00125       self.cond.acquire()
00126       if self.count > 0:
00127          self.count -= 1
00128       if self.count == 0:
00129          self.cond.notify_all()
00130       self.cond.release()
00131 
00132    def await(self, timeout=None):
00133       try:
00134          self.cond.acquire()
00135          if self.count == 0:
00136             return True
00137          else:
00138             self.cond.wait(timeout)
00139             return self.count == 0
00140       finally:
00141          self.cond.release()
00142 
00143    def get_count(self):
00144       try:
00145           self.cond.acquire()
00146           return self.count
00147       finally:
00148           self.cond.release()