@@ -904,13 +904,13 @@ def resize(self, pixels, action, smaller_side, larger_side, scale_factor, resize
904
904
905
905
crop_x , crop_y , pad_x , pad_y = (0.0 , 0.0 , 0.0 , 0.0 )
906
906
if action == self .ACTION_TYPE_CROP :
907
- target_ratio = self .parse_side_ratio (side_ratio )
907
+ target_ratio = target_width / target_height if target_width != 0 and target_height != 0 else self .parse_side_ratio (side_ratio )
908
908
if height * target_ratio < width :
909
909
crop_x = width - height * target_ratio
910
910
else :
911
911
crop_y = height - width / target_ratio
912
912
elif action == self .ACTION_TYPE_PAD :
913
- target_ratio = self .parse_side_ratio (side_ratio )
913
+ target_ratio = target_width / target_height if target_width != 0 and target_height != 0 else self .parse_side_ratio (side_ratio )
914
914
if height * target_ratio > width :
915
915
pad_x = height * target_ratio - width
916
916
else :
@@ -930,10 +930,7 @@ def resize(self, pixels, action, smaller_side, larger_side, scale_factor, resize
930
930
if (resize_mode == self .RESIZE_MODE_DOWNSCALE and scale_factor >= 1.0 ) or (resize_mode == self .RESIZE_MODE_UPSCALE and scale_factor <= 1.0 ):
931
931
scale_factor = 0.0
932
932
933
- if target_width != 0 and target_height != 0 :
934
- pixels , mask = self .interpolate_to_target_size (pixels , mask , target_height , target_width )
935
- crop_x , crop_y , pad_x , pad_y = (0.0 , 0.0 , 0.0 , 0.0 )
936
- elif scale_factor > 0.0 :
933
+ if scale_factor > 0.0 :
937
934
pixels = torch .nn .functional .interpolate (
938
935
pixels .movedim (- 1 , 1 ), scale_factor = scale_factor , mode = "bicubic" , antialias = True ).movedim (1 , - 1 ).clamp (0.0 , 1.0 )
939
936
mask = torch .nn .functional .interpolate (mask .unsqueeze (
@@ -993,6 +990,10 @@ def resize(self, pixels, action, smaller_side, larger_side, scale_factor, resize
993
990
for k in range (width ):
994
991
mask [i , height + add_y [0 ] - j - 1 , k ] = max (
995
992
mask [i , height + add_y [0 ] - j - 1 , k ], feather_strength )
993
+
994
+ if target_width != 0 and target_height != 0 :
995
+ pixels , mask = self .interpolate_to_target_size (pixels , mask , target_height , target_width )
996
+
996
997
if all_szie_8x == "crop" :
997
998
pixels = self .vae_encode_crop_pixels (pixels )
998
999
mask = self .vae_encode_crop_pixels (mask )
0 commit comments