-
Notifications
You must be signed in to change notification settings - Fork 6.8k
dynamic custom operator support #15921
Changes from 113 commits
5030a65
23a226a
67c22c0
915c1d5
f568e3d
8e12588
1e27a47
9aecf86
02deacf
cf9350d
ada3895
38e77a5
7b8f6a2
5b817bd
3bccfbe
a8c19c8
2f34471
3502aa9
52e687b
5438a35
592249a
3186d60
e8b413b
bf549b4
439ee20
bba25db
a681f61
711f9a3
172129f
5af1736
33d9cd7
4576570
6f3e3d9
9587483
34efb2b
ff9a868
0be218b
570a059
4b01932
e4be175
8cfcc85
9884ec6
8e21600
794e30b
8fbf664
e7c6e8f
6047378
d1587ab
0ee56c9
adc9770
5c06d47
9136839
ffe7623
435e01e
0de79a9
0d6f7b0
18b028e
a4690b4
c901828
5ddb919
7b4c4e6
18117ec
ee65419
c66438c
698a0b6
bd55612
35ff973
f243e2f
efbb858
0032143
14ef3a7
eec71d6
abcb8cb
9cf0455
82f1bff
f7ff481
ba563d2
a9b7215
7bf4f7a
8aec7ac
39e3d6b
b3ba028
c9d8498
1686273
7009ad4
4b73179
dca521e
baed04e
aedcf91
75102a3
9c29deb
c5a3ed6
44683f1
d1b6c8e
44affc7
ef1d4cf
24d8cc3
279a989
5db9e97
75b1169
79c0e3a
11d3344
de157a8
28450b5
9504b33
cf27d57
7f456d4
e50819b
5984f3a
bd2c3a0
b07e46b
2466d67
e041400
f16942c
50a6b64
adb0415
6148ef8
6d9ac54
7c256cd
6e824fb
40c471b
dfb5946
5146fd5
1b9fee2
5761891
ef840b4
bba61b3
53d18ec
e0c778c
141328f
deacae2
2b2c6a4
ed8ac16
56b0e28
1bd166e
50c8aea
34a9ee9
9910c39
5fd4314
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
|
||
# Builds the sample custom-operator libraries (gemm and subgraph) as
# standalone shared objects against the MXNet lib_api headers.

# Declare all targets phony: none of them produce a file with the target's
# name, so a stray file called e.g. "clean" or "gemm_lib" must not mask them.
.PHONY: all gemm_lib subgraph_lib clean

all: subgraph_lib gemm_lib

gemm_lib:
	g++ -shared -fPIC -std=c++11 gemm_lib.cc -o libgemm_lib.so -I ../../../include/mxnet

subgraph_lib:
	g++ -shared -fPIC -std=c++11 subgraph_lib.cc -o libsubgraph_lib.so -I ../../../include/mxnet

clean:
	rm -rf libsubgraph_lib.so libgemm_lib.so
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,240 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
|
||
/*! | ||
* Copyright (c) 2019 by Contributors | ||
* \file gemm_lib.cc | ||
* \brief Sample 2D gemm custom operator implementation library file | ||
*/ | ||
|
||
#include <iostream> | ||
#include "lib_api.h" | ||
|
||
// Main matrix multiplication routine: computes C = A * B, where A is
// (n x k), B is (k x m), and C is (n x m), all stored row-major.
void gemm(const float* A, const float* B, float* C,
          const unsigned n, const unsigned k, const unsigned m) {
  for (unsigned row = 0; row < n; row++) {
    for (unsigned col = 0; col < m; col++) {
      // accumulate the dot product of A's row and B's column locally,
      // then store it once
      float acc = 0;
      for (unsigned t = 0; t < k; t++) {
        acc += A[row*k + t] * B[t*m + col];
      }
      C[row*m + col] = acc;
    }
  }
}
|
||
// Writes the transpose of A into At. NOTE: n and m name the OUTPUT
// dimensions — At is (n x m) row-major and A is read as an (m x n)
// row-major matrix, i.e. At[r][c] = A[c][r].
void transpose(const float* A, float* At, const unsigned n, const unsigned m) {
  for (unsigned r = 0; r < n; r++) {
    for (unsigned c = 0; c < m; c++) {
      At[r*m + c] = A[c*n + r];
    }
  }
}
|
||
/* | ||
* Executes C = A * B | ||
* inputs[0] = A; inputs[1] = B; outputs[0] = C | ||
*/ | ||
/*
 * Executes C = A * B
 * inputs[0] = A; inputs[1] = B; outputs[0] = C
 */
MXReturnValue forward(std::map<std::string, std::string> attrs,
                      std::vector<MXTensor> inputs,
                      std::vector<MXTensor> outputs,
                      OpResource res) {
  // simple example of using the runtime data type: only float32 is
  // implemented, any other dtype is a silent no-op that still succeeds
  if (inputs[0].dtype != kFloat32)
    return MX_SUCCESS;

  // extract raw data pointers from the tensors
  float* matA = inputs[0].data<float>();
  float* matB = inputs[1].data<float>();
  float* matC = outputs[0].data<float>();

  // A is (n x k) and B is (k x m), so C is (n x m)
  unsigned n = inputs[0].shape[0];
  unsigned k = inputs[0].shape[1];
  unsigned m = inputs[1].shape[1];

  gemm(matA, matB, matC, n, k, m);
  return MX_SUCCESS;
}
|
||
/* | ||
* Executes dA = dC * B.T; Executes dB = A.T * dC | ||
***** gradient inputs | ||
* inputs[0] = dC | ||
***** original inputs | ||
* inputs[1] = A; inputs[2] = B | ||
***** original outputs | ||
* inputs[3] = C | ||
***** gradient outputs | ||
* outputs[0] = dA; outputs[1] = dB | ||
*/ | ||
/*
 * Executes dA = dC * B.T; Executes dB = A.T * dC
 ***** gradient inputs
 * inputs[0] = dC
 ***** original inputs
 * inputs[1] = A; inputs[2] = B
 ***** original outputs
 * inputs[3] = C
 ***** gradient outputs
 * outputs[0] = dA; outputs[1] = dB
 */
MXReturnValue backward(std::map<std::string, std::string> attrs,
                       std::vector<MXTensor> inputs,
                       std::vector<MXTensor> outputs,
                       OpResource res) {
  // extract data pointers from tensors
  float* dC = inputs[0].data<float>();
  float* A = inputs[1].data<float>();
  float* B = inputs[2].data<float>();
  float* dA = outputs[0].data<float>();
  float* dB = outputs[1].data<float>();
  // set tensor shapes: A is (n x k), B is (k x m)
  unsigned n = inputs[1].shape[0];
  unsigned k = inputs[1].shape[1];
  unsigned m = inputs[2].shape[1];

  // Scratch buffers for A.T and B.T. Use the framework allocator exposed
  // via OpResource instead of raw new/delete so the library does not
  // manage raw memory itself; the resource owns the buffers (see the
  // same alloc() usage in MyStatefulGemm::Forward), so no explicit free.
  float* At = static_cast<float*>(res.alloc(k * n * sizeof(float)));
  float* Bt = static_cast<float*>(res.alloc(m * k * sizeof(float)));

  transpose(A, At, k, n);     // At is (k x n)
  transpose(B, Bt, m, k);     // Bt is (m x k)
  gemm(dC, Bt, dA, n, m, k);  // dA = dC * B.T : (n x m)*(m x k) -> (n x k)
  gemm(At, dC, dB, k, n, m);  // dB = A.T * dC : (k x n)*(n x m) -> (k x m)

  return MX_SUCCESS;
}
|
||
/*
 * Reports the operator's arity: my_gemm consumes two inputs (A and B)
 * and produces a single output (C). Attributes are not consulted.
 */
MXReturnValue parseAttrs(std::map<std::string, std::string> attrs, int* num_in, int* num_out) {
  *num_out = 1;
  *num_in = 2;
  return MX_SUCCESS;
}
|
||
/*
 * Type inference: requires exactly two float32 inputs; the single output
 * inherits the first input's dtype.
 */
MXReturnValue inferType(std::map<std::string, std::string> attrs,
                        std::vector<int> &intypes,
                        std::vector<int> &outtypes) {
  // validate inputs
  if (intypes.size() != 2) {
    std::cout << "Expected 2 inputs to inferType" << std::endl;
    return MX_FAIL;
  }
  for (size_t i = 0; i < intypes.size(); ++i) {
    if (intypes[i] == kFloat32)
      continue;
    std::cout << "Expected input " << i << " to have float32 type" << std::endl;
    return MX_FAIL;
  }

  outtypes[0] = intypes[0];
  return MX_SUCCESS;
}
|
||
/*
 * Shape inference: both inputs must be 2-D with compatible inner
 * dimensions; the output C = A * B has shape (n, m).
 */
MXReturnValue inferShape(std::map<std::string, std::string> attrs,
                         std::vector<std::vector<unsigned int>> &inshapes,
                         std::vector<std::vector<unsigned int>> &outshapes) {
  // validate inputs
  if (inshapes.size() != 2) {
    std::cout << "Expected 2 inputs to inferShape" << std::endl;
    return MX_FAIL;
  }
  if (inshapes[0].size() != 2) {
    std::cout << "Expected 2D for first input to inferShape" << std::endl;
    return MX_FAIL;
  }
  if (inshapes[1].size() != 2) {
    std::cout << "Expected 2D for second input to inferShape" << std::endl;
    return MX_FAIL;
  }

  unsigned n = inshapes[0][0];   // rows of A
  unsigned k = inshapes[0][1];   // cols of A
  unsigned kk = inshapes[1][0];  // rows of B
  unsigned m = inshapes[1][1];   // cols of B
  if (k != kk) {
    // fixed typo ("Exected") and grammar in the user-facing error message
    std::cout << "Expected first input axis 1 to equal second input axis 0" << std::endl;
    return MX_FAIL;
  }

  // assign the full output shape at once rather than push_back twice
  outshapes[0] = {n, m};
  return MX_SUCCESS;
}
|
||
// Register the stateless custom operator "my_gemm" with its compute and
// inference callbacks.
REGISTER_OP(my_gemm)
.setForward(forward)
.setBackward(backward)
.setParseAttrs(parseAttrs)
.setInferType(inferType)
.setInferShape(inferShape);
|
||
/* ------------------------------------------------------------------------- */ | ||
|
||
/*
 * Stateful variant of the gemm operator: keeps a forward-invocation
 * counter across calls to demonstrate CustomStatefulOp state and the
 * OpResource CPU allocator.
 */
class MyStatefulGemm : public CustomStatefulOp {
 public:
  // count seeds the forward-call counter (taken from the "test_kw"
  // attribute in createOpState)
  explicit MyStatefulGemm(int count) : count(count) {}

  MXReturnValue Forward(std::vector<MXTensor> inputs,
                        std::vector<MXTensor> outputs,
                        OpResource op_res) {
    // exercise the framework-provided CPU allocator; *p is the updated
    // number of Forward calls
    int* p = static_cast<int*>(op_res.alloc(sizeof(int)));
    *p = ++count;
    std::cout << "Info: cpu malloc test: keyword + number of forward: " << *p << std::endl;

    // delegate the actual computation to the stateless forward with
    // empty attributes
    std::map<std::string, std::string> attrs;
    return forward(attrs, inputs, outputs, op_res);
  }

  MXReturnValue Backward(std::vector<MXTensor> inputs,
                         std::vector<MXTensor> outputs,
                         OpResource op_res) {
    // delegate to the stateless backward with empty attributes
    std::map<std::string, std::string> attrs;
    return backward(attrs, inputs, outputs, op_res);
  }

  ~MyStatefulGemm() {}

 private:
  int count;  // Forward calls so far, plus any initial seed from attrs
};
|
||
/*
 * Factory for the stateful operator: builds a MyStatefulGemm, seeding
 * its counter from the optional "test_kw" keyword attribute.
 */
MXReturnValue createOpState(std::map<std::string, std::string> attrs,
                            CustomStatefulOp** op_inst) {
  int count = 0;
  auto it = attrs.find("test_kw");
  if (it != attrs.end())
    count = std::stoi(it->second);
  *op_inst = new MyStatefulGemm(count);
  std::cout << "Info: stateful operator created" << std::endl;
  return MX_SUCCESS;
}
|
||
// Declares which inputs the operator mutates in place. This sample
// mutates none; the commented line shows how input 1 would be marked.
MXReturnValue mutateInputs(std::map<std::string, std::string> attrs,
                           std::vector<int> &input_indices) {
  // input_indices.push_back(1);  // mark mutate input
  return MX_SUCCESS;
}
|
||
// Register the stateful custom operator "state_gemm"; compute is supplied
// by the CustomStatefulOp instance that createOpState produces.
REGISTER_OP(state_gemm)
.setParseAttrs(parseAttrs)
.setInferType(inferType)
.setInferShape(inferShape)
.setMutateInputs(mutateInputs)
.setCreateOpState(createOpState);
|
||
/*
 * Library entry point invoked by MXNet when the shared object is loaded.
 * version is MXNet's numeric version; 10400 presumably maps to release
 * 1.4.0 — confirm against lib_api.h. Older runtimes are rejected.
 */
MXReturnValue initialize(int version) {
  if (version < 10400) {
    std::cout << "MXNet version " << version << " not supported" << std::endl;
    return MX_FAIL;
  }
  std::cout << "MXNet version " << version << " supported" << std::endl;
  return MX_SUCCESS;
}
Uh oh!
There was an error while loading. Please reload this page.