@@ -104,12 +104,12 @@ void CheckContiguousArray(PyArrayObject* arr, string name,
// Net constructor
shared_ptr<Net<Dtype> > Net_Init(string network_file, int phase,
-    const int level, const bp::object& stages,
+    int level, const bp::object& stages,
    const bp::object& weights) {
  CheckFile(network_file);

  // Convert stages from list to vector
-  vector<string> stages_vector;
+  std::vector<std::string> stages_vector;
  if (!stages.is_none()) {
    for (int i = 0; i < len(stages); i++) {
      stages_vector.push_back(bp::extract<string>(stages[i]));
@@ -133,7 +133,8 @@ shared_ptr<Net<Dtype> > Net_Init(string network_file, int phase,
// Legacy Net construct-and-load convenience constructor
shared_ptr<Net<Dtype> > Net_Init_Load(
-    string param_file, string pretrained_param_file, int phase) {
+    string param_file, string pretrained_param_file, int phase,
+    int level, const bp::object& stages) {
  LOG(WARNING) << "DEPRECATION WARNING - deprecated use of Python interface";
  LOG(WARNING) << "Use this instead (with the named \"weights\""
      << " parameter):";
@@ -142,8 +143,17 @@ shared_ptr<Net<Dtype> > Net_Init_Load(
  CheckFile(param_file);
  CheckFile(pretrained_param_file);

+  // Convert stages from list to vector
+  std::vector<std::string> stages_vector;
+  if (!stages.is_none()) {
+    for (int i = 0; i < len(stages); i++) {
+      stages_vector.push_back(bp::extract<string>(stages[i]));
+    }
+  }
+
  shared_ptr<Net<Dtype> > net(new Net<Dtype>(param_file,
-      static_cast<Phase>(phase), Caffe::GetDefaultDevice()));
+      static_cast<Phase>(phase), Caffe::GetDefaultDevice(),
+      level, &stages_vector));
  net->CopyTrainedLayersFrom(pretrained_param_file);
  return net;
}
@@ -201,22 +211,31 @@ void Net_SetLayerInputArrays(Net<Dtype>* net, Layer<Dtype>* layer,
  // check that we were passed appropriately-sized contiguous memory
  PyArrayObject* data_arr =
      reinterpret_cast<PyArrayObject*>(data_obj.ptr());
-  PyArrayObject* labels_arr =
-      reinterpret_cast<PyArrayObject*>(labels_obj.ptr());
  CheckContiguousArray(data_arr, "data array", md_layer->shape());
-  CheckContiguousArray(labels_arr, "labels array", md_layer->label_shape());
-  if (PyArray_DIMS(data_arr)[0] != PyArray_DIMS(labels_arr)[0]) {
-    throw std::runtime_error("data and labels must have the same first"
-        " dimension");
-  }
  if (PyArray_DIMS(data_arr)[0] % md_layer->batch_size() != 0) {
    throw std::runtime_error("first dimensions of input arrays must be a"
        " multiple of batch size");
  }

-  md_layer->Reset(static_cast<Dtype*>(PyArray_DATA(data_arr)),
-      static_cast<Dtype*>(PyArray_DATA(labels_arr)),
-      PyArray_DIMS(data_arr)[0]);
+  PyArrayObject* labels_arr = nullptr;
+
+  if (labels_obj.ptr() != bp::object().ptr()) {
+    labels_arr = reinterpret_cast<PyArrayObject*>(labels_obj.ptr());
+    CheckContiguousArray(labels_arr, "labels array", md_layer->label_shape());
+    if (PyArray_DIMS(data_arr)[0] != PyArray_DIMS(labels_arr)[0]) {
+      throw std::runtime_error("data and labels must have the same first"
+          " dimension");
+    }
+    md_layer->Reset(static_cast<Dtype*>(PyArray_DATA(data_arr)),
+        static_cast<Dtype*>(PyArray_DATA(labels_arr)),
+        PyArray_DIMS(data_arr)[0]);
+  } else {
+    md_layer->Reset(static_cast<Dtype*>(PyArray_DATA(data_arr)),
+        nullptr,
+        PyArray_DIMS(data_arr)[0]);
+  }
+
+
}
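The rewritten Net_SetLayerInputArrays above makes the labels array optional: when the caller passes None, the label checks are skipped and the MemoryData layer is reset with a null label pointer. Below is a minimal pycaffe sketch of the intended call pattern; the Python-side method name and argument order are not part of this diff, so set_input_arrays(data, labels, layer) is only a stand-in, and the prototxt name is a placeholder. Arrays are assumed to be contiguous float32 (the usual single-precision Dtype) with a first dimension that is a multiple of the MemoryData layer's batch size.

    import numpy as np
    import caffe

    net = caffe.Net("memory_data_net.prototxt", caffe.TRAIN)  # placeholder prototxt
    layer = net.layers[0]                                     # assumed to be the MemoryData layer

    data = np.zeros((64, 3, 28, 28), dtype=np.float32)        # first dim: multiple of batch size
    labels = np.arange(64, dtype=np.float32)

    # With labels: data and labels must agree on their first dimension.
    net.set_input_arrays(data, labels, layer)   # method name assumed, not defined in this hunk

    # After this change, labels may be None; the layer is reset with a null label pointer.
    net.set_input_arrays(data, None, layer)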
@@ -385,7 +404,10 @@ BOOST_PYTHON_MODULE(_caffe) {
            bp::arg("level")=0, bp::arg("stages")=bp::object(),
            bp::arg("weights")=bp::object())))
    // Legacy constructor
-    .def("__init__", bp::make_constructor(&Net_Init_Load))
+    .def("__init__", bp::make_constructor(&Net_Init_Load,
+          bp::default_call_policies(), (bp::arg("network_file"),
+            bp::arg("pretrained_param_file"), "phase",
+            bp::arg("level")=0, bp::arg("stages")=bp::object())))
    .def("_forward", &ForwardFromTo_NoGIL)
    .def("_backward", &BackwardFromTo_NoGIL)
    .def("reshape", &Net<Dtype>::Reshape)
@@ -450,10 +472,26 @@ BOOST_PYTHON_MODULE(_caffe) {
          bp::return_internal_reference<>()))
    .def("setup", &Layer<Dtype>::LayerSetUp)
    .def("reshape", &Layer<Dtype>::Reshape)
-    .add_property("type", bp::make_function(&Layer<Dtype>::type));
+    .add_property("type", bp::make_function(&Layer<Dtype>::type))
+    .add_property("layer_param", bp::make_function(&Layer<Dtype>::layer_param,
+        bp::return_internal_reference<>()));
  BP_REGISTER_SHARED_PTR_TO_PYTHON(Layer<Dtype>);

-  bp::class_<LayerParameter>("LayerParameter", bp::no_init);
+  bp::class_<LayerParameter>("LayerParameter", bp::no_init)
+    .add_property("name", bp::make_function(
+        static_cast<const string& (LayerParameter::*)
+        (void) const>(&LayerParameter::name),
+        bp::return_value_policy<bp::return_by_value>()))
+    .add_property("bottom_size", &LayerParameter::bottom_size)
+    .def("get_bottom", bp::make_function(
+        static_cast<const string& (LayerParameter::*)
+        (int) const>(&LayerParameter::bottom),
+        bp::return_value_policy<bp::return_by_value>()))
+    .add_property("top_size", &LayerParameter::top_size)
+    .def("get_top", bp::make_function(
+        static_cast<const string& (LayerParameter::*)
+        (int) const>(&LayerParameter::top),
+        bp::return_value_policy<bp::return_by_value>()));

  bp::class_<Solver<Dtype>, shared_ptr<Solver<Dtype> >, boost::noncopyable>(
    "Solver", bp::no_init)
@@ -471,6 +509,31 @@ BOOST_PYTHON_MODULE(_caffe) {
    .def("snapshot", &Solver<Dtype>::Snapshot);
  BP_REGISTER_SHARED_PTR_TO_PYTHON(Solver<Dtype>);

+  bp::class_<NetState>("NetState", bp::init<>())
+    .add_property("phase", &NetState::phase,
+        &NetState::set_phase)
+    .add_property("level", &NetState::level,
+        &NetState::set_level)
+    .def("stage_size", &NetState::stage_size)
+    .def("get_stage", bp::make_function(
+        static_cast<const string& (NetState::*)
+        (int) const>(&NetState::stage),
+        bp::return_value_policy<bp::return_by_value>()))
+    .def("add_stage", static_cast<void (NetState::*)
+        (const string&)>(&NetState::add_stage))
+    .def("set_stage", static_cast<void (NetState::*)
+        (int, const string&)>(&NetState::set_stage))
+    .def("clear_stage", &NetState::clear_stage);
+
+  bp::class_<NetParameter>("NetParameter", bp::init<>())
+    .add_property("force_backward", &NetParameter::force_backward,
+        &NetParameter::set_force_backward)
+    .add_property("state",
+        bp::make_function(&NetParameter::state,
+            bp::return_value_policy<bp::copy_const_reference>()),
+        static_cast<void (NetParameter::*)(NetState*)>(
+            &NetParameter::set_allocated_state));
+

  bp::class_<SolverParameter>("SolverParameter", bp::init<>())
    .add_property("base_lr", &SolverParameter::base_lr,
@@ -529,7 +592,17 @@ BOOST_PYTHON_MODULE(_caffe) {
        bp::make_function(&SolverParameter::train_net,
            bp::return_value_policy<bp::copy_const_reference>()),
        static_cast<void (SolverParameter::*)(const string&)>(
-            &SolverParameter::set_train_net));
+            &SolverParameter::set_train_net))
+    .add_property("net_param",
+        bp::make_function(&SolverParameter::mutable_net_param,
+            bp::return_value_policy<bp::reference_existing_object>()),
+        static_cast<void (SolverParameter::*)(NetParameter*)>(
+            &SolverParameter::set_allocated_net_param))
+    .add_property("train_state",
+        bp::make_function(&SolverParameter::mutable_train_state,
+            bp::return_value_policy<bp::reference_existing_object>()),
+        static_cast<void (SolverParameter::*)(NetState*)>(
+            &SolverParameter::set_allocated_train_state));

  bp::enum_<::caffe::SolverParameter_SnapshotFormat>("snapshot_format")
    .value("HDF5", SolverParameter_SnapshotFormat_HDF5)