Skip to content

Commit ccf0985

Browse files
committed
enable mkldnn acl backend for pytorch cpu libary
1 parent 202060c commit ccf0985

File tree

1 file changed

+31
-9
lines changed

1 file changed

+31
-9
lines changed

build_aarch64_wheel.py

+31-9
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,11 @@ def build_OpenBLAS(host: RemoteHost, git_clone_flags: str = "") -> None:
223223
make_flags = "NUM_THREADS=64 USE_OPENMP=1 NO_SHARED=1 DYNAMIC_ARCH=1 TARGET=ARMV8"
224224
host.run_cmd(f"pushd OpenBLAS; make {make_flags} -j8; sudo make {make_flags} install; popd; rm -rf OpenBLAS")
225225

226+
def build_ArmComputeLibrary(host: RemoteHost, git_clone_flags: str = "") -> None:
227+
print('Building Arm Compute Library')
228+
host.run_cmd("mkdir $HOME/acl")
229+
host.run_cmd(f"git clone https://github.com/ARM-software/ComputeLibrary.git -b v22.05 {git_clone_flags}")
230+
host.run_cmd(f"pushd ComputeLibrary; export acl_install_dir=$HOME/acl; scons Werror=1 -j8 debug=0 neon=1 opencl=0 os=linux openmp=1 cppthreads=0 arch=armv8.2-a multi_isa=1 build=native build_dir=$acl_install_dir/build; cp -r arm_compute $acl_install_dir; cp -r include $acl_install_dir; cp -r utils $acl_install_dir; cp -r support $acl_install_dir; popd")
226231

227232
def build_FFTW(host: RemoteHost, git_clone_flags: str = "") -> None:
228233
print("Building FFTW3")
@@ -401,18 +406,19 @@ def build_torchaudio(host: RemoteHost, *,
401406
def configure_system(host: RemoteHost, *,
402407
compiler="gcc-8",
403408
use_conda=True,
404-
python_version="3.8") -> None:
409+
python_version="3.8",
410+
enable_mkldnn=False) -> None:
405411
if use_conda:
406412
install_condaforge_python(host, python_version)
407413

408414
print('Configuring the system')
409415
if not host.using_docker():
410416
update_apt_repo(host)
411-
host.run_cmd("sudo apt-get install -y ninja-build g++ git cmake gfortran unzip pkg-config")
417+
host.run_cmd("sudo apt-get install -y ninja-build g++ git cmake gfortran unzip pkg-config scons")
412418
else:
413419
host.run_cmd("yum install -y sudo")
414420
host.run_cmd("yum install -y pkgconfig")
415-
host.run_cmd("conda install -y ninja")
421+
host.run_cmd("conda install -y ninja scons")
416422

417423
if not use_conda:
418424
host.run_cmd("sudo apt-get install -y python3-dev python3-yaml python3-setuptools python3-wheel python3-pip")
@@ -424,7 +430,7 @@ def configure_system(host: RemoteHost, *,
424430

425431
# Install and switch to gcc-10 on Ubuntu-20.04. This is required to support
426432
# SVE instruction set for newer aarch64 architectures
427-
if not host.using_docker() and host.ami == ubuntu20_04_ami and compiler == 'gcc-10':
433+
if not host.using_docker() and host.ami == ubuntu20_04_ami and (compiler == 'gcc-10' or enable_mkldnn):
428434
host.run_cmd("sudo apt-get install -y g++-10 gfortran-10")
429435
host.run_cmd("sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-10 100")
430436
host.run_cmd("sudo update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-10 100")
@@ -436,7 +442,8 @@ def start_build(host: RemoteHost, *,
436442
compiler="gcc-8",
437443
use_conda=True,
438444
python_version="3.8",
439-
shallow_clone=True) -> Tuple[str, str]:
445+
shallow_clone=True,
446+
enable_mkldnn=False) -> Tuple[str, str]:
440447
git_clone_flags = " --depth 1 --shallow-submodules" if shallow_clone else ""
441448
if host.using_docker() and not use_conda:
442449
print("Auto-selecting conda option for docker images")
@@ -445,7 +452,8 @@ def start_build(host: RemoteHost, *,
445452
configure_system(host,
446453
compiler=compiler,
447454
use_conda=use_conda,
448-
python_version=python_version)
455+
python_version=python_version,
456+
enable_mkldnn=enable_mkldnn)
449457
build_OpenBLAS(host, git_clone_flags)
450458
# build_FFTW(host, git_clone_flags)
451459

@@ -474,7 +482,16 @@ def start_build(host: RemoteHost, *,
474482
build_vars += f"BUILD_TEST=0 PYTORCH_BUILD_VERSION={branch[1:branch.find('-')]} PYTORCH_BUILD_NUMBER=1"
475483
if host.using_docker():
476484
build_vars += " CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000"
477-
host.run_cmd(f"cd pytorch ; {build_vars} python3 setup.py bdist_wheel")
485+
486+
if enable_mkldnn:
487+
build_ArmComputeLibrary(host, git_clone_flags)
488+
print("build pytorch with mkldnn+acl backend")
489+
build_vars += " USE_MKLDNN=ON USE_MKLDNN_ACL=ON"
490+
host.run_cmd(f"cd pytorch ; export ACL_ROOT_DIR=$HOME/acl; {build_vars} python3 setup.py bdist_wheel")
491+
else:
492+
print("build pytorch without mkldnn backend")
493+
host.run_cmd(f"cd pytorch ; {build_vars} python3 setup.py bdist_wheel")
494+
478495
print("Deleting build folder")
479496
host.run_cmd("cd pytorch; rm -rf build")
480497
pytorch_wheel_name = host.list_dir("pytorch/dist")[0]
@@ -625,6 +642,7 @@ def parse_arguments():
625642
parser.add_argument("--use-docker", action="store_true")
626643
parser.add_argument("--compiler", type=str, choices=['gcc-7', 'gcc-8', 'gcc-9', 'gcc-10', 'clang'], default="gcc-8")
627644
parser.add_argument("--use-torch-from-pypi", action="store_true")
645+
parser.add_argument("--enable-mkldnn", action="store_true")
628646
return parser.parse_args()
629647

630648

@@ -682,7 +700,8 @@ def parse_arguments():
682700
if args.use_torch_from_pypi:
683701
configure_system(host,
684702
compiler=args.compiler,
685-
python_version=python_version)
703+
python_version=python_version,
704+
enable_mkldnn=False)
686705
print("Installing PyTorch wheel")
687706
host.run_cmd("pip3 install torch")
688707
build_torchvision(host,
@@ -692,7 +711,10 @@ def parse_arguments():
692711
start_build(host,
693712
branch=args.branch,
694713
compiler=args.compiler,
695-
python_version=python_version)
714+
use_conda=True,
715+
python_version=python_version,
716+
shallow_clone=True,
717+
enable_mkldnn=args.enable_mkldnn)
696718
if not args.keep_running:
697719
print(f'Waiting for instance {inst.id} to terminate')
698720
inst.terminate()

0 commit comments

Comments
 (0)