Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 30 additions & 6 deletions src/_deepcopy.c
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,9 @@ static PyObject* reconstruct_newobj(PyObject* argtup, PyMemoObject* memo) {
return instance;
}

static PyObject* reconstruct_newobj_ex(PyObject* argtup, PyMemoObject* memo) {
static PyObject* reconstruct_newobj_ex(
PyObject* argtup, PyMemoObject* memo, PyTypeObject* reducing_type
) {
if (PyTuple_GET_SIZE(argtup) != 3) {
PyErr_Format(
PyExc_TypeError, "__newobj_ex__ requires 3 arguments, got %zd", PyTuple_GET_SIZE(argtup)
Expand All @@ -487,8 +489,13 @@ static PyObject* reconstruct_newobj_ex(PyObject* argtup, PyMemoObject* memo) {

if (!PyTuple_Check(args)) {
coerced_args = PySequence_Tuple(args);
if (!coerced_args)
if (!coerced_args) {
_chain_type_error(
"__newobj_ex__ args in %s.__reduce__ result must be a tuple, not %.200s",
reducing_type->tp_name, Py_TYPE(args)->tp_name
);
return NULL;
}
args = coerced_args;
}
if (!PyDict_Check(kwargs)) {
Expand All @@ -498,6 +505,10 @@ static PyObject* reconstruct_newobj_ex(PyObject* argtup, PyMemoObject* memo) {
return NULL;
}
if (PyDict_Merge(coerced_kwargs, kwargs, 1) < 0) {
_chain_type_error(
"__newobj_ex__ kwargs in %s.__reduce__ result must be a dict, not %.200s",
reducing_type->tp_name, Py_TYPE(kwargs)->tp_name
);
Py_XDECREF(coerced_args);
Py_DECREF(coerced_kwargs);
return NULL;
Expand Down Expand Up @@ -592,6 +603,12 @@ static int apply_dict_state(PyObject* instance, PyObject* dict_state, PyMemoObje
}
int ret = PyDict_Merge(instance_dict, copied, 1);
Py_DECREF(instance_dict);
if (ret < 0) {
_chain_type_error(
"dict state from %s.__reduce__ must be a dict or mapping, got %.200s",
Py_TYPE(instance)->tp_name, Py_TYPE(copied)->tp_name
);
}
Py_DECREF(copied);
return ret;
}
Expand Down Expand Up @@ -628,9 +645,16 @@ static int apply_slot_state(PyObject* instance, PyObject* slotstate, PyMemoObjec

if (UNLIKELY(!PyDict_Check(copied))) {
PyObject* items = PyObject_CallMethod(copied, "items", NULL);
Py_DECREF(copied);
if (!items)
if (!items) {
_chain_type_error(
"slot state from %s.__reduce__ must be a dict or have an items() method, "
"got %.200s",
Py_TYPE(instance)->tp_name, Py_TYPE(copied)->tp_name
);
Py_DECREF(copied);
return -1;
}
Py_DECREF(copied);

PyObject* iterator = PyObject_GetIter(items);
Py_DECREF(items);
Expand Down Expand Up @@ -838,7 +862,7 @@ static PyObject* deepcopy_object(

PyObject *callable, *argtup, *state, *listitems, *dictitems;
int valid = validate_reduce_tuple(
reduce_result, &callable, &argtup, &state, &listitems, &dictitems
reduce_result, tp, &callable, &argtup, &state, &listitems, &dictitems
);

if (valid == REDUCE_ERROR) {
Expand All @@ -855,7 +879,7 @@ static PyObject* deepcopy_object(
if (callable == module_state.copyreg___newobj__)
instance = reconstruct_newobj(argtup, memo);
else if (callable == module_state.copyreg___newobj___ex)
instance = reconstruct_newobj_ex(argtup, memo);
instance = reconstruct_newobj_ex(argtup, memo, tp);
else
instance = reconstruct_callable(callable, argtup, memo);

Expand Down
35 changes: 29 additions & 6 deletions src/_deepcopy_legacy.c
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,8 @@ static PyObject* reconstruct_newobj_legacy(
}

static PyObject* reconstruct_newobj_ex_legacy(
PyObject* argtup, PyObject* memo, PyObject** keepalive_pointer
PyObject* argtup, PyObject* memo, PyObject** keepalive_pointer,
PyTypeObject* reducing_type
) {
if (PyTuple_GET_SIZE(argtup) != 3) {
PyErr_Format(
Expand All @@ -528,8 +529,13 @@ static PyObject* reconstruct_newobj_ex_legacy(

if (!PyTuple_Check(args)) {
coerced_args = PySequence_Tuple(args);
if (!coerced_args)
if (!coerced_args) {
_chain_type_error(
"__newobj_ex__ args in %s.__reduce__ result must be a tuple, not %.200s",
reducing_type->tp_name, Py_TYPE(args)->tp_name
);
return NULL;
}
args = coerced_args;
}
if (!PyDict_Check(kwargs)) {
Expand All @@ -539,6 +545,10 @@ static PyObject* reconstruct_newobj_ex_legacy(
return NULL;
}
if (PyDict_Merge(coerced_kwargs, kwargs, 1) < 0) {
_chain_type_error(
"__newobj_ex__ kwargs in %s.__reduce__ result must be a dict, not %.200s",
reducing_type->tp_name, Py_TYPE(kwargs)->tp_name
);
Py_XDECREF(coerced_args);
Py_DECREF(coerced_kwargs);
return NULL;
Expand Down Expand Up @@ -639,6 +649,12 @@ static int apply_dict_state_legacy(
}
int ret = PyDict_Merge(instance_dict, copied, 1);
Py_DECREF(instance_dict);
if (ret < 0) {
_chain_type_error(
"dict state from %s.__reduce__ must be a dict or mapping, got %.200s",
Py_TYPE(instance)->tp_name, Py_TYPE(copied)->tp_name
);
}
Py_DECREF(copied);
return ret;
}
Expand Down Expand Up @@ -677,9 +693,16 @@ static int apply_slot_state_legacy(

if (UNLIKELY(!PyDict_Check(copied))) {
PyObject* items = PyObject_CallMethod(copied, "items", NULL);
Py_DECREF(copied);
if (!items)
if (!items) {
_chain_type_error(
"slot state from %s.__reduce__ must be a dict or have an items() method, "
"got %.200s",
Py_TYPE(instance)->tp_name, Py_TYPE(copied)->tp_name
);
Py_DECREF(copied);
return -1;
}
Py_DECREF(copied);

PyObject* iterator = PyObject_GetIter(items);
Py_DECREF(items);
Expand Down Expand Up @@ -893,7 +916,7 @@ static PyObject* deepcopy_object_legacy(

PyObject *callable, *argtup, *state, *listitems, *dictitems;
int valid = validate_reduce_tuple(
reduce_result, &callable, &argtup, &state, &listitems, &dictitems
reduce_result, tp, &callable, &argtup, &state, &listitems, &dictitems
);

if (valid == REDUCE_ERROR) {
Expand All @@ -910,7 +933,7 @@ static PyObject* deepcopy_object_legacy(
if (callable == module_state.copyreg___newobj__)
instance = reconstruct_newobj_legacy(argtup, memo, keepalive_pointer);
else if (callable == module_state.copyreg___newobj___ex)
instance = reconstruct_newobj_ex_legacy(argtup, memo, keepalive_pointer);
instance = reconstruct_newobj_ex_legacy(argtup, memo, keepalive_pointer, tp);
else
instance = reconstruct_callable_legacy(callable, argtup, memo, keepalive_pointer);

Expand Down
48 changes: 47 additions & 1 deletion src/_reduce_helpers.c
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,48 @@ static PyObject* call_reduce_method_preferring_ex(PyObject* obj) {
return NULL;
}

static void _chain_type_error(const char* fmt, ...) {
PyObject *cause_type, *cause_val, *cause_tb;
PyErr_Fetch(&cause_type, &cause_val, &cause_tb);

va_list vargs;
va_start(vargs, fmt);
PyObject* msg = PyUnicode_FromFormatV(fmt, vargs);
va_end(vargs);

if (!msg) {
PyErr_Restore(cause_type, cause_val, cause_tb);
return;
}

if (cause_val)
PyErr_NormalizeException(&cause_type, &cause_val, &cause_tb);

PyObject* new_exc = PyObject_CallOneArg(PyExc_TypeError, msg);
Py_DECREF(msg);

if (!new_exc) {
PyErr_Restore(cause_type, cause_val, cause_tb);
return;
}

if (cause_val) {
PyException_SetCause(cause_val, new_exc);

PyErr_Restore(cause_type, cause_val, cause_tb);
return;
}

Py_XDECREF(cause_type);
Py_XDECREF(cause_tb);

PyErr_SetObject(PyExc_TypeError, new_exc);
Py_DECREF(new_exc);
}

static int validate_reduce_tuple(
PyObject* reduce_result,
PyTypeObject* reducing_type,
PyObject** out_callable,
PyObject** out_argtup,
PyObject** out_state,
Expand Down Expand Up @@ -97,8 +137,14 @@ static int validate_reduce_tuple(

if (!PyTuple_Check(argtup)) {
PyObject* coerced = PySequence_Tuple(argtup);
if (!coerced)
if (!coerced) {
_chain_type_error(
"second element of the tuple returned by %s.__reduce__ "
"must be a tuple, not %.200s",
reducing_type->tp_name, Py_TYPE(argtup)->tp_name
);
return REDUCE_ERROR;
}
PyObject* old = argtup;
PyTuple_SET_ITEM(reduce_result, 1, coerced);
Py_DECREF(old);
Expand Down
2 changes: 1 addition & 1 deletion src/copium.c
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ PyObject* py_copy(PyObject* self, PyObject* obj) {

PyObject *constructor = NULL, *args = NULL, *state = NULL, *listiter = NULL, *dictiter = NULL;
int unpack_result = validate_reduce_tuple(
reduce_result, &constructor, &args, &state, &listiter, &dictiter
reduce_result, obj_type, &constructor, &args, &state, &listiter, &dictiter
);
if (unpack_result == REDUCE_ERROR) {
Py_DECREF(reduce_result);
Expand Down
79 changes: 79 additions & 0 deletions tests/test_reduce_errors_causes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
"""
Test that on top of errors equivalent to stdlib, copium also chains descriptive __cause__.
"""

import copyreg

import pytest

import copium


class ArgtupNotIterable:
def __reduce__(self):
return (type(self), 42)


class NewObjExArgsNotIterable:
def __reduce__(self):
return (copyreg.__newobj_ex__, (type(self), 42, {}))


class NewObjExKwargsNotMapping:
def __reduce__(self):
return (copyreg.__newobj_ex__, (type(self), (), 42))


class DictStateNotMapping:
def __reduce__(self):
return (type(self), (), [1, 2, 3])


class SlotStateNoItems:
def __reduce__(self):
return (type(self), (), ({}, 42))


@pytest.mark.parametrize(
"obj, cause_message",
[
pytest.param(
ArgtupNotIterable(),
"second element of the tuple returned by ArgtupNotIterable.__reduce__"
" must be a tuple, not int",
id="argtup-not-iterable",
),
pytest.param(
NewObjExArgsNotIterable(),
"__newobj_ex__ args in NewObjExArgsNotIterable.__reduce__ result"
" must be a tuple, not int",
id="newobj_ex-args-not-iterable",
),
pytest.param(
NewObjExKwargsNotMapping(),
"__newobj_ex__ kwargs in NewObjExKwargsNotMapping.__reduce__ result"
" must be a dict, not int",
id="newobj_ex-kwargs-not-mapping",
),
pytest.param(
DictStateNotMapping(),
"dict state from DictStateNotMapping.__reduce__"
" must be a dict or mapping, got list",
id="dict-state-not-mapping",
),
pytest.param(
SlotStateNoItems(),
"slot state from SlotStateNoItems.__reduce__"
" must be a dict or have an items() method, got int",
id="slot-state-no-items",
),
],
)
def test_reduce_chained_type_error(obj, cause_message):
with pytest.raises(Exception) as exc_info:
copium.deepcopy(obj)

cause = exc_info.value.__cause__
assert cause is not None, f"expected chained __cause__, got bare {exc_info.value!r}"
assert isinstance(cause, TypeError), f"expected TypeError cause, got {type(cause).__name__}: {cause}"
assert str(cause) == str(cause_message)