Skip to content

Commit ee5738c

Browse files
committed
upgrade python and pytorch version for darts-cnn-cifar10 and pytorch-mninst examples
1 parent 9ee8fda commit ee5738c

File tree

8 files changed

+16
-19
lines changed

8 files changed

+16
-19
lines changed

examples/v1beta1/trial-images/darts-cnn-cifar10/Dockerfile.cpu

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
# TODO (tenzen-y): Upgrade Python version and Pytorch version
2-
FROM python:3.7-slim
1+
FROM python:3.9-slim
32

43
ENV TARGET_DIR /opt/darts-cnn-cifar10
54

examples/v1beta1/trial-images/darts-cnn-cifar10/architect.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
1514
import torch
1615
import copy
1716

@@ -47,13 +46,13 @@ def virtual_step(self, train_x, train_y, xi, w_optim):
4746
gradients = torch.autograd.grad(loss, self.model.getWeights())
4847

4948
# Do virtual step (Update gradient)
50-
# Below opeartions do not need gradient tracking
49+
# Below operations do not need gradient tracking
5150
with torch.no_grad():
5251
# dict key is not the value, but the pointer. So original network weight have to
5352
# be iterated also.
5453
for w, vw, g in zip(self.model.getWeights(), self.v_model.getWeights(), gradients):
5554
m = w_optim.state[w].get("momentum_buffer", 0.) * self.w_momentum
56-
vw.copy_(w - xi * (m + g + self.w_weight_decay * w))
55+
vw.copy_(w - torch.FloatTensor(xi) * (m + g + self.w_weight_decay * w))
5756

5857
# Sync alphas
5958
for a, va in zip(self.model.getAlphas(), self.v_model.getAlphas()):
@@ -85,7 +84,7 @@ def unrolled_backward(self, train_x, train_y, valid_x, valid_y, xi, w_optim):
8584
# Update final gradient = dalpha - xi * hessian
8685
with torch.no_grad():
8786
for alpha, da, h in zip(self.model.getAlphas(), dalpha, hessian):
88-
alpha.grad = da - xi * h
87+
alpha.grad = da - torch.FloatTensor(xi) * h
8988

9089
def compute_hessian(self, dws, train_x, train_y):
9190
"""
Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
torch==1.0.0
2-
torchvision==0.2.1
3-
Pillow==6.2.2
1+
torch==1.11.0
2+
torchvision==0.12.0
3+
Pillow>=9.1.1

examples/v1beta1/trial-images/darts-cnn-cifar10/run_trial.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -146,15 +146,15 @@ def main():
146146
best_top1 = 0.
147147

148148
for epoch in range(num_epochs):
149-
lr_scheduler.step()
150-
lr = lr_scheduler.get_lr()[0]
149+
lr = lr_scheduler.get_last_lr()
151150

152151
model.print_alphas()
153152

154153
# Training
155154
print(">>> Training")
156155
train(train_loader, valid_loader, model, architect, w_optim, alpha_optim,
157156
lr, epoch, num_epochs, device, w_grad_clip, print_step)
157+
lr_scheduler.step()
158158

159159
# Validation
160160
print("\n>>> Validation")

examples/v1beta1/trial-images/darts-cnn-cifar10/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def update(self, val, n=1):
3838

3939

4040
def accuracy(output, target, topk=(1,)):
41-
""" Computes the precision@k for the specified values of k """
41+
""" Computes the precision@k for the specified values of k """
4242
maxk = max(topk)
4343
batch_size = target.size(0)
4444

@@ -53,7 +53,7 @@ def accuracy(output, target, topk=(1,)):
5353

5454
res = []
5555
for k in topk:
56-
correct_k = correct[:k].view(-1).float().sum(0)
56+
correct_k = correct[:k].contiguous().view(-1).float().sum(0)
5757
res.append(correct_k.mul_(1.0 / batch_size))
5858

5959
return res

examples/v1beta1/trial-images/pytorch-mnist/Dockerfile

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
# TODO (tenzen-y): Upgrade Python version and Pytorch version
2-
FROM python:3.7-slim
1+
FROM python:3.9-slim
32

43
ADD examples/v1beta1/trial-images/pytorch-mnist /opt/pytorch-mnist
54
WORKDIR /opt/pytorch-mnist
Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
cloudml-hypertune==0.1.0.dev6
2-
torch==1.0.0
3-
torchvision==0.2.1
4-
Pillow==6.2.2
2+
torch==1.11.0
3+
torchvision==0.12.0
4+
Pillow>=9.1.1

hack/verify-yamllint.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,4 +25,4 @@ if [ -z "$(command -v yamllint)" ]; then
2525
fi
2626

2727
echo 'Running yamllint'
28-
yamllint -d "{extends: default, rules: {line-length: disable}}" examples/* manifests/*
28+
yamllint -d "{extends: default, rules: {line-length: disable}}" examples/* manifests/*

0 commit comments

Comments
 (0)