
Commit 7a928be

Merge branch 'temp'

2 parents: aa3c416 + 42a47b1


53 files changed: +1501 / -487 lines

CODEOWNERS

Lines changed: 2 additions & 3 deletions

```diff
@@ -42,8 +42,8 @@
 /plugin/ @pllarroy

 # CMake
-CMakeLists.txt @szha @rahul003 @pllarroy
-/cmake/ @szha @rahul003 @pllarroy
+CMakeLists.txt @szha @pllarroy
+/cmake/ @szha @pllarroy

 # MXNet CI
 dev_menu.py @pllarroy
@@ -71,4 +71,3 @@ prepare_mkl.sh @szha

 # Github templates
 /.github/ @szha
-
```

Makefile

Lines changed: 2 additions & 1 deletion

```diff
@@ -190,6 +190,7 @@ endif

 ifeq ($(USE_OPENMP), 1)
 	CFLAGS += -fopenmp
+	CFLAGS += -DMXNET_USE_OPENMP=1
 endif

 ifeq ($(USE_NNPACK), 1)
@@ -252,7 +253,7 @@ ifeq ($(USE_CUDNN), 1)
 	LDFLAGS += -lcudnn
 endif

-ifeq ($(USE_BLAS), open)
+ifeq ($(USE_BLAS), openblas)
 	CFLAGS += -DMXNET_USE_BLAS_OPEN=1
 else ifeq ($(USE_BLAS), atlas)
 	CFLAGS += -DMXNET_USE_BLAS_ATLAS=1
```
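To double-check that these flags took effect in a build, MXNet's runtime feature introspection can be used from Python; a minimal sketch, assuming the `mxnet.runtime.Features` API (added around 1.5.0) and the libinfo feature names `OPENMP` and `BLAS_OPEN`:

```python
# Sketch: check from Python which compile-time flags took effect.
# Assumes mxnet.runtime.Features (MXNet >= 1.5.0) and the libinfo
# feature names 'OPENMP' and 'BLAS_OPEN'.
import mxnet as mx

features = mx.runtime.Features()
print(features.is_enabled('OPENMP'))     # True when built with USE_OPENMP=1
print(features.is_enabled('BLAS_OPEN'))  # True when built with USE_BLAS=openblas
```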

benchmark/opperf/README.md

Lines changed: 1 addition & 1 deletion

````diff
@@ -75,7 +75,7 @@ For example, you want to run benchmarks for all NDArray Broadcast Binary Operato

 ```
 #!/usr/bin/python
-from benchmark.opperf.tensor_operations.binary_broadcast_operators import run_mx_binary_broadcast_operators_benchmarks
+from benchmark.opperf.nd_operations.binary_broadcast_operators import run_mx_binary_broadcast_operators_benchmarks

 # Run all Binary Broadcast operations benchmarks with default input values
 print(run_mx_binary_broadcast_operators_benchmarks())
````

benchmark/opperf/nd_operations/README.md

Lines changed: 1 addition & 12 deletions

```diff
@@ -28,17 +28,14 @@
 6. reshape
 7. one_hot
 8. linalg_potri
-9. mp_sgd_update
 10. multi_sgd_update
-11. signum_update
 12. Convolution_v1
 13. repeat
 14. Custom
 15. softmax_cross_entropy
 16. SwapAxis
 17. norm
 18. Softmax
-19. rmspropalex_update
 20. fill_element_0index
 21. cast
 22. UpSampling
@@ -52,32 +49,27 @@
 30. Activation
 31. LinearRegressionOutput
 32. Pooling_v1
-33. ftml_update
 34. Crop
 35. ElementWiseSum
 36. diag
 37. Reshape
 38. Pad
 39. linalg_gemm2
 40. crop
-41. rmsprop_update
 43. RNN
 45. SoftmaxOutput
 46. linalg_extractdiag
-47. sgd_mom_update
 48. SequenceLast
 51. SequenceReverse
 53. SVMOutput
 54. linalg_trsm
 55. where
 56. SoftmaxActivation
-57. signsgd_update
 58. slice
 59. linalg_gelqf
 60. softmin
 61. linalg_gemm
 62. BilinearSampler
-63. mp_sgd_mom_update
 64. choose_element_0index
 65. tile
 67. gather_nd
@@ -110,7 +102,6 @@
 98. linalg_syrk
 99. squeeze
 101. ROIPooling
-102. ftrl_update
 103. SliceChannel
 104. slice_like
 106. linalg_maketrian
@@ -127,6 +118,4 @@
 119. normal
 120. take
 121. MakeLoss
-122. sgd_update
-123. adam_update
-124. concat
+124. concat
```

benchmark/opperf/nd_operations/nn_optimizer_operators.py (new file)

Lines changed: 64 additions & 0 deletions

```diff
@@ -0,0 +1,64 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import mxnet as mx
+from benchmark.opperf.utils.benchmark_utils import run_op_benchmarks
+from benchmark.opperf.utils.op_registry_utils import get_all_optimizer_operators
+
+"""Performance benchmark tests for MXNet Neural Network Optimizer Update Operators.
+
+1. Stochastic Gradient Descent (SGD)
+    1.1 mp_sgd_update
+    1.2 sgd_mom_update
+    1.3 signsgd_update
+    1.4 mp_sgd_mom_update
+    1.5 sgd_update
+2. signum_update
+3. rmspropalex_update
+4. ftml_update
+5. rmsprop_update
+6. ftrl_update
+7. adam_update
+"""
+
+
+def run_optimizer_operators_benchmarks(ctx=mx.cpu(), dtype='float32', warmup=25, runs=100):
+    """Runs benchmarks with the given context and precision (dtype) for all the neural network
+    optimizer update operators in MXNet.
+
+    Parameters
+    ----------
+    ctx: mx.ctx
+        Context to run benchmarks
+    dtype: str, default 'float32'
+        Precision to use for benchmarks
+    warmup: int, default 25
+        Number of times to run for warmup
+    runs: int, default 100
+        Number of runs to capture benchmark results
+
+    Returns
+    -------
+    Dictionary of results. Key -> Name of the operator, Value -> Benchmark results.
+
+    """
+    # Fetch all optimizer operators
+    mx_optimizer_ops = get_all_optimizer_operators()
+
+    # Run benchmarks
+    mx_optimizer_op_results = run_op_benchmarks(mx_optimizer_ops, dtype, ctx, warmup, runs)
+    return mx_optimizer_op_results
```
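The new entry point can also be invoked on its own; a minimal sketch, assuming it is launched from the repository root so the `benchmark.opperf` package is importable (the reduced `warmup`/`runs` values are just for a quick smoke test, the defaults being 25/100):

```python
#!/usr/bin/python
# Sketch: run only the optimizer update benchmarks added in this commit.
# Assumes execution from the incubator-mxnet repository root so that the
# `benchmark.opperf` package is on the Python path.
import mxnet as mx

from benchmark.opperf.nd_operations.nn_optimizer_operators import \
    run_optimizer_operators_benchmarks

results = run_optimizer_operators_benchmarks(ctx=mx.cpu(), dtype='float32',
                                             warmup=5, runs=10)
print(results)
```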

benchmark/opperf/opperf.py

Lines changed: 3 additions & 0 deletions

```diff
@@ -39,6 +39,7 @@
 from benchmark.opperf.nd_operations.nn_conv_operators import run_pooling_operators_benchmarks, \
     run_convolution_operators_benchmarks, run_transpose_convolution_operators_benchmarks
 from benchmark.opperf.nd_operations.nn_basic_operators import run_nn_basic_operators_benchmarks
+from benchmark.opperf.nd_operations.nn_optimizer_operators import run_optimizer_operators_benchmarks
 from benchmark.opperf.nd_operations.array_rearrange import run_rearrange_operators_benchmarks

 from benchmark.opperf.utils.common_utils import merge_map_list, save_to_file
@@ -96,6 +97,8 @@ def run_all_mxnet_operator_benchmarks(ctx=mx.cpu(), dtype='float32'):
     # Run all Convolution operations benchmarks with default input values
     mxnet_operator_benchmark_results.append(run_convolution_operators_benchmarks(ctx=ctx, dtype=dtype))

+    # Run all Optimizer operations benchmarks with default input values
+    mxnet_operator_benchmark_results.append(run_optimizer_operators_benchmarks(ctx=ctx, dtype=dtype))
     # Run all Transpose Convolution operations benchmarks with default input values
     mxnet_operator_benchmark_results.append(run_transpose_convolution_operators_benchmarks(ctx=ctx, dtype=dtype))

```
benchmark/opperf/rules/default_params.py

Lines changed: 56 additions & 2 deletions

```diff
@@ -34,6 +34,7 @@

 # For operators like - random_uniform, random_normal etc..
 DEFAULT_SHAPE = [(1024, 1024), (10000, 1), (10000, 100)]
+DEFAULT_SAMPLE = [(2,)]
 DEFAULT_LOW = [0]
 DEFAULT_HIGH = [5]
 DEFAULT_K = [1]
@@ -62,6 +63,31 @@
 # NOTE: Data used is DEFAULT_DATA
 DEFAULT_AXIS = [0]

+# For optimizer operators
+DEFAULT_WEIGHT = [(1024, 1024), (10000, 1), (10000, 100)]
+DEFAULT_GRAD = [(1024, 1024), (10000, 1), (10000, 100)]
+DEFAULT_MOM = [(1024, 1024), (10000, 1), (10000, 100)]
+DEFAULT_MEAN = [(1024, 1024), (10000, 1), (10000, 100)]
+DEFAULT_VAR = [(1024, 1024), (10000, 1), (10000, 100)]
+DEFAULT_N = [(1024, 1024), (10000, 1), (10000, 100)]
+DEFAULT_D = [(1024, 1024), (10000, 1), (10000, 100)]
+DEFAULT_V = [(1024, 1024), (10000, 1), (10000, 100)]
+DEFAULT_Z = [(1024, 1024), (10000, 1), (10000, 100)]
+DEFAULT_G = [(1024, 1024), (10000, 1), (10000, 100)]
+DEFAULT_DELTA = [(1024, 1024), (10000, 1), (10000, 100)]
+DEFAULT_LRS = [(0.1,0.1)]
+DEFAULT_LR = [0.1,0.5,0.9]
+DEFAULT_GAMMA_1 = [0.1,0.5,0.9]
+DEFAULT_GAMMA_2 = [0.1,0.5,0.9]
+DEFAULT_EPSILON = [1e-08]
+DEFAULT_BETA_1 = [0.1,0.5,0.9]
+DEFAULT_BETA_2 = [0.1,0.5,0.9]
+DEFAULT_T = [1,5]
+DEFAULT_RESCALE_GRAD = [0.4, 0.77]
+DEFAULT_CLIP_GRADIENT = [-1.0,0.8]
+DEFAULT_CLIP_WEIGHTS = [-1.0,0.8]
+DEFAULT_LAZY_UPDATE = [0,1]
+
 # For rearrange operators
 # NOTE: Data needs to be a 4D tensor for operators like space_to_depth and depth_to_space
 # Hence below we append 4d to mark the difference.
@@ -71,8 +97,10 @@
 DEFAULT_DIM_2 = [1, 2, 3, 0]
 DEFAULT_BLOCK_SIZE = [2, 5]

+
 # Default Inputs. MXNet Op Param Name to Default Input mapping
 DEFAULTS_INPUTS = {"data": DEFAULT_DATA,
+                   "sample": DEFAULT_SAMPLE,
                    "lhs": DEFAULT_LHS,
                    "rhs": DEFAULT_RHS,
                    "shape": DEFAULT_SHAPE,
@@ -91,16 +119,42 @@
                    "p_nd": DEFAULT_P_ND,
                    "axis_shape": DEFAULT_AXIS_SHAPE,
                    "axis": DEFAULT_AXIS,
+                   "weight" : DEFAULT_WEIGHT,
+                   "weight32" : DEFAULT_WEIGHT,
+                   "grad" : DEFAULT_GRAD,
+                   "mean" : DEFAULT_MEAN,
+                   "var" : DEFAULT_VAR,
+                   "mom" : DEFAULT_MOM,
+                   "n" : DEFAULT_N,
+                   "d" : DEFAULT_D,
+                   "v" : DEFAULT_V,
+                   "z" : DEFAULT_Z,
+                   "g" : DEFAULT_G,
+                   "delta" : DEFAULT_DELTA,
+                   "lr" : DEFAULT_LR,
+                   "lrs" : DEFAULT_LRS,
+                   "wds" : DEFAULT_LRS,
+                   "gamma1" : DEFAULT_GAMMA_1,
+                   "gamma2" : DEFAULT_GAMMA_2,
+                   "epsilon" : DEFAULT_EPSILON,
+                   "beta1" : DEFAULT_BETA_1,
+                   "beta2" : DEFAULT_BETA_2,
+                   "t" : DEFAULT_T,
+                   "rescale_grad" : DEFAULT_RESCALE_GRAD,
+                   "clip_grad" : DEFAULT_CLIP_GRADIENT,
+                   "lazy_update" : DEFAULT_LAZY_UPDATE,
                    "data_4d": DEFAULT_DATA_4d,
                    "dim1": DEFAULT_DIM_1,
                    "dim2": DEFAULT_DIM_2,
                    "block_size": DEFAULT_BLOCK_SIZE}

+
 # These are names of MXNet operator parameters that is of type NDArray.
 # We maintain this list to automatically recognize these parameters are to be
 # given as NDArray and translate users inputs such as a shape tuple, Numpy Array or
 # a list to MXNet NDArray. This is just a convenience added so benchmark utility users
 # can just say shape of the tensor, and we automatically create Tensors.
-PARAMS_OF_TYPE_NDARRAY = ["lhs", "rhs", "data", "base", "exp",
+PARAMS_OF_TYPE_NDARRAY = ["lhs", "rhs", "data", "base", "exp", "sample",
                           "mu", "sigma", "lam", "alpha", "beta", "gamma", "k", "p",
-                          "low", "high", "weight", "bias", "moving_mean", "moving_var"]
+                          "low", "high", "weight", "bias", "moving_mean", "moving_var",
+                          "weight", "weight32", "grad", "mean", "var", "mom", "n", "d", "v", "z", "g", "delta"]
```

benchmark/opperf/utils/op_registry_utils.py

Lines changed: 21 additions & 0 deletions

```diff
@@ -244,6 +244,27 @@ def get_all_reduction_operators():
     return reduction_mx_operators


+def get_all_optimizer_operators():
+    """Gets all Optimizer operators registered with MXNet.
+
+    Returns
+    -------
+    {"operator_name": {"has_backward", "nd_op_handle", "params"}}
+    """
+    optimizer_ops = ['mp_sgd_update', 'signum_update', 'rmspropalex_update', 'ftml_update', 'rmsprop_update',
+                     'sgd_mom_update', 'signsgd_update', 'mp_sgd_mom_update', 'ftrl_update', 'sgd_update',
+                     'adam_update']
+
+    # Get all mxnet operators
+    mx_operators = _get_all_mxnet_operators()
+
+    # Filter for Optimizer operators
+    optimizer_mx_operators = {}
+    for op_name, op_params in mx_operators.items():
+        if op_name in optimizer_ops and op_name not in unique_ops:
+            optimizer_mx_operators[op_name] = mx_operators[op_name]
+    return optimizer_mx_operators
+
 def get_all_sorting_searching_operators():
     """Gets all Sorting and Searching operators registered with MXNet.

```
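A quick sanity check of the new helper; a sketch, again assuming execution from the repository root:

```python
# Sketch: list the optimizer update operators the new helper discovers.
from benchmark.opperf.utils.op_registry_utils import get_all_optimizer_operators

optimizer_ops = get_all_optimizer_operators()
# Values follow the layout documented above: has_backward, nd_op_handle, params.
for op_name, op_metadata in optimizer_ops.items():
    print(op_name, sorted(op_metadata.keys()))
```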

cpp-package/example/inference/README.md

Lines changed: 0 additions & 3 deletions

````diff
@@ -41,7 +41,6 @@ The following performance numbers are collected via using C++ inference API on A
 ```
 export KMP_AFFINITY=granularity=fine,noduplicates,compact,1,0
 export OMP_NUM_THREADS=$(vCPUs/2)
-export MXNET_SUBGRAPH_BACKEND=MKLDNN
 export MXNET_ENGINE_TYPE=NaiveEngine
 ```
 Also users are recommended to use ```numactl``` or ```taskset``` to bind a running process to the specified cores.
@@ -87,8 +86,6 @@ Follow the below steps to do inference with more models.

 The below command lines show how to run inference with FP32/INT8 resnet50_v1 model. Because the C++ inference script provides the almost same command line as this [Python script](https://github.com/apache/incubator-mxnet/blob/master/example/quantization/imagenet_inference.py) and then users can easily go from Python to C++.
 ```
-# set MKLDNN as subgraph backend
-export MXNET_SUBGRAPH_BACKEND=MKLDNN

 # FP32 inference
 ./imagenet_inference --symbol_file "./model/resnet50_v1-symbol.json" --params_file "./model/resnet50_v1-0000.params" --dataset "./data/val_256_q90.rec" --rgb_mean "123.68 116.779 103.939" --rgb_std "58.393 57.12 57.375" --batch_size 64 --num_skipped_batches 50 --num_inference_batches 500
````

docs/_static/js/options.js

Lines changed: 1 addition & 1 deletion

```diff
@@ -19,7 +19,7 @@
 */

 /* Installation page display functions for install selector */
-var versionSelect = defaultVersion = 'v1.4.1';
+var versionSelect = defaultVersion = 'v1.5.0';
 var platformSelect = 'Linux';
 var languageSelect = 'Python';
 var processorSelect = 'CPU';
```

docs/_static/mxnet-theme/index.html

Lines changed: 3 additions & 3 deletions

```diff
@@ -23,9 +23,9 @@
 <div class="container">
   <div class="row">
     <div class="col-lg-4 col-sm-12">
-      <h3>MXNet 1.4.1 Released</h3>
-      <p>This patch release features bug fixes and performance improvements.</p>
-      <a href="https://github.com/apache/incubator-mxnet/releases/tag/1.4.1">Learn More</a>
+      <h3>MXNet 1.5.0 Released</h3>
+      <p>This release features Automatic Mixed Precision, MKL-DNN updates, CUDA10.1 support and more.</p>
+      <a href="https://github.com/apache/incubator-mxnet/releases/tag/1.5.0">Learn More</a>
     </div>
     <div class="col-lg-4 col-sm-12">
       <h3>A 60-minute Gluon Crash Course</h3>
```

docs/api/scala/symbol.md

Lines changed: 1 addition & 1 deletion

```diff
@@ -41,7 +41,7 @@ The following example configures a two-layer neural network.
 val data = Symbol.Variable("data")
 val fc1 = Symbol.api.FullyConnected(Some(data), num_hidden = 128, name = "fc1")
 val act1 = Symbol.api.Activation(Some(fc1), "relu", "relu1")
-val fc2 = Symbol.api.FullyConnected(some(act1), num_hidden = 64, name = "fc2")
+val fc2 = Symbol.api.FullyConnected(Some(act1), num_hidden = 64, name = "fc2")
 val net = Symbol.api.SoftmaxOutput(Some(fc2), name = "out")
 :type net
 // org.apache.mxnet.Symbol
```

docs/faq/env_var.md

Lines changed: 2 additions & 1 deletion

```diff
@@ -307,9 +307,10 @@ If ctypes is used, it must be `mxnet._ctypes.ndarray.NDArrayBase`.
   - This variable controls how many CuDNN dropout state resources to create for each GPU context for use in operator.

 * MXNET_SUBGRAPH_BACKEND
-  - Values: String ```(default="")```
+  - Values: String ```(default="MKLDNN")``` if MKLDNN is available, otherwise ```(default="")```
   - This variable controls the subgraph partitioning in MXNet.
   - This variable is used to perform MKL-DNN FP32 operator fusion and quantization. Please refer to the [MKL-DNN operator list](../tutorials/mkldnn/operator_list.md) for how this variable is used and the list of fusion passes.
+  - Set ```MXNET_SUBGRAPH_BACKEND=NONE``` to disable the subgraph backend.

 * MXNET_SAFE_ACCUMULATION
   - Values: 0(false) or 1(true) ```(default=0)```
```

docs/install/download.md

Lines changed: 1 addition & 0 deletions

```diff
@@ -21,6 +21,7 @@ These source archives are generated from tagged releases. Updates and patches wi

 | Version | Source | PGP | SHA |
 |---------|--------|-----|-----|
+| 1.5.0 | [Download](https://apache.org/dist/incubator/mxnet/1.5.0/apache-mxnet-src-1.5.0-incubating.tar.gz) | [Download](https://apache.org/dist/incubator/mxnet/1.5.0/apache-mxnet-src-1.5.0-incubating.tar.gz.asc) | [Download](https://apache.org/dist/incubator/mxnet/1.5.0/apache-mxnet-src-1.5.0-incubating.tar.gz.sha512) |
 | 1.4.1 | [Download](https://www.apache.org/dyn/closer.cgi/incubator/mxnet/1.4.1/apache-mxnet-src-1.4.1-incubating.tar.gz) | [Download](https://apache.org/dist/incubator/mxnet/1.4.1/apache-mxnet-src-1.4.1-incubating.tar.gz.asc) | [Download](https://apache.org/dist/incubator/mxnet/1.4.1/apache-mxnet-src-1.4.1-incubating.tar.gz.sha512) |
 | 1.4.0 | [Download](https://www.apache.org/dyn/closer.cgi/incubator/mxnet/1.4.0/apache-mxnet-src-1.4.0-incubating.tar.gz) | [Download](https://apache.org/dist/incubator/mxnet/1.4.0/apache-mxnet-src-1.4.0-incubating.tar.gz.asc) | [Download](https://apache.org/dist/incubator/mxnet/1.4.0/apache-mxnet-src-1.4.0-incubating.tar.gz.sha512) |
 | 1.3.1 | [Download](https://www.apache.org/dyn/closer.cgi/incubator/mxnet/1.3.1/apache-mxnet-src-1.3.1-incubating.tar.gz) | [Download](https://apache.org/dist/incubator/mxnet/1.3.1/apache-mxnet-src-1.3.1-incubating.tar.gz.asc) | [Download](https://apache.org/dist/incubator/mxnet/1.3.1/apache-mxnet-src-1.3.1-incubating.tar.gz.sha512) |
```
