Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit a8f9ee3

Browse files
committed
Merge pull request #3 from dmlc/master
Update dev branch
2 parents b55212f + 480604f commit a8f9ee3

28 files changed

+1488
-99
lines changed

.gitignore

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,14 @@
3030
dmlc-core
3131
mshadow
3232
config.mk
33+
34+
*.pyc
35+
.Rhistory
36+
*log
37+
Debug
38+
*suo
39+
40+
# vim
41+
*.swp
42+
*.swo
43+
*.swn

Makefile

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,20 +48,23 @@ endif
4848

4949
BIN = test/api_registry_test
5050
OBJ = storage.o narray_op_cpu.o operator.o operator_cpu.o
51-
OBJCXX11 = engine.o narray.o mxnet_api.o api_registry.o
51+
# add threaded engine after it is done
52+
OBJCXX11 = engine.o narray.o mxnet_api.o api_registry.o engine.o
5253
CUOBJ = narray_op_gpu.o operator_gpu.o
53-
54+
SLIB = api/libmxnet.so
55+
ALIB = api/libmxnet.a
5456
LIB_DEP = $(DMLC_CORE)/libdmlc.a
5557

5658
.PHONY: clean all
5759

58-
all: $(OBJ) $(OBJCXX11) $(CUOBJ) $(BIN)
60+
all: $(ALIB) $(SLIB) $(BIN)
5961

6062
$(DMLC_CORE)/libdmlc.a:
6163
+ cd $(DMLC_CORE); make libdmlc.a config=$(ROOTDIR)/$(config); cd $(ROOTDIR)
6264

6365
storage.o: src/storage/storage.cc
6466
engine.o: src/dag_engine/simple_engine.cc
67+
threaded_engine.o: src/dag_engine/threaded_engine.cc src/common/concurrent_blocking_queue.h src/common/spin_lock.h
6568
narray.o: src/narray/narray.cc
6669
narray_op_cpu.o: src/narray/narray_op_cpu.cc src/narray/narray_op-inl.h
6770
narray_op_gpu.o: src/narray/narray_op_gpu.cu src/narray/narray_op-inl.h
@@ -71,7 +74,10 @@ operator_gpu.o: src/operator/operator_gpu.cu
7174
api_registry.o: src/api_registry.cc
7275
mxnet_api.o: api/mxnet_api.cc
7376

74-
test/api_registry_test: test/api_registry_test.cc $(OBJ) $(OBJCXX11) $(CUOBJ)
77+
api/libmxnet.a: $(OBJ) $(OBJCXX11) $(CUOBJ)
78+
api/libmxnet.so: $(OBJ) $(OBJCXX11) $(CUOBJ)
79+
80+
test/api_registry_test: test/api_registry_test.cc api/libmxnet.a
7581

7682
$(BIN) :
7783
$(CXX) $(CFLAGS) -std=c++11 -o $@ $(filter %.cpp %.o %.c %.a %.cc, $^) $(LDFLAGS)
@@ -85,12 +91,15 @@ $(OBJCXX11) :
8591
$(SLIB) :
8692
$(CXX) $(CFLAGS) -shared -o $@ $(filter %.cpp %.o %.c %.a %.cc, $^) $(LDFLAGS)
8793

94+
$(ALIB):
95+
ar cr $@ $+
96+
8897
$(CUOBJ) :
8998
$(NVCC) -c -o $@ $(NVCCFLAGS) -Xcompiler "$(CFLAGS)" $(filter %.cu, $^)
9099

91100
$(CUBIN) :
92101
$(NVCC) -o $@ $(NVCCFLAGS) -Xcompiler "$(CFLAGS)" -Xlinker "$(LDFLAGS)" $(filter %.cu %.cpp %.o, $^)
93102

94103
clean:
95-
$(RM) $(OBJ) $(OBJCXX11) $(BIN) $(CUBIN) $(CUOBJ) $(SLIB) *~ */*~ */*/*~
104+
$(RM) $(OBJ) $(OBJCXX11) $(BIN) $(CUBIN) $(CUOBJ) $(SLIB) $(ALIB) *~ */*~ */*/*~ */*/*/*~
96105
cd $(DMLC_CORE); make clean; cd -

README.md

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,15 @@
11
# MXNet
2-
This is an experimental project to put cxxnet and minerva together, nothing is working yet.
2+
This is a project that combines lessons and ideas we learnt from [cxxnet](https://github.com/dmlc/cxxnet), [minerva](https://github.com/dmlc/minerva) and [purine2](https://github.com/purine/purine2).
3+
- The interface is designed in collaboration by authors of three projects.
4+
- Nothing is yet working
35

46
# Guidelines
57
* Use google c style
68
* Put module header in [include](include)
7-
- move them to ```project-name/include``` when we finalized the name
89
* Depend on [dmlc-core](https://github.com/dmlc/dmlc-core)
910
* Doxygen comment every function, class and variable for the module headers
1011
- Ref headers in [dmlc-core/include](https://github.com/dmlc/dmlc-core/tree/master/include/dmlc)
1112
- Use the same style as dmlc-core
12-
* Try write some use-cases of interface in [test](test)
13-
- They do not need to link, but need to pass compile
1413
* Minimize dependency, if possible only depend on dmlc-core
1514
* Macro Guard CXX11 code by
1615
- Try to make interface compile when c++11 was not avaialable(but with some functionalities pieces missing)

api/mxnet_api.cc

Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,191 @@
1+
#include <dmlc/base.h>
2+
#include <dmlc/logging.h>
13
#include <mxnet/base.h>
24
#include <mxnet/narray.h>
5+
#include <mxnet/api_registry.h>
36
#include "./mxnet_api.h"
47

8+
// NOTE: all functions return 0 upon success
9+
// consider add try/catch block for user error
10+
// handling in the future
11+
using namespace mxnet;
12+
13+
// macro to guard beginning and end section of all functions
14+
// every function starts with API_BEGIN(); and finishes with API_END();
15+
#define API_BEGIN() try {
16+
#define API_END() } catch(dmlc::Error &e) { return MXHandleException(e); } return 0;
17+
18+
/*!
19+
* \brief a helper function for error handling
20+
* will set the last error to be str_set when it is not NULL
21+
* \param str_set the error to set
22+
* \return a pointer message to last error
23+
*/
24+
const char *MXSetGetLastError_(const char *str_set) {
25+
// use last_error to record last error
26+
static thread_local std::string last_error;
27+
if (str_set != NULL) {
28+
last_error = str_set;
29+
}
30+
return last_error.c_str();
31+
}
32+
33+
/*! \brief return str message of the last error */
34+
const char *MXGetLastError() {
35+
return MXSetGetLastError_(NULL);
36+
}
37+
38+
/*!
39+
* \brief handle exception throwed out
40+
* \param e the exception
41+
* \return the return value of API after exception is handled
42+
*/
43+
int MXHandleException(const dmlc::Error &e) {
44+
MXSetGetLastError_(e.what());
45+
return -1;
46+
}
47+
48+
// NOTE: return value is added in API_END
49+
int MXNArrayCreateNone(NArrayHandle *out) {
50+
API_BEGIN();
51+
*out = new NArray();
52+
API_END();
53+
}
54+
55+
int MXNArrayCreateShareMem(mx_float *data,
56+
mx_uint *shape,
57+
mx_uint ndim,
58+
NArrayHandle *out) {
59+
API_BEGIN();
60+
*out = new NArray(TBlob(data, TShape(shape, shape + ndim),
61+
cpu::kDevMask), 0);
62+
API_END();
63+
}
64+
65+
int MXNArrayCreate(const mx_uint *shape,
66+
mx_uint ndim,
67+
int dev_mask,
68+
int dev_id,
69+
int delay_alloc,
70+
NArrayHandle *out) {
71+
API_BEGIN();
72+
*out = new NArray(TShape(shape, shape + ndim),
73+
Context(dev_mask, dev_id),
74+
delay_alloc != 0);
75+
API_END();
76+
}
77+
78+
int MXNArrayWait(NArrayHandle handle) {
79+
API_BEGIN();
80+
static_cast<NArray*>(handle)->Wait();
81+
API_END();
82+
}
83+
84+
int MXNArrayWaitAll() {
85+
API_BEGIN();
86+
DAGEngine::Get()->WaitForAll();
87+
API_END();
88+
}
89+
90+
int MXNArrayFree(NArrayHandle handle) {
91+
API_BEGIN();
92+
delete static_cast<NArray*>(handle);
93+
API_END();
94+
}
95+
96+
int MXNArrayGetShape(NArrayHandle handle,
97+
mx_uint *out_dim,
98+
const mx_uint **out_pdata) {
99+
API_BEGIN();
100+
NArray *arr = static_cast<NArray*>(handle);
101+
if (!arr->is_none()) {
102+
const TShape &s = arr->shape();
103+
*out_dim = s.ndim();
104+
*out_pdata = s.data();
105+
} else {
106+
*out_dim = 0;
107+
}
108+
API_END();
109+
}
110+
111+
int MXNArrayGetData(NArrayHandle handle,
112+
mx_float **out_pdata) {
113+
API_BEGIN();
114+
NArray *arr = static_cast<NArray*>(handle);
115+
if (!arr->is_none()) {
116+
CHECK(arr->ctx().dev_mask == cpu::kDevMask)
117+
<< "MXNArrayGetData can only be called for NArray on CPU";
118+
const TBlob &b = arr->data();
119+
CHECK(b.CheckContiguous());
120+
*out_pdata = b.FlatTo2D<cpu, mx_float>().dptr_;
121+
} else {
122+
*out_pdata = nullptr;
123+
}
124+
API_END();
125+
}
126+
127+
int MXNArrayGetContext(NArrayHandle handle,
128+
int *out_dev_mask,
129+
int *out_dev_id) {
130+
API_BEGIN();
131+
NArray *arr = static_cast<NArray*>(handle);
132+
if (!arr->is_none()) {
133+
const Context &ctx = arr->ctx();
134+
*out_dev_mask = ctx.dev_mask;
135+
*out_dev_id = ctx.dev_id;
136+
} else {
137+
*out_dev_mask = 0;
138+
*out_dev_id = 0;
139+
}
140+
API_END();
141+
}
142+
143+
int MXListFunctions(mx_uint *out_size,
144+
FunctionHandle **out_array) {
145+
API_BEGIN();
146+
auto &vec = FunctionRegistry::List();
147+
*out_size = static_cast<mx_uint>(vec.size());
148+
*out_array = (FunctionHandle*)(dmlc::BeginPtr(vec));
149+
API_END();
150+
}
151+
152+
int MXGetFunction(const char *name,
153+
FunctionHandle *out) {
154+
API_BEGIN();
155+
*out = FunctionRegistry::Find(name);
156+
API_END();
157+
}
158+
159+
int MXFuncGetName(FunctionHandle fun,
160+
const char **out_name) {
161+
API_BEGIN();
162+
auto *f = static_cast<const FunctionRegistry::Entry *>(fun);
163+
*out_name = f->name.c_str();
164+
API_END();
165+
}
166+
167+
int MXFuncDescribe(FunctionHandle fun,
168+
mx_uint *num_use_vars,
169+
mx_uint *num_scalars,
170+
mx_uint *num_mutate_vars,
171+
int *type_mask) {
172+
API_BEGIN();
173+
auto *f = static_cast<const FunctionRegistry::Entry *>(fun);
174+
*num_use_vars = f->num_use_vars;
175+
*num_scalars = f->num_scalars;
176+
*num_mutate_vars = f->num_mutate_vars;
177+
*type_mask = f->type_mask;
178+
API_END();
179+
}
180+
181+
int MXFuncInvoke(FunctionHandle fun,
182+
NArrayHandle *use_vars,
183+
mx_float *scalar_args,
184+
NArrayHandle *mutate_vars) {
185+
API_BEGIN();
186+
auto *f = static_cast<const FunctionRegistry::Entry *>(fun);
187+
(*f)((NArray**)(use_vars),
188+
scalar_args,
189+
(NArray**)(mutate_vars));
190+
API_END();
191+
}

api/mxnet_api.h

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,24 @@ typedef float mx_float;
2626
/*! \brief handle to NArray */
2727
typedef void *NArrayHandle;
2828
/*! \brief handle to a mxnet narray function that changes NArray */
29-
typedef void *FunctionHandle;
29+
typedef const void *FunctionHandle;
3030
/*! \brief handle to a symbol that can be bind as operator */
3131
typedef void *SymbolHandle;
3232
/*! \brief handle to a NArrayOperator */
3333
typedef void *OperatorHandle;
3434
/*! \brief handle to a DataIterator */
3535
typedef void *DataIterHandle;
3636

37+
/*!
38+
* \brief return str message of the last error
39+
* all function in this file will return 0 when success
40+
* and -1 when an error occured,
41+
* MXGetLastError can be called to retrieve the error
42+
*
43+
* this function is threadsafe and can be called by different thread
44+
*/
45+
MXNET_DLL const char *MXGetLastError();
46+
3747
//--------------------------------
3848
// Part 1: NArray creation and deletion
3949
//--------------------------------
@@ -71,13 +81,16 @@ MXNET_DLL int MXNArrayCreateShareMem(mx_float *data,
7181
* \param ndim the dimension of the shape
7282
* \param dev_mask device mask, specify device we want to take
7383
* \param dev_id the device id of the specific device
84+
* \param delay_alloc whether to delay allocation until
85+
* the narray is first mutated
7486
* \param out the returning handle
7587
* \return 0 when success, -1 when failure happens
7688
*/
7789
MXNET_DLL int MXNArrayCreate(const mx_uint *shape,
7890
mx_uint ndim,
7991
int dev_mask,
8092
int dev_id,
93+
int delay_alloc,
8194
NArrayHandle *out);
8295
/*!
8396
* \brief wait until all the operation with respect NArray
@@ -105,25 +118,27 @@ MXNET_DLL int MXNArrayFree(NArrayHandle handle);
105118
* \param out_pdata pointer holder to get data pointer of the shape
106119
* \return 0 when success, -1 when failure happens
107120
*/
108-
MXNET_DLL int MXNArrayGetShape(NArrayHandle *handle,
121+
MXNET_DLL int MXNArrayGetShape(NArrayHandle handle,
109122
mx_uint *out_dim,
110-
mx_uint **out_pdata);
123+
const mx_uint **out_pdata);
111124
/*!
112125
* \brief get the content of the data in NArray
113126
* \param handle the handle to the narray
114127
* \param out_pdata pointer holder to get pointer of data
115128
* \return 0 when success, -1 when failure happens
116129
*/
117-
MXNET_DLL int MXNArrayGetData(NArrayHandle *handle,
130+
MXNET_DLL int MXNArrayGetData(NArrayHandle handle,
118131
mx_float **out_pdata);
119132
/*!
120-
* \brief get the device of the NArray
133+
* \brief get the context of the NArray
121134
* \param handle the handle to the narray
122-
* \param out_device the output device mask
135+
* \param out_dev_mask the output device mask
136+
* \param out_dev_id the output device id
123137
* \return 0 when success, -1 when failure happens
124138
*/
125-
MXNET_DLL int MXNArrayGetDevice(NArrayHandle *handle,
126-
int *out_device);
139+
MXNET_DLL int MXNArrayGetContext(NArrayHandle handle,
140+
int *out_dev_mask,
141+
int *out_dev_id);
127142

128143
//--------------------------------
129144
// Part 2: functions on NArray
@@ -158,13 +173,15 @@ MXNET_DLL int MXFuncGetName(FunctionHandle fun,
158173
* \param num_use_vars how many NArrays to be passed in as used_vars
159174
* \param num_scalars scalar variable is needed
160175
* \param num_mutate_vars how many NArrays to be passed in as mutate_vars
176+
* \param type_mask the type mask of this function
161177
* \return 0 when success, -1 when failure happens
162178
* \sa MXFuncInvoke
163179
*/
164-
MXNET_DLL int MXFuncDescribeArgs(FunctionHandle fun,
165-
mx_uint *num_use_vars,
166-
mx_uint *num_scalars,
167-
mx_uint *num_mutate_vars);
180+
MXNET_DLL int MXFuncDescribe(FunctionHandle fun,
181+
mx_uint *num_use_vars,
182+
mx_uint *num_scalars,
183+
mx_uint *num_mutate_vars,
184+
int *type_mask);
168185

169186
/*!
170187
* \brief invoke a function, the array size of passed in arguments

api/python/mxnet/__init__.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
#!/usr/bin/env python
2+
# coding: utf-8
3+
"""MXNet: a concise, fast and flexible framework for deep learning
4+
5+
MXNet is a project that evolves from cxxnet, minerva and purine2.
6+
The interface is designed in collaboration by authors of three projects.
7+
8+
Version : 0.10
9+
"""
10+
from __future__ import absolute_import
11+
12+
from .context import Context, current_context
13+
from .narray import NArray, _init_function_registry
14+
from .function import _FunctionRegistry
15+
16+
# this is a global function registry that can be used to invoke functions
17+
op = _init_function_registry(_FunctionRegistry())

0 commit comments

Comments
 (0)