Skip to content

Commit 3b0489b

Browse files
authored
Merge pull request #140 from geometric-intelligence/new-nbs
New nbs
2 parents d7acc36 + 0abd3c7 commit 3b0489b

17 files changed

+513379
-50317
lines changed

tutorials/01_methods_create_synthetic_data.ipynb

+17,358-3,085
Large diffs are not rendered by default.

tutorials/02_methods_estimate_manifold_dimension.ipynb

+3,091-30
Large diffs are not rendered by default.

tutorials/11_application_synthetic_v1.ipynb

+143,367-1,457
Large diffs are not rendered by default.

tutorials/place cell processing/14_preprocess_ca1_hippocampus_data.ipynb

-41,781
This file was deleted.

tutorials/place cell processing/21_explore_binned_experimental_place_cells.ipynb

-1,931
This file was deleted.

tutorials/place cell processing/23_inspect_reconstructed_experimental_place_cells.ipynb

-2,028
This file was deleted.

tutorials/place_cells/14_preprocess_ca1_hippocampus_data.ipynb

+335,739
Large diffs are not rendered by default.

tutorials/place_cells/21_explore_binned_experimental_place_cells.ipynb

+10,312
Large diffs are not rendered by default.

tutorials/place_cells/23_inspect_reconstructed_experimental_place_cells.ipynb

+3,234
Large diffs are not rendered by default.

tutorials/place cell processing/27_plot_neural_manifolds.ipynb renamed to tutorials/place_cells/27_plot_neural_manifolds.ipynb

+278-5
Original file line numberDiff line numberDiff line change
@@ -105,15 +105,18 @@
105105
"source": [
106106
"import json\n",
107107
"\n",
108-
"import default_config as config\n",
109108
"import matplotlib as mpl\n",
109+
"import matplotlib.cm as cm\n",
110110
"import matplotlib.pyplot as plt\n",
111111
"import numpy as np\n",
112112
"import pandas as pd\n",
113113
"import torch\n",
114-
"import viz\n",
115114
"\n",
116-
"import neurometry.datasets.utils as utils"
115+
"import default_config as config\n",
116+
"import neurometry.datasets.utils as utils\n",
117+
"import neurometry.models.neural_vae as neural_vae\n",
118+
"import train\n",
119+
"import viz"
117120
]
118121
},
119122
{
@@ -909,6 +912,7 @@
909912
"source": [
910913
"with open(\n",
911914
" os.path.join(CONFIG_DIR, run_id_config_file),\n",
915+
" \"r\",\n",
912916
") as f:\n",
913917
" config_dict = json.load(f)\n",
914918
"\n",
@@ -1005,7 +1009,7 @@
10051009
"# This is needed for the utils.load() function\n",
10061010
"class AttrDict(dict):\n",
10071011
" def __init__(self, *args, **kwargs):\n",
1008-
" super().__init__(*args, **kwargs)\n",
1012+
" super(AttrDict, self).__init__(*args, **kwargs)\n",
10091013
" self.__dict__ = self\n",
10101014
"\n",
10111015
"\n",
@@ -1821,10 +1825,11 @@
18211825
"source": [
18221826
"import time\n",
18231827
"\n",
1824-
"import evaluate\n",
18251828
"import geomstats.backend as gs\n",
18261829
"from geomstats.geometry.pullback_metric import PullbackMetric\n",
18271830
"\n",
1831+
"import evaluate\n",
1832+
"\n",
18281833
"learned_immersion = evaluate.get_learned_immersion(model, config)\n",
18291834
"\n",
18301835
"neural_metric = PullbackMetric(\n",
@@ -2012,3 +2017,271 @@
20122017
"id": "28c0e407",
20132018
"metadata": {},
20142019
"source": [
2020+
"Let's parallelize this."
2021+
]
2022+
},
2023+
{
2024+
"cell_type": "code",
2025+
"execution_count": null,
2026+
"id": "42875293",
2027+
"metadata": {
2028+
"tags": []
2029+
},
2030+
"outputs": [],
2031+
"source": [
2032+
"import copy\n",
2033+
"import logging\n",
2034+
"\n",
2035+
"from joblib import Parallel, delayed\n",
2036+
"\n",
2037+
"model.to(\"cuda:0\")\n",
2038+
"z_grid = torch.tensor(curv_norm_learned_profile[\"z_grid\"].values)\n",
2039+
"z0 = torch.unsqueeze(z_grid[0], dim=0)\n",
2040+
"\n",
2041+
"\n",
2042+
"# TODO (use logging to actually print and know which iterations we are on)\n",
2043+
"def _geodesic_dist(i_z, z, grid_interval):\n",
2044+
" if i_z == 0:\n",
2045+
" return (0, torch.tensor(0.0))\n",
2046+
" # Parallelize on gpus: is that line really helping,\n",
2047+
" # i.e. is the copy taking less time than the computation? if not, don't bother\n",
2048+
" # also: does it change anything since neural metric has already been computed and uses the old version of model?\n",
2049+
" model_copy = copy.deepcopy(model).to(f\"cuda:{i_z % 9}\")\n",
2050+
" z = torch.unsqueeze(z, dim=0)\n",
2051+
" z_previous = torch.unsqueeze(z_grid[i_z - grid_interval], dim=0)\n",
2052+
" start = time.time()\n",
2053+
" # Tricks to speed up this computation:\n",
2054+
" # 1. Compute with less number of steps for the integration of the geodesic eqn\n",
2055+
" # 2. Compute distance between neighboring z's\n",
2056+
" dist = neural_metric.dist(z_previous, z, n_steps=7)\n",
2057+
" duration = time.time() - start\n",
2058+
" logging.info(f\"Time (it: {i_z}): {duration:.3f}\")\n",
2059+
" return (i_z, dist)\n",
2060+
"\n",
2061+
"\n",
2062+
"# To try this code , use z_grid[:5] to run on smaller batch\n",
2063+
"# Currently, the z_grid is too big (~800) --> go to 100\n",
2064+
"# Note: each distance computation takes ~5s.\n",
2065+
"grid_interval = 1\n",
2066+
"res = Parallel(n_jobs=-1)(\n",
2067+
" delayed(_geodesic_dist)(i_z, z, grid_interval)\n",
2068+
" for i_z, z in enumerate(z_grid)\n",
2069+
" if i_z % grid_interval == 0\n",
2070+
")"
2071+
]
2072+
},
2073+
{
2074+
"cell_type": "code",
2075+
"execution_count": 16,
2076+
"id": "8641f943-5883-4fdc-bc18-f4ccb2f4f5b9",
2077+
"metadata": {},
2078+
"outputs": [
2079+
{
2080+
"data": {
2081+
"text/plain": [
2082+
"pandas.core.series.Series"
2083+
]
2084+
},
2085+
"execution_count": 16,
2086+
"metadata": {},
2087+
"output_type": "execute_result"
2088+
}
2089+
],
2090+
"source": [
2091+
"type(curv_norm_learned_profile[\"curv_norm_learned\"])"
2092+
]
2093+
},
2094+
{
2095+
"cell_type": "code",
2096+
"execution_count": null,
2097+
"id": "e3889082",
2098+
"metadata": {
2099+
"tags": []
2100+
},
2101+
"outputs": [],
2102+
"source": [
2103+
"geodesic_dists = torch.zeros(len(res))\n",
2104+
"curv_norms = torch.zeros(len(res))\n",
2105+
"for i_z, dist in res:\n",
2106+
" geodesic_dists[i_z] = dist\n",
2107+
" curv_norms[i_z] = curv_norm_learned_profile[\"curv_norm_learned\"].values[i_z]\n",
2108+
"\n",
2109+
"print(geodesic_dists[:10])\n",
2110+
"print(curv_norms[:10])\n",
2111+
"print(1 / curv_norms[:10])\n",
2112+
"\n",
2113+
"print(len(geodesic_dists))"
2114+
]
2115+
},
2116+
{
2117+
"cell_type": "code",
2118+
"execution_count": null,
2119+
"id": "17f4214c",
2120+
"metadata": {},
2121+
"outputs": [],
2122+
"source": [
2123+
"cumul_geodesic_dists = torch.cumsum(geodesic_dists, dim=0)\n",
2124+
"cumul_geodesic_dists[:10]\n",
2125+
"print(cumul_geodesic_dists.max())"
2126+
]
2127+
},
2128+
{
2129+
"cell_type": "markdown",
2130+
"id": "a6d31bbf-b37b-46a3-9b8d-ec48e0b39597",
2131+
"metadata": {},
2132+
"source": [
2133+
"## Plot invariant neural manifold (colored)"
2134+
]
2135+
},
2136+
{
2137+
"cell_type": "code",
2138+
"execution_count": null,
2139+
"id": "b4807db9",
2140+
"metadata": {
2141+
"tags": []
2142+
},
2143+
"outputs": [],
2144+
"source": [
2145+
"import matplotlib.cm as cm\n",
2146+
"import matplotlib.pyplot as plt\n",
2147+
"import numpy as np\n",
2148+
"\n",
2149+
"stats = [\n",
2150+
" \"mean_velocities\",\n",
2151+
" \"median_velocities\",\n",
2152+
" \"std_velocities\",\n",
2153+
" \"min_velocities\",\n",
2154+
" \"max_velocities\",\n",
2155+
"]\n",
2156+
"cmaps = [\"viridis\", \"viridis\", \"magma\", \"Blues\", \"Reds\"]\n",
2157+
"\n",
2158+
"fig, axes = plt.subplots(\n",
2159+
" nrows=len(stats), ncols=1, figsize=(20, 20), subplot_kw={\"projection\": \"polar\"}\n",
2160+
")\n",
2161+
"\n",
2162+
"i_zs = [i_z for i_z, _ in res]\n",
2163+
"subgrid_profile = curv_norm_learned_profile # .take(i_zs)\n",
2164+
"print(len(subgrid_profile))\n",
2165+
"\n",
2166+
"for i_stat, stat_velocities in enumerate(stats):\n",
2167+
" ax = axes[i_stat]\n",
2168+
" ax.scatter(\n",
2169+
" cumul_geodesic_dists,\n",
2170+
" 1 / curv_norms,\n",
2171+
" c=subgrid_profile[stat_velocities],\n",
2172+
" cmap=cmaps[i_stat],\n",
2173+
" )\n",
2174+
" ax.plot(\n",
2175+
" cumul_geodesic_dists,\n",
2176+
" 1 / curv_norms,\n",
2177+
" c=\"black\",\n",
2178+
" )\n",
2179+
" # ax.set_rticks([0.5, 1, 1.5, 2]) # Less radial ticks\n",
2180+
" ax.set_rlabel_position(-22.5) # Move radial labels away from plotted line\n",
2181+
" ax.grid(True)\n",
2182+
" ax.set_title(\"Color: \" + stat_velocities, va=\"bottom\")\n",
2183+
"fig.tight_layout()"
2184+
]
2185+
},
2186+
{
2187+
"cell_type": "markdown",
2188+
"id": "a20ad8fc-90cf-4f3a-aa59-85ee4a794550",
2189+
"metadata": {},
2190+
"source": [
2191+
"## Plot invariant neural manifold (uncolored)"
2192+
]
2193+
},
2194+
{
2195+
"cell_type": "code",
2196+
"execution_count": null,
2197+
"id": "fd80354e-322f-4e1d-a364-8d1c28b93d81",
2198+
"metadata": {
2199+
"tags": []
2200+
},
2201+
"outputs": [],
2202+
"source": [
2203+
"closed_geodesic_dists = torch.concat(\n",
2204+
" [cumul_geodesic_dists, torch.tensor([cumul_geodesic_dists[0]])]\n",
2205+
")\n",
2206+
"closed_curv_norms = torch.concat([curv_norms, torch.tensor([curv_norms[0]])])"
2207+
]
2208+
},
2209+
{
2210+
"cell_type": "code",
2211+
"execution_count": null,
2212+
"id": "af04df43-29d5-4a9a-b359-e7fa1e9cf9f0",
2213+
"metadata": {
2214+
"tags": []
2215+
},
2216+
"outputs": [],
2217+
"source": [
2218+
"import matplotlib.cm as cm\n",
2219+
"import matplotlib.pyplot as plt\n",
2220+
"import numpy as np\n",
2221+
"\n",
2222+
"stats = [\n",
2223+
" \"mean_velocities\",\n",
2224+
" \"median_velocities\",\n",
2225+
" \"std_velocities\",\n",
2226+
" \"min_velocities\",\n",
2227+
" \"max_velocities\",\n",
2228+
"]\n",
2229+
"cmaps = [\"viridis\", \"viridis\", \"magma\", \"Blues\", \"Reds\"]\n",
2230+
"\n",
2231+
"fig, ax = plt.subplots(\n",
2232+
" nrows=1, ncols=1, figsize=(6, 6), subplot_kw={\"projection\": \"polar\"}\n",
2233+
")\n",
2234+
"\n",
2235+
"i_zs = [i_z for i_z, _ in res]\n",
2236+
"subgrid_profile = curv_norm_learned_profile.take(i_zs)\n",
2237+
"\n",
2238+
"ax.plot(closed_geodesic_dists, 1 / closed_curv_norms, c=\"black\")\n",
2239+
"# ax.set_rticks([0.5, 1, 1.5, 2]) # Less radial ticks\n",
2240+
"ax.set_rlabel_position(-22.5) # Move radial labels away from plotted line\n",
2241+
"ax.grid(True)\n",
2242+
"ax.set_title(\"Color: \" + stat_velocities, va=\"bottom\")\n",
2243+
"fig.tight_layout()\n",
2244+
"# import os\n",
2245+
"# print(os.getcwd())\n",
2246+
"fig.savefig(f\"notebooks/figures/run_{run_id}_invariant_manifold.svg\")"
2247+
]
2248+
},
2249+
{
2250+
"cell_type": "markdown",
2251+
"id": "dc68132e-7d46-415b-a895-344cad5ddf49",
2252+
"metadata": {},
2253+
"source": [
2254+
"Note: this looping is weird, and the geodesic distances go well over 14, i.e. over 3.14*2 which is 2 pi. There must be an error."
2255+
]
2256+
},
2257+
{
2258+
"cell_type": "code",
2259+
"execution_count": null,
2260+
"id": "4dc802a5-3f1e-4ec8-9ad2-8c82d100240f",
2261+
"metadata": {},
2262+
"outputs": [],
2263+
"source": []
2264+
}
2265+
],
2266+
"metadata": {
2267+
"kernelspec": {
2268+
"display_name": "Python 3 (ipykernel)",
2269+
"language": "python",
2270+
"name": "python3"
2271+
},
2272+
"language_info": {
2273+
"codemirror_mode": {
2274+
"name": "ipython",
2275+
"version": 3
2276+
},
2277+
"file_extension": ".py",
2278+
"mimetype": "text/x-python",
2279+
"name": "python",
2280+
"nbconvert_exporter": "python",
2281+
"pygments_lexer": "ipython3",
2282+
"version": "3.8.16"
2283+
}
2284+
},
2285+
"nbformat": 4,
2286+
"nbformat_minor": 5
2287+
}

0 commit comments

Comments
 (0)