16
16
17
17
#include " caffe/caffe.hpp"
18
18
#include " caffe/python_layer.hpp"
19
+ #include " caffe/sgd_solvers.hpp"
19
20
20
21
// Temporary solution for numpy < 1.7 versions: old macro, no promises.
21
22
// You're strongly advised to upgrade to >= 1.7.
@@ -50,7 +51,7 @@ static void CheckFile(const string& filename) {
50
51
}
51
52
52
53
void CheckContiguousArray (PyArrayObject* arr, string name,
53
- vector<int > shape) {
54
+ vector<int_tp > shape) {
54
55
if (!(PyArray_FLAGS (arr) & NPY_ARRAY_C_CONTIGUOUS)) {
55
56
throw std::runtime_error (name + " must be C contiguous" );
56
57
}
@@ -63,11 +64,12 @@ void CheckContiguousArray(PyArrayObject* arr, string name,
63
64
if (PyArray_TYPE (arr) != NPY_FLOAT32) {
64
65
throw std::runtime_error (name + " must be float32" );
65
66
}
66
- for (int i = 1 ; i < PyArray_NDIM (arr); ++i) {
67
+ for (int_tp i = 1 ; i < PyArray_NDIM (arr); ++i) {
67
68
if (PyArray_DIMS (arr)[i] != shape[i]) {
68
69
throw std::runtime_error (
69
70
" Shape dimension " + std::to_string (i) + " has wrong size ("
70
- + std::to_string (static_cast <int >(PyArray_DIMS (arr)[i])) + " vs. "
71
+ + std::to_string (static_cast <int_tp>
72
+ (PyArray_DIMS (arr)[i])) + " vs. "
71
73
+ std::to_string (shape[i]) + " )" );
72
74
}
73
75
}
@@ -134,8 +136,8 @@ void Net_SetInputArrays(Net<Dtype>* net, int index, bp::object data_obj,
134
136
135
137
Solver<Dtype>* GetSolverFromFile (const string& filename) {
136
138
SolverParameter param;
137
- ReadProtoFromTextFileOrDie (filename, ¶m);
138
- return GetSolver <Dtype>(param);
139
+ ReadSolverParamsFromTextFileOrDie (filename, ¶m);
140
+ return SolverRegistry <Dtype>:: CreateSolver (param);
139
141
}
140
142
141
143
struct NdarrayConverterGenerator {
@@ -165,8 +167,8 @@ struct NdarrayCallPolicies : public bp::default_call_policies {
165
167
// the shape information from the blob.
166
168
void * data = PyArray_DATA (reinterpret_cast <PyArrayObject*>(result));
167
169
Py_DECREF (result);
168
- const int num_axes = blob->num_axes ();
169
- vector<npy_intp > dims (blob->shape ().begin (), blob->shape ().end ());
170
+ const int_tp num_axes = blob->num_axes ();
171
+ vector<npy_long > dims (blob->shape ().begin (), blob->shape ().end ());
170
172
PyObject *arr_obj = PyArray_SimpleNewFromData (num_axes, dims.data (),
171
173
NPY_FLOAT32, data);
172
174
// SetBaseObject steals a ref, so we need to INCREF.
@@ -182,9 +184,9 @@ bp::object Blob_Reshape(bp::tuple args, bp::dict kwargs) {
182
184
throw std::runtime_error (" Blob.reshape takes no kwargs" );
183
185
}
184
186
Blob<Dtype>* self = bp::extract<Blob<Dtype>*>(args[0 ]);
185
- vector<int > shape (bp::len (args) - 1 );
186
- for (int i = 1 ; i < bp::len (args); ++i) {
187
- shape[i - 1 ] = bp::extract<int >(args[i]);
187
+ vector<int_tp > shape (bp::len (args) - 1 );
188
+ for (int_tp i = 1 ; i < bp::len (args); ++i) {
189
+ shape[i - 1 ] = bp::extract<int_tp >(args[i]);
188
190
}
189
191
self->Reshape (shape);
190
192
// We need to explicitly return None to use bp::raw_function.
@@ -197,9 +199,9 @@ bp::object BlobVec_add_blob(bp::tuple args, bp::dict kwargs) {
197
199
}
198
200
typedef vector<shared_ptr<Blob<Dtype> > > BlobVec;
199
201
BlobVec* self = bp::extract<BlobVec*>(args[0 ]);
200
- vector<int > shape (bp::len (args) - 1 );
201
- for (int i = 1 ; i < bp::len (args); ++i) {
202
- shape[i - 1 ] = bp::extract<int >(args[i]);
202
+ vector<int_tp > shape (bp::len (args) - 1 );
203
+ for (int_tp i = 1 ; i < bp::len (args); ++i) {
204
+ shape[i - 1 ] = bp::extract<int_tp >(args[i]);
203
205
}
204
206
self->push_back (shared_ptr<Blob<Dtype> >(new Blob<Dtype>(shape)));
205
207
// We need to explicitly return None to use bp::raw_function.
@@ -252,14 +254,14 @@ BOOST_PYTHON_MODULE(_caffe) {
252
254
" Blob" , bp::no_init)
253
255
.add_property (" shape" ,
254
256
bp::make_function (
255
- static_cast <const vector<int >& (Blob<Dtype>::*)() const >(
257
+ static_cast <const vector<int_tp >& (Blob<Dtype>::*)() const >(
256
258
&Blob<Dtype>::shape),
257
259
bp::return_value_policy<bp::copy_const_reference>()))
258
260
.add_property (" num" , &Blob<Dtype>::num)
259
261
.add_property (" channels" , &Blob<Dtype>::channels)
260
262
.add_property (" height" , &Blob<Dtype>::height)
261
263
.add_property (" width" , &Blob<Dtype>::width)
262
- .add_property (" count" , static_cast <int (Blob<Dtype>::*)() const >(
264
+ .add_property (" count" , static_cast <int_tp (Blob<Dtype>::*)() const >(
263
265
&Blob<Dtype>::count))
264
266
.def (" reshape" , bp::raw_function (&Blob_Reshape))
265
267
.add_property (" data" , bp::make_function (&Blob<Dtype>::mutable_cpu_data,
@@ -322,8 +324,8 @@ BOOST_PYTHON_MODULE(_caffe) {
322
324
.def (bp::vector_indexing_suite<vector<shared_ptr<Layer<Dtype> > >, true >());
323
325
bp::class_<vector<string> >(" StringVec" )
324
326
.def (bp::vector_indexing_suite<vector<string> >());
325
- bp::class_<vector<int > >(" IntVec" )
326
- .def (bp::vector_indexing_suite<vector<int > >());
327
+ bp::class_<vector<int_tp > >(" IntVec" )
328
+ .def (bp::vector_indexing_suite<vector<int_tp > >());
327
329
bp::class_<vector<Dtype> >(" DtypeVec" )
328
330
.def (bp::vector_indexing_suite<vector<Dtype> >());
329
331
bp::class_<vector<shared_ptr<Net<Dtype> > > >(" NetVec" )
0 commit comments