@@ -431,25 +431,25 @@ def lut_image(self, image: torch.Tensor, lut_name, strength: float, log: bool):
431
431
batch_size , height , width , _ = image .shape
432
432
result = torch .zeros_like (image )
433
433
434
+ # Read the LUT
435
+ lut_path = os .path .join (dir_luts , lut_name )
436
+ lut = loading_utils .read_lut (lut_path , clip = True )
437
+
434
438
for b in range (batch_size ):
435
439
tensor_image = image [b ].numpy ()
436
440
437
441
# Apply LUT
438
- lut_image = self .apply_lut (tensor_image , lut_name , strength , log )
442
+ lut_image = self .apply_lut (tensor_image , lut , strength , log )
439
443
440
444
tensor = torch .from_numpy (lut_image ).unsqueeze (0 )
441
445
result [b ] = tensor
442
446
443
447
return (result ,)
444
448
445
- def apply_lut (self , image , lut_name , strength , log ):
449
+ def apply_lut (self , image , lut , strength , log ):
446
450
if strength == 0 :
447
451
return image
448
452
449
- # Read the LUT
450
- lut_path = os .path .join (dir_luts , lut_name )
451
- lut = loading_utils .read_lut (lut_path , clip = True )
452
-
453
453
# Apply the LUT
454
454
is_non_default_domain = not np .array_equal (lut .domain , np .array ([[0. , 0. , 0. ], [1. , 1. , 1. ]]))
455
455
dom_scale = None
0 commit comments