Skip to content

Commit 73c055f

Browse files
committed
fix JSON caster so it catches circular reference, add AeroMode bindings
1 parent 094b5bf commit 73c055f

File tree

4 files changed

+71
-47
lines changed

4 files changed

+71
-47
lines changed

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ set(PyPartMC_sources
6969
json_resource.cpp
7070
bin_grid.F90
7171
bin_grid.cpp
72+
aero_mode.F90
7273
# json_resource.cpp spec_file_pypartmc.cpp sys.cpp
7374
# run_part.F90 run_part_opt.F90 util.F90 aero_data.F90 aero_state.F90 env_state.F90 gas_data.F90
7475
# gas_state.F90 scenario.F90 condense.F90 aero_particle.F90 bin_grid.F90

src/aero_mode.hpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
#pragma once
88

99
#include "pmc_resource.hpp"
10-
#include "pybind11/stl.h"
1110
#include "aero_data.hpp"
1211
#include "bin_grid.hpp"
1312

src/pypartmc.cpp

Lines changed: 69 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
#include "nanobind/nanobind.h"
99
#include "nanobind/stl/vector.h"
1010
#include "nanobind/stl/string.h"
11-
#include <nanobind/ndarray.h>
11+
#include "nanobind/ndarray.h"
1212
#include "nlohmann/json.hpp"
1313
// #include "nanobind_json/nanobind_json.hpp"
1414
// #include "pybind11_json/pybind11_json.hpp"
@@ -22,7 +22,7 @@
2222
// #include "run_part_opt.hpp"
2323
#include "aero_data.hpp"
2424
// #include "aero_dist.hpp"
25-
// #include "aero_mode.hpp"
25+
#include "aero_mode.hpp"
2626
// #include "aero_state.hpp"
2727
// #include "env_state.hpp"
2828
// // #include "gas_data.hpp"
@@ -39,12 +39,9 @@
3939
#define STRINGIFY(x) #x
4040
#define MACRO_STRINGIFY(x) STRINGIFY(x)
4141

42-
// namespace py = pybind11;
43-
4442
namespace nb = nanobind;
4543
namespace nl = nlohmann;
4644

47-
#include <iostream>
4845
namespace pyjson
4946
{
5047
inline nb::handle from_json(const nl::json& j)
@@ -91,7 +88,7 @@ namespace pyjson
9188
}
9289
}
9390

94-
nl::json to_json(const nb::handle& obj)
91+
inline nl::json to_json(const nb::handle& obj, std::set<const PyObject*> prevs)
9592
{
9693
if (obj.ptr() == nullptr || obj.is_none())
9794
{
@@ -109,32 +106,43 @@ namespace pyjson
109106
{
110107
return nb::cast<double>(obj);
111108
}
112-
// if (nb::isinstance<nb::bytes>(obj))
113-
// {
114-
// nb::module base64 = nb::module::import("base64");
115-
// // return base64.attr("b64encode")(obj).attr("decode")("utf-8").cast<std::string>();
116-
// return nb::cast<std::string>(base64.attr("b64encode")(obj).attr("decode")("utf-8"));
117-
// }
109+
if (nb::isinstance<nb::bytes>(obj))
110+
{
111+
nb::module_ base64 = nb::module_::import_("base64");
112+
return nb::cast<std::string>(base64.attr("b64encode")(obj).attr("decode")("utf-8"));
113+
}
118114
if (nb::isinstance<nb::str>(obj))
119115
{
120116
return nb::cast<std::string>(obj);
121117
}
122118
if (nb::isinstance<nb::tuple>(obj) || nb::isinstance<nb::list>(obj))
123119
{
120+
auto insert_return = prevs.insert(obj.ptr());
121+
if (!insert_return.second) {
122+
throw std::runtime_error("Circular reference detected");
123+
}
124+
124125
auto out = nl::json::array();
126+
125127
for (const nb::handle value : obj)
126128
{
127-
out.push_back(to_json(value));
129+
out.push_back(to_json(value, prevs));
128130
}
129131

130132
return out;
131133
}
132-
if (nb::isinstance<nb::dict>(obj) || nb::isinstance<nb::tuple>(obj) || nb::isinstance<nb::list>(obj))
134+
if (nb::isinstance<nb::dict>(obj))
133135
{
136+
auto insert_return = prevs.insert(obj.ptr());
137+
if (!insert_return.second) {
138+
throw std::runtime_error("Circular reference detected");
139+
}
140+
134141
auto out = nl::json::object();
142+
135143
for (const nb::handle key : obj)
136144
{
137-
out[nb::cast<std::string>(nb::str(key))] = to_json(obj[key]);
145+
out[nb::cast<std::string>(nb::str(key))] = to_json(obj[key], prevs);
138146
}
139147
return out;
140148
}
@@ -208,7 +216,7 @@ namespace nanobind
208216

209217
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept {
210218
try {
211-
value = pyjson::to_json(src);
219+
value = pyjson::to_json(src, std::set<const PyObject*>());
212220
return true;
213221
}
214222
catch (...)
@@ -225,14 +233,30 @@ namespace nanobind
225233
};
226234

227235
template <typename Type> struct type_caster<std::valarray<Type>> {
228-
NB_TYPE_CASTER(std::valarray<Type>, const_name("[") +
229-
const_name("valarray") +
230-
const_name("]"))
236+
NB_TYPE_CASTER(std::valarray<Type>, const_name("[") + const_name("std::valarray") + const_name("]"))
231237

232238
using Caster = make_caster<Type>;
233239

234240
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept {
235-
if (nb::isinstance<nb::list>(src) || nb::isinstance<nb::ndarray<Type, nb::ndim<1>>>(src)) {
241+
if (nb::isinstance<nb::list>(src)) {
242+
try {
243+
auto py_array = nb::cast<nb::list>(src);
244+
size_t size = py_array.size();
245+
246+
value.resize(size);
247+
248+
for (size_t i = 0; i < size; i++) {
249+
value[i] = nb::cast<Type>(py_array[i]);
250+
}
251+
252+
return true;
253+
}
254+
catch (...) {
255+
PyErr_Clear();
256+
return false;
257+
}
258+
}
259+
else if (nb::isinstance<nb::ndarray<Type, nb::ndim<1>>>(src)) {
236260
try {
237261
auto py_array = nb::cast<nb::ndarray<Type, nb::ndim<1>>>(src);
238262
size_t size = py_array.size();
@@ -668,31 +692,31 @@ NB_MODULE(_PyPartMC, m) {
668692
.def_prop_ro("centers", BinGrid::centers, "Bin centers")
669693
;
670694

671-
// py::class_<AeroMode>(m,"AeroMode")
672-
// .def(py::init<AeroData&, const nlohmann::json&>())
673-
// .def_property("num_conc", &AeroMode::get_num_conc, &AeroMode::set_num_conc,
674-
// "provides access (read or write) to the total number concentration of a mode")
675-
// .def("num_dist", &AeroMode::num_dist,
676-
// "returns the binned number concenration of a mode")
677-
// .def_property("vol_frac", &AeroMode::get_vol_frac,
678-
// &AeroMode::set_vol_frac, "Species fractions by volume")
679-
// .def_property("vol_frac_std", &AeroMode::get_vol_frac_std,
680-
// &AeroMode::set_vol_frac_std, "Species fraction standard deviation")
681-
// .def_property("char_radius", &AeroMode::get_char_radius,
682-
// &AeroMode::set_char_radius,
683-
// "Characteristic radius, with meaning dependent on mode type (m)")
684-
// .def_property("gsd", &AeroMode::get_gsd,
685-
// &AeroMode::set_gsd, "Geometric standard deviation")
686-
// .def("set_sample", &AeroMode::set_sampled)
687-
// .def_property_readonly("sample_num_conc", &AeroMode::get_sample_num_conc,
688-
// "Sample bin number concentrations (m^{-3})")
689-
// .def_property_readonly("sample_radius", &AeroMode::get_sample_radius,
690-
// "Sample bin radii (m).")
691-
// .def_property("type", &AeroMode::get_type, &AeroMode::set_type,
692-
// "Mode type (given by module constants)")
693-
// .def_property("name", &AeroMode::get_name, &AeroMode::set_name,
694-
// "Mode name, used to track particle sources")
695-
// ;
695+
nb::class_<AeroMode>(m,"AeroMode")
696+
.def(nb::init<AeroData&, const nlohmann::json&>())
697+
.def_prop_rw("num_conc", &AeroMode::get_num_conc, &AeroMode::set_num_conc,
698+
"provides access (read or write) to the total number concentration of a mode")
699+
.def("num_dist", &AeroMode::num_dist,
700+
"returns the binned number concenration of a mode")
701+
.def_prop_rw("vol_frac", &AeroMode::get_vol_frac,
702+
&AeroMode::set_vol_frac, "Species fractions by volume")
703+
.def_prop_rw("vol_frac_std", &AeroMode::get_vol_frac_std,
704+
&AeroMode::set_vol_frac_std, "Species fraction standard deviation")
705+
.def_prop_rw("char_radius", &AeroMode::get_char_radius,
706+
&AeroMode::set_char_radius,
707+
"Characteristic radius, with meaning dependent on mode type (m)")
708+
.def_prop_rw("gsd", &AeroMode::get_gsd,
709+
&AeroMode::set_gsd, "Geometric standard deviation")
710+
.def("set_sample", &AeroMode::set_sampled)
711+
.def_prop_ro("sample_num_conc", &AeroMode::get_sample_num_conc,
712+
"Sample bin number concentrations (m^{-3})")
713+
.def_prop_ro("sample_radius", &AeroMode::get_sample_radius,
714+
"Sample bin radii (m).")
715+
.def_prop_rw("type", &AeroMode::get_type, &AeroMode::set_type,
716+
"Mode type (given by module constants)")
717+
.def_prop_rw("name", &AeroMode::get_name, &AeroMode::set_name,
718+
"Mode name, used to track particle sources")
719+
;
696720

697721
// py::class_<AeroDist>(m,"AeroDist")
698722
// .def(py::init<std::shared_ptr<AeroData>, const nlohmann::json&>())

tests/test_aero_mode.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,7 @@ def test_fixed_segfault_case_on_circular_reference():
305305
ppmc.AeroMode(aero_data, fishy_ctor_arg)
306306

307307
# assert
308-
assert "incompatible constructor arguments" in str(exc_info.value)
308+
assert "incompatible function arguments" in str(exc_info.value)
309309

310310
@staticmethod
311311
@pytest.mark.skipif(platform.machine() == "arm64", reason="TODO #348")

0 commit comments

Comments
 (0)