@@ -224,6 +224,13 @@ def build_OpenBLAS(host: RemoteHost, git_clone_flags: str = "") -> None:
224
224
host .run_cmd (f"pushd OpenBLAS; make { make_flags } -j8; sudo make { make_flags } install; popd; rm -rf OpenBLAS" )
225
225
226
226
227
+ def build_ArmComputeLibrary (host : RemoteHost , git_clone_flags : str = "" ) -> None :
228
+ print ('Building Arm Compute Library' )
229
+ host .run_cmd ("mkdir $HOME/acl" )
230
+ host .run_cmd (f"git clone https://github.com/ARM-software/ComputeLibrary.git -b v22.05 { git_clone_flags } " )
231
+ host .run_cmd (f"pushd ComputeLibrary; export acl_install_dir=$HOME/acl; scons Werror=1 -j8 debug=0 neon=1 opencl=0 os=linux openmp=1 cppthreads=0 arch=armv8.2-a multi_isa=1 build=native build_dir=$acl_install_dir/build; cp -r arm_compute $acl_install_dir; cp -r include $acl_install_dir; cp -r utils $acl_install_dir; cp -r support $acl_install_dir; popd" )
232
+
233
+
227
234
def build_FFTW (host : RemoteHost , git_clone_flags : str = "" ) -> None :
228
235
print ("Building FFTW3" )
229
236
host .run_cmd ("sudo apt-get install -y ocaml ocamlbuild autoconf automake indent libtool fig2dev texinfo" )
@@ -395,7 +402,8 @@ def build_torchaudio(host: RemoteHost, *,
395
402
def configure_system (host : RemoteHost , * ,
396
403
compiler = "gcc-8" ,
397
404
use_conda = True ,
398
- python_version = "3.8" ) -> None :
405
+ python_version = "3.8" ,
406
+ enable_mkldnn = False ) -> None :
399
407
if use_conda :
400
408
install_condaforge_python (host , python_version )
401
409
@@ -405,7 +413,7 @@ def configure_system(host: RemoteHost, *,
405
413
host .run_cmd ("sudo apt-get install -y ninja-build g++ git cmake gfortran unzip" )
406
414
else :
407
415
host .run_cmd ("yum install -y sudo" )
408
- host .run_cmd ("conda install -y ninja" )
416
+ host .run_cmd ("conda install -y ninja scons " )
409
417
410
418
if not use_conda :
411
419
host .run_cmd ("sudo apt-get install -y python3-dev python3-yaml python3-setuptools python3-wheel python3-pip" )
@@ -427,16 +435,21 @@ def start_build(host: RemoteHost, *,
427
435
compiler = "gcc-8" ,
428
436
use_conda = True ,
429
437
python_version = "3.8" ,
430
- shallow_clone = True ) -> Tuple [str , str ]:
438
+ shallow_clone = True ,
439
+ enable_mkldnn = False ) -> Tuple [str , str ]:
431
440
git_clone_flags = " --depth 1 --shallow-submodules" if shallow_clone else ""
432
441
if host .using_docker () and not use_conda :
433
442
print ("Auto-selecting conda option for docker images" )
434
443
use_conda = True
444
+ if not host .using_docker ():
445
+ print ("Diable mkldnn for host builds" )
446
+ enable_mkldnn = False
435
447
436
448
configure_system (host ,
437
449
compiler = compiler ,
438
450
use_conda = use_conda ,
439
- python_version = python_version )
451
+ python_version = python_version ,
452
+ enable_mkldnn = enable_mkldnn )
440
453
build_OpenBLAS (host , git_clone_flags )
441
454
# build_FFTW(host, git_clone_flags)
442
455
@@ -465,7 +478,21 @@ def start_build(host: RemoteHost, *,
465
478
build_vars += f"BUILD_TEST=0 PYTORCH_BUILD_VERSION={ branch [1 :branch .find ('-' )]} PYTORCH_BUILD_NUMBER=1"
466
479
if host .using_docker ():
467
480
build_vars += " CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000"
468
- host .run_cmd (f"cd pytorch ; { build_vars } python3 setup.py bdist_wheel" )
481
+ if enable_mkldnn :
482
+ build_ArmComputeLibrary (host , git_clone_flags )
483
+ print ("build pytorch with mkldnn+acl backend" )
484
+ build_vars += " USE_MKLDNN=ON USE_MKLDNN_ACL=ON"
485
+ host .run_cmd (f"cd pytorch ; export ACL_ROOT_DIR=$HOME/acl; { build_vars } python3 setup.py bdist_wheel" )
486
+ print ('Repair the wheel' )
487
+ pytorch_wheel_name = host .list_dir ("pytorch/dist" )[0 ]
488
+ host .run_cmd (f"export LD_LIBRARY_PATH=$HOME/acl/build:$HOME/pytorch/build/lib; auditwheel repair $HOME/pytorch/dist/{ pytorch_wheel_name } " )
489
+ print ('replace the original wheel with the repaired one' )
490
+ pytorch_repaired_wheel_name = host .list_dir ("wheelhouse" )[0 ]
491
+ host .run_cmd (f"cp $HOME/wheelhouse/{ pytorch_repaired_wheel_name } $HOME/pytorch/dist/{ pytorch_wheel_name } " )
492
+ else :
493
+ print ("build pytorch without mkldnn backend" )
494
+ host .run_cmd (f"cd pytorch ; { build_vars } python3 setup.py bdist_wheel" )
495
+
469
496
print ("Deleting build folder" )
470
497
host .run_cmd ("cd pytorch; rm -rf build" )
471
498
pytorch_wheel_name = host .list_dir ("pytorch/dist" )[0 ]
@@ -616,6 +643,7 @@ def parse_arguments():
616
643
parser .add_argument ("--use-docker" , action = "store_true" )
617
644
parser .add_argument ("--compiler" , type = str , choices = ['gcc-7' , 'gcc-8' , 'gcc-9' , 'clang' ], default = "gcc-8" )
618
645
parser .add_argument ("--use-torch-from-pypi" , action = "store_true" )
646
+ parser .add_argument ("--enable-mkldnn" , action = "store_true" )
619
647
return parser .parse_args ()
620
648
621
649
@@ -673,7 +701,8 @@ def parse_arguments():
673
701
if args .use_torch_from_pypi :
674
702
configure_system (host ,
675
703
compiler = args .compiler ,
676
- python_version = python_version )
704
+ python_version = python_version ,
705
+ enable_mkldnn = False )
677
706
print ("Installing PyTorch wheel" )
678
707
host .run_cmd ("pip3 install torch" )
679
708
build_torchvision (host ,
@@ -683,7 +712,8 @@ def parse_arguments():
683
712
start_build (host ,
684
713
branch = args .branch ,
685
714
compiler = args .compiler ,
686
- python_version = python_version )
715
+ python_version = python_version ,
716
+ enable_mkldnn = args .enable_mkldnn )
687
717
if not args .keep_running :
688
718
print (f'Waiting for instance { inst .id } to terminate' )
689
719
inst .terminate ()
0 commit comments