88import time
99import sys
1010
11- class TestThreadTaskNode ( InputIteratorThreadTask ):
11+ class _TestTaskBase ( object ):
1212 def __init__ (self , * args , ** kwargs ):
13- super (TestThreadTaskNode , self ).__init__ (* args , ** kwargs )
13+ super (_TestTaskBase , self ).__init__ (* args , ** kwargs )
1414 self .should_fail = False
1515 self .lock = threading .Lock () # yes, can't safely do x = x + 1 :)
1616 self .plock = threading .Lock ()
1717 self .item_count = 0
1818 self .process_count = 0
19- self ._scheduled_items = 0
2019
2120 def do_fun (self , item ):
2221 self .lock .acquire ()
@@ -32,44 +31,118 @@ def process(self, count=1):
3231 self .plock .acquire ()
3332 self .process_count += 1
3433 self .plock .release ()
35- super (TestThreadTaskNode , self ).process (count )
34+ super (_TestTaskBase , self ).process (count )
3635
3736 def _assert (self , pc , fc , check_scheduled = False ):
3837 """Assert for num process counts (pc) and num function counts (fc)
3938 :return: self"""
40- # TODO: fixme
41- return self
42- self .plock .acquire ()
43- if self .process_count != pc :
44- print self .process_count , pc
45- assert self .process_count == pc
46- self .plock .release ()
4739 self .lock .acquire ()
4840 if self .item_count != fc :
4941 print self .item_count , fc
5042 assert self .item_count == fc
5143 self .lock .release ()
5244
53- # if we read all, we can't really use scheduled items
54- if check_scheduled :
55- assert self ._scheduled_items == 0
56- assert not self .error ()
5745 return self
46+
47+ class TestThreadTaskNode (_TestTaskBase , InputIteratorThreadTask ):
48+ pass
5849
5950
6051class TestThreadFailureNode (TestThreadTaskNode ):
6152 """Fails after X items"""
53+ def __init__ (self , * args , ** kwargs ):
54+ self .fail_after = kwargs .pop ('fail_after' )
55+ super (TestThreadFailureNode , self ).__init__ (* args , ** kwargs )
6256
57+ def do_fun (self , item ):
58+ item = TestThreadTaskNode .do_fun (self , item )
59+ if self .item_count > self .fail_after :
60+ raise AssertionError ("Simulated failure after processing %i items" % self .fail_after )
61+ return item
62+
63+
64+ class TestThreadInputChannelTaskNode (_TestTaskBase , InputChannelTask ):
65+ """Apply a transformation on items read from an input channel"""
66+
67+ def do_fun (self , item ):
68+ """return tuple(i, i*2)"""
69+ item = super (TestThreadInputChannelTaskNode , self ).do_fun (item )
70+ if isinstance (item , tuple ):
71+ i = item [0 ]
72+ return item + (i * self .id , )
73+ else :
74+ return (item , item * self .id )
75+ # END handle tuple
76+
77+
78+ class TestThreadInputChannelVerifyTaskNode (_TestTaskBase , InputChannelTask ):
79+ """An input channel task, which verifies the result of its input channels,
80+ should be last in the chain.
81+ Id must be int"""
82+
83+ def do_fun (self , item ):
84+ """return tuple(i, i*2)"""
85+ item = super (TestThreadInputChannelTaskNode , self ).do_fun (item )
86+
87+ # make sure the computation order matches
88+ assert isinstance (item , tuple )
89+
90+ base = item [0 ]
91+ for num in item [1 :]:
92+ assert num == base * 2
93+ base = num
94+ # END verify order
95+
96+ return item
97+
98+
6399
64100class TestThreadPool (TestBase ):
65101
66102 max_threads = cpu_count ()
67103
68- def _add_triple_task (self , p ):
69- """Add a triplet of feeder, transformer and finalizer to the pool, like
70- t1 -> t2 -> t3, return all 3 return channels in order"""
71- # t1 = TestThreadTaskNode(make_task(), 'iterator', None)
72- # TODO:
104+ def _add_task_chain (self , p , ni , count = 1 ):
105+ """Create a task chain of feeder, count transformers and order verifcator
106+ to the pool p, like t1 -> t2 -> t3
107+ :return: tuple(list(task1, taskN, ...), list(rc1, rcN, ...))"""
108+ nt = p .num_tasks ()
109+
110+ feeder = self ._make_iterator_task (ni )
111+ frc = p .add_task (feeder )
112+
113+ assert p .num_tasks () == nt + 1
114+
115+ rcs = [frc ]
116+ tasks = [feeder ]
117+
118+ inrc = frc
119+ for tc in xrange (count ):
120+ t = TestThreadInputChannelTaskNode (inrc , tc , None )
121+ t .fun = t .do_fun
122+ inrc = p .add_task (t )
123+
124+ tasks .append (t )
125+ rcs .append (inrc )
126+ assert p .num_tasks () == nt + 2 + tc
127+ # END create count transformers
128+
129+ verifier = TestThreadInputChannelVerifyTaskNode (inrc , 'verifier' , None )
130+ verifier .fun = verifier .do_fun
131+ vrc = p .add_task (verifier )
132+
133+ assert p .num_tasks () == nt + tc + 3
134+
135+ tasks .append (verifier )
136+ rcs .append (vrc )
137+ return tasks , rcs
138+
139+ def _make_iterator_task (self , ni , taskcls = TestThreadTaskNode , ** kwargs ):
140+ """:return: task which yields ni items
141+ :param taskcls: the actual iterator type to use
142+ :param **kwargs: additional kwargs to be passed to the task"""
143+ t = taskcls (iter (range (ni )), 'iterator' , None , ** kwargs )
144+ t .fun = t .do_fun
145+ return t
73146
74147 def _assert_single_task (self , p , async = False ):
75148 """Performs testing in a synchronized environment"""
@@ -82,11 +155,7 @@ def _assert_single_task(self, p, async=False):
82155 assert ni % 2 == 0 , "ni needs to be dividable by 2"
83156 assert ni % 4 == 0 , "ni needs to be dividable by 4"
84157
85- def make_task ():
86- t = TestThreadTaskNode (iter (range (ni )), 'iterator' , None )
87- t .fun = t .do_fun
88- return t
89- # END utility
158+ make_task = lambda * args , ** kwargs : self ._make_iterator_task (ni , * args , ** kwargs )
90159
91160 task = make_task ()
92161
@@ -252,15 +321,44 @@ def make_task():
252321
253322 # test failure after ni / 2 items
254323 # This makes sure it correctly closes the channel on failure to prevent blocking
324+ nri = ni / 2
325+ task = make_task (TestThreadFailureNode , fail_after = ni / 2 )
326+ rc = p .add_task (task )
327+ assert len (rc .read ()) == nri
328+ assert task .is_done ()
329+ assert isinstance (task .error (), AssertionError )
255330
256331
257332
258- def _assert_async_dependent_tasks (self , p ):
333+ def _assert_async_dependent_tasks (self , pool ):
259334 # includes failure in center task, 'recursive' orphan cleanup
260335 # This will also verify that the channel-close mechanism works
261336 # t1 -> t2 -> t3
262337 # t1 -> x -> t3
263- pass
338+ null_tasks = pool .num_tasks ()
339+ ni = 100
340+ count = 1
341+ make_task = lambda * args , ** kwargs : self ._add_task_chain (pool , ni , count , * args , ** kwargs )
342+
343+ ts , rcs = make_task ()
344+ assert len (ts ) == count + 2
345+ assert len (rcs ) == count + 2
346+ assert pool .num_tasks () == null_tasks + len (ts )
347+ print pool ._tasks .nodes
348+
349+
350+ # in the end, we expect all tasks to be gone, automatically
351+
352+
353+
354+ # order of deletion matters - just keep the end, then delete
355+ final_rc = rcs [- 1 ]
356+ del (ts )
357+ del (rcs )
358+ del (final_rc )
359+ assert pool .num_tasks () == null_tasks
360+
361+
264362
265363 @terminate_threads
266364 def test_base (self ):
@@ -301,8 +399,8 @@ def test_base(self):
301399 assert p .num_tasks () == 0
302400
303401
304- # DEPENDENT TASKS SERIAL
305- ########################
402+ # DEPENDENT TASKS SYNC MODE
403+ ###########################
306404 self ._assert_async_dependent_tasks (p )
307405
308406
@@ -311,12 +409,11 @@ def test_base(self):
311409 # step one gear up - just one thread for now.
312410 p .set_size (1 )
313411 assert p .size () == 1
314- print len (threading .enumerate ()), num_threads
315412 assert len (threading .enumerate ()) == num_threads + 1
316413 # deleting the pool stops its threads - just to be sure ;)
317414 # Its not synchronized, hence we wait a moment
318415 del (p )
319- time .sleep (0.25 )
416+ time .sleep (0.05 )
320417 assert len (threading .enumerate ()) == num_threads
321418
322419 p = ThreadPool (1 )
0 commit comments