Skip to content

Commit 52732be

Browse files
committed
Merge, stronger PyCaffe interface.
1 parent a1a0d7f commit 52732be

File tree

3 files changed

+12
-5
lines changed

3 files changed

+12
-5
lines changed

python/caffe/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from .pycaffe import Net, SGDSolver, NesterovSolver, AdaGradSolver, RMSPropSolver, AdaDeltaSolver, AdamSolver
1+
from .pycaffe import SolverParameter, Net, SGDSolver, NesterovSolver, AdaGradSolver, RMSPropSolver, AdaDeltaSolver, AdamSolver
22
from ._caffe import set_mode_cpu, set_mode_gpu, set_device, enumerate_devices, Layer, get_solver, layer_type_list
33
from .proto.caffe_pb2 import TRAIN, TEST
44
from .classifier import Classifier

python/caffe/_caffe.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,7 @@ BOOST_PYTHON_MODULE(_caffe) {
305305
.def("snapshot", &Solver<Dtype>::Snapshot);
306306

307307

308-
bp::class_<SolverParameter>("SolverParam", bp::no_init)
308+
bp::class_<SolverParameter>("SolverParameter", bp::init<>())
309309
.add_property("base_lr", &SolverParameter::base_lr,
310310
&SolverParameter::set_base_lr)
311311
.add_property("max_iter", &SolverParameter::max_iter,
@@ -355,7 +355,12 @@ BOOST_PYTHON_MODULE(_caffe) {
355355
bp::make_function(&SolverParameter::net,
356356
bp::return_value_policy<bp::copy_const_reference>()),
357357
static_cast<void (SolverParameter::*)(const string&)>(
358-
&SolverParameter::set_net));
358+
&SolverParameter::set_net))
359+
.add_property("train_net",
360+
bp::make_function(&SolverParameter::train_net,
361+
bp::return_value_policy<bp::copy_const_reference>()),
362+
static_cast<void (SolverParameter::*)(const string&)>(
363+
&SolverParameter::set_train_net));
359364

360365

361366
bp::class_<SGDSolver<Dtype>, bp::bases<Solver<Dtype> >,

python/caffe/pycaffe.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,10 @@
1010
from itertools import zip_longest as izip_longest
1111
import numpy as np
1212

13-
from ._caffe import Net, SGDSolver, NesterovSolver, AdaGradSolver, \
14-
RMSPropSolver, AdaDeltaSolver, AdamSolver
13+
from ._caffe import \
14+
SolverParameter, Net, SGDSolver, NesterovSolver, AdaGradSolver, \
15+
RMSPropSolver, AdaDeltaSolver, AdamSolver
16+
1517
import caffe.io
1618

1719
# We directly update methods from Net here (rather than using composition or

0 commit comments

Comments
 (0)