Skip to content

Commit 4fd006e

Browse files
gh-142752: add more thread safety tests for mock (#142791)
1 parent c35b812 commit 4fd006e

File tree

1 file changed

+76
-0
lines changed

1 file changed

+76
-0
lines changed

Lib/test/test_unittest/testmock/testthreadingmock.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,5 +219,81 @@ def test_function():
219219
self.assertEqual(m.call_count, LOOPS * THREADS)
220220

221221

222+
def test_call_args_thread_safe(self):
223+
m = ThreadingMock()
224+
LOOPS = 100
225+
THREADS = 10
226+
def test_function(thread_id):
227+
for i in range(LOOPS):
228+
m(thread_id, i)
229+
230+
oldswitchinterval = sys.getswitchinterval()
231+
setswitchinterval(1e-6)
232+
try:
233+
threads = [
234+
threading.Thread(target=test_function, args=(thread_id,))
235+
for thread_id in range(THREADS)
236+
]
237+
with threading_helper.start_threads(threads):
238+
pass
239+
finally:
240+
sys.setswitchinterval(oldswitchinterval)
241+
expected_calls = {
242+
(thread_id, i)
243+
for thread_id in range(THREADS)
244+
for i in range(LOOPS)
245+
}
246+
self.assertSetEqual({call.args for call in m.call_args_list}, expected_calls)
247+
248+
def test_method_calls_thread_safe(self):
249+
m = ThreadingMock()
250+
LOOPS = 100
251+
THREADS = 10
252+
def test_function(thread_id):
253+
for i in range(LOOPS):
254+
getattr(m, f"method_{thread_id}")(i)
255+
256+
oldswitchinterval = sys.getswitchinterval()
257+
setswitchinterval(1e-6)
258+
try:
259+
threads = [
260+
threading.Thread(target=test_function, args=(thread_id,))
261+
for thread_id in range(THREADS)
262+
]
263+
with threading_helper.start_threads(threads):
264+
pass
265+
finally:
266+
sys.setswitchinterval(oldswitchinterval)
267+
for thread_id in range(THREADS):
268+
self.assertEqual(getattr(m, f"method_{thread_id}").call_count, LOOPS)
269+
self.assertEqual({call.args for call in getattr(m, f"method_{thread_id}").call_args_list},
270+
{(i,) for i in range(LOOPS)})
271+
272+
def test_mock_calls_thread_safe(self):
273+
m = ThreadingMock()
274+
LOOPS = 100
275+
THREADS = 10
276+
def test_function(thread_id):
277+
for i in range(LOOPS):
278+
m(thread_id, i)
279+
280+
oldswitchinterval = sys.getswitchinterval()
281+
setswitchinterval(1e-6)
282+
try:
283+
threads = [
284+
threading.Thread(target=test_function, args=(thread_id,))
285+
for thread_id in range(THREADS)
286+
]
287+
with threading_helper.start_threads(threads):
288+
pass
289+
finally:
290+
sys.setswitchinterval(oldswitchinterval)
291+
expected_calls = {
292+
(thread_id, i)
293+
for thread_id in range(THREADS)
294+
for i in range(LOOPS)
295+
}
296+
self.assertSetEqual({call.args for call in m.mock_calls}, expected_calls)
297+
222298
if __name__ == "__main__":
223299
unittest.main()

0 commit comments

Comments
 (0)