4
4
import tarfile
5
5
import os
6
6
7
- def random_flip_image_and_annotation (image_tensor , annotation_tensor ):
7
+ def random_flip_image_and_annotation (image_tensor , annotation_tensor , image_shape ):
8
8
"""Accepts image tensor and annotation tensor and returns randomly flipped tensors of both.
9
9
The function performs random flip of image and annotation tensors with probability of 1/2
10
10
The flip is performed or not performed for image and annotation consistently, so that
@@ -44,10 +44,10 @@ def random_flip_image_and_annotation(image_tensor, annotation_tensor):
44
44
true_fn = lambda : tf .image .flip_left_right (annotation_tensor ),
45
45
false_fn = lambda : annotation_tensor )
46
46
47
- return randomly_flipped_img , tf .reshape (randomly_flipped_annotation , original_shape )
47
+ return randomly_flipped_img , tf .reshape (randomly_flipped_annotation , original_shape , name = "reshape_random_flip_image_and_annotation" ), image_shape
48
48
49
49
50
- def rescale_image_and_annotation_by_factor (image , annotation , nin_scale = 0.5 , max_scale = 2 ):
50
+ def rescale_image_and_annotation_by_factor (image , annotation , image_shape , nin_scale = 0.5 , max_scale = 2 ):
51
51
# We apply data augmentation by randomly scaling the input images (from 0.5 to 2.0)
52
52
# and randomly left-right flipping during training.
53
53
input_shape = tf .shape (image )[0 :2 ]
@@ -66,7 +66,7 @@ def rescale_image_and_annotation_by_factor(image, annotation, nin_scale=0.5, max
66
66
annotation = tf .image .resize_images (annotation , scaled_input_shape ,
67
67
method = tf .image .ResizeMethod .NEAREST_NEIGHBOR )
68
68
69
- return image , annotation
69
+ return image , annotation , image_shape
70
70
71
71
72
72
def download_resnet_checkpoint_if_necessary (resnet_checkpoints_path , resnet_model_name ):
@@ -93,7 +93,7 @@ def download_resnet_checkpoint_if_necessary(resnet_checkpoints_path, resnet_mode
93
93
print ("ResNet checkpoints file successfully found." )
94
94
95
95
96
- def scale_image_with_crop_padding (image , annotation , crop_size ):
96
+ def scale_image_with_crop_padding (image , annotation , image_shape , crop_size ):
97
97
98
98
image_croped = tf .image .resize_image_with_crop_or_pad (image ,crop_size ,crop_size )
99
99
@@ -108,7 +108,7 @@ def scale_image_with_crop_padding(image, annotation, crop_size):
108
108
annotation_additional_mask_out = tf .to_int32 (tf .equal (cropped_padded_annotation , 0 )) * (mask_out_number + 1 )
109
109
cropped_padded_annotation = cropped_padded_annotation + annotation_additional_mask_out - 1
110
110
111
- return image_croped , tf .squeeze (cropped_padded_annotation )
111
+ return image_croped , tf .squeeze (cropped_padded_annotation ), image_shape
112
112
113
113
def tf_record_parser (record ):
114
114
keys_to_features = {
@@ -131,9 +131,9 @@ def tf_record_parser(record):
131
131
annotation = tf .reshape (annotation , (height ,width ,1 ), name = "annotation_reshape" )
132
132
annotation = tf .to_int32 (annotation )
133
133
134
- return tf .to_float (image ), annotation
134
+ return tf .to_float (image ), annotation , ( height , width )
135
135
136
- def distort_randomly_image_color (image_tensor , annotation_tensor ):
136
+ def distort_randomly_image_color (image_tensor , annotation_tensor , image_shape ):
137
137
"""Accepts image tensor of (width, height, 3) and returns color distorted image.
138
138
The function performs random brightness, saturation, hue, contrast change as it is performed
139
139
for inception model training in TF-Slim (you can find the link below in comments). All the
@@ -167,4 +167,4 @@ def distort_randomly_image_color(image_tensor, annotation_tensor):
167
167
168
168
img_float_distorted_original_range = distorted_image * 255
169
169
170
- return img_float_distorted_original_range , annotation_tensor
170
+ return img_float_distorted_original_range , annotation_tensor , image_shape
0 commit comments