diff --git a/src/csrc/umath/comparison_ops.cpp b/src/csrc/umath/comparison_ops.cpp index dcee2e5..9924bd3 100644 --- a/src/csrc/umath/comparison_ops.cpp +++ b/src/csrc/umath/comparison_ops.cpp @@ -260,15 +260,31 @@ NPY_NO_EXPORT int comparison_ufunc_promoter(PyObject *ufunc_obj, PyArray_DTypeMeta *const op_dtypes[], PyArray_DTypeMeta *const signature[], PyArray_DTypeMeta *new_op_dtypes[]) { - PyArray_DTypeMeta *new_signature[NPY_MAXARGS]; - memcpy(new_signature, signature, 3 * sizeof(PyArray_DTypeMeta *)); - new_signature[2] = NULL; - int res = quad_ufunc_promoter(ufunc_obj, op_dtypes, new_signature, new_op_dtypes); - if (res < 0) { - return -1; + // Reduction: accumulator is Bool, element is QuadPrecDType, output is Bool + if (op_dtypes[0] == NULL) { + Py_INCREF(&PyArray_BoolDType); + new_op_dtypes[0] = &PyArray_BoolDType; + Py_INCREF(op_dtypes[1]); + new_op_dtypes[1] = op_dtypes[1]; + Py_INCREF(&PyArray_BoolDType); + new_op_dtypes[2] = &PyArray_BoolDType; + return 0; + } + + // Normal path: promote both inputs to QuadPrecDType, output is Bool + for (int i = 0; i < 2; i++) { + if (signature[i]) { + Py_INCREF(signature[i]); + new_op_dtypes[i] = signature[i]; + } + else { + Py_INCREF(&QuadPrecDType); + new_op_dtypes[i] = &QuadPrecDType; + } } + Py_INCREF(&PyArray_BoolDType); - Py_XSETREF(new_op_dtypes[2], &PyArray_BoolDType); + new_op_dtypes[2] = &PyArray_BoolDType; return 0; } diff --git a/src/include/umath/promoters.hpp b/src/include/umath/promoters.hpp index aeafc77..d9545bc 100644 --- a/src/include/umath/promoters.hpp +++ b/src/include/umath/promoters.hpp @@ -16,14 +16,11 @@ quad_ufunc_promoter(PyObject *ufunc_obj, PyArray_DTypeMeta *const op_dtypes[], PyArray_DTypeMeta *const signature[], PyArray_DTypeMeta *new_op_dtypes[]) { PyUFuncObject *ufunc = (PyUFuncObject *)ufunc_obj; - int nin = ufunc->nin; int nargs = ufunc->nargs; - PyArray_DTypeMeta *common = NULL; - bool has_quad = false; // Handle the special case for reductions if (op_dtypes[0] == NULL) { - assert(nin == 2 && ufunc->nout == 1); /* must be reduction */ + assert(ufunc->nin == 2 && ufunc->nout == 1); /* must be reduction */ for (int i = 0; i < 3; i++) { Py_INCREF(op_dtypes[1]); new_op_dtypes[i] = op_dtypes[1]; @@ -31,59 +28,19 @@ quad_ufunc_promoter(PyObject *ufunc_obj, PyArray_DTypeMeta *const op_dtypes[], return 0; } - // Check if any input or signature is QuadPrecision - for (int i = 0; i < nin; i++) { - if (op_dtypes[i] == &QuadPrecDType) { - has_quad = true; - } - } - - if (has_quad) { - common = &QuadPrecDType; - } - else { - for (int i = nin; i < nargs; i++) { - if (signature[i] != NULL) { - if (common == NULL) { - Py_INCREF(signature[i]); - common = signature[i]; - } - else if (common != signature[i]) { - Py_CLEAR(common); // Not homogeneous, unset common - break; - } - } - } - } - // If no common output dtype, use standard promotion for inputs - if (common == NULL) { - common = PyArray_PromoteDTypeSequence(nin, const_cast(op_dtypes)); - if (common == NULL) { - if (PyErr_ExceptionMatches(PyExc_TypeError)) { - PyErr_Clear(); // Do not propagate normal promotion errors - } - - return -1; - } - } - - // Set all new_op_dtypes to the common dtype + // This promoter is only registered for patterns where at least one + // input is QuadPrecDType, so we always promote all args to QuadPrecDType. for (int i = 0; i < nargs; i++) { if (signature[i]) { - // If signature is specified for this argument, use it Py_INCREF(signature[i]); new_op_dtypes[i] = signature[i]; } else { - // Otherwise, use the common dtype - Py_INCREF(common); - - new_op_dtypes[i] = common; + Py_INCREF(&QuadPrecDType); + new_op_dtypes[i] = &QuadPrecDType; } } - Py_XDECREF(common); - return 0; }