
Improve mask_overlay efficiency #1


Open · wants to merge 2 commits into master
10 changes: 9 additions & 1 deletion range_compression/__init__.py
@@ -1 +1,9 @@
from .range_compression import RangeCompressedMask, mask_encode, rcm_find_index, rcm_load, calc_area_from_encodings, calc_area_from_mask
from .range_compression import (
RangeCompressedMask,
mask_encode,
rcm_find_index,
rcm_load,
calc_area_from_encodings,
calc_area_from_mask,
mask_overlay,
)
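
With this change, mask_overlay is re-exported from the package root next to the existing helpers, so callers can import everything from one place. A minimal sketch:

    from range_compression import mask_encode, mask_overlay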
229 changes: 184 additions & 45 deletions range_compression/range_compression.py
@@ -17,27 +17,35 @@ class RangeCompressedMask:
h: int

encodings: np.ndarray
'Represents the intervals, shape: (n, 4), [(start_x, end_x, value, row_index), ...]'
"Represents the intervals, shape: (n, 4), [(start_x, end_x, value, row_index), ...]"

row_indexes: np.ndarray
'For fast lookup of encoding rows from y, shape: (n, 2), [(start_y, encoding_count), ...]'
"For fast lookup of encoding rows from y, shape: (n, 2), [(start_y, encoding_count), ...]"


def save(self, base_dir: Union[str, PathLike], compression: tap.ParquetCompression='gzip'):
def save(
self,
base_dir: Union[str, PathLike],
compression: tap.ParquetCompression = "gzip",
):
base_dir = Path(base_dir)

base_dir.mkdir(exist_ok=True)
dfe = pl.from_numpy(self.encodings, schema=['start', 'end', 'v', 'row_index'])
dfe.write_parquet(base_dir / f'encodings.parquet', compression=compression)

dfer = pl.from_numpy(self.row_indexes, schema=['row_start_index', 'row_count'])
dfer.write_parquet(base_dir / f'row_indexes.parquet', compression=compression)

(base_dir / 'meta.json').write_text(json.dumps({
'w': self.w,
'h': self.h,
'datetime': datetime.now(),
}, default=str))
dfe = pl.from_numpy(self.encodings, schema=["start", "end", "v", "row_index"])
dfe.write_parquet(base_dir / f"encodings.parquet", compression=compression)

dfer = pl.from_numpy(self.row_indexes, schema=["row_start_index", "row_count"])
dfer.write_parquet(base_dir / f"row_indexes.parquet", compression=compression)

(base_dir / "meta.json").write_text(
json.dumps(
{
"w": self.w,
"h": self.h,
"datetime": datetime.now(),
},
default=str,
)
)

def three_columns_encodings(self, try_contiguous=True):
res = self.encodings[:, :3]
@@ -46,40 +54,44 @@ def three_columns_encodings(self, try_contiguous=True):
return res

@staticmethod
def load(base_dir: Union[str, PathLike], chip: Optional[str] = None, no_row_index=True):
'''Load from a directory
def load(
base_dir: Union[str, PathLike], chip: Optional[str] = None, no_row_index=True
):
"""从文件夹中导入

base_dir: directory path
chip: if chip is given, files are looked up under base_dir/chip
no_row_index: the raw encodings have four columns, the last being row_index;
    if row_index is not needed, set this to True so more data fits in the cache
'''
"""
base_dir = Path(base_dir)

if chip is not None:
base_dir = base_dir / chip

dfe = pl.read_parquet(base_dir / f'encodings.parquet')
dfe = pl.read_parquet(base_dir / f"encodings.parquet")
encodings = dfe.to_numpy()[:, :3]
if no_row_index:
encodings = np.ascontiguousarray(encodings[:, :3])
dfer = pl.read_parquet(base_dir / f'row_indexes.parquet')
dfer = pl.read_parquet(base_dir / f"row_indexes.parquet")
row_indexes = dfer.to_numpy()

meta = json.loads((base_dir / 'meta.json').read_text())
meta = json.loads((base_dir / "meta.json").read_text())
return RangeCompressedMask(
w=meta['w'],
h=meta['h'],
w=meta["w"],
h=meta["h"],
encodings=encodings,
row_indexes=row_indexes,
)
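# A usage sketch for load (hypothetical paths; assumes files written by save()):
#
#     rcm = RangeCompressedMask.load('out/masks', chip='A01')
#     # looks for encodings.parquet, row_indexes.parquet and meta.json
#     # under out/masks/A01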

@staticmethod
def targets(base_dir: Union[str, PathLike, None] = None, chip: Optional[str] = None):
def targets(
base_dir: Union[str, PathLike, None] = None, chip: Optional[str] = None
):
base = [
'encodings.parquet',
'row_indexes.parquet',
'meta.json',
"encodings.parquet",
"row_indexes.parquet",
"meta.json",
]

if base_dir is not None:
@@ -90,14 +102,11 @@ def targets(base_dir: Union[str, PathLike, None] = None, chip: Optional[str] = N

return base

def find_index(
self,
X: np.ndarray, Y: np.ndarray,
binary_search=False
):
def find_index(self, X: np.ndarray, Y: np.ndarray, binary_search=False):
if self.encodings.shape[1] == 4:
import warnings
warnings.warn(f'`row_indexes` has 4 columns.')

warnings.warn(f"`row_indexes` has 4 columns.")
if not isinstance(X, np.ndarray):
X = np.array(X)
if not isinstance(Y, np.ndarray):
@@ -115,6 +124,7 @@ def to_mask(self):
def calc_area(self):
return calc_area_from_encodings(self.encodings, self.row_indexes)


@nb.njit
def _mask_encode(mask):
n_rows, n_cols = mask.shape
@@ -149,21 +159,20 @@ def _mask_encode(mask):

return encodings_np, row_indexes_np


def mask_encode(mask: np.ndarray):
'''Encode a mask into a range-compressed representation.'''
"""Encode a mask into a range-compressed representation."""

h, w = mask.shape
encodings_np, row_indexes_np = _mask_encode(mask)
return RangeCompressedMask(
w=w, h=h,
encodings=encodings_np,
row_indexes=row_indexes_np
w=w, h=h, encodings=encodings_np, row_indexes=row_indexes_np
)
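
# A round-trip sketch with a hypothetical toy mask: to_mask() should
# reproduce the encoded input exactly.
#
#     m = np.array([[0, 1, 1], [2, 2, 0]], dtype=np.int32)
#     rcm = mask_encode(m)
#     assert np.array_equal(rcm.to_mask(), m)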


@nb.njit
def find_encoding_in_row(row_encodings, col):
'''When calling this function directly, the caller must ensure col stays within the valid domain'''
"""When calling this function directly, the caller must ensure col stays within the valid domain"""
if len(row_encodings) == 1 and row_encodings[0, 0] == 0:
return 0
if row_encodings[0][0] > col:
@@ -180,6 +189,7 @@ def find_encoding_in_row(row_encodings, col):
# This return 0 handles the case where col > row_encodings[-1][1]
return 0


@nb.njit
def find_encoding_in_row_binary(row_encodings, col):
# For cell masks, binary search does not affect speed; for region masks it is 50% slower
@@ -251,7 +261,7 @@ def _find_index_binary(row_indexes, encodings_np, X, Y):

@nb.njit(parallel=True)
def to_mask(encodings: np.ndarray, row_indexes: np.ndarray, w: int, h: int):
assert encodings.shape[1] == 3, 'encodings should have 3 columns'
assert encodings.shape[1] == 3, "encodings should have 3 columns"

out = np.zeros((h, w), dtype="int32")
for row in nb.prange(h):
@@ -261,25 +271,27 @@ def to_mask(encodings: np.ndarray, row_indexes: np.ndarray, w: int, h: int):
out[row, start : stop + 1] = value
return out


@nb.njit
def calc_area_from_encodings(encodings: np.ndarray, row_indexes: np.ndarray):
'''Compute the area of each labeled region'''
"""Compute the area of each labeled region"""
areas = {0: 0}

for i in range(len(row_indexes)):
start_index, length = row_indexes[i]
row_encoding = encodings[start_index : start_index + length, :3]
for start, stop, value in row_encoding:
if value == 0: continue
if value == 0:
continue
areas[value] = areas.get(value, 0) + (stop - start + 1)

return areas


@nb.njit
def calc_area_from_mask(mask: np.ndarray):
'''Compute the area of each labeled region'''
areas = {0:{0:0}}
"""计算每一个分块儿的面积"""
areas = {0: {0: 0}}

for i in range(nb.get_num_threads()):
_areas = {}
@@ -290,22 +302,149 @@ def calc_area_from_mask(mask: np.ndarray):
areas_ = calc_area_from_mask_row(mask[row, :])
for col_v, cnt in areas_.items():
# print(type(col_v))
areas[nb.get_thread_id()][col_v] = areas[nb.get_thread_id()].get(col_v, 0) + cnt
areas[nb.get_thread_id()][col_v] = (
areas[nb.get_thread_id()].get(col_v, 0) + cnt
)
areas_res = {0: 0}
for d in areas.values():
for col_v, cnt in d.items():
areas_res[col_v] = areas_res.get(col_v, 0) + cnt
return areas_res


@nb.njit
def calc_area_from_mask_row(row: np.ndarray):
areas = {}
areas[np.int64(0)] = np.int64(0)

for col_v in row:
if col_v == 0: continue
if col_v == 0:
continue
areas[col_v] = areas.get(col_v, 0) + 1
return areas
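
# A quick area sketch (hypothetical toy mask): label 5 covers a 2x2 block,
# so calc_area_from_mask should report {0: 0, 5: 4}.
#
#     m = np.zeros((4, 4), dtype=np.int32)
#     m[1:3, 1:3] = 5
#     calc_area_from_mask(m)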


@nb.njit
def _collect_remove_ids(enc_a, rows_a, enc_b, rows_b, h):
remove_dict = nb.typed.Dict.empty(nb.int32, nb.int8)

for row in range(h):
sa, la = rows_a[row]
sb, lb = rows_b[row]

ia = 0
ib = 0
while ia < la and ib < lb:
a_start, a_end, a_val, _ = enc_a[sa + ia]
b_start, b_end, b_val, _ = enc_b[sb + ib]

if a_val == 0:
ia += 1
continue
if b_val == 0:
ib += 1
continue

if a_end < b_start:
ia += 1
continue
if b_end < a_start:
ib += 1
continue

remove_dict[np.int32(a_val)] = 1

if a_end <= b_end:
ia += 1
else:
ib += 1

remove_ids_list = nb.typed.List.empty_list(nb.int32)
for k in remove_dict.keys():
remove_ids_list.append(k)
return remove_ids_list


def _merge_masks(enc_a, rows_a, enc_b, rows_b, remove_ids, w, h):
new_encodings = []
new_row_indexes = []

for row in range(h):
sa, la = rows_a[row]
sb, lb = rows_b[row]

row_segs = []

for i in range(la):
s, e, v, _ = enc_a[sa + i]
if v == 0:
continue
skip = False
for rid in remove_ids:
if v == rid:
skip = True
break
if skip:
continue
row_segs.append((int(s), int(e), int(v)))

for i in range(lb):
s, e, v, _ = enc_b[sb + i]
if v == 0:
continue
row_segs.append((int(s), int(e), int(v)))

row_segs.sort(key=lambda x: x[0])

if not row_segs:
row_segs = [(0, w - 1, 0)]

start_index = len(new_encodings)
for seg in row_segs:
new_encodings.append(seg)
new_row_indexes.append((start_index, len(row_segs)))

enc_np = np.array(new_encodings, dtype=np.int32)
idx_np = np.array(new_row_indexes, dtype=np.int32)
return enc_np, idx_np


def mask_overlay(
rcm_a: RangeCompressedMask, rcm_b: RangeCompressedMask
) -> RangeCompressedMask:
"""把 ``rcm_b`` 覆盖到 ``rcm_a`` 上, 直接在区间压缩数据上完成。

``rcm_a`` 与 ``rcm_b`` 的尺寸必须一致。若 ``rcm_b`` 与 ``rcm_a`` 的任
意细胞有交集, ``rcm_a`` 中这些细胞会被完全删除, 最终返回 ``rcm_a``
剩余细胞与 ``rcm_b`` 的并集。
"""

if rcm_a.w != rcm_b.w or rcm_a.h != rcm_b.h:
raise ValueError("rcm_a and rcm_b must have the same shape")

w, h = rcm_a.w, rcm_a.h

remove_ids_list = _collect_remove_ids(
rcm_a.encodings, rcm_a.row_indexes, rcm_b.encodings, rcm_b.row_indexes, h
)
remove_ids = np.array(list(remove_ids_list), dtype=np.int32)
new_enc, new_rows = _merge_masks(
rcm_a.encodings,
rcm_a.row_indexes,
rcm_b.encodings,
rcm_b.row_indexes,
remove_ids,
w,
h,
)

return RangeCompressedMask(
w=w,
h=h,
encodings=new_enc,
row_indexes=new_rows,
)


rcm_load = RangeCompressedMask.load
rcm_find_index = RangeCompressedMask.find_index
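
A usage sketch for the new function (toy data; both masks must share the same shape, as the ValueError check above requires). Cell 1 in a intersects cell 9 in b, so cell 1 is dropped entirely, not just where the two overlap:

    import numpy as np
    from range_compression import mask_encode, mask_overlay

    a = np.zeros((3, 6), dtype=np.int32)
    a[:, 0:3] = 1
    b = np.zeros((3, 6), dtype=np.int32)
    b[:, 2:4] = 9

    merged = mask_overlay(mask_encode(a), mask_encode(b))
    print(merged.to_mask())  # every row becomes [0, 0, 9, 9, 0, 0]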