Skip to content

Commit e81291a

Browse files
authored
Optimise shared data, generate once (#18)
* Optimise shared data, generate once * fixup
1 parent bad7182 commit e81291a

File tree

1 file changed

+39
-39
lines changed

1 file changed

+39
-39
lines changed

nodes.py

Lines changed: 39 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -62,42 +62,41 @@ def vignette_image(self, image: torch.Tensor, intensity: float, center_x: float,
6262
batch_size, height, width, _ = image.shape
6363
result = torch.zeros_like(image)
6464

65-
for b in range(batch_size):
66-
tensor_image = image[b].numpy()
67-
68-
# Apply vignette
69-
vignette_image = self.apply_vignette(tensor_image, intensity, center_x, center_y)
70-
71-
tensor = torch.from_numpy(vignette_image).unsqueeze(0)
72-
result[b] = tensor
73-
74-
return (result,)
75-
76-
def apply_vignette(self, image, vignette_strength, center_x_ratio, center_y_ratio):
77-
if vignette_strength == 0:
65+
if intensity == 0:
7866
return image
7967

80-
height, width, _ = image.shape
81-
68+
# Generate the vignette for each image in the batch
8269
# Create linear space but centered around the provided center point ratios
8370
x = np.linspace(-1, 1, width)
8471
y = np.linspace(-1, 1, height)
85-
X, Y = np.meshgrid(x - (2 * center_x_ratio - 1), y - (2 * center_y_ratio - 1))
72+
X, Y = np.meshgrid(x - (2 * center_x - 1), y - (2 * center_y - 1))
8673

8774
# Calculate distances to the furthest corner
8875
distances_to_corners = [
89-
np.sqrt((0 - center_x_ratio) ** 2 + (0 - center_y_ratio) ** 2),
90-
np.sqrt((1 - center_x_ratio) ** 2 + (0 - center_y_ratio) ** 2),
91-
np.sqrt((0 - center_x_ratio) ** 2 + (1 - center_y_ratio) ** 2),
92-
np.sqrt((1 - center_x_ratio) ** 2 + (1 - center_y_ratio) ** 2)
76+
np.sqrt((0 - center_x) ** 2 + (0 - center_y) ** 2),
77+
np.sqrt((1 - center_x) ** 2 + (0 - center_y) ** 2),
78+
np.sqrt((0 - center_x) ** 2 + (1 - center_y) ** 2),
79+
np.sqrt((1 - center_x) ** 2 + (1 - center_y) ** 2)
9380
]
9481
max_distance_to_corner = np.max(distances_to_corners)
9582

9683
radius = np.sqrt(X ** 2 + Y ** 2)
9784
radius = radius / (max_distance_to_corner * np.sqrt(2)) # Normalize radius
98-
opacity = np.clip(vignette_strength, 0, 1)
85+
opacity = np.clip(intensity, 0, 1)
9986
vignette = 1 - radius * opacity
10087

88+
for b in range(batch_size):
89+
tensor_image = image[b].numpy()
90+
91+
# Apply vignette
92+
vignette_image = self.apply_vignette(tensor_image, vignette)
93+
94+
tensor = torch.from_numpy(vignette_image).unsqueeze(0)
95+
result[b] = tensor
96+
97+
return (result,)
98+
99+
def apply_vignette(self, image, vignette):
101100
# If image needs to be normalized (0-1 range)
102101
needs_normalization = image.max() > 1
103102
if needs_normalization:
@@ -206,6 +205,7 @@ def apply_filmgrain(self, image, gray_scale, grain_type, grain_sat, grain_power,
206205
out_image = filmgrainer.process(image, scale, src_gamma,
207206
grain_power, shadows, highs, grain_type,
208207
grain_sat, gray_scale, sharpen, seed)
208+
209209
return out_image
210210

211211

@@ -263,37 +263,37 @@ def radialblur_image(self, image: torch.Tensor, blur_strength: float, center_x:
263263
batch_size, height, width, _ = image.shape
264264
result = torch.zeros_like(image)
265265

266+
# Generate the vignette for each image in the batch
267+
c_x, c_y = int(width * center_x), int(height * center_y)
268+
269+
# Calculate distances to all corners from the center
270+
distances_to_corners = [
271+
np.sqrt((c_x - 0)**2 + (c_y - 0)**2),
272+
np.sqrt((c_x - width)**2 + (c_y - 0)**2),
273+
np.sqrt((c_x - 0)**2 + (c_y - height)**2),
274+
np.sqrt((c_x - width)**2 + (c_y - height)**2)
275+
]
276+
max_distance_to_corner = max(distances_to_corners)
277+
278+
# Create and adjust radial mask
279+
X, Y = np.meshgrid(np.arange(width) - c_x, np.arange(height) - c_y)
280+
radial_mask = np.sqrt(X**2 + Y**2) / max_distance_to_corner
281+
266282
for b in range(batch_size):
267283
tensor_image = image[b].numpy()
268284

269285
# Apply blur
270-
blur_image = self.apply_radialblur(tensor_image, blur_strength, center_x, center_y, focus_spread, steps)
286+
blur_image = self.apply_radialblur(tensor_image, blur_strength, radial_mask, focus_spread, steps)
271287

272288
tensor = torch.from_numpy(blur_image).unsqueeze(0)
273289
result[b] = tensor
274290

275291
return (result,)
276292

277-
def apply_radialblur(self, image, blur_strength, center_x_ratio, center_y_ratio, focus_spread, steps):
293+
def apply_radialblur(self, image, blur_strength, radial_mask, focus_spread, steps):
278294
needs_normalization = image.max() > 1
279295
if needs_normalization:
280296
image = image.astype(np.float32) / 255
281-
282-
height, width = image.shape[:2]
283-
center_x, center_y = int(width * center_x_ratio), int(height * center_y_ratio)
284-
285-
# Calculate distances to all corners from the center
286-
distances_to_corners = [
287-
np.sqrt((center_x - 0)**2 + (center_y - 0)**2),
288-
np.sqrt((center_x - width)**2 + (center_y - 0)**2),
289-
np.sqrt((center_x - 0)**2 + (center_y - height)**2),
290-
np.sqrt((center_x - width)**2 + (center_y - height)**2)
291-
]
292-
max_distance_to_corner = max(distances_to_corners)
293-
294-
# Create and adjust radial mask
295-
X, Y = np.meshgrid(np.arange(width) - center_x, np.arange(height) - center_y)
296-
radial_mask = np.sqrt(X**2 + Y**2) / max_distance_to_corner
297297

298298
blurred_images = processing_utils.generate_blurred_images(image, blur_strength, steps, focus_spread)
299299
final_image = processing_utils.apply_blurred_images(image, blurred_images, radial_mask)

0 commit comments

Comments
 (0)