Skip to content

Commit b33469e

Browse files
committed
Use ContextVar.get_changed() for decimal context.
Ensure that decimal.getcontext() returns a per-task copy of the decimal.Context() so that mutations are isolated between async tasks and threads using sys.flags.thread_inherit_context.
1 parent 0a20e79 commit b33469e

File tree

4 files changed

+81
-6
lines changed

4 files changed

+81
-6
lines changed

Lib/_pydecimal.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -361,11 +361,19 @@ def getcontext():
361361
New contexts are copies of DefaultContext.
362362
"""
363363
try:
364-
return _current_context_var.get()
364+
context, changed = _current_context_var.get_changed()
365365
except LookupError:
366366
context = Context()
367367
_current_context_var.set(context)
368368
return context
369+
if not changed:
370+
# The context value was inherited from another task/thread. Because
371+
# the Context() instance is mutable, copy it to ensure that if it is
372+
# changed, those changes are isolated from other tasks/threads.
373+
context = context.copy()
374+
_current_context_var.set(context)
375+
return context
376+
369377

370378
def setcontext(context):
371379
"""Set this thread's context to context."""

Lib/test/test_decimal.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1770,6 +1770,59 @@ def test_threading(self):
17701770
DefaultContext.Emax = save_emax
17711771
DefaultContext.Emin = save_emin
17721772

1773+
@threading_helper.requires_working_threading()
1774+
def test_inherited_context_isolation(self):
1775+
# Test that when threads inherit contextvars (e.g. via
1776+
# sys.flags.thread_inherit_context), each thread gets its own
1777+
# copy of the decimal context so mutations don't leak between
1778+
# threads. Also verifies correct behavior with asyncio tasks.
1779+
Decimal = self.decimal.Decimal
1780+
getcontext = self.decimal.getcontext
1781+
setcontext = self.decimal.setcontext
1782+
Context = self.decimal.Context
1783+
Underflow = self.decimal.Underflow
1784+
1785+
# Set up parent context with specific precision
1786+
parent_ctx = getcontext()
1787+
parent_ctx.prec = 20
1788+
1789+
barrier = threading.Barrier(2, timeout=2)
1790+
results = {}
1791+
1792+
def child(name, prec_delta):
1793+
barrier.wait()
1794+
ctx = getcontext()
1795+
# Each child should see a context with the parent's precision
1796+
results[name + '_initial_prec'] = ctx.prec
1797+
results[name + '_ctx_id'] = id(ctx)
1798+
# Mutate this thread's context
1799+
ctx.prec += prec_delta
1800+
results[name + '_modified_prec'] = ctx.prec
1801+
1802+
# Spawn threads that inherit the parent's contextvars.
1803+
t1 = threading.Thread(target=child, args=('t1', 5),
1804+
context=contextvars.copy_context())
1805+
t2 = threading.Thread(target=child, args=('t2', 10),
1806+
context=contextvars.copy_context())
1807+
t1.start()
1808+
t2.start()
1809+
t1.join()
1810+
t2.join()
1811+
1812+
# Each thread should have started with the parent's precision
1813+
self.assertEqual(results['t1_initial_prec'], 20)
1814+
self.assertEqual(results['t2_initial_prec'], 20)
1815+
1816+
# Each thread should have its own context (different id)
1817+
self.assertNotEqual(results['t1_ctx_id'], results['t2_ctx_id'])
1818+
1819+
# Mutations should be independent
1820+
self.assertEqual(results['t1_modified_prec'], 25)
1821+
self.assertEqual(results['t2_modified_prec'], 30)
1822+
1823+
# Parent context should be unaffected
1824+
self.assertEqual(getcontext().prec, 20)
1825+
17731826

17741827
@requires_cdecimal
17751828
class CThreadingTest(ThreadingTest, unittest.TestCase):
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
Ensure that :func:`decimal.getcontext` returns a per-task copy of the
2+
:class:`decimal.Context` so that mutations are isolated between asyncio
3+
tasks and threads using :data:`sys.flags.thread_inherit_context`. Added
4+
:meth:`contextvars.ContextVar.get_changed` to support this.

Modules/_decimal/_decimal.c

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1914,9 +1914,9 @@ PyDec_SetCurrentContext(PyObject *self, PyObject *v)
19141914
}
19151915
#else
19161916
static PyObject *
1917-
init_current_context(decimal_state *state)
1917+
init_current_context(decimal_state *state, PyObject *prev_context)
19181918
{
1919-
PyObject *tl_context = context_copy(state, state->default_context_template);
1919+
PyObject *tl_context = context_copy(state, prev_context);
19201920
if (tl_context == NULL) {
19211921
return NULL;
19221922
}
@@ -1936,15 +1936,25 @@ static inline PyObject *
19361936
current_context(decimal_state *state)
19371937
{
19381938
PyObject *tl_context;
1939-
if (PyContextVar_Get(state->current_context_var, NULL, &tl_context) < 0) {
1939+
int changed;
1940+
if (PyContextVar_GetChanged(state->current_context_var, NULL, &tl_context,
1941+
&changed) < 0) {
19401942
return NULL;
19411943
}
19421944

19431945
if (tl_context != NULL) {
1944-
return tl_context;
1946+
if (!changed) {
1947+
/* inherited context object from another thread for async task */
1948+
PyObject *new_context = init_current_context(state, tl_context);
1949+
Py_DECREF(tl_context);
1950+
return new_context;
1951+
}
1952+
else {
1953+
return tl_context;
1954+
}
19451955
}
19461956

1947-
return init_current_context(state);
1957+
return init_current_context(state, state->default_context_template);
19481958
}
19491959

19501960
/* ctxobj := borrowed reference to the current context */

0 commit comments

Comments
 (0)