@@ -62,42 +62,41 @@ def vignette_image(self, image: torch.Tensor, intensity: float, center_x: float,
62
62
batch_size , height , width , _ = image .shape
63
63
result = torch .zeros_like (image )
64
64
65
- for b in range (batch_size ):
66
- tensor_image = image [b ].numpy ()
67
-
68
- # Apply vignette
69
- vignette_image = self .apply_vignette (tensor_image , intensity , center_x , center_y )
70
-
71
- tensor = torch .from_numpy (vignette_image ).unsqueeze (0 )
72
- result [b ] = tensor
73
-
74
- return (result ,)
75
-
76
- def apply_vignette (self , image , vignette_strength , center_x_ratio , center_y_ratio ):
77
- if vignette_strength == 0 :
65
+ if intensity == 0 :
78
66
return image
79
67
80
- height , width , _ = image .shape
81
-
68
+ # Generate the vignette for each image in the batch
82
69
# Create linear space but centered around the provided center point ratios
83
70
x = np .linspace (- 1 , 1 , width )
84
71
y = np .linspace (- 1 , 1 , height )
85
- X , Y = np .meshgrid (x - (2 * center_x_ratio - 1 ), y - (2 * center_y_ratio - 1 ))
72
+ X , Y = np .meshgrid (x - (2 * center_x - 1 ), y - (2 * center_y - 1 ))
86
73
87
74
# Calculate distances to the furthest corner
88
75
distances_to_corners = [
89
- np .sqrt ((0 - center_x_ratio ) ** 2 + (0 - center_y_ratio ) ** 2 ),
90
- np .sqrt ((1 - center_x_ratio ) ** 2 + (0 - center_y_ratio ) ** 2 ),
91
- np .sqrt ((0 - center_x_ratio ) ** 2 + (1 - center_y_ratio ) ** 2 ),
92
- np .sqrt ((1 - center_x_ratio ) ** 2 + (1 - center_y_ratio ) ** 2 )
76
+ np .sqrt ((0 - center_x ) ** 2 + (0 - center_y ) ** 2 ),
77
+ np .sqrt ((1 - center_x ) ** 2 + (0 - center_y ) ** 2 ),
78
+ np .sqrt ((0 - center_x ) ** 2 + (1 - center_y ) ** 2 ),
79
+ np .sqrt ((1 - center_x ) ** 2 + (1 - center_y ) ** 2 )
93
80
]
94
81
max_distance_to_corner = np .max (distances_to_corners )
95
82
96
83
radius = np .sqrt (X ** 2 + Y ** 2 )
97
84
radius = radius / (max_distance_to_corner * np .sqrt (2 )) # Normalize radius
98
- opacity = np .clip (vignette_strength , 0 , 1 )
85
+ opacity = np .clip (intensity , 0 , 1 )
99
86
vignette = 1 - radius * opacity
100
87
88
+ for b in range (batch_size ):
89
+ tensor_image = image [b ].numpy ()
90
+
91
+ # Apply vignette
92
+ vignette_image = self .apply_vignette (tensor_image , vignette )
93
+
94
+ tensor = torch .from_numpy (vignette_image ).unsqueeze (0 )
95
+ result [b ] = tensor
96
+
97
+ return (result ,)
98
+
99
+ def apply_vignette (self , image , vignette ):
101
100
# If image needs to be normalized (0-1 range)
102
101
needs_normalization = image .max () > 1
103
102
if needs_normalization :
@@ -206,6 +205,7 @@ def apply_filmgrain(self, image, gray_scale, grain_type, grain_sat, grain_power,
206
205
out_image = filmgrainer .process (image , scale , src_gamma ,
207
206
grain_power , shadows , highs , grain_type ,
208
207
grain_sat , gray_scale , sharpen , seed )
208
+
209
209
return out_image
210
210
211
211
@@ -263,37 +263,37 @@ def radialblur_image(self, image: torch.Tensor, blur_strength: float, center_x:
263
263
batch_size , height , width , _ = image .shape
264
264
result = torch .zeros_like (image )
265
265
266
+ # Generate the vignette for each image in the batch
267
+ c_x , c_y = int (width * center_x ), int (height * center_y )
268
+
269
+ # Calculate distances to all corners from the center
270
+ distances_to_corners = [
271
+ np .sqrt ((c_x - 0 )** 2 + (c_y - 0 )** 2 ),
272
+ np .sqrt ((c_x - width )** 2 + (c_y - 0 )** 2 ),
273
+ np .sqrt ((c_x - 0 )** 2 + (c_y - height )** 2 ),
274
+ np .sqrt ((c_x - width )** 2 + (c_y - height )** 2 )
275
+ ]
276
+ max_distance_to_corner = max (distances_to_corners )
277
+
278
+ # Create and adjust radial mask
279
+ X , Y = np .meshgrid (np .arange (width ) - c_x , np .arange (height ) - c_y )
280
+ radial_mask = np .sqrt (X ** 2 + Y ** 2 ) / max_distance_to_corner
281
+
266
282
for b in range (batch_size ):
267
283
tensor_image = image [b ].numpy ()
268
284
269
285
# Apply blur
270
- blur_image = self .apply_radialblur (tensor_image , blur_strength , center_x , center_y , focus_spread , steps )
286
+ blur_image = self .apply_radialblur (tensor_image , blur_strength , radial_mask , focus_spread , steps )
271
287
272
288
tensor = torch .from_numpy (blur_image ).unsqueeze (0 )
273
289
result [b ] = tensor
274
290
275
291
return (result ,)
276
292
277
- def apply_radialblur (self , image , blur_strength , center_x_ratio , center_y_ratio , focus_spread , steps ):
293
+ def apply_radialblur (self , image , blur_strength , radial_mask , focus_spread , steps ):
278
294
needs_normalization = image .max () > 1
279
295
if needs_normalization :
280
296
image = image .astype (np .float32 ) / 255
281
-
282
- height , width = image .shape [:2 ]
283
- center_x , center_y = int (width * center_x_ratio ), int (height * center_y_ratio )
284
-
285
- # Calculate distances to all corners from the center
286
- distances_to_corners = [
287
- np .sqrt ((center_x - 0 )** 2 + (center_y - 0 )** 2 ),
288
- np .sqrt ((center_x - width )** 2 + (center_y - 0 )** 2 ),
289
- np .sqrt ((center_x - 0 )** 2 + (center_y - height )** 2 ),
290
- np .sqrt ((center_x - width )** 2 + (center_y - height )** 2 )
291
- ]
292
- max_distance_to_corner = max (distances_to_corners )
293
-
294
- # Create and adjust radial mask
295
- X , Y = np .meshgrid (np .arange (width ) - center_x , np .arange (height ) - center_y )
296
- radial_mask = np .sqrt (X ** 2 + Y ** 2 ) / max_distance_to_corner
297
297
298
298
blurred_images = processing_utils .generate_blurred_images (image , blur_strength , steps , focus_spread )
299
299
final_image = processing_utils .apply_blurred_images (image , blurred_images , radial_mask )
0 commit comments