Skip to content

Commit e634ea9

Browse files
committed
fix tf 1.7 test.py support
1 parent 221f5dd commit e634ea9

File tree

4 files changed

+24
-22
lines changed

4 files changed

+24
-22
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ Check out the *train.py* file for more input argument options. Each run produces
4242

4343
To evaluate the model, run the *test.py* file passing to it the *model_id* parameter (the name of the folder created inside *tboard_logs* during training).
4444

45+
Note: Make sure the `test.tfrecords` is downloaded and placed inside `./dataset/tfrecords`.
46+
4547
```
4648
python test.py --model_id=16645
4749
```

preprocessing/read_data.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import tarfile
55
import os
66

7-
def random_flip_image_and_annotation(image_tensor, annotation_tensor):
7+
def random_flip_image_and_annotation(image_tensor, annotation_tensor, image_shape):
88
"""Accepts image tensor and annotation tensor and returns randomly flipped tensors of both.
99
The function performs random flip of image and annotation tensors with probability of 1/2
1010
The flip is performed or not performed for image and annotation consistently, so that
@@ -44,10 +44,10 @@ def random_flip_image_and_annotation(image_tensor, annotation_tensor):
4444
true_fn=lambda: tf.image.flip_left_right(annotation_tensor),
4545
false_fn=lambda: annotation_tensor)
4646

47-
return randomly_flipped_img, tf.reshape(randomly_flipped_annotation, original_shape)
47+
return randomly_flipped_img, tf.reshape(randomly_flipped_annotation, original_shape, name="reshape_random_flip_image_and_annotation"), image_shape
4848

4949

50-
def rescale_image_and_annotation_by_factor(image, annotation, nin_scale=0.5, max_scale=2):
50+
def rescale_image_and_annotation_by_factor(image, annotation, image_shape, nin_scale=0.5, max_scale=2):
5151
#We apply data augmentation by randomly scaling theinput images(from 0.5 to 2.0)
5252
#and randomly left - right flipping during training.
5353
input_shape = tf.shape(image)[0:2]
@@ -66,7 +66,7 @@ def rescale_image_and_annotation_by_factor(image, annotation, nin_scale=0.5, max
6666
annotation = tf.image.resize_images(annotation, scaled_input_shape,
6767
method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
6868

69-
return image, annotation
69+
return image, annotation, image_shape
7070

7171

7272
def download_resnet_checkpoint_if_necessary(resnet_checkpoints_path, resnet_model_name):
@@ -93,7 +93,7 @@ def download_resnet_checkpoint_if_necessary(resnet_checkpoints_path, resnet_mode
9393
print("ResNet checkpoints file successfully found.")
9494

9595

96-
def scale_image_with_crop_padding(image, annotation, crop_size):
96+
def scale_image_with_crop_padding(image, annotation, image_shape, crop_size):
9797

9898
image_croped = tf.image.resize_image_with_crop_or_pad(image,crop_size,crop_size)
9999

@@ -108,7 +108,7 @@ def scale_image_with_crop_padding(image, annotation, crop_size):
108108
annotation_additional_mask_out = tf.to_int32(tf.equal(cropped_padded_annotation, 0)) * (mask_out_number+1)
109109
cropped_padded_annotation = cropped_padded_annotation + annotation_additional_mask_out - 1
110110

111-
return image_croped, tf.squeeze(cropped_padded_annotation)
111+
return image_croped, tf.squeeze(cropped_padded_annotation), image_shape
112112

113113
def tf_record_parser(record):
114114
keys_to_features = {
@@ -131,9 +131,9 @@ def tf_record_parser(record):
131131
annotation = tf.reshape(annotation, (height,width,1), name="annotation_reshape")
132132
annotation = tf.to_int32(annotation)
133133

134-
return tf.to_float(image), annotation
134+
return tf.to_float(image), annotation, (height, width)
135135

136-
def distort_randomly_image_color(image_tensor, annotation_tensor):
136+
def distort_randomly_image_color(image_tensor, annotation_tensor, image_shape):
137137
"""Accepts image tensor of (width, height, 3) and returns color distorted image.
138138
The function performs random brightness, saturation, hue, contrast change as it is performed
139139
for inception model training in TF-Slim (you can find the link below in comments). All the
@@ -167,4 +167,4 @@ def distort_randomly_image_color(image_tensor, annotation_tensor):
167167

168168
img_float_distorted_original_range = distorted_image * 255
169169

170-
return img_float_distorted_original_range, annotation_tensor
170+
return img_float_distorted_original_range, annotation_tensor, image_shape

test.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
parser = argparse.ArgumentParser()
1717

1818
envarg = parser.add_argument_group('Eval params')
19-
envarg.add_argument("--model_id", type=int, help="Model id name to be loaded.")
19+
envarg.add_argument("--model_id", default=16645, type=int, help="Model id name to be loaded.")
2020
input_args = parser.parse_args()
2121

2222
# best: 16645
@@ -68,13 +68,13 @@ class Dotdict(dict):
6868
class_labels[-1] = 255
6969

7070
LOG_FOLDER = './tboard_logs'
71-
TEST_DATASET_DIR="./dataset/"
71+
TEST_DATASET_DIR="./dataset/tfrecords"
7272
TEST_FILE = 'test.tfrecords'
7373

7474
test_filenames = [os.path.join(TEST_DATASET_DIR,TEST_FILE)]
7575
test_dataset = tf.data.TFRecordDataset(test_filenames)
7676
test_dataset = test_dataset.map(tf_record_parser) # Parse the record into tensors.
77-
test_dataset = test_dataset.map(scale_image_with_crop_padding)
77+
test_dataset = test_dataset.map(lambda image, annotation, image_shape: scale_image_with_crop_padding(image, annotation, image_shape, args.crop_size))
7878
test_dataset = test_dataset.shuffle(buffer_size=100)
7979
test_dataset = test_dataset.batch(args.batch_size)
8080

@@ -88,8 +88,8 @@ class Dotdict(dict):
8888
logits_batch_tensor=logits_tf,
8989
class_labels=class_labels)
9090

91-
cross_entropies_tf = tf.nn.softmax_cross_entropy_with_logits(logits=valid_logits_batch_tf,
92-
labels=valid_labels_batch_tf)
91+
cross_entropies_tf = tf.nn.softmax_cross_entropy_with_logits_v2(logits=valid_logits_batch_tf,
92+
labels=valid_labels_batch_tf)
9393

9494
cross_entropy_mean_tf = tf.reduce_mean(cross_entropies_tf)
9595
tf.summary.scalar('cross_entropy', cross_entropy_mean_tf)
@@ -154,12 +154,12 @@ class Dotdict(dict):
154154
mean_IoU.append(IoU)
155155
mean_freq_weighted_IU.append(freq_weighted_IU)
156156

157-
#f, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(8, 8))
157+
f, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(8, 8))
158158

159-
#ax1.imshow(input_image.astype(np.uint8))
160-
#ax2.imshow(label_image)
161-
#ax3.imshow(pred_image)
162-
#plt.show()
159+
ax1.imshow(input_image.astype(np.uint8))
160+
ax2.imshow(label_image)
161+
ax3.imshow(pred_image)
162+
plt.show()
163163

164164
except tf.errors.OutOfRangeError:
165165
break

train.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@
6868
training_dataset = training_dataset.map(tf_record_parser)
6969
training_dataset = training_dataset.map(rescale_image_and_annotation_by_factor)
7070
training_dataset = training_dataset.map(distort_randomly_image_color)
71-
training_dataset = training_dataset.map(lambda image, annotation: scale_image_with_crop_padding(image, annotation, crop_size))
71+
training_dataset = training_dataset.map(lambda image, annotation, image_shape: scale_image_with_crop_padding(image, annotation, image_shape, crop_size))
7272
training_dataset = training_dataset.map(random_flip_image_and_annotation) # Parse the record into tensors.
7373
training_dataset = training_dataset.repeat() # number of epochs
7474
training_dataset = training_dataset.shuffle(buffer_size=500)
@@ -77,7 +77,7 @@
7777
validation_filenames = [os.path.join(TRAIN_DATASET_DIR,VALIDATION_FILE)]
7878
validation_dataset = tf.data.TFRecordDataset(validation_filenames)
7979
validation_dataset = validation_dataset.map(tf_record_parser) # Parse the record into tensors.
80-
validation_dataset = validation_dataset.map(lambda image, annotation: scale_image_with_crop_padding(image, annotation, crop_size))
80+
validation_dataset = validation_dataset.map(lambda image, annotation, image_shape: scale_image_with_crop_padding(image, annotation, image_shape, crop_size))
8181
validation_dataset = validation_dataset.shuffle(buffer_size=100)
8282
validation_dataset = validation_dataset.batch(args.batch_size)
8383

@@ -92,7 +92,7 @@
9292

9393
iterator = tf.data.Iterator.from_string_handle(
9494
handle, training_dataset.output_types, training_dataset.output_shapes)
95-
batch_images_tf, batch_labels_tf = iterator.get_next()
95+
batch_images_tf, batch_labels_tf, _ = iterator.get_next()
9696

9797
# You can use feedable iterators with a variety of different kinds of iterator
9898
# (such as one-shot and initializable iterators).

0 commit comments

Comments
 (0)