Skip to content

Commit e36827a

Browse files
committed
improve get_crop_region
1 parent f939bce commit e36827a

File tree

2 files changed

+10
-35
lines changed

2 files changed

+10
-35
lines changed

modules/masking.py

+9-34
Original file line numberDiff line numberDiff line change
@@ -3,40 +3,15 @@
33

44
def get_crop_region(mask, pad=0):
55
"""finds a rectangular region that contains all masked ares in an image. Returns (x1, y1, x2, y2) coordinates of the rectangle.
6-
For example, if a user has painted the top-right part of a 512x512 image", the result may be (256, 0, 512, 256)"""
7-
8-
h, w = mask.shape
9-
10-
crop_left = 0
11-
for i in range(w):
12-
if not (mask[:, i] == 0).all():
13-
break
14-
crop_left += 1
15-
16-
crop_right = 0
17-
for i in reversed(range(w)):
18-
if not (mask[:, i] == 0).all():
19-
break
20-
crop_right += 1
21-
22-
crop_top = 0
23-
for i in range(h):
24-
if not (mask[i] == 0).all():
25-
break
26-
crop_top += 1
27-
28-
crop_bottom = 0
29-
for i in reversed(range(h)):
30-
if not (mask[i] == 0).all():
31-
break
32-
crop_bottom += 1
33-
34-
return (
35-
int(max(crop_left-pad, 0)),
36-
int(max(crop_top-pad, 0)),
37-
int(min(w - crop_right + pad, w)),
38-
int(min(h - crop_bottom + pad, h))
39-
)
6+
For example, if a user has painted the top-right part of a 512x512 image, the result may be (256, 0, 512, 256)"""
7+
mask_img = mask if isinstance(mask, Image.Image) else Image.fromarray(mask)
8+
box = mask_img.getbbox()
9+
if box:
10+
x1, y1, x2, y2 = box
11+
else: # when no box is found
12+
x1, y1 = mask_img.size
13+
x2 = y2 = 0
14+
return max(x1 - pad, 0), max(y1 - pad, 0), min(x2 + pad, mask_img.size[0]), min(y2 + pad, mask_img.size[1])
4015

4116

4217
def expand_crop_region(crop_region, processing_width, processing_height, image_width, image_height):

modules/processing.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1562,7 +1562,7 @@ def init(self, all_prompts, all_seeds, all_subseeds):
15621562
if self.inpaint_full_res:
15631563
self.mask_for_overlay = image_mask
15641564
mask = image_mask.convert('L')
1565-
crop_region = masking.get_crop_region(np.array(mask), self.inpaint_full_res_padding)
1565+
crop_region = masking.get_crop_region(mask, self.inpaint_full_res_padding)
15661566
crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height)
15671567
x1, y1, x2, y2 = crop_region
15681568

0 commit comments

Comments
 (0)