@@ -475,6 +475,7 @@ def start_build(host: RemoteHost, *,
475
475
compiler = "gcc-8" ,
476
476
use_conda = True ,
477
477
python_version = "3.8" ,
478
+ pytorch_only :bool = False ,
478
479
shallow_clone = True ,
479
480
enable_mkldnn = False ) -> Tuple [str , str ]:
480
481
git_clone_flags = " --depth 1 --shallow-submodules" if shallow_clone else ""
@@ -543,6 +544,8 @@ def start_build(host: RemoteHost, *,
543
544
print ('Installing PyTorch wheel' )
544
545
host .run_cmd (f"pip3 install pytorch/dist/{ pytorch_wheel_name } " )
545
546
547
+ if pytorch_only :
548
+ return pytorch_wheel_name , None
546
549
vision_wheel_name = build_torchvision (host , branch = branch , use_conda = use_conda , git_clone_flags = git_clone_flags )
547
550
build_torchaudio (host , branch = branch , use_conda = use_conda , git_clone_flags = git_clone_flags )
548
551
build_torchtext (host , branch = branch , use_conda = use_conda , git_clone_flags = git_clone_flags )
@@ -674,9 +677,10 @@ def parse_arguments():
674
677
parser .add_argument ("--build-only" , action = "store_true" )
675
678
parser .add_argument ("--test-only" , type = str )
676
679
parser .add_argument ("--os" , type = str , choices = list (os_amis .keys ()), default = 'ubuntu18_04' )
677
- parser .add_argument ("--python-version" , type = str , choices = ['3.6' , '3.7' , '3.8' , '3.9' , '3.10' ], default = None )
680
+ parser .add_argument ("--python-version" , type = str , choices = ['3.6' , '3.7' , '3.8' , '3.9' , '3.10' , '3.11' ], default = None )
678
681
parser .add_argument ("--alloc-instance" , action = "store_true" )
679
682
parser .add_argument ("--list-instances" , action = "store_true" )
683
+ parser .add_argument ("--pytorch-only" , action = "store_true" )
680
684
parser .add_argument ("--keep-running" , action = "store_true" )
681
685
parser .add_argument ("--terminate-instances" , action = "store_true" )
682
686
parser .add_argument ("--instance-type" , type = str , default = "t4g.2xlarge" )
@@ -754,6 +758,7 @@ def parse_arguments():
754
758
branch = args .branch ,
755
759
compiler = args .compiler ,
756
760
python_version = python_version ,
761
+ pytorch_only = args .pytorch_only ,
757
762
enable_mkldnn = args .enable_mkldnn )
758
763
if not args .keep_running :
759
764
print (f'Waiting for instance { inst .id } to terminate' )
0 commit comments