@@ -224,6 +224,13 @@ def build_OpenBLAS(host: RemoteHost, git_clone_flags: str = "") -> None:
224
224
host .run_cmd (f"pushd OpenBLAS; make { make_flags } -j8; sudo make { make_flags } install; popd; rm -rf OpenBLAS" )
225
225
226
226
227
def build_ArmComputeLibrary(host: RemoteHost, git_clone_flags: str = "") -> None:
    """Clone Arm Compute Library v22.05 on ``host`` and build it into ``$HOME/acl``.

    The library is built with scons (NEON and OpenMP enabled, OpenCL and
    cppthreads disabled, ``arch=armv8.2-a``, ``multi_isa=1``, native build)
    with ``build_dir`` set to ``$HOME/acl/build``.  The ``arm_compute``,
    ``include``, ``utils`` and ``support`` directories of the checkout are
    then copied next to the build output, so ``$HOME/acl`` can be used as
    ``ACL_ROOT_DIR`` by the PyTorch build.

    Args:
        host: Object whose ``run_cmd`` executes a shell command line.
        git_clone_flags: Extra flags appended to ``git clone``
            (e.g. shallow-clone options).
    """
    print('Building Arm Compute Library')

    # -p keeps this step re-runnable when $HOME/acl is left over from an
    # earlier build; a plain mkdir exits non-zero in that case.
    host.run_cmd("mkdir -p $HOME/acl")
    host.run_cmd(f"git clone https://github.com/ARM-software/ComputeLibrary.git -b v22.05 {git_clone_flags}")

    scons_cmd = ("scons Werror=1 -j8 debug=0 neon=1 opencl=0 os=linux openmp=1 cppthreads=0 "
                 "arch=armv8.2-a multi_isa=1 build=native build_dir=$acl_install_dir/build")
    steps = [
        "pushd ComputeLibrary",
        "export acl_install_dir=$HOME/acl",
        scons_cmd,
        "cp -r arm_compute $acl_install_dir",
        "cp -r include $acl_install_dir",
        "cp -r utils $acl_install_dir",
        "cp -r support $acl_install_dir",
        "popd",
    ]
    # Chain with && rather than ';' so that a failing step makes the whole
    # command line exit non-zero instead of being masked by the final popd.
    host.run_cmd(" && ".join(steps))
232
+
233
+
227
234
def build_FFTW (host : RemoteHost , git_clone_flags : str = "" ) -> None :
228
235
print ("Building FFTW3" )
229
236
host .run_cmd ("sudo apt-get install -y ocaml ocamlbuild autoconf automake indent libtool fig2dev texinfo" )
@@ -233,7 +240,7 @@ def build_FFTW(host: RemoteHost, git_clone_flags: str = "") -> None:
233
240
host .run_cmd ("pushd fftw3; sh bootstrap.sh; make -j8; sudo make install; popd" )
234
241
235
242
236
- def embed_libgomp (host : RemoteHost , use_conda , wheel_name ) -> None :
243
+ def embed_libgomp_acl (host : RemoteHost , use_conda , wheel_name , enable_mkldnn = False ) -> None :
237
244
host .run_cmd ("pip3 install auditwheel" )
238
245
host .run_cmd ("conda install -y patchelf" if use_conda else "sudo apt-get install -y patchelf" )
239
246
from tempfile import NamedTemporaryFile
@@ -244,7 +251,10 @@ def embed_libgomp(host: RemoteHost, use_conda, wheel_name) -> None:
244
251
245
252
print ('Embedding libgomp into wheel' )
246
253
if host .using_docker ():
247
- host .run_cmd (f"python3 embed_library.py { wheel_name } --update-tag" )
254
+ if enable_mkldnn :
255
+ host .run_cmd (f"python3 embed_library.py { wheel_name } --update-tag --enable_mkldnn" )
256
+ else :
257
+ host .run_cmd (f"python3 embed_library.py { wheel_name } --update-tag" )
248
258
else :
249
259
host .run_cmd (f"python3 embed_library.py { wheel_name } " )
250
260
@@ -302,7 +312,7 @@ def build_torchvision(host: RemoteHost, *,
302
312
303
313
host .run_cmd (f"cd vision; { build_vars } python3 setup.py bdist_wheel" )
304
314
vision_wheel_name = host .list_dir ("vision/dist" )[0 ]
305
- embed_libgomp (host , use_conda , os .path .join ('vision' , 'dist' , vision_wheel_name ))
315
+ embed_libgomp_acl (host , use_conda , os .path .join ('vision' , 'dist' , vision_wheel_name ))
306
316
307
317
print ('Copying TorchVision wheel' )
308
318
host .download_wheel (os .path .join ('vision' , 'dist' , vision_wheel_name ))
@@ -344,7 +354,7 @@ def build_torchtext(host: RemoteHost, *,
344
354
345
355
host .run_cmd (f"cd text; { build_vars } python3 setup.py bdist_wheel" )
346
356
wheel_name = host .list_dir ("text/dist" )[0 ]
347
- embed_libgomp (host , use_conda , os .path .join ('text' , 'dist' , wheel_name ))
357
+ embed_libgomp_acl (host , use_conda , os .path .join ('text' , 'dist' , wheel_name ))
348
358
349
359
print ('Copying TorchText wheel' )
350
360
host .download_wheel (os .path .join ('text' , 'dist' , wheel_name ))
@@ -384,7 +394,7 @@ def build_torchaudio(host: RemoteHost, *,
384
394
385
395
host .run_cmd (f"cd audio; { build_vars } python3 setup.py bdist_wheel" )
386
396
wheel_name = host .list_dir ("audio/dist" )[0 ]
387
- embed_libgomp (host , use_conda , os .path .join ('audio' , 'dist' , wheel_name ))
397
+ embed_libgomp_acl (host , use_conda , os .path .join ('audio' , 'dist' , wheel_name ))
388
398
389
399
print ('Copying TorchAudio wheel' )
390
400
host .download_wheel (os .path .join ('audio' , 'dist' , wheel_name ))
@@ -395,7 +405,8 @@ def build_torchaudio(host: RemoteHost, *,
395
405
def configure_system (host : RemoteHost , * ,
396
406
compiler = "gcc-8" ,
397
407
use_conda = True ,
398
- python_version = "3.8" ) -> None :
408
+ python_version = "3.8" ,
409
+ enable_mkldnn = False ) -> None :
399
410
if use_conda :
400
411
install_condaforge_python (host , python_version )
401
412
@@ -405,7 +416,7 @@ def configure_system(host: RemoteHost, *,
405
416
host .run_cmd ("sudo apt-get install -y ninja-build g++ git cmake gfortran unzip" )
406
417
else :
407
418
host .run_cmd ("yum install -y sudo" )
408
- host .run_cmd ("conda install -y ninja" )
419
+ host .run_cmd ("conda install -y ninja scons " )
409
420
410
421
if not use_conda :
411
422
host .run_cmd ("sudo apt-get install -y python3-dev python3-yaml python3-setuptools python3-wheel python3-pip" )
@@ -427,16 +438,21 @@ def start_build(host: RemoteHost, *,
427
438
compiler = "gcc-8" ,
428
439
use_conda = True ,
429
440
python_version = "3.8" ,
430
- shallow_clone = True ) -> Tuple [str , str ]:
441
+ shallow_clone = True ,
442
+ enable_mkldnn = False ) -> Tuple [str , str ]:
431
443
git_clone_flags = " --depth 1 --shallow-submodules" if shallow_clone else ""
432
444
if host .using_docker () and not use_conda :
433
445
print ("Auto-selecting conda option for docker images" )
434
446
use_conda = True
447
+ if not host .using_docker ():
448
+ print ("Diable mkldnn for host builds" )
449
+ enable_mkldnn = False
435
450
436
451
configure_system (host ,
437
452
compiler = compiler ,
438
453
use_conda = use_conda ,
439
- python_version = python_version )
454
+ python_version = python_version ,
455
+ enable_mkldnn = enable_mkldnn )
440
456
build_OpenBLAS (host , git_clone_flags )
441
457
# build_FFTW(host, git_clone_flags)
442
458
@@ -465,11 +481,18 @@ def start_build(host: RemoteHost, *,
465
481
build_vars += f"BUILD_TEST=0 PYTORCH_BUILD_VERSION={ branch [1 :branch .find ('-' )]} PYTORCH_BUILD_NUMBER=1"
466
482
if host .using_docker ():
467
483
build_vars += " CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000"
468
- host .run_cmd (f"cd pytorch ; { build_vars } python3 setup.py bdist_wheel" )
484
+ if enable_mkldnn :
485
+ build_ArmComputeLibrary (host , git_clone_flags )
486
+ print ("build pytorch with mkldnn+acl backend" )
487
+ build_vars += " USE_MKLDNN=ON USE_MKLDNN_ACL=ON"
488
+ host .run_cmd (f"cd pytorch ; export ACL_ROOT_DIR=$HOME/acl; { build_vars } python3 setup.py bdist_wheel" )
489
+ else :
490
+ print ("build pytorch without mkldnn backend" )
491
+ host .run_cmd (f"cd pytorch ; { build_vars } python3 setup.py bdist_wheel" )
469
492
print ("Deleting build folder" )
470
493
host .run_cmd ("cd pytorch; rm -rf build" )
471
494
pytorch_wheel_name = host .list_dir ("pytorch/dist" )[0 ]
472
- embed_libgomp (host , use_conda , os .path .join ('pytorch' , 'dist' , pytorch_wheel_name ))
495
+ embed_libgomp_acl (host , use_conda , os .path .join ('pytorch' , 'dist' , pytorch_wheel_name ), enable_mkldnn )
473
496
print ('Copying the wheel' )
474
497
host .download_wheel (os .path .join ('pytorch' , 'dist' , pytorch_wheel_name ))
475
498
@@ -556,6 +579,10 @@ def embed_library(whl_path, lib_soname, update_tag=False):
556
579
557
580
if __name__ == '__main__':
    # Positional command line:
    #   argv[1]  path of the wheel to patch
    #   argv[2]  optional '--update-tag'
    #   argv[3]  optional '--enable_mkldnn' (only recognised in this position)
    wheel_path = sys.argv[1]
    retag = len(sys.argv) > 2 and sys.argv[2] == '--update-tag'
    bundle_acl = len(sys.argv) > 3 and sys.argv[3] == '--enable_mkldnn'

    # libgomp is always embedded; the Arm Compute Library shared objects are
    # embedded only when the mkldnn flag was passed.
    sonames = ['libgomp.so.1']
    if bundle_acl:
        sonames.extend([
            'libarm_compute.so',
            'libarm_compute_graph.so',
            'libarm_compute_core.so',
        ])

    for soname in sonames:
        embed_library(wheel_path, soname, retag)
559
586
"""
560
587
561
588
@@ -616,6 +643,7 @@ def parse_arguments():
616
643
parser .add_argument ("--use-docker" , action = "store_true" )
617
644
parser .add_argument ("--compiler" , type = str , choices = ['gcc-7' , 'gcc-8' , 'gcc-9' , 'clang' ], default = "gcc-8" )
618
645
parser .add_argument ("--use-torch-from-pypi" , action = "store_true" )
646
+ parser .add_argument ("--enable-mkldnn" , action = "store_true" )
619
647
return parser .parse_args ()
620
648
621
649
@@ -673,7 +701,8 @@ def parse_arguments():
673
701
if args .use_torch_from_pypi :
674
702
configure_system (host ,
675
703
compiler = args .compiler ,
676
- python_version = python_version )
704
+ python_version = python_version ,
705
+ enable_mkldnn = False )
677
706
print ("Installing PyTorch wheel" )
678
707
host .run_cmd ("pip3 install torch" )
679
708
build_torchvision (host ,
@@ -683,7 +712,8 @@ def parse_arguments():
683
712
start_build (host ,
684
713
branch = args .branch ,
685
714
compiler = args .compiler ,
686
- python_version = python_version )
715
+ python_version = python_version ,
716
+ enable_mkldnn = args .enable_mkldnn )
687
717
if not args .keep_running :
688
718
print (f'Waiting for instance { inst .id } to terminate' )
689
719
inst .terminate ()
0 commit comments