Skip to content

Commit 280eaed

Browse files
committed
chore: gofmt, black, and prettier run across PBT changes
1 parent 7402c48 commit 280eaed

File tree

11 files changed

+381
-175
lines changed

11 files changed

+381
-175
lines changed

examples/v1beta1/trial-images/simple-pbt/pbt_test.py

Lines changed: 44 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111
import tensorflow as tf
1212
import time
1313

14-
class PBTBenchmarkExample():
14+
15+
class PBTBenchmarkExample:
1516
"""Toy PBT problem for benchmarking adaptive learning rate.
1617
The goal is to optimize this trainable's accuracy. The accuracy increases
1718
fastest at the optimal lr, which is a function of the current accuracy.
@@ -36,24 +37,23 @@ def __init__(self, lr, log_dir: str, log_interval: int, checkpoint: str):
3637
self._log_interval = log_interval
3738
self._lr = lr
3839

39-
self._checkpoint_file = os.path.join(checkpoint, 'training.ckpt')
40+
self._checkpoint_file = os.path.join(checkpoint, "training.ckpt")
4041
if os.path.exists(self._checkpoint_file):
41-
with open(self._checkpoint_file, 'rb') as fin:
42+
with open(self._checkpoint_file, "rb") as fin:
4243
checkpoint_data = pickle.load(fin)
43-
self._accuracy = checkpoint_data['accuracy']
44-
self._step = checkpoint_data['step']
44+
self._accuracy = checkpoint_data["accuracy"]
45+
self._step = checkpoint_data["step"]
4546
else:
4647
os.makedirs(checkpoint, exist_ok=True)
4748
self._step = 1
4849
self._accuracy = 0.0
49-
5050

5151
def save_checkpoint(self):
52-
with open(self._checkpoint_file, 'wb') as fout:
53-
pickle.dump({'step': self._step, 'accuracy': self._accuracy}, fout)
52+
with open(self._checkpoint_file, "wb") as fout:
53+
pickle.dump({"step": self._step, "accuracy": self._accuracy}, fout)
5454

5555
def step(self):
56-
midpoint = 100 # lr starts decreasing after acc > midpoint
56+
midpoint = 100 # lr starts decreasing after acc > midpoint
5757
q_tolerance = 3 # penalize exceeding lr by more than this multiple
5858
noise_level = 2 # add gaussian noise to the acc increase
5959
# triangle wave:
@@ -80,32 +80,53 @@ def step(self):
8080
if not self._writer:
8181
self._writer = tf.summary.create_file_writer(self._log_dir)
8282
with self._writer.as_default():
83-
tf.summary.scalar("Validation-accuracy", self._accuracy, step=self._step)
83+
tf.summary.scalar(
84+
"Validation-accuracy", self._accuracy, step=self._step
85+
)
8486
tf.summary.scalar("lr", self._lr, step=self._step)
8587
self._writer.flush()
8688

8789
self._step += 1
8890

8991
def __repr__(self):
90-
return "epoch {}:\nlr={:0.4f}\nValidation-accuracy={:0.4f}".format(self._step, self._lr, self._accuracy)
92+
return "epoch {}:\nlr={:0.4f}\nValidation-accuracy={:0.4f}".format(
93+
self._step, self._lr, self._accuracy
94+
)
9195

9296

9397
if __name__ == "__main__":
9498
# Parse CLI arguments
95-
parser = argparse.ArgumentParser(description='PBT Basic Test')
96-
parser.add_argument('--lr', type=float, default=0.0001,
97-
help='learning rate (default: 0.0001)')
98-
parser.add_argument('--epochs', type=int, default=20,
99-
help='number of epochs to train (default: 20)')
100-
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
101-
help='how many batches to wait before logging training status (default: 1)')
102-
parser.add_argument('--log-path', type=str, default="/var/log/katib/tfevent/",
103-
help='tfevent output path (default: /var/log/katib/tfevent/)')
104-
parser.add_argument('--checkpoint', type=str, default="/var/log/katib/checkpoints/",
105-
help='checkpoint directory (resume and save)')
99+
parser = argparse.ArgumentParser(description="PBT Basic Test")
100+
parser.add_argument(
101+
"--lr", type=float, default=0.0001, help="learning rate (default: 0.0001)"
102+
)
103+
parser.add_argument(
104+
"--epochs", type=int, default=20, help="number of epochs to train (default: 20)"
105+
)
106+
parser.add_argument(
107+
"--log-interval",
108+
type=int,
109+
default=10,
110+
metavar="N",
111+
help="how many batches to wait before logging training status (default: 1)",
112+
)
113+
parser.add_argument(
114+
"--log-path",
115+
type=str,
116+
default="/var/log/katib/tfevent/",
117+
help="tfevent output path (default: /var/log/katib/tfevent/)",
118+
)
119+
parser.add_argument(
120+
"--checkpoint",
121+
type=str,
122+
default="/var/log/katib/checkpoints/",
123+
help="checkpoint directory (resume and save)",
124+
)
106125
opt = parser.parse_args()
107126

108-
benchmark = PBTBenchmarkExample(opt.lr, opt.log_path, opt.log_interval, opt.checkpoint)
127+
benchmark = PBTBenchmarkExample(
128+
opt.lr, opt.log_path, opt.log_interval, opt.checkpoint
129+
)
109130
for i in range(opt.epochs):
110131
benchmark.step()
111132
time.sleep(0.2)

pkg/controller.v1beta1/suggestion/suggestionclient/suggestionclient.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ func (g *General) SyncAssignments(
184184
if responseSuggestion.Annotations != nil {
185185
assignment.Annotations = responseSuggestion.Annotations[n].Annotations
186186
}
187-
trialAssignments = append(trialAssignments, assignment)
187+
trialAssignments = append(trialAssignments, assignment)
188188
}
189189

190190
instance.Status.Suggestions = append(instance.Status.Suggestions, trialAssignments...)

pkg/new-ui/v1beta1/frontend/src/app/pages/experiment-details/pbt/pbt-tab-loader.module.ts

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,13 @@ import { PbtTabComponent } from './pbt-tab.component';
1010

1111
@NgModule({
1212
declarations: [PbtTabComponent],
13-
imports: [CommonModule, FormsModule, MatFormFieldModule, MatSelectModule, MatCheckboxModule],
13+
imports: [
14+
CommonModule,
15+
FormsModule,
16+
MatFormFieldModule,
17+
MatSelectModule,
18+
MatCheckboxModule,
19+
],
1420
exports: [PbtTabComponent],
1521
})
1622
export class PbtTabModule {}
Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,21 @@
11
<div class="pbt-wrapper">
2-
32
<div class="pbt-options-wrapper">
43
<mat-form-field appearance="fill" class="pbt-option">
54
<mat-label>Y-Axis</mat-label>
6-
<mat-select [(ngModel)]="selectedName" (ngModelChange)="onDropdownChange()">
5+
<mat-select
6+
[(ngModel)]="selectedName"
7+
(ngModelChange)="onDropdownChange()"
8+
>
79
<mat-option *ngFor="let name of selectableNames" [value]="name">
8-
{{name}}
10+
{{ name }}
911
</mat-option>
1012
</mat-select>
1113
</mat-form-field>
1214

13-
<mat-checkbox [(ngModel)]="displayTrace" (ngModelChange)="onTraceChange()">Display Seed Traces</mat-checkbox>
15+
<mat-checkbox [(ngModel)]="displayTrace" (ngModelChange)="onTraceChange()"
16+
>Display Seed Traces</mat-checkbox
17+
>
1418
</div>
1519

1620
<div #pbtGraph id="pbt-graph" class="d3-tab-graph"></div>
17-
1821
</div>

pkg/new-ui/v1beta1/frontend/src/app/pages/experiment-details/pbt/pbt-tab.component.scss

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
}
1818

1919
.pbt-option {
20-
margin: 10px
20+
margin: 10px;
2121
}
2222

2323
.d3-tab-graph {

0 commit comments

Comments
 (0)