From 62d5c6e4660f66a754b2c3fbe96386d58c20248a Mon Sep 17 00:00:00 2001
From: Bobronium
Date: Fri, 13 Feb 2026 13:18:50 +0700
Subject: [PATCH] Add descriptive error causes for reduce values errors

---
 src/_deepcopy.c                    | 36 +++++++++++---
 src/_deepcopy_legacy.c             | 35 ++++++++++---
 src/_reduce_helpers.c              | 48 +++++++++++++++++-
 src/copium.c                       |  2 +-
 tests/test_reduce_errors_causes.py | 79 ++++++++++++++++++++++++++++++
 5 files changed, 186 insertions(+), 14 deletions(-)
 create mode 100644 tests/test_reduce_errors_causes.py

diff --git a/src/_deepcopy.c b/src/_deepcopy.c
index 18dc8a3..bfad853 100644
--- a/src/_deepcopy.c
+++ b/src/_deepcopy.c
@@ -463,7 +463,9 @@ static PyObject* reconstruct_newobj(PyObject* argtup, PyMemoObject* memo) {
     return instance;
 }
 
-static PyObject* reconstruct_newobj_ex(PyObject* argtup, PyMemoObject* memo) {
+static PyObject* reconstruct_newobj_ex(
+    PyObject* argtup, PyMemoObject* memo, PyTypeObject* reducing_type
+) {
     if (PyTuple_GET_SIZE(argtup) != 3) {
         PyErr_Format(
             PyExc_TypeError, "__newobj_ex__ requires 3 arguments, got %zd", PyTuple_GET_SIZE(argtup)
@@ -487,8 +489,13 @@ static PyObject* reconstruct_newobj_ex(PyObject* argtup, PyMemoObject* memo) {
 
     if (!PyTuple_Check(args)) {
         coerced_args = PySequence_Tuple(args);
-        if (!coerced_args)
+        if (!coerced_args) {
+            _chain_type_error(
+                "__newobj_ex__ args in %s.__reduce__ result must be a tuple, not %.200s",
+                reducing_type->tp_name, Py_TYPE(args)->tp_name
+            );
             return NULL;
+        }
         args = coerced_args;
     }
     if (!PyDict_Check(kwargs)) {
@@ -498,6 +505,10 @@ static PyObject* reconstruct_newobj_ex(PyObject* argtup, PyMemoObject* memo) {
             return NULL;
         }
         if (PyDict_Merge(coerced_kwargs, kwargs, 1) < 0) {
+            _chain_type_error(
+                "__newobj_ex__ kwargs in %s.__reduce__ result must be a dict, not %.200s",
+                reducing_type->tp_name, Py_TYPE(kwargs)->tp_name
+            );
             Py_XDECREF(coerced_args);
             Py_DECREF(coerced_kwargs);
             return NULL;
@@ -592,6 +603,12 @@ static int apply_dict_state(PyObject* instance, PyObject* dict_state, PyMemoObje
     }
     int ret = PyDict_Merge(instance_dict, copied, 1);
     Py_DECREF(instance_dict);
+    if (ret < 0) {
+        _chain_type_error(
+            "dict state from %s.__reduce__ must be a dict or mapping, got %.200s",
+            Py_TYPE(instance)->tp_name, Py_TYPE(copied)->tp_name
+        );
+    }
     Py_DECREF(copied);
     return ret;
 }
@@ -628,9 +645,16 @@ static int apply_slot_state(PyObject* instance, PyObject* slotstate, PyMemoObjec
 
     if (UNLIKELY(!PyDict_Check(copied))) {
         PyObject* items = PyObject_CallMethod(copied, "items", NULL);
-        Py_DECREF(copied);
-        if (!items)
+        if (!items) {
+            _chain_type_error(
+                "slot state from %s.__reduce__ must be a dict or have an items() method, "
+                "got %.200s",
+                Py_TYPE(instance)->tp_name, Py_TYPE(copied)->tp_name
+            );
+            Py_DECREF(copied);
             return -1;
+        }
+        Py_DECREF(copied);
 
         PyObject* iterator = PyObject_GetIter(items);
         Py_DECREF(items);
@@ -838,7 +862,7 @@ static PyObject* deepcopy_object(
 
     PyObject *callable, *argtup, *state, *listitems, *dictitems;
     int valid = validate_reduce_tuple(
-        reduce_result, &callable, &argtup, &state, &listitems, &dictitems
+        reduce_result, tp, &callable, &argtup, &state, &listitems, &dictitems
     );
 
     if (valid == REDUCE_ERROR) {
@@ -855,7 +879,7 @@ static PyObject* deepcopy_object(
     if (callable == module_state.copyreg___newobj__)
         instance = reconstruct_newobj(argtup, memo);
     else if (callable == module_state.copyreg___newobj___ex)
-        instance = reconstruct_newobj_ex(argtup, memo);
+        instance = reconstruct_newobj_ex(argtup, memo, tp);
     else
         instance = reconstruct_callable(callable, argtup, memo);
 
diff --git a/src/_deepcopy_legacy.c b/src/_deepcopy_legacy.c
index f1b7ecf..6400e69 100644
--- a/src/_deepcopy_legacy.c
+++ b/src/_deepcopy_legacy.c
@@ -503,7 +503,8 @@ static PyObject* reconstruct_newobj_legacy(
 }
 
 static PyObject* reconstruct_newobj_ex_legacy(
-    PyObject* argtup, PyObject* memo, PyObject** keepalive_pointer
+    PyObject* argtup, PyObject* memo, PyObject** keepalive_pointer,
+    PyTypeObject* reducing_type
 ) {
     if (PyTuple_GET_SIZE(argtup) != 3) {
         PyErr_Format(
@@ -528,8 +529,13 @@ static PyObject* reconstruct_newobj_ex_legacy(
 
     if (!PyTuple_Check(args)) {
         coerced_args = PySequence_Tuple(args);
-        if (!coerced_args)
+        if (!coerced_args) {
+            _chain_type_error(
+                "__newobj_ex__ args in %s.__reduce__ result must be a tuple, not %.200s",
+                reducing_type->tp_name, Py_TYPE(args)->tp_name
+            );
             return NULL;
+        }
         args = coerced_args;
     }
     if (!PyDict_Check(kwargs)) {
@@ -539,6 +545,10 @@ static PyObject* reconstruct_newobj_ex_legacy(
             return NULL;
         }
         if (PyDict_Merge(coerced_kwargs, kwargs, 1) < 0) {
+            _chain_type_error(
+                "__newobj_ex__ kwargs in %s.__reduce__ result must be a dict, not %.200s",
+                reducing_type->tp_name, Py_TYPE(kwargs)->tp_name
+            );
             Py_XDECREF(coerced_args);
             Py_DECREF(coerced_kwargs);
             return NULL;
@@ -639,6 +649,12 @@ static int apply_dict_state_legacy(
     }
     int ret = PyDict_Merge(instance_dict, copied, 1);
     Py_DECREF(instance_dict);
+    if (ret < 0) {
+        _chain_type_error(
+            "dict state from %s.__reduce__ must be a dict or mapping, got %.200s",
+            Py_TYPE(instance)->tp_name, Py_TYPE(copied)->tp_name
+        );
+    }
     Py_DECREF(copied);
     return ret;
 }
@@ -677,9 +693,16 @@ static int apply_slot_state_legacy(
 
     if (UNLIKELY(!PyDict_Check(copied))) {
         PyObject* items = PyObject_CallMethod(copied, "items", NULL);
-        Py_DECREF(copied);
-        if (!items)
+        if (!items) {
+            _chain_type_error(
+                "slot state from %s.__reduce__ must be a dict or have an items() method, "
+                "got %.200s",
+                Py_TYPE(instance)->tp_name, Py_TYPE(copied)->tp_name
+            );
+            Py_DECREF(copied);
             return -1;
+        }
+        Py_DECREF(copied);
 
         PyObject* iterator = PyObject_GetIter(items);
         Py_DECREF(items);
@@ -893,7 +916,7 @@ static PyObject* deepcopy_object_legacy(
 
     PyObject *callable, *argtup, *state, *listitems, *dictitems;
     int valid = validate_reduce_tuple(
-        reduce_result, &callable, &argtup, &state, &listitems, &dictitems
+        reduce_result, tp, &callable, &argtup, &state, &listitems, &dictitems
     );
 
     if (valid == REDUCE_ERROR) {
@@ -910,7 +933,7 @@ static PyObject* deepcopy_object_legacy(
     if (callable == module_state.copyreg___newobj__)
         instance = reconstruct_newobj_legacy(argtup, memo, keepalive_pointer);
     else if (callable == module_state.copyreg___newobj___ex)
-        instance = reconstruct_newobj_ex_legacy(argtup, memo, keepalive_pointer);
+        instance = reconstruct_newobj_ex_legacy(argtup, memo, keepalive_pointer, tp);
     else
         instance = reconstruct_callable_legacy(callable, argtup, memo, keepalive_pointer);
 
diff --git a/src/_reduce_helpers.c b/src/_reduce_helpers.c
index 9cd49b7..08c4644 100644
--- a/src/_reduce_helpers.c
+++ b/src/_reduce_helpers.c
@@ -64,8 +64,48 @@ static PyObject* call_reduce_method_preferring_ex(PyObject* obj) {
     return NULL;
 }
 
+static void _chain_type_error(const char* fmt, ...) {
+    PyObject *cause_type, *cause_val, *cause_tb;
+    PyErr_Fetch(&cause_type, &cause_val, &cause_tb);
+
+    va_list vargs;
+    va_start(vargs, fmt);
+    PyObject* msg = PyUnicode_FromFormatV(fmt, vargs);
+    va_end(vargs);
+
+    if (!msg) {
+        PyErr_Restore(cause_type, cause_val, cause_tb);
+        return;
+    }
+
+    if (cause_type) /* PyErr_Fetch may return a NULL value with a set type; normalize to get an instance */
+        PyErr_NormalizeException(&cause_type, &cause_val, &cause_tb);
+
+    PyObject* new_exc = PyObject_CallOneArg(PyExc_TypeError, msg);
+    Py_DECREF(msg);
+
+    if (!new_exc) {
+        PyErr_Restore(cause_type, cause_val, cause_tb);
+        return;
+    }
+
+    if (cause_val) {
+        PyException_SetCause(cause_val, new_exc);
+
+        PyErr_Restore(cause_type, cause_val, cause_tb);
+        return;
+    }
+
+    Py_XDECREF(cause_type);
+    Py_XDECREF(cause_tb);
+
+    PyErr_SetObject(PyExc_TypeError, new_exc);
+    Py_DECREF(new_exc);
+}
+
 static int validate_reduce_tuple(
     PyObject* reduce_result,
+    PyTypeObject* reducing_type,
     PyObject** out_callable,
     PyObject** out_argtup,
     PyObject** out_state,
@@ -97,8 +137,14 @@ static int validate_reduce_tuple(
 
     if (!PyTuple_Check(argtup)) {
         PyObject* coerced = PySequence_Tuple(argtup);
-        if (!coerced)
+        if (!coerced) {
+            _chain_type_error(
+                "second element of the tuple returned by %s.__reduce__ "
+                "must be a tuple, not %.200s",
+                reducing_type->tp_name, Py_TYPE(argtup)->tp_name
+            );
             return REDUCE_ERROR;
+        }
         PyObject* old = argtup;
         PyTuple_SET_ITEM(reduce_result, 1, coerced);
         Py_DECREF(old);
diff --git a/src/copium.c b/src/copium.c
index 45fcbe0..140f290 100644
--- a/src/copium.c
+++ b/src/copium.c
@@ -115,7 +115,7 @@ PyObject* py_copy(PyObject* self, PyObject* obj) {
 
     PyObject *constructor = NULL, *args = NULL, *state = NULL, *listiter = NULL, *dictiter = NULL;
     int unpack_result = validate_reduce_tuple(
-        reduce_result, &constructor, &args, &state, &listiter, &dictiter
+        reduce_result, obj_type, &constructor, &args, &state, &listiter, &dictiter
     );
     if (unpack_result == REDUCE_ERROR) {
         Py_DECREF(reduce_result);
diff --git a/tests/test_reduce_errors_causes.py b/tests/test_reduce_errors_causes.py
new file mode 100644
index 0000000..f119c07
--- /dev/null
+++ b/tests/test_reduce_errors_causes.py
@@ -0,0 +1,79 @@
+"""
+Test that on top of errors equivalent to stdlib, copium also chains descriptive __cause__.
+"""
+
+import copyreg
+
+import pytest
+
+import copium
+
+
+class ArgtupNotIterable:
+    def __reduce__(self):
+        return (type(self), 42)
+
+
+class NewObjExArgsNotIterable:
+    def __reduce__(self):
+        return (copyreg.__newobj_ex__, (type(self), 42, {}))
+
+
+class NewObjExKwargsNotMapping:
+    def __reduce__(self):
+        return (copyreg.__newobj_ex__, (type(self), (), 42))
+
+
+class DictStateNotMapping:
+    def __reduce__(self):
+        return (type(self), (), [1, 2, 3])
+
+
+class SlotStateNoItems:
+    def __reduce__(self):
+        return (type(self), (), ({}, 42))
+
+
+@pytest.mark.parametrize(
+    "obj, cause_message",
+    [
+        pytest.param(
+            ArgtupNotIterable(),
+            "second element of the tuple returned by ArgtupNotIterable.__reduce__"
+            " must be a tuple, not int",
+            id="argtup-not-iterable",
+        ),
+        pytest.param(
+            NewObjExArgsNotIterable(),
+            "__newobj_ex__ args in NewObjExArgsNotIterable.__reduce__ result"
+            " must be a tuple, not int",
+            id="newobj_ex-args-not-iterable",
+        ),
+        pytest.param(
+            NewObjExKwargsNotMapping(),
+            "__newobj_ex__ kwargs in NewObjExKwargsNotMapping.__reduce__ result"
+            " must be a dict, not int",
+            id="newobj_ex-kwargs-not-mapping",
+        ),
+        pytest.param(
+            DictStateNotMapping(),
+            "dict state from DictStateNotMapping.__reduce__"
+            " must be a dict or mapping, got list",
+            id="dict-state-not-mapping",
+        ),
+        pytest.param(
+            SlotStateNoItems(),
+            "slot state from SlotStateNoItems.__reduce__"
+            " must be a dict or have an items() method, got int",
+            id="slot-state-no-items",
+        ),
+    ],
+)
+def test_reduce_chained_type_error(obj, cause_message):
+    with pytest.raises(Exception) as exc_info:
+        copium.deepcopy(obj)
+
+    cause = exc_info.value.__cause__
+    assert cause is not None, f"expected chained __cause__, got bare {exc_info.value!r}"
+    assert isinstance(cause, TypeError), f"expected TypeError cause, got {type(cause).__name__}: {cause}"
+    assert str(cause) == str(cause_message)