Skip to content

Commit 7ae51b1

Browse files
committed
[aarch64] add mkldnn acl backend build support for pytorch cpu libary
1 parent ac931b5 commit 7ae51b1

File tree

1 file changed

+43
-13
lines changed

1 file changed

+43
-13
lines changed

build_aarch64_wheel.py

+43-13
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,13 @@ def build_OpenBLAS(host: RemoteHost, git_clone_flags: str = "") -> None:
224224
host.run_cmd(f"pushd OpenBLAS; make {make_flags} -j8; sudo make {make_flags} install; popd; rm -rf OpenBLAS")
225225

226226

227+
def build_ArmComputeLibrary(host: RemoteHost, git_clone_flags: str = "") -> None:
228+
print('Building Arm Compute Library')
229+
host.run_cmd("mkdir $HOME/acl")
230+
host.run_cmd(f"git clone https://github.com/ARM-software/ComputeLibrary.git -b v22.05 {git_clone_flags}")
231+
host.run_cmd(f"pushd ComputeLibrary; export acl_install_dir=$HOME/acl; scons Werror=1 -j8 debug=0 neon=1 opencl=0 os=linux openmp=1 cppthreads=0 arch=armv8.2-a multi_isa=1 build=native build_dir=$acl_install_dir/build; cp -r arm_compute $acl_install_dir; cp -r include $acl_install_dir; cp -r utils $acl_install_dir; cp -r support $acl_install_dir; popd")
232+
233+
227234
def build_FFTW(host: RemoteHost, git_clone_flags: str = "") -> None:
228235
print("Building FFTW3")
229236
host.run_cmd("sudo apt-get install -y ocaml ocamlbuild autoconf automake indent libtool fig2dev texinfo")
@@ -233,7 +240,7 @@ def build_FFTW(host: RemoteHost, git_clone_flags: str = "") -> None:
233240
host.run_cmd("pushd fftw3; sh bootstrap.sh; make -j8; sudo make install; popd")
234241

235242

236-
def embed_libgomp(host: RemoteHost, use_conda, wheel_name) -> None:
243+
def embed_libgomp_acl(host: RemoteHost, use_conda, wheel_name, enable_mkldnn=False) -> None:
237244
host.run_cmd("pip3 install auditwheel")
238245
host.run_cmd("conda install -y patchelf" if use_conda else "sudo apt-get install -y patchelf")
239246
from tempfile import NamedTemporaryFile
@@ -244,7 +251,10 @@ def embed_libgomp(host: RemoteHost, use_conda, wheel_name) -> None:
244251

245252
print('Embedding libgomp into wheel')
246253
if host.using_docker():
247-
host.run_cmd(f"python3 embed_library.py {wheel_name} --update-tag")
254+
if enable_mkldnn:
255+
host.run_cmd(f"python3 embed_library.py {wheel_name} --update-tag --enable_mkldnn")
256+
else:
257+
host.run_cmd(f"python3 embed_library.py {wheel_name} --update-tag")
248258
else:
249259
host.run_cmd(f"python3 embed_library.py {wheel_name}")
250260

@@ -302,7 +312,7 @@ def build_torchvision(host: RemoteHost, *,
302312

303313
host.run_cmd(f"cd vision; {build_vars} python3 setup.py bdist_wheel")
304314
vision_wheel_name = host.list_dir("vision/dist")[0]
305-
embed_libgomp(host, use_conda, os.path.join('vision', 'dist', vision_wheel_name))
315+
embed_libgomp_acl(host, use_conda, os.path.join('vision', 'dist', vision_wheel_name))
306316

307317
print('Copying TorchVision wheel')
308318
host.download_wheel(os.path.join('vision', 'dist', vision_wheel_name))
@@ -344,7 +354,7 @@ def build_torchtext(host: RemoteHost, *,
344354

345355
host.run_cmd(f"cd text; {build_vars} python3 setup.py bdist_wheel")
346356
wheel_name = host.list_dir("text/dist")[0]
347-
embed_libgomp(host, use_conda, os.path.join('text', 'dist', wheel_name))
357+
embed_libgomp_acl(host, use_conda, os.path.join('text', 'dist', wheel_name))
348358

349359
print('Copying TorchText wheel')
350360
host.download_wheel(os.path.join('text', 'dist', wheel_name))
@@ -384,7 +394,7 @@ def build_torchaudio(host: RemoteHost, *,
384394

385395
host.run_cmd(f"cd audio; {build_vars} python3 setup.py bdist_wheel")
386396
wheel_name = host.list_dir("audio/dist")[0]
387-
embed_libgomp(host, use_conda, os.path.join('audio', 'dist', wheel_name))
397+
embed_libgomp_acl(host, use_conda, os.path.join('audio', 'dist', wheel_name))
388398

389399
print('Copying TorchAudio wheel')
390400
host.download_wheel(os.path.join('audio', 'dist', wheel_name))
@@ -395,7 +405,8 @@ def build_torchaudio(host: RemoteHost, *,
395405
def configure_system(host: RemoteHost, *,
396406
compiler="gcc-8",
397407
use_conda=True,
398-
python_version="3.8") -> None:
408+
python_version="3.8",
409+
enable_mkldnn=False) -> None:
399410
if use_conda:
400411
install_condaforge_python(host, python_version)
401412

@@ -405,7 +416,7 @@ def configure_system(host: RemoteHost, *,
405416
host.run_cmd("sudo apt-get install -y ninja-build g++ git cmake gfortran unzip")
406417
else:
407418
host.run_cmd("yum install -y sudo")
408-
host.run_cmd("conda install -y ninja")
419+
host.run_cmd("conda install -y ninja scons")
409420

410421
if not use_conda:
411422
host.run_cmd("sudo apt-get install -y python3-dev python3-yaml python3-setuptools python3-wheel python3-pip")
@@ -427,16 +438,21 @@ def start_build(host: RemoteHost, *,
427438
compiler="gcc-8",
428439
use_conda=True,
429440
python_version="3.8",
430-
shallow_clone=True) -> Tuple[str, str]:
441+
shallow_clone=True,
442+
enable_mkldnn=False) -> Tuple[str, str]:
431443
git_clone_flags = " --depth 1 --shallow-submodules" if shallow_clone else ""
432444
if host.using_docker() and not use_conda:
433445
print("Auto-selecting conda option for docker images")
434446
use_conda = True
447+
if not host.using_docker():
448+
print("Diable mkldnn for host builds")
449+
enable_mkldnn = False
435450

436451
configure_system(host,
437452
compiler=compiler,
438453
use_conda=use_conda,
439-
python_version=python_version)
454+
python_version=python_version,
455+
enable_mkldnn=enable_mkldnn)
440456
build_OpenBLAS(host, git_clone_flags)
441457
# build_FFTW(host, git_clone_flags)
442458

@@ -465,11 +481,18 @@ def start_build(host: RemoteHost, *,
465481
build_vars += f"BUILD_TEST=0 PYTORCH_BUILD_VERSION={branch[1:branch.find('-')]} PYTORCH_BUILD_NUMBER=1"
466482
if host.using_docker():
467483
build_vars += " CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000"
468-
host.run_cmd(f"cd pytorch ; {build_vars} python3 setup.py bdist_wheel")
484+
if enable_mkldnn:
485+
build_ArmComputeLibrary(host, git_clone_flags)
486+
print("build pytorch with mkldnn+acl backend")
487+
build_vars += " USE_MKLDNN=ON USE_MKLDNN_ACL=ON"
488+
host.run_cmd(f"cd pytorch ; export ACL_ROOT_DIR=$HOME/acl; {build_vars} python3 setup.py bdist_wheel")
489+
else:
490+
print("build pytorch without mkldnn backend")
491+
host.run_cmd(f"cd pytorch ; {build_vars} python3 setup.py bdist_wheel")
469492
print("Deleting build folder")
470493
host.run_cmd("cd pytorch; rm -rf build")
471494
pytorch_wheel_name = host.list_dir("pytorch/dist")[0]
472-
embed_libgomp(host, use_conda, os.path.join('pytorch', 'dist', pytorch_wheel_name))
495+
embed_libgomp_acl(host, use_conda, os.path.join('pytorch', 'dist', pytorch_wheel_name), enable_mkldnn)
473496
print('Copying the wheel')
474497
host.download_wheel(os.path.join('pytorch', 'dist', pytorch_wheel_name))
475498

@@ -556,6 +579,10 @@ def embed_library(whl_path, lib_soname, update_tag=False):
556579
557580
if __name__ == '__main__':
558581
embed_library(sys.argv[1], 'libgomp.so.1', len(sys.argv) > 2 and sys.argv[2] == '--update-tag')
582+
if (len(sys.argv) > 3 and sys.argv[3] == '--enable_mkldnn'):
583+
embed_library(sys.argv[1], 'libarm_compute.so', len(sys.argv) > 2 and sys.argv[2] == '--update-tag')
584+
embed_library(sys.argv[1], 'libarm_compute_graph.so', len(sys.argv) > 2 and sys.argv[2] == '--update-tag')
585+
embed_library(sys.argv[1], 'libarm_compute_core.so', len(sys.argv) > 2 and sys.argv[2] == '--update-tag')
559586
"""
560587

561588

@@ -616,6 +643,7 @@ def parse_arguments():
616643
parser.add_argument("--use-docker", action="store_true")
617644
parser.add_argument("--compiler", type=str, choices=['gcc-7', 'gcc-8', 'gcc-9', 'clang'], default="gcc-8")
618645
parser.add_argument("--use-torch-from-pypi", action="store_true")
646+
parser.add_argument("--enable-mkldnn", action="store_true")
619647
return parser.parse_args()
620648

621649

@@ -673,7 +701,8 @@ def parse_arguments():
673701
if args.use_torch_from_pypi:
674702
configure_system(host,
675703
compiler=args.compiler,
676-
python_version=python_version)
704+
python_version=python_version,
705+
enable_mkldnn=False)
677706
print("Installing PyTorch wheel")
678707
host.run_cmd("pip3 install torch")
679708
build_torchvision(host,
@@ -683,7 +712,8 @@ def parse_arguments():
683712
start_build(host,
684713
branch=args.branch,
685714
compiler=args.compiler,
686-
python_version=python_version)
715+
python_version=python_version,
716+
enable_mkldnn=args.enable_mkldnn)
687717
if not args.keep_running:
688718
print(f'Waiting for instance {inst.id} to terminate')
689719
inst.terminate()

0 commit comments

Comments
 (0)