-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathtrain.py
24 lines (20 loc) · 839 Bytes
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
""" Training script! """
import sys
sys.path.append('../../')
import tensorflow as tf
from model.neat_config import NeatConfig
from model.interact.dataloader import input_fn_builder
from model.interact.modeling import model_fn_builder
# Build the run configuration; the YAML path below is handed to
# NeatConfig.from_args as its `default_config_file`.
config = NeatConfig.from_args(
    "Train detector script",
    default_config_file='configs/jan5_basic0.yaml',
)

# Model function for the estimator, derived from the config.
model_fn = model_fn_builder(config)

# All device-related settings (TPU flag, run config, batch sizes) are read
# from `config.device`.
_device_settings = config.device

# Note: the validation batch size is used for both evaluation and prediction.
estimator = tf.contrib.tpu.TPUEstimator(
    use_tpu=_device_settings['use_tpu'],
    model_fn=model_fn,
    config=_device_settings['tpu_run_config'],
    train_batch_size=_device_settings['train_batch_size'],
    eval_batch_size=_device_settings['val_batch_size'],
    predict_batch_size=_device_settings['val_batch_size'],
)

# Training input pipeline, built in training mode.
_train_input_fn = input_fn_builder(config, is_training=True)

# Train until the step count configured under `config.optimizer` is reached.
estimator.train(
    input_fn=_train_input_fn,
    max_steps=config.optimizer['num_train_steps'],
)