This repository was archived by the owner on Nov 17, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 6.8k
Dynamic custom operator GPU support #17270
Merged
Merged
Changes from 11 commits
Commits
Show all changes
16 commits
Select commit
Hold shift + click to select a range
09ebadf
poc gpu customop end to end
rondogency bf3cfdc
add backward and device id
rondogency 4bcda15
clear up customop makefile
rondogency 61db076
new fcomp register
rondogency 2f8cf79
new setforward to pass custom context to c_api
rondogency cb33160
resolve sam comment: add cond register and fix setforward char
rondogency 589df6b
tmp stateful op
rondogency 7c50e5b
passing ctx of stateful op
rondogency 0d121f2
add gpu alloc and refactor all fcomp
rondogency 8132602
resolve sam comments and refactor alloc
rondogency 9499173
add gpu check to pass cpu build
rondogency c455751
add unittest and resolve ptrend comments
rondogency 6a06b84
add cmake and jenkins
rondogency 114456a
fix windows
rondogency f388e73
windows gpu cmake build fix
rondogency 6e3b739
remove verbose
rondogency File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,193 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
|
||
/*! | ||
* Copyright (c) 2020 by Contributors | ||
* \file relu_lib.cu | ||
* \brief simple custom relu operator implemented using CUDA function | ||
*/ | ||
|
||
#include <iostream> | ||
#include "lib_api.h" | ||
|
||
__global__ void relu_gpu_forward(float *out, float *in, int64_t N) { | ||
int tid = blockIdx.x * blockDim.x + threadIdx.x; | ||
if (tid < N) | ||
out[tid] = in[tid] > 0 ? in[tid] : 0; | ||
} | ||
|
||
__global__ void relu_gpu_backward(float *out, float *in, int64_t N) { | ||
int tid = blockIdx.x * blockDim.x + threadIdx.x; | ||
if (tid < N) | ||
out[tid] = in[tid] > 0 ? 1 : 0; | ||
rondogency marked this conversation as resolved.
Show resolved
Hide resolved
|
||
} | ||
|
||
MXReturnValue forwardCPU(std::map<std::string, std::string> attrs, | ||
std::vector<MXTensor> inputs, | ||
std::vector<MXTensor> outputs, | ||
OpResource res) { | ||
float* in_data = inputs[0].data<float>(); | ||
float* out_data = outputs[0].data<float>(); | ||
for (int i=0; i<inputs[0].size(); i++) { | ||
out_data[i] = in_data[i] > 0 ? in_data[i] : 0; | ||
} | ||
return MX_SUCCESS; | ||
} | ||
|
||
MXReturnValue backwardCPU(std::map<std::string, std::string> attrs, | ||
std::vector<MXTensor> inputs, | ||
std::vector<MXTensor> outputs, | ||
OpResource res) { | ||
float* in_data = inputs[0].data<float>(); | ||
float* out_data = outputs[0].data<float>(); | ||
for (int i=0; i<inputs[0].size(); i++) { | ||
out_data[i] = in_data[i] > 0 ? 1 : 0; | ||
} | ||
return MX_SUCCESS; | ||
} | ||
|
||
MXReturnValue forwardGPU(std::map<std::string, std::string> attrs, | ||
std::vector<MXTensor> inputs, | ||
std::vector<MXTensor> outputs, | ||
OpResource res) { | ||
float* in_data = inputs[0].data<float>(); | ||
float* out_data = outputs[0].data<float>(); | ||
|
||
// test on memory resource allocation | ||
void *workspace_cpu = res.alloc_cpu(8 * sizeof(float)); | ||
void *workspace_gpu = res.alloc_gpu(8 * sizeof(float)); | ||
rondogency marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
mx_stream_t cuda_stream = res.get_cuda_stream(); | ||
int64_t N = inputs[0].size(); | ||
int block = 256; | ||
int grid = (N + (block - 1)) / block; | ||
relu_gpu_forward<<<grid,block,0,cuda_stream>>>(out_data, in_data, N); | ||
|
||
return MX_SUCCESS; | ||
} | ||
|
||
MXReturnValue backwardGPU(std::map<std::string, std::string> attrs, | ||
std::vector<MXTensor> inputs, | ||
std::vector<MXTensor> outputs, | ||
OpResource res) { | ||
float* in_data = inputs[0].data<float>(); | ||
float* out_data = outputs[0].data<float>(); | ||
|
||
mx_stream_t cuda_stream = res.get_cuda_stream(); | ||
int64_t N = inputs[0].size(); | ||
int block = 256; | ||
int grid = (N + (block - 1)) / block; | ||
relu_gpu_backward<<<grid,block,0,cuda_stream>>>(out_data, in_data, N); | ||
|
||
return MX_SUCCESS; | ||
} | ||
|
||
MXReturnValue parseAttrs(std::map<std::string, std::string> attrs, int* num_in, int* num_out) { | ||
*num_in = 1; | ||
*num_out = 1; | ||
return MX_SUCCESS; | ||
} | ||
|
||
MXReturnValue inferType(std::map<std::string, std::string> attrs, | ||
std::vector<int> &intypes, | ||
std::vector<int> &outtypes) { | ||
outtypes[0] = intypes[0]; | ||
return MX_SUCCESS; | ||
} | ||
|
||
MXReturnValue inferShape(std::map<std::string, std::string> attrs, | ||
std::vector<std::vector<unsigned int>> &inshapes, | ||
std::vector<std::vector<unsigned int>> &outshapes) { | ||
outshapes[0] = inshapes[0]; | ||
return MX_SUCCESS; | ||
} | ||
|
||
REGISTER_OP(my_relu) | ||
.setParseAttrs(parseAttrs) | ||
.setInferType(inferType) | ||
.setInferShape(inferShape) | ||
.setForward(forwardCPU, "cpu") | ||
.setForward(forwardGPU, "gpu") | ||
.setBackward(backwardCPU, "cpu") | ||
.setBackward(backwardGPU, "gpu"); | ||
|
||
class MyStatefulReluCPU : public CustomStatefulOp { | ||
public: | ||
explicit MyStatefulReluCPU() {} | ||
MXReturnValue Forward(std::vector<MXTensor> inputs, | ||
std::vector<MXTensor> outputs, | ||
OpResource op_res) { | ||
std::map<std::string, std::string> attrs; | ||
return forwardCPU(attrs, inputs, outputs, op_res); | ||
} | ||
MXReturnValue Backward(std::vector<MXTensor> inputs, | ||
std::vector<MXTensor> outputs, | ||
OpResource op_res) { | ||
std::map<std::string, std::string> attrs; | ||
return backwardCPU(attrs, inputs, outputs, op_res); | ||
} | ||
~MyStatefulReluCPU() {} | ||
}; | ||
|
||
class MyStatefulReluGPU : public CustomStatefulOp { | ||
public: | ||
explicit MyStatefulReluGPU() {} | ||
MXReturnValue Forward(std::vector<MXTensor> inputs, | ||
std::vector<MXTensor> outputs, | ||
OpResource op_res) { | ||
std::map<std::string, std::string> attrs; | ||
return forwardGPU(attrs, inputs, outputs, op_res); | ||
} | ||
MXReturnValue Backward(std::vector<MXTensor> inputs, | ||
std::vector<MXTensor> outputs, | ||
OpResource op_res) { | ||
std::map<std::string, std::string> attrs; | ||
return backwardGPU(attrs, inputs, outputs, op_res); | ||
} | ||
~MyStatefulReluGPU() {} | ||
}; | ||
|
||
MXReturnValue createOpStateCPU(std::map<std::string, std::string> attrs, | ||
CustomStatefulOp** op_inst) { | ||
*op_inst = new MyStatefulReluCPU(); | ||
return MX_SUCCESS; | ||
} | ||
|
||
MXReturnValue createOpStateGPU(std::map<std::string, std::string> attrs, | ||
CustomStatefulOp** op_inst) { | ||
*op_inst = new MyStatefulReluGPU(); | ||
return MX_SUCCESS; | ||
} | ||
|
||
REGISTER_OP(my_state_relu) | ||
.setParseAttrs(parseAttrs) | ||
.setInferType(inferType) | ||
.setInferShape(inferShape) | ||
.setCreateOpState(createOpStateCPU, "cpu") | ||
.setCreateOpState(createOpStateGPU, "gpu"); | ||
|
||
MXReturnValue initialize(int version) { | ||
if (version >= 10400) { | ||
rondogency marked this conversation as resolved.
Show resolved
Hide resolved
|
||
std::cout << "MXNet version " << version << " supported" << std::endl; | ||
return MX_SUCCESS; | ||
} else { | ||
std::cout << "MXNet version " << version << " not supported" << std::endl; | ||
return MX_FAIL; | ||
} | ||
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
#!/usr/bin/env python3 | ||
|
||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
|
||
# coding: utf-8 | ||
# pylint: disable=arguments-differ | ||
|
||
# This test checks dynamic loading of custom library into MXNet | ||
# and checks end to end compute of a simple 2D gemm custom op | ||
|
||
import mxnet as mx | ||
import os | ||
import time | ||
|
||
#load library | ||
if (os.name=='posix'): | ||
path = os.path.abspath('librelu_lib.so') | ||
mx.library.load(path) | ||
|
||
a = mx.nd.array([[-2,-1],[1,2]], ctx=mx.cpu()) | ||
b = mx.nd.array([[-2,-1],[1,2]], ctx=mx.gpu()) | ||
|
||
print("--------start ndarray compute---------") | ||
print(mx.nd.my_relu(a)) | ||
print(mx.nd.my_relu(b)) | ||
print(mx.nd.my_state_relu(a)) | ||
print(mx.nd.my_state_relu(b)) | ||
|
||
print("--------start symbolic compute--------") | ||
c = mx.sym.Variable('c') | ||
d = mx.sym.my_relu(c) | ||
in_grad = [mx.nd.empty((2,2), ctx=mx.gpu())] | ||
exe = d.bind(ctx=mx.gpu(), args={'c':b}, args_grad=in_grad) | ||
out = exe.forward() | ||
print(out) | ||
|
||
print("--------start backward compute--------") | ||
out_grad = mx.nd.ones((2,2), ctx=mx.gpu()) | ||
exe.backward([out_grad]) | ||
print(in_grad) | ||
|
||
print("--------start stress test---------") | ||
a = mx.nd.uniform(shape=(1000,1000,100), ctx=mx.cpu()) | ||
b = mx.nd.uniform(shape=(1000,1000,100), ctx=mx.gpu()) | ||
t1 = time.time() | ||
r1 = mx.nd.my_relu(a) | ||
t2 = time.time() | ||
r2 = mx.nd.my_relu(b) | ||
t3 = time.time() | ||
print("CPU running time:") | ||
print(t2 - t1) | ||
rondogency marked this conversation as resolved.
Show resolved
Hide resolved
|
||
print("GPU running time:") | ||
print(t3 - t2) | ||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.