From 70632cd4ff807018f9aa915e4ae8e3b8b94009cd Mon Sep 17 00:00:00 2001 From: Amil Date: Wed, 17 May 2023 17:42:33 -0700 Subject: [PATCH] :fire: Removed and renamed files for PyPi release --- cells/streamlit/cellgeometry/Hello.py | 22 ++ cells/streamlit/cellgeometry/__init__.py | 0 ...astic_Metric_for_Cell_Boundary_Analysis.py | 216 ++++++++++++ .../cellgeometry/pages/1-Load_Data.py | 119 +++++++ .../cellgeometry/pages/Cell_Shear.py | 107 ++++++ cells/streamlit/cellgeometry/utils/basic.py | 34 ++ .../cellgeometry/utils/data_utils.py | 100 ++++++ .../cellgeometry/utils/experimental.py | 317 ++++++++++++++++++ 8 files changed, 915 insertions(+) create mode 100644 cells/streamlit/cellgeometry/Hello.py create mode 100644 cells/streamlit/cellgeometry/__init__.py create mode 100644 cells/streamlit/cellgeometry/pages/ Elastic_Metric_for_Cell_Boundary_Analysis.py create mode 100644 cells/streamlit/cellgeometry/pages/1-Load_Data.py create mode 100644 cells/streamlit/cellgeometry/pages/Cell_Shear.py create mode 100644 cells/streamlit/cellgeometry/utils/basic.py create mode 100644 cells/streamlit/cellgeometry/utils/data_utils.py create mode 100644 cells/streamlit/cellgeometry/utils/experimental.py diff --git a/cells/streamlit/cellgeometry/Hello.py b/cells/streamlit/cellgeometry/Hello.py new file mode 100644 index 0000000..73841fc --- /dev/null +++ b/cells/streamlit/cellgeometry/Hello.py @@ -0,0 +1,22 @@ +import streamlit as st + + + + +st.set_page_config( + page_title="Welcome", + page_icon="šŸ‘‹", +) + +st.write("# Welcome to the Cell Shape Analysis App! šŸ‘‹") + +st.sidebar.success("Select a demo above.") + +st.markdown( + """ + Geomstats is an open-source Python package for computations, statistics, and machine learning on nonlinear manifolds. Data from many application fields are elements of manifolds. For instance, the manifold of 3D rotations SO(3) naturally appears when performing statistical learning on articulated objects like the human spine or robotics arms. + ** + + šŸ‘ˆ Select a demo from the sidebar** +""" +) \ No newline at end of file diff --git a/cells/streamlit/cellgeometry/__init__.py b/cells/streamlit/cellgeometry/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/cells/streamlit/cellgeometry/pages/ Elastic_Metric_for_Cell_Boundary_Analysis.py b/cells/streamlit/cellgeometry/pages/ Elastic_Metric_for_Cell_Boundary_Analysis.py new file mode 100644 index 0000000..b231f5b --- /dev/null +++ b/cells/streamlit/cellgeometry/pages/ Elastic_Metric_for_Cell_Boundary_Analysis.py @@ -0,0 +1,216 @@ +import os +import subprocess +import geomstats.backend as gs +import streamlit as st +import time +import numpy as np +import pandas as pd +import matplotlib.pyplot as plt +from scipy import stats + +from sklearn.cluster import KMeans +from sklearn.metrics import confusion_matrix +from sklearn.decomposition import PCA + +from scipy.optimize import linear_sum_assignment as linear_assignment +from sklearn import manifold +from joblib import Parallel, delayed +from numba import jit, njit, prange + +from geomstats.geometry.euclidean import Euclidean +from geomstats.geometry.discrete_curves import R2, DiscreteCurves, ClosedDiscreteCurves + +from geomstats.learning.frechet_mean import FrechetMean +from geomstats.learning.kmeans import RiemannianKMeans +from geomstats.learning.mdm import RiemannianMinimumDistanceToMean +from geomstats.learning.pca import TangentPCA + +import sys +sys.path.append("/app/utils") + +# import utils +from utils import experimental as experimental +from utils import basic as basic + + + + +st.set_page_config(page_title="Elastic Metric for Cell Boundary Analysis", page_icon="šŸ“ˆ") + +st.markdown("# Shape Analysis of Cancer Cells") +st.sidebar.header("Shape Analysis of Cancer Cells") +st.write( + """This notebook studies Osteosarcoma (bone cancer) cells and the impact of drug treatment on their morphological shapes, by analyzing cell images obtained from fluorescence microscopy. + +This analysis relies on the elastic metric between discrete curves from Geomstats. We will study to which extent this metric can detect how the cell shape is associated with the response to treatment.""" +) + +dataset_name = "osteosarcoma" + +n_sampling_points = st.slider('Select the Number of Sampling Points', 0, 250, 100) +n_cells = 650 +# n_sampling_points = 100 +labels_a_name = "lines" +labels_b_name = "treatments" + +quotient = ["rotation"] #["scaling"] #, "rotation"] +do_not_quotient = False + + +if dataset_name == "osteosarcoma": + cells, cell_shapes, labels_a, labels_b = experimental.load_treated_osteosarcoma_cells( + n_cells=n_cells, n_sampling_points=n_sampling_points, quotient=quotient + ) +else: + pass + + +labels_a_dict = {lab: i_lab for i_lab, lab in enumerate(np.unique(labels_a))} +labels_b_dict = {lab: i_lab for i_lab, lab in enumerate(np.unique(labels_b))} + +print(f"Dictionary associated to label \"{labels_a_name}\":") +print(labels_a_dict) +print(f"Dictionary associated to label \"{labels_b_name}\":") +print(labels_b_dict) + +if do_not_quotient: + cell_shapes = cells + +n_cells_to_plot = 10 + +fig = plt.figure(figsize=(16, 6)) +count = 1 +for label_b in np.unique(labels_b): + for i_lab_a, label_a in enumerate(np.unique(labels_a)): + cell_data = [cell for cell, lab_a, lab_b in zip(cell_shapes, labels_a, labels_b) if lab_a == label_a and lab_b == label_b] + for i_to_plot in range(n_cells_to_plot): + cell = gs.random.choice(a=cell_data) + fig.add_subplot(len(np.unique(labels_b)), len(np.unique(labels_a)) * n_cells_to_plot, count) + count += 1 + plt.plot(cell[:, 0], cell[:, 1], color=f"C{i_lab_a}" ) + plt.axis("equal") + plt.axis("off") + if i_to_plot == n_cells_to_plot // 2: + plt.title(f"{label_a} - {label_b}", fontsize=20) +st.pyplot(fig) + +# Define shape space +R1 = Euclidean(dim=1) +CLOSED_CURVES_SPACE = ClosedDiscreteCurves(R2) +CURVES_SPACE = DiscreteCurves(R2) +SRV_METRIC = CURVES_SPACE.srv_metric +L2_METRIC = CURVES_SPACE.l2_curves_metric + +ELASTIC_METRIC = {} +AS = [1, 2, 0.75, 0.5, 0.25, 0.01] #, 1.6] #, 1.4, 1.2, 1, 0.5, 0.2, 0.1] +BS = [0.5, 1, 0.5, 0.5, 0.5, 0.5] #, 2, 2, 2, 2, 2, 2, 2] +for a, b in zip(AS, BS): + ELASTIC_METRIC[a, b] = DiscreteCurves(R2, a=a, b=b).elastic_metric +METRICS = {} +METRICS["Linear"] = L2_METRIC +METRICS["SRV"] = SRV_METRIC + +means = {} + +means["Linear"] = gs.mean(cell_shapes, axis=0) +means["SRV"] = FrechetMean( + metric=SRV_METRIC, + method="default").fit(cell_shapes).estimate_ + +for a, b in zip(AS, BS): + means[a, b] = FrechetMean( + metric=ELASTIC_METRIC[a, b], + method="default").fit(cell_shapes).estimate_ + +st.header("Sample Means") +st.markdown("We compare results when computing the mean cell versus the mean cell shapes with different elastic metrics.") +fig = plt.figure(figsize=(18, 8)) + +ncols = len(means) // 2 + +for i, (mean_name, mean) in enumerate(means.items()): + ax = fig.add_subplot(2, ncols, i+1) + ax.plot(mean[:, 0], mean[:, 1], "black") + ax.set_aspect("equal") + ax.axis("off") + axs_title = mean_name + if mean_name not in ["Linear", "SRV"]: + a = mean_name[0] + b = mean_name[1] + ratio = a / (2 * b) + mean_name = f"Elastic {mean_name}\n a / (2b) = {ratio}" + ax.set_title(mean_name) + +st.pyplot(fig) + + + +fig = plt.figure(figsize=(18, 8)) + +ncols = len(means) // 2 + +for i, (mean_name, mean) in enumerate(means.items()): + ax = fig.add_subplot(2, ncols, i+1) + mean = CLOSED_CURVES_SPACE.projection(mean) + ax.plot(mean[:, 0], mean[:, 1], "black") + ax.set_aspect("equal") + ax.axis("off") + axs_title = mean_name + if mean_name not in ["Linear", "SRV"]: + a = mean_name[0] + b = mean_name[1] + ratio = a / (2 * b) + mean_name = f"Elastic {mean_name}\n a / (2b) = {ratio}" + ax.set_title(mean_name) + + +st.markdown("__Remark:__ Unfortunately, there are some numerical issues with the projection in the space of closed curves, as shown by the V-shaped results above.") + +st.markdown("Since ratios of 1 give the same results as for the SRV metric, we only select AS, BS with a ratio that is not 1 for the elastic metrics.") + +st.markdown("We also continue the analysis with the space of open curves, as opposed to the space of closed curves, for the numerical issues observed above.") + + +NEW_AS = [0.75, 0.5, 0.25, 0.01] #, 1.6] #, 1.4, 1.2, 1, 0.5, 0.2, 0.1] +NEW_BS = [0.5, 0.5, 0.5, 0.5] #, 2, 2, 2, 2, 2, 2, 2] + +st.markdown("## Distances to the Mean") + +# We multiply the distances by a 100, for visualization purposes. It amounts to a change of units. +dists = {} + +dists["Linear"] = [100 * gs.linalg.norm(means["Linear"] - cell) / n_sampling_points for cell in cell_shapes] + +dists["SRV"] = [ + 100 * SRV_METRIC.dist(means["SRV"], cell) / n_sampling_points for cell in cell_shapes +] + +for a, b in zip(NEW_AS, NEW_BS): + dists[a, b] = [ + 100 * ELASTIC_METRIC[a, b].dist(means[a, b], cell) / n_sampling_points for cell in cell_shapes +] + + +dists_summary = pd.DataFrame( + data={ + labels_a_name: labels_a, + labels_b_name: labels_b, + "Linear": dists["Linear"], + "SRV": dists["SRV"], + } +) + +for a, b in zip(NEW_AS, NEW_BS): + dists_summary[f"Elastic({a}, {b})"] = dists[a, b] + +st.dataframe(dists_summary) +# SAVEFIG = True +# if SAVEFIG: +# figs_dir = os.path.join(work_dir, f"cells/saved_figs/{dataset_name}") +# if not os.path.exists(figs_dir): +# os.makedirs(figs_dir) +# print(f"Will save figs to {figs_dir}") +# from datetime import datetime + +# now = datetime.now().strftime("%Y%m%d_%H_%M_%S") +# print(now) \ No newline at end of file diff --git a/cells/streamlit/cellgeometry/pages/1-Load_Data.py b/cells/streamlit/cellgeometry/pages/1-Load_Data.py new file mode 100644 index 0000000..ebc6173 --- /dev/null +++ b/cells/streamlit/cellgeometry/pages/1-Load_Data.py @@ -0,0 +1,119 @@ +import streamlit as st +import pandas as pd +import os +import time +import matplotlib.pyplot as plt + +import sys +sys.path.append("/app/utils") + +from utils.data_utils import build_rois, find_all_instances + +current_time = time.localtime() + +year = time.strftime("%Y", current_time) +day_of_year = time.strftime("%j", current_time) +time_string = time.strftime("%H%M%S", current_time) + +current_time_string = f"{year}{day_of_year}-{time_string}" + +if "cells_list" not in st.session_state: + st.session_state["cells_list"] = True + +st.write("# Load Your Cell Data šŸ‘‹") + +st.markdown( +""" +## Getting Started + +We currently support an ROI zip folder created by FIJI/ImageJ. What this means is you may have a folder structure as follows: +``` + └── Cropped_Images + ā”œā”€ā”€ Bottom_plank_0 + │ ā”œā”€ā”€ Averaged_ROI + │ ā”œā”€ā”€ Data + │ ā”œā”€ā”€ Data_Filtered + │ ā”œā”€ā”€ Labels + │ ā”œā”€ā”€ OG + │ ā”œā”€ā”€ Outlines + │ └── ROIs <---- Folder of zipped ROIs +``` +You can simply upload this ROIs folder and we will load your data for you. We plan on supporting data given in `xy` coordinate format from `JSON` and CSV/TXT files. +Your chosen data structure __must__ contain `x` and `y` for the program to correctly parse and load your data. +""" +) + +def get_files_from_folder(folder_path): + """ + Retrieves a list of files from a specific folder. + + Parameters: + folder_path (str): The path to the folder. + + Returns: + list: A list of file paths. + + Example: + >>> folder_path = '/path/to/folder' + >>> files = get_files_from_folder(folder_path) + >>> print(files) + ['/path/to/folder/file1.txt', '/path/to/folder/file2.csv', '/path/to/folder/file3.jpg'] + """ + files = [] + for filename in os.listdir(folder_path): + if os.path.isfile(os.path.join(folder_path, filename)): + files.append(os.path.join(folder_path, filename)) + return files + + + +# Specify the folder path for file uploads and save run with date and time +upload_folder = f"/app/data/run-{current_time_string}" + +# Check if the upload folder exists, and create it if it doesn't +if not os.path.exists(upload_folder): + os.makedirs(upload_folder) + st.info(f"Upload folder created: {upload_folder}") + +# Get the list of files in the upload folder +files = get_files_from_folder(upload_folder) + + + +# Display the file uploader +uploaded_files = st.file_uploader("Upload a file", type=["zip"], accept_multiple_files=True) + + +# Process the uploaded files +if uploaded_files is not None: + progress_bar = st.progress(0) + total_files = len(uploaded_files) + completed_files = 0 + + for uploaded_file in uploaded_files: + file_path = os.path.join(upload_folder, uploaded_file.name) + with open(file_path, "wb") as f: + f.write(uploaded_file.getbuffer()) + completed_files += 1 + progress = int((completed_files / total_files) * 100) + progress_bar.progress(progress) + # st.write(f"File saved: {file_path}") + + +# Build a dictionary of all the ROIs +dict_rois = build_rois(upload_folder) + +# Extract the cells +cells_list = [] +find_all_instances(dict_rois, 'x', 'y', cells_list) +st.session_state["cells_list"] = cells_list + +st.write(f"Successfully Loaded {len(cells_list)} cells.") + +# Sanity check visualization +cell_num = st.number_input(f"Visualize a cell. Pick a number between 0 and {len(cells_list)-1}", min_value=0) + + +fig, ax = plt.subplots() +ax.plot(cells_list[cell_num][:,0], cells_list[cell_num][:,1]) +st.pyplot(fig) diff --git a/cells/streamlit/cellgeometry/pages/Cell_Shear.py b/cells/streamlit/cellgeometry/pages/Cell_Shear.py new file mode 100644 index 0000000..fcbe50d --- /dev/null +++ b/cells/streamlit/cellgeometry/pages/Cell_Shear.py @@ -0,0 +1,107 @@ +import streamlit as st +import pandas as pd +import matplotlib.pyplot as plt + +import geomstats.backend as gs +from geomstats.geometry.euclidean import Euclidean +from geomstats.geometry.discrete_curves import R2, DiscreteCurves, ClosedDiscreteCurves + +from geomstats.learning.frechet_mean import FrechetMean +from geomstats.learning.kmeans import RiemannianKMeans +from geomstats.learning.mdm import RiemannianMinimumDistanceToMean +from geomstats.learning.pca import TangentPCA + +from utils import experimental + +st.write(st.session_state["cells_list"]) + +st.write("# Welcome to the Cell Shear Analysis App! šŸ‘‹") + +st.markdown( + """ + ## Step Zero + + šŸ‘ˆ If you have not already uploaded your data, please select the __Load Data__ page and follow the instructions. The format is important, so please read carefully. + + ## Analyzing Cell Data + + Now we will start analyzing our data. The first step is preprocessing our data, specifically interpolating, removing duplicates, and quotienting. +""" +) + +cells_list = st.session_state["cells_list"] + +n_sampling_points = st.slider('Select the Number of Sampling Points', 0, 100, 50) +cells, cell_shapes = experimental.nolabel_preprocess(cells_list, len(cells_list), n_sampling_points) + + +R1 = Euclidean(dim=1) +CLOSED_CURVES_SPACE = ClosedDiscreteCurves(R2) +CURVES_SPACE = DiscreteCurves(R2) +SRV_METRIC = CURVES_SPACE.srv_metric +L2_METRIC = CURVES_SPACE.l2_curves_metric + +ELASTIC_METRIC = {} +AS = [1, 2, 0.75, 0.5, 0.25, 0.01] #, 1.6] #, 1.4, 1.2, 1, 0.5, 0.2, 0.1] +BS = [0.5, 1, 0.5, 0.5, 0.5, 0.5] #, 2, 2, 2, 2, 2, 2, 2] +for a, b in zip(AS, BS): + ELASTIC_METRIC[a, b] = DiscreteCurves(R2, a=a, b=b).elastic_metric +METRICS = {} +METRICS["Linear"] = L2_METRIC +METRICS["SRV"] = SRV_METRIC + + +means = {} + +means["Linear"] = gs.mean(cell_shapes, axis=0) +means["SRV"] = FrechetMean( + metric=SRV_METRIC, + method="default").fit(cell_shapes).estimate_ + + +for a, b in zip(AS, BS): + means[a, b] = FrechetMean( + metric=ELASTIC_METRIC[a, b], + method="default").fit(cell_shapes).estimate_ + + +fig = plt.figure(figsize=(18, 8)) + +ncols = len(means) // 2 + +for i, (mean_name, mean) in enumerate(means.items()): + ax = fig.add_subplot(2, ncols, i+1) + ax.plot(mean[:, 0], mean[:, 1], "black") + ax.set_aspect("equal") + ax.axis("off") + axs_title = mean_name + if mean_name not in ["Linear", "SRV"]: + a = mean_name[0] + b = mean_name[1] + ratio = a / (2 * b) + mean_name = f"Elastic {mean_name}\n a / (2b) = {ratio}" + ax.set_title(mean_name) + +st.pyplot(fig) + + +fig = plt.figure(figsize=(18, 8)) + +ncols = len(means) // 2 + +for i, (mean_name, mean) in enumerate(means.items()): + ax = fig.add_subplot(2, ncols, i+1) + mean = CLOSED_CURVES_SPACE.projection(mean) + ax.plot(mean[:, 0], mean[:, 1], "black") + ax.set_aspect("equal") + ax.axis("off") + axs_title = mean_name + if mean_name not in ["Linear", "SRV"]: + a = mean_name[0] + b = mean_name[1] + ratio = a / (2 * b) + mean_name = f"Elastic {mean_name}\n a / (2b) = {ratio}" + ax.set_title(mean_name) + +st.pyplot(fig) + diff --git a/cells/streamlit/cellgeometry/utils/basic.py b/cells/streamlit/cellgeometry/utils/basic.py new file mode 100644 index 0000000..529faa0 --- /dev/null +++ b/cells/streamlit/cellgeometry/utils/basic.py @@ -0,0 +1,34 @@ +"""Compute basic shape features.""" + +import geomstats.backend as gs + + +def perimeter(xy): + """Calculate polygon perimeter. + + Parameters + ---------- + xy : array-like, shape=[n_points, 2] + Polygon, such that: + x = xy[:, 0]; y = xy[:, 1] + """ + first_point = gs.expand_dims(gs.array(xy[0]), axis=0) + xy1 = gs.concatenate([xy[1:], first_point], axis=0) + return gs.sum(gs.sqrt((xy1[:, 0] - xy[:, 0]) ** 2 + (xy1[:, 1] - xy[:, 1]) ** 2)) + + +def area(xy): + """Calculate polygon area. + + Parameters + ---------- + xy : array-like, shape=[n_points, 2] + Polygon, such that: + x = xy[:, 0]; y = xy[:, 1] + """ + n_points = len(xy) + s = 0.0 + for i in range(n_points): + j = (i + 1) % n_points + s += (xy[j, 0] - xy[i, 0]) * (xy[j, 1] + xy[i, 1]) + return -0.5 * s diff --git a/cells/streamlit/cellgeometry/utils/data_utils.py b/cells/streamlit/cellgeometry/utils/data_utils.py new file mode 100644 index 0000000..37dbb31 --- /dev/null +++ b/cells/streamlit/cellgeometry/utils/data_utils.py @@ -0,0 +1,100 @@ +import os +from read_roi import read_roi_zip +import numpy as np + + +def build_rois(path) -> dict: + """ + Builds a dictionary of region of interest (ROI) data from a directory of ROI files. + + Parameters: + path (str): The path to the directory containing ROI files. + + Returns: + dict: A dictionary where the keys are ROI names and the values are the corresponding ROI data. + + Example: + >>> roi_directory = '/path/to/roi_directory' + >>> rois = build_rois(roi_directory) + >>> print(rois) + {'roi1': , 'roi2': , ...} + """ + rois = {} + for roi in sorted(os.listdir(path)): + # print(roi.split(".")[0]) + rois[roi.split(".")[0]] = read_roi_zip(os.path.join(path,roi)) + return rois + + +def find_key(dictionary, target_key): + """ + Recursively searches for a key in a nested dictionary. + + Parameters: + dictionary (dict): The nested dictionary to search. + target_key (str): The key to find. + + Returns: + object: The value associated with the target key, or None if the key is not found. + + Example: + >>> data = { + ... 'key1': { + ... 'key2': { + ... 'key3': 'value3', + ... 'key4': 'value4' + ... } + ... } + ... } + >>> result = find_key(data, 'key4') + >>> print(result) + value4 + """ + if target_key in dictionary: + return dictionary[target_key] + + for value in dictionary.values(): + if isinstance(value, dict): + result = find_key(value, target_key) + if result is not None: + return result + + return None + + +def find_all_instances(dictionary, target_key1, target_key2, results_list): + """ + Recursively finds instances of two target keys in a nested dictionary and appends their corresponding values together. + + Parameters: + dictionary (dict): The nested dictionary to search. + target_key1 (hashable): The first target key to find. + target_key2 (hashable): The second target key to find. + results_list (list): The list where the corresponding values will be appended. + + Returns: + None + + Example: + >>> my_dict = { + ... "a": 1, + ... "b": {"c": 2, "d": 3}, + ... "e": {"f": 4, "g": {"a": 5, "c": 6}}, + ... "i": 7 + ... } + >>> target_key1 = "a" + >>> target_key2 = "c" + >>> instances = [] + >>> find_all_instances(my_dict, target_key1, target_key2, instances) + >>> print(instances) + [5, 6] + """ + found_keys = set() + for key, value in dictionary.items(): + if key == target_key1 or key == target_key2: + found_keys.add(key) + elif isinstance(value, dict): + find_all_instances(value, target_key1, target_key2, results_list) + + if {target_key1, target_key2}.issubset(found_keys): + results_list.append(np.array([dictionary[target_key1], dictionary[target_key2]]).T) \ No newline at end of file diff --git a/cells/streamlit/cellgeometry/utils/experimental.py b/cells/streamlit/cellgeometry/utils/experimental.py new file mode 100644 index 0000000..e072982 --- /dev/null +++ b/cells/streamlit/cellgeometry/utils/experimental.py @@ -0,0 +1,317 @@ +"""Utils to load experimental datasets of cells.""" + +from utils import basic as basic +import geomstats.backend as gs +import geomstats.datasets.utils as data_utils +import numpy as np +import skimage.io as skio +from geomstats.geometry.pre_shape import PreShapeSpace +from skimage import measure +from skimage.filters import threshold_otsu + +M_AMBIENT = 2 + + +def img_to_contour(img): + """Extract the longest cluster/cell contour from an image. + + Parameters + ---------- + img : array-like + Image showing cells or cell clusters. + + Returns + ------- + contour : array-like, shape=[n_sampling_points, 2] + Contour of the longest cluster/cell in the image, + as an array of 2D coordinates on points sampling + the contour. + """ + thresh = threshold_otsu(img) + binary = img > thresh + contours = measure.find_contours(binary, 0.8) + lengths = [len(c) for c in contours] + max_length = max(lengths) + index_max_length = lengths.index(max_length) + contour = contours[index_max_length] + return contour + + +def _tif_video_to_lists(tif_path): + """Convert a cell video into two trajectories of contours and images. + + Parameters + ---------- + tif_path : absolute path of video in .tif format. + + Returns + ------- + contours_list : list of arrays + List of 2D coordinates of points defining the contours of each cell + within the video. + imgs_list : list of array + List of images in the input video. + """ + img_stack = skio.imread(tif_path, plugin="tifffile") + contours_list = [] + imgs_list = [] + for img in img_stack: + imgs_list.append(img) + contour = img_to_contour(img) + contours_list.append(contour) + + return contours_list, imgs_list + + +def _interpolate(curve, n_sampling_points): + """Interpolate a discrete curve with nb_points from a discrete curve. + + Parameters + ---------- + curve : array-like, shape=[n_points, 2] + n_sampling_points : int + + Returns + ------- + interpolation : array-like, shape=[n_sampling_points, 2] + Discrete curve with n_sampling_points + """ + old_length = curve.shape[0] + interpolation = np.zeros((n_sampling_points, 2)) + incr = old_length / n_sampling_points + pos = np.array(0.0, dtype=np.float32) + for i in range(n_sampling_points): + index = int(np.floor(pos)) + interpolation[i] = curve[index] + (pos - index) * ( + curve[(index + 1) % old_length] - curve[index] + ) + pos += incr + return gs.array(interpolation, dtype=gs.float32) + + +def _remove_consecutive_duplicates(curve, tol=1e-2): + """Preprocess curve to ensure that there are no consecutive duplicate points. + + Returns + ------- + curve : discrete curve + """ + dist = curve[1:] - curve[:-1] + dist_norm = gs.sqrt(gs.sum(dist**2, axis=1)) + + if gs.any(dist_norm < tol): + for i in range(len(curve) - 2): + if gs.sqrt(gs.sum((curve[i + 1] - curve[i]) ** 2, axis=0)) < tol: + curve[i + 1] = (curve[i] + curve[i + 2]) / 2 + + return curve + + +def _exhaustive_align(curve, base_curve): + """Project a curve in shape space. + + This happens in 2 steps: + - remove translation (and scaling?) by projecting in pre-shape space. + - remove rotation by exhaustive alignment minimizing the L² distance. + + Returns + ------- + aligned_curve : discrete curve + """ + n_sampling_points = curve.shape[-2] + preshape = PreShapeSpace(m_ambient=M_AMBIENT, k_landmarks=n_sampling_points) + + nb_sampling = len(curve) + distances = gs.zeros(nb_sampling) + for shift in range(nb_sampling): + reparametrized = gs.array( + [curve[(i + shift) % nb_sampling] for i in range(nb_sampling)] + ) + aligned = preshape.align(point=reparametrized, base_point=base_curve) + distances[shift] = preshape.total_space_metric.norm( + gs.array(aligned) - gs.array(base_curve) + ) + shift_min = gs.argmin(distances) + reparametrized_min = gs.array( + [curve[(i + shift_min) % nb_sampling] for i in range(nb_sampling)] + ) + aligned_curve = preshape.align(point=reparametrized_min, base_point=base_curve) + return aligned_curve + + +def preprocess( + cells, + labels_a, + labels_b, + n_cells, + n_sampling_points, + quotient=["scaling", "rotation"], +): + """Preprocess a dataset of cells. + + Parameters + ---------- + cells : list of all cells + Each cell is an array of points in 2D. + labels_a : list of str + List of labels associated with each cell. + labels_b : list of str + List of labels associated with each cell. + n_cells : int + Number of cells to (randomly) select from this dataset. + n_sampling_points : int + Number of sampling points along the boundary of each cell. + """ + if n_cells > 0: + print(f"... Selecting only a random subset of {n_cells} / {len(cells)} cells.") + indices = sorted( + np.random.choice(gs.arange(0, len(cells), 1), size=n_cells, replace=False) + ) + cells = [cells[idx] for idx in indices] + labels_a = [labels_a[idx] for idx in indices] + labels_b = [labels_b[idx] for idx in indices] + + if n_sampling_points > 0: + print( + "... Interpolating: " + f"Cell boundaries have {n_sampling_points} samplings points." + ) + interpolated_cells = gs.zeros((n_cells, n_sampling_points, 2)) + for i_cell, cell in enumerate(cells): + interpolated_cells[i_cell] = _interpolate(cell, n_sampling_points) + + cells = interpolated_cells + + print("... Removing potential duplicate sampling points on cell boundaries.") + for i_cell, cell in enumerate(cells): + cells[i_cell] = _remove_consecutive_duplicates(cell) + + print("\n- Cells: quotienting translation.") + cells = cells - gs.mean(cells, axis=-2)[..., None, :] + + cell_shapes = gs.zeros_like(cells) + if "scaling" in quotient: + print("- Cell shapes: quotienting scaling (length).") + for i_cell, cell in enumerate(cells): + cell_shapes[i_cell] = cell / basic.perimeter(cell) + + if "rotation" in quotient: + print("- Cell shapes: quotienting rotation.") + if "scaling" not in quotient: + for i_cell, cell_shape in enumerate(cells): + cell_shapes[i_cell] = _exhaustive_align(cell_shape, cells[0]) + else: + for i_cell, cell_shape in enumerate(cell_shapes): + cell_shapes[i_cell] = _exhaustive_align(cell_shape, cell_shapes[0]) + + return cells, cell_shapes, labels_a, labels_b + + + +def nolabel_preprocess( + cells, + # labels_a, + # labels_b, + n_cells, + n_sampling_points, + quotient=["scaling", "rotation"], +): + """Preprocess a dataset of cells. + + Parameters + ---------- + cells : list of all cells + Each cell is an array of points in 2D. + labels_a : list of str + List of labels associated with each cell. + labels_b : list of str + List of labels associated with each cell. + n_cells : int + Number of cells to (randomly) select from this dataset. + n_sampling_points : int + Number of sampling points along the boundary of each cell. + """ + # if n_cells > 0: + # print(f"... Selecting only a random subset of {n_cells} / {len(cells)} cells.") + # indices = sorted( + # np.random.choice(gs.arange(0, len(cells), 1), size=n_cells, replace=False) + # ) + # cells = [cells[idx] for idx in indices] + # labels_a = [labels_a[idx] for idx in indices] + # labels_b = [labels_b[idx] for idx in indices] + + if n_sampling_points > 0: + print( + "... Interpolating: " + f"Cell boundaries have {n_sampling_points} samplings points." + ) + interpolated_cells = gs.zeros((n_cells, n_sampling_points, 2)) + for i_cell, cell in enumerate(cells): + interpolated_cells[i_cell] = _interpolate(cell, n_sampling_points) + + cells = interpolated_cells + + print("... Removing potential duplicate sampling points on cell boundaries.") + for i_cell, cell in enumerate(cells): + cells[i_cell] = _remove_consecutive_duplicates(cell) + + print("\n- Cells: quotienting translation.") + cells = cells - gs.mean(cells, axis=-2)[..., None, :] + + cell_shapes = gs.zeros_like(cells) + if "scaling" in quotient: + print("- Cell shapes: quotienting scaling (length).") + for i_cell, cell in enumerate(cells): + cell_shapes[i_cell] = cell / basic.perimeter(cell) + + if "rotation" in quotient: + print("- Cell shapes: quotienting rotation.") + if "scaling" not in quotient: + for i_cell, cell_shape in enumerate(cells): + cell_shapes[i_cell] = _exhaustive_align(cell_shape, cells[0]) + else: + for i_cell, cell_shape in enumerate(cell_shapes): + cell_shapes[i_cell] = _exhaustive_align(cell_shape, cell_shapes[0]) + + return cells, cell_shapes #, labels_a, labels_b + + +def load_treated_osteosarcoma_cells( + n_cells=-1, n_sampling_points=10, quotient=["scaling", "rotation"] +): + """Load dataset of osteosarcoma cells (bone cancer cells). + + This cell dataset contains cell boundaries of mouse osteosarcoma + (bone cancer) cells. The dlm8 cell line is derived from dunn and is more + aggressive as a cancer. The cells have been treated with one of three + treatments : control (no treatment), jasp (jasplakinolide) + and cytd (cytochalasin D). These are drugs which perturb the cytoskelet + of the cells. + + Parameters + ---------- + n_sampling_points : int + Number of points used to interpolate each cell boundary. + Optional, Default: 0. + If equal to 0, then no interpolation is performed. + + Returns + ------- + cells : array of n_cells planar discrete curves + Each curve represents the boundary of a cell in counterclockwise order. + Their barycenters are fixed at 0 (translation has been removed). + Their lengths are not necessarily equal (scaling has not been removed). + cell_shapes : array of n_cells planar discrete curves shapes + Each curve represents the boundary of a cell in counterclockwise order. + Their barycenters are fixed at 0 (translation has been removed). + Their lengths are fixed at 1 (scaling has been removed). + They are aligned in rotation to the first cell (rotation has been removed). + lines : list of n_cells strings + List of the cell lines of each cell (dlm8 or dunn). + treatments : list of n_cells strings + List of the treatments given to each cell (control, cytd or jasp). + """ + cells, lines, treatments = data_utils.load_cells() + return preprocess( + cells, lines, treatments, n_cells, n_sampling_points, quotient=quotient + )