1
+ import os
2
+ import subprocess
3
+ import geomstats .backend as gs
4
+ import streamlit as st
5
+ import time
6
+ import numpy as np
7
+ import matplotlib .pyplot as plt
8
+ from scipy import stats
9
+
10
+ from sklearn .cluster import KMeans
11
+ from sklearn .metrics import confusion_matrix
12
+ from sklearn .decomposition import PCA
13
+
14
+ from scipy .optimize import linear_sum_assignment as linear_assignment
15
+ from sklearn import manifold
16
+ from joblib import Parallel , delayed
17
+ from numba import jit , njit , prange
18
+
19
+ from geomstats .geometry .euclidean import Euclidean
20
+ from geomstats .geometry .discrete_curves import R2 , DiscreteCurves , ClosedDiscreteCurves
21
+
22
+ from geomstats .learning .frechet_mean import FrechetMean
23
+ from geomstats .learning .kmeans import RiemannianKMeans
24
+ from geomstats .learning .mdm import RiemannianMinimumDistanceToMean
25
+ from geomstats .learning .pca import TangentPCA
26
+
27
+ import sys
28
+ sys .path .append ("/app/utils" )
29
+
30
+ # import utils
31
+ from utils import experimental as experimental
32
+ from utils import basic as basic
33
+
34
+
35
+
36
+
37
# Streamlit page setup: browser-tab title and icon first, then the page
# heading, the sidebar heading and the introductory paragraph.
st.set_page_config(
    page_title="Elastic Metric for Cell Boundary Analysis",
    page_icon="📈",
)

page_heading = "Shape Analysis of Cancer Cells"
st.markdown("# " + page_heading)
st.sidebar.header(page_heading)

# Introductory text displayed under the page heading.
intro_text = """This notebook studies Osteosarcoma (bone cancer) cells and the impact of drug treatment on their morphological shapes, by analyzing cell images obtained from fluorescence microscopy.

This analysis relies on the elastic metric between discrete curves from Geomstats. We will study to which extent this metric can detect how the cell shape is associated with the response to treatment."""
st.write(intro_text)
46
+
47
# --- Dataset selection and preprocessing configuration ---------------------
dataset_name = "osteosarcoma"

# Number of sampling points forwarded to the loader, chosen interactively.
# The lower bound is 3 instead of 0: a value of 0 would request curves with
# no points at all, and three vertices is the fewest that can form a closed
# planar curve. The default (100) and upper bound (250) are unchanged.
n_sampling_points = st.slider("Select the Number of Sampling Points", 3, 250, 100)
n_cells = 650

# Human-readable names of the two label sets returned by the loader.
labels_a_name = "lines"
labels_b_name = "treatments"

# Transformations handed to the loader through its `quotient` argument.
quotient = ["rotation"]
# When True, the downstream analysis uses `cells` instead of `cell_shapes`.
do_not_quotient = False

if dataset_name == "osteosarcoma":
    cells, cell_shapes, labels_a, labels_b = experimental.load_treated_osteosarcoma_cells(
        n_cells=n_cells, n_sampling_points=n_sampling_points, quotient=quotient
    )
else:
    # Fail here with a clear message rather than with a NameError further
    # down, when `cells` / `cell_shapes` / labels are first used.
    raise ValueError(
        f"Unsupported dataset_name {dataset_name!r}; only 'osteosarcoma' is available."
    )
65
+
66
+
67
def _index_labels(labels):
    """Map every distinct label to its position in ``np.unique(labels)``."""
    distinct = np.unique(labels)
    return dict(zip(distinct, range(len(distinct))))


# Integer codes for both label sets.
labels_a_dict = _index_labels(labels_a)
labels_b_dict = _index_labels(labels_b)

# Log both mappings to the console (stdout, not the Streamlit page).
for label_name, label_dict in (
    (labels_a_name, labels_a_dict),
    (labels_b_name, labels_b_dict),
):
    print(f"Dictionary associated to label \" {label_name} \" :")
    print(label_dict)

# Optionally analyse `cells` in place of `cell_shapes`.
if do_not_quotient:
    cell_shapes = cells
77
+
78
# Draw a grid of randomly chosen cell curves: one row per distinct value of
# `labels_b`, and within each row `n_cells_to_plot` consecutive panels per
# distinct value of `labels_a`.
n_cells_to_plot = 10

# Computed once; the grid dimensions do not change inside the loops.
unique_labels_a = np.unique(labels_a)
unique_labels_b = np.unique(labels_b)
n_rows = len(unique_labels_b)
n_cols = len(unique_labels_a) * n_cells_to_plot

fig = plt.figure(figsize=(16, 6))
count = 1  # 1-based index of the next subplot slot
for label_b in unique_labels_b:
    for i_lab_a, label_a in enumerate(unique_labels_a):
        cell_data = [
            cell
            for cell, lab_a, lab_b in zip(cell_shapes, labels_a, labels_b)
            if lab_a == label_a and lab_b == label_b
        ]
        if not cell_data:
            # No cell carries this label combination: leave its panels blank
            # but still advance the slot counter so later groups stay aligned.
            count += n_cells_to_plot
            continue
        for i_to_plot in range(n_cells_to_plot):
            cell = gs.random.choice(a=cell_data)
            ax = fig.add_subplot(n_rows, n_cols, count)
            count += 1
            # "C<n>" selects the n-th colour of the Matplotlib colour cycle;
            # the spec must not contain stray whitespace.
            ax.plot(cell[:, 0], cell[:, 1], color=f"C{i_lab_a}")
            ax.axis("equal")
            ax.axis("off")
            # Title only the middle panel of each group.
            if i_to_plot == n_cells_to_plot // 2:
                ax.set_title(f"{label_a} - {label_b} ", fontsize=20)
st.pyplot(fig)
95
+
96
# Define shape space
R1 = Euclidean(dim=1)
CLOSED_CURVES_SPACE = ClosedDiscreteCurves(R2)
CURVES_SPACE = DiscreteCurves(R2)
SRV_METRIC = CURVES_SPACE.srv_metric
L2_METRIC = CURVES_SPACE.l2_curves_metric

# Elastic metric parameters, paired position by position: (AS[k], BS[k]).
AS = [1, 2, 0.75, 0.5, 0.25, 0.01]
BS = [0.5, 1, 0.5, 0.5, 0.5, 0.5]

# One elastic metric per (a, b) pair, keyed by the pair itself.
ELASTIC_METRIC = {}
for a, b in zip(AS, BS):
    curves_ab = DiscreteCurves(R2, a=a, b=b)
    ELASTIC_METRIC[(a, b)] = curves_ab.elastic_metric

# Reference metrics addressed by display name.
METRICS = {"Linear": L2_METRIC, "SRV": SRV_METRIC}
111
+
112
def _frechet_mean(metric):
    """Fit a ``FrechetMean`` with ``metric`` on ``cell_shapes`` and return its estimate."""
    fitted = FrechetMean(metric=metric, method="default").fit(cell_shapes)
    return fitted.estimate_


# Mean curves keyed by name: the pointwise ("Linear") mean, the SRV Frechet
# mean, then one Frechet mean per elastic (a, b) pair.
means = {
    "Linear": gs.mean(cell_shapes, axis=0),
    "SRV": _frechet_mean(SRV_METRIC),
}
for a, b in zip(AS, BS):
    means[a, b] = _frechet_mean(ELASTIC_METRIC[a, b])
123
+
124
# Show every computed mean curve in a two-row grid of panels.
st.header("Sample Means")
st.markdown("We compare results when computing the mean cell versus the mean cell shapes with different elastic metrics.")
fig = plt.figure(figsize=(18, 8))

# Round up so that an odd number of means still fits in the two rows
# (floor division would leave the last panel without a slot, and a single
# mean would give zero columns).
ncols = (len(means) + 1) // 2

for i, (mean_name, mean) in enumerate(means.items()):
    ax = fig.add_subplot(2, ncols, i + 1)
    ax.plot(mean[:, 0], mean[:, 1], "black")
    ax.set_aspect("equal")
    ax.axis("off")
    if mean_name in ("Linear", "SRV"):
        title = mean_name
    else:
        # Elastic means are keyed by their (a, b) pair; show the ratio too.
        a, b = mean_name
        ratio = a / (2 * b)
        title = f"Elastic {mean_name} \n a / (2b) = {ratio} "
    ax.set_title(title)

st.pyplot(fig)
144
+
145
+ # SAVEFIG = True
146
+ # if SAVEFIG:
147
+ # figs_dir = os.path.join(work_dir, f"cells/saved_figs/{dataset_name}")
148
+ # if not os.path.exists(figs_dir):
149
+ # os.makedirs(figs_dir)
150
+ # print(f"Will save figs to {figs_dir}")
151
+ # from datetime import datetime
152
+
153
+ # now = datetime.now().strftime("%Y%m%d_%H_%M_%S")
154
+ # print(now)
0 commit comments