Skip to content

feat: numpy scalars #5726

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Jun 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions docs/advanced/pycpp/numpy.rst
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,46 @@ prevent many types of unsupported structures, it is still the user's
responsibility to use only "plain" structures that can be safely manipulated as
raw memory without violating invariants.

Scalar types
============

In some cases we may want to accept or return NumPy scalar values such as
``np.float32`` or ``np.float64``. We hope to be able to handle single-precision
and double-precision on the C-side. However, both are bound to Python's
double-precision builtin float by default, so they cannot be processed separately.
We used the ``py::buffer`` trick to implement the previous approach, which
will cause the readability of the code to drop significantly.

Luckily, there's a helper type for this occasion - ``py::numpy_scalar``:

.. code-block:: cpp

m.def("add", [](py::numpy_scalar<float> a, py::numpy_scalar<float> b) {
return py::make_scalar(a + b);
});
m.def("add", [](py::numpy_scalar<double> a, py::numpy_scalar<double> b) {
return py::make_scalar(a + b);
});

This type is trivially convertible to and from the type it wraps; currently
supported scalar types are NumPy arithmetic types: ``bool_``, ``int8``,
``int16``, ``int32``, ``int64``, ``uint8``, ``uint16``, ``uint32``,
``uint64``, ``float32``, ``float64``, ``complex64``, ``complex128``, all of
them mapping to respective C++ counterparts.

.. note::

``py::numpy_scalar<T>`` strictly matches NumPy scalar types. For example,
``py::numpy_scalar<int64_t>`` will accept ``np.int64(123)``,
but **not** a regular Python ``int`` like ``123``.

.. note::

Native C types are mapped to NumPy types in a platform specific way: for
instance, ``char`` may be mapped to either ``np.int8`` or ``np.uint8``
and ``long`` may use 4 or 8 bytes depending on the platform. Unless you
clearly understand the difference and your needs, please use ``<cstdint>``.

Vectorizing functions
=====================

Expand Down
197 changes: 161 additions & 36 deletions include/pybind11/numpy.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ PYBIND11_WARNING_DISABLE_MSVC(4127)
class dtype; // Forward declaration
class array; // Forward declaration

template <typename>
struct numpy_scalar; // Forward declaration

PYBIND11_NAMESPACE_BEGIN(detail)

template <>
Expand Down Expand Up @@ -245,6 +248,21 @@ struct npy_api {
NPY_UINT64_
= platform_lookup<std::uint64_t, unsigned long, unsigned long long, unsigned int>(
NPY_ULONG_, NPY_ULONGLONG_, NPY_UINT_),
NPY_FLOAT32_ = platform_lookup<float, double, float, long double>(
NPY_DOUBLE_, NPY_FLOAT_, NPY_LONGDOUBLE_),
NPY_FLOAT64_ = platform_lookup<double, double, float, long double>(
NPY_DOUBLE_, NPY_FLOAT_, NPY_LONGDOUBLE_),
NPY_COMPLEX64_
= platform_lookup<std::complex<float>,
std::complex<double>,
std::complex<float>,
std::complex<long double>>(NPY_DOUBLE_, NPY_FLOAT_, NPY_LONGDOUBLE_),
NPY_COMPLEX128_
= platform_lookup<std::complex<double>,
std::complex<double>,
std::complex<float>,
std::complex<long double>>(NPY_DOUBLE_, NPY_FLOAT_, NPY_LONGDOUBLE_),
NPY_CHAR_ = std::is_signed<char>::value ? NPY_BYTE_ : NPY_UBYTE_,
};

unsigned int PyArray_RUNTIME_VERSION_;
Expand All @@ -268,6 +286,7 @@ struct npy_api {

unsigned int (*PyArray_GetNDArrayCFeatureVersion_)();
PyObject *(*PyArray_DescrFromType_)(int);
PyObject *(*PyArray_TypeObjectFromType_)(int);
PyObject *(*PyArray_NewFromDescr_)(PyTypeObject *,
PyObject *,
int,
Expand All @@ -284,6 +303,8 @@ struct npy_api {
PyTypeObject *PyVoidArrType_Type_;
PyTypeObject *PyArrayDescr_Type_;
PyObject *(*PyArray_DescrFromScalar_)(PyObject *);
PyObject *(*PyArray_Scalar_)(void *, PyObject *, PyObject *);
void (*PyArray_ScalarAsCtype_)(PyObject *, void *);
PyObject *(*PyArray_FromAny_)(PyObject *, PyObject *, int, int, int, PyObject *);
int (*PyArray_DescrConverter_)(PyObject *, PyObject **);
bool (*PyArray_EquivTypes_)(PyObject *, PyObject *);
Expand All @@ -301,7 +322,10 @@ struct npy_api {
API_PyArrayDescr_Type = 3,
API_PyVoidArrType_Type = 39,
API_PyArray_DescrFromType = 45,
API_PyArray_TypeObjectFromType = 46,
API_PyArray_DescrFromScalar = 57,
API_PyArray_Scalar = 60,
API_PyArray_ScalarAsCtype = 62,
API_PyArray_FromAny = 69,
API_PyArray_Resize = 80,
// CopyInto was slot 82 and 50 was effectively an alias. NumPy 2 removed 82.
Expand Down Expand Up @@ -336,7 +360,10 @@ struct npy_api {
DECL_NPY_API(PyVoidArrType_Type);
DECL_NPY_API(PyArrayDescr_Type);
DECL_NPY_API(PyArray_DescrFromType);
DECL_NPY_API(PyArray_TypeObjectFromType);
DECL_NPY_API(PyArray_DescrFromScalar);
DECL_NPY_API(PyArray_Scalar);
DECL_NPY_API(PyArray_ScalarAsCtype);
DECL_NPY_API(PyArray_FromAny);
DECL_NPY_API(PyArray_Resize);
DECL_NPY_API(PyArray_CopyInto);
Expand All @@ -355,6 +382,83 @@ struct npy_api {
}
};

template <typename T>
struct is_complex : std::false_type {};
template <typename T>
struct is_complex<std::complex<T>> : std::true_type {};

template <typename T, typename = void>
struct npy_format_descriptor_name;

template <typename T>
struct npy_format_descriptor_name<T, enable_if_t<std::is_integral<T>::value>> {
static constexpr auto name = const_name<std::is_same<T, bool>::value>(
const_name("numpy.bool"),
const_name<std::is_signed<T>::value>("numpy.int", "numpy.uint")
+ const_name<sizeof(T) * 8>());
};

template <typename T>
struct npy_format_descriptor_name<T, enable_if_t<std::is_floating_point<T>::value>> {
static constexpr auto name = const_name < std::is_same<T, float>::value
|| std::is_same<T, const float>::value
|| std::is_same<T, double>::value
|| std::is_same<T, const double>::value
> (const_name("numpy.float") + const_name<sizeof(T) * 8>(),
const_name("numpy.longdouble"));
};

template <typename T>
struct npy_format_descriptor_name<T, enable_if_t<is_complex<T>::value>> {
static constexpr auto name = const_name < std::is_same<typename T::value_type, float>::value
|| std::is_same<typename T::value_type, const float>::value
|| std::is_same<typename T::value_type, double>::value
|| std::is_same<typename T::value_type, const double>::value
> (const_name("numpy.complex")
+ const_name<sizeof(typename T::value_type) * 16>(),
const_name("numpy.longcomplex"));
};

template <typename T>
struct numpy_scalar_info {};

#define PYBIND11_NUMPY_SCALAR_IMPL(ctype_, typenum_) \
template <> \
struct numpy_scalar_info<ctype_> { \
static constexpr auto name = npy_format_descriptor_name<ctype_>::name; \
static constexpr int typenum = npy_api::typenum_##_; \
}

// boolean type
PYBIND11_NUMPY_SCALAR_IMPL(bool, NPY_BOOL);

// character types
PYBIND11_NUMPY_SCALAR_IMPL(char, NPY_CHAR);
PYBIND11_NUMPY_SCALAR_IMPL(signed char, NPY_BYTE);
PYBIND11_NUMPY_SCALAR_IMPL(unsigned char, NPY_UBYTE);

// signed integer types
PYBIND11_NUMPY_SCALAR_IMPL(std::int16_t, NPY_INT16);
PYBIND11_NUMPY_SCALAR_IMPL(std::int32_t, NPY_INT32);
PYBIND11_NUMPY_SCALAR_IMPL(std::int64_t, NPY_INT64);

// unsigned integer types
PYBIND11_NUMPY_SCALAR_IMPL(std::uint16_t, NPY_UINT16);
PYBIND11_NUMPY_SCALAR_IMPL(std::uint32_t, NPY_UINT32);
PYBIND11_NUMPY_SCALAR_IMPL(std::uint64_t, NPY_UINT64);

// floating point types
PYBIND11_NUMPY_SCALAR_IMPL(float, NPY_FLOAT);
PYBIND11_NUMPY_SCALAR_IMPL(double, NPY_DOUBLE);
PYBIND11_NUMPY_SCALAR_IMPL(long double, NPY_LONGDOUBLE);

// complex types
PYBIND11_NUMPY_SCALAR_IMPL(std::complex<float>, NPY_CFLOAT);
PYBIND11_NUMPY_SCALAR_IMPL(std::complex<double>, NPY_CDOUBLE);
PYBIND11_NUMPY_SCALAR_IMPL(std::complex<long double>, NPY_CLONGDOUBLE);

#undef PYBIND11_NUMPY_SCALAR_IMPL

// This table normalizes typenums by mapping NPY_INT_, NPY_LONG, ... to NPY_INT32_, NPY_INT64, ...
// This is needed to correctly handle situations where multiple typenums map to the same type,
// e.g. NPY_LONG_ may be equivalent to NPY_INT_ or NPY_LONGLONG_ despite having a different
Expand Down Expand Up @@ -453,10 +557,6 @@ template <typename T>
struct is_std_array : std::false_type {};
template <typename T, size_t N>
struct is_std_array<std::array<T, N>> : std::true_type {};
template <typename T>
struct is_complex : std::false_type {};
template <typename T>
struct is_complex<std::complex<T>> : std::true_type {};

template <typename T>
struct array_info_scalar {
Expand Down Expand Up @@ -670,8 +770,65 @@ template <typename T, ssize_t Dim>
struct type_caster<unchecked_mutable_reference<T, Dim>>
: type_caster<unchecked_reference<T, Dim>> {};

template <typename T>
struct type_caster<numpy_scalar<T>> {
using value_type = T;
using type_info = numpy_scalar_info<T>;

PYBIND11_TYPE_CASTER(numpy_scalar<T>, type_info::name);

static handle &target_type() {
static handle tp = npy_api::get().PyArray_TypeObjectFromType_(type_info::typenum);
return tp;
}

static handle &target_dtype() {
static handle tp = npy_api::get().PyArray_DescrFromType_(type_info::typenum);
return tp;
}

bool load(handle src, bool) {
if (isinstance(src, target_type())) {
npy_api::get().PyArray_ScalarAsCtype_(src.ptr(), &value.value);
return true;
}
return false;
}

static handle cast(numpy_scalar<T> src, return_value_policy, handle) {
return npy_api::get().PyArray_Scalar_(&src.value, target_dtype().ptr(), nullptr);
}
};

PYBIND11_NAMESPACE_END(detail)

template <typename T>
struct numpy_scalar {
using value_type = T;

value_type value;

numpy_scalar() = default;
explicit numpy_scalar(value_type value) : value(value) {}

explicit operator value_type() const { return value; }
numpy_scalar &operator=(value_type value) {
this->value = value;
return *this;
}

friend bool operator==(const numpy_scalar &a, const numpy_scalar &b) {
return a.value == b.value;
}

friend bool operator!=(const numpy_scalar &a, const numpy_scalar &b) { return !(a == b); }
};

template <typename T>
numpy_scalar<T> make_scalar(T value) {
return numpy_scalar<T>(value);
}

class dtype : public object {
public:
PYBIND11_OBJECT_DEFAULT(dtype, object, detail::npy_api::get().PyArrayDescr_Check_)
Expand Down Expand Up @@ -1409,38 +1566,6 @@ struct compare_buffer_info<T, detail::enable_if_t<detail::is_pod_struct<T>::valu
}
};

template <typename T, typename = void>
struct npy_format_descriptor_name;

template <typename T>
struct npy_format_descriptor_name<T, enable_if_t<std::is_integral<T>::value>> {
static constexpr auto name = const_name<std::is_same<T, bool>::value>(
const_name("bool"),
const_name<std::is_signed<T>::value>("numpy.int", "numpy.uint")
+ const_name<sizeof(T) * 8>());
};

template <typename T>
struct npy_format_descriptor_name<T, enable_if_t<std::is_floating_point<T>::value>> {
static constexpr auto name = const_name < std::is_same<T, float>::value
|| std::is_same<T, const float>::value
|| std::is_same<T, double>::value
|| std::is_same<T, const double>::value
> (const_name("numpy.float") + const_name<sizeof(T) * 8>(),
const_name("numpy.longdouble"));
};

template <typename T>
struct npy_format_descriptor_name<T, enable_if_t<is_complex<T>::value>> {
static constexpr auto name = const_name < std::is_same<typename T::value_type, float>::value
|| std::is_same<typename T::value_type, const float>::value
|| std::is_same<typename T::value_type, double>::value
|| std::is_same<typename T::value_type, const double>::value
> (const_name("numpy.complex")
+ const_name<sizeof(typename T::value_type) * 16>(),
const_name("numpy.longcomplex"));
};

template <typename T>
struct npy_format_descriptor<
T,
Expand Down
1 change: 1 addition & 0 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ set(PYBIND11_TEST_FILES
test_native_enum
test_numpy_array
test_numpy_dtypes
test_numpy_scalars
test_numpy_vectorize
test_opaque_types
test_operator_overloading
Expand Down
63 changes: 63 additions & 0 deletions tests/test_numpy_scalars.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
tests/test_numpy_scalars.cpp -- strict NumPy scalars

Copyright (c) 2021 Steve R. Sun

All rights reserved. Use of this source code is governed by a
BSD-style license that can be found in the LICENSE file.
*/

#include <pybind11/numpy.h>

#include "pybind11_tests.h"

#include <complex>
#include <cstdint>

namespace py = pybind11;

namespace pybind11_test_numpy_scalars {

template <typename T>
struct add {
T x;
explicit add(T x) : x(x) {}
T operator()(T y) const { return static_cast<T>(x + y); }
};

template <typename T, typename F>
void register_test(py::module &m, const char *name, F &&func) {
m.def((std::string("test_") + name).c_str(),
[=](py::numpy_scalar<T> v) {
return std::make_tuple(name, py::make_scalar(static_cast<T>(func(v.value))));
},
py::arg("x"));
}

} // namespace pybind11_test_numpy_scalars

using namespace pybind11_test_numpy_scalars;

TEST_SUBMODULE(numpy_scalars, m) {
using cfloat = std::complex<float>;
using cdouble = std::complex<double>;

register_test<bool>(m, "bool", [](bool x) { return !x; });
register_test<int8_t>(m, "int8", add<int8_t>(-8));
register_test<int16_t>(m, "int16", add<int16_t>(-16));
register_test<int32_t>(m, "int32", add<int32_t>(-32));
register_test<int64_t>(m, "int64", add<int64_t>(-64));
register_test<uint8_t>(m, "uint8", add<uint8_t>(8));
register_test<uint16_t>(m, "uint16", add<uint16_t>(16));
register_test<uint32_t>(m, "uint32", add<uint32_t>(32));
register_test<uint64_t>(m, "uint64", add<uint64_t>(64));
register_test<float>(m, "float32", add<float>(0.125f));
register_test<double>(m, "float64", add<double>(0.25f));
register_test<cfloat>(m, "complex64", add<cfloat>({0, -0.125f}));
register_test<cdouble>(m, "complex128", add<cdouble>({0, -0.25f}));

m.def("test_eq",
[](py::numpy_scalar<int32_t> a, py::numpy_scalar<int32_t> b) { return a == b; });
m.def("test_ne",
[](py::numpy_scalar<int32_t> a, py::numpy_scalar<int32_t> b) { return a != b; });
}
Loading
Loading