|
105 | 105 | "source": [
|
106 | 106 | "import json\n",
|
107 | 107 | "\n",
|
108 |
| - "import default_config as config\n", |
109 | 108 | "import matplotlib as mpl\n",
|
| 109 | + "import matplotlib.cm as cm\n", |
110 | 110 | "import matplotlib.pyplot as plt\n",
|
111 | 111 | "import numpy as np\n",
|
112 | 112 | "import pandas as pd\n",
|
113 | 113 | "import torch\n",
|
114 |
| - "import viz\n", |
115 | 114 | "\n",
|
116 |
| - "import neurometry.datasets.utils as utils" |
| 115 | + "import default_config as config\n", |
| 116 | + "import neurometry.datasets.utils as utils\n", |
| 117 | + "import neurometry.models.neural_vae as neural_vae\n", |
| 118 | + "import train\n", |
| 119 | + "import viz" |
117 | 120 | ]
|
118 | 121 | },
|
119 | 122 | {
|
|
909 | 912 | "source": [
|
910 | 913 | "with open(\n",
|
911 | 914 | " os.path.join(CONFIG_DIR, run_id_config_file),\n",
|
| 915 | + " \"r\",\n", |
912 | 916 | ") as f:\n",
|
913 | 917 | " config_dict = json.load(f)\n",
|
914 | 918 | "\n",
|
|
1005 | 1009 | "# This is needed for the utils.load() function\n",
|
1006 | 1010 | "class AttrDict(dict):\n",
|
1007 | 1011 | " def __init__(self, *args, **kwargs):\n",
|
1008 |
| - " super().__init__(*args, **kwargs)\n", |
| 1012 | + " super(AttrDict, self).__init__(*args, **kwargs)\n", |
1009 | 1013 | " self.__dict__ = self\n",
|
1010 | 1014 | "\n",
|
1011 | 1015 | "\n",
|
|
1821 | 1825 | "source": [
|
1822 | 1826 | "import time\n",
|
1823 | 1827 | "\n",
|
1824 |
| - "import evaluate\n", |
1825 | 1828 | "import geomstats.backend as gs\n",
|
1826 | 1829 | "from geomstats.geometry.pullback_metric import PullbackMetric\n",
|
1827 | 1830 | "\n",
|
| 1831 | + "import evaluate\n", |
| 1832 | + "\n", |
1828 | 1833 | "learned_immersion = evaluate.get_learned_immersion(model, config)\n",
|
1829 | 1834 | "\n",
|
1830 | 1835 | "neural_metric = PullbackMetric(\n",
|
|
2012 | 2017 | "id": "28c0e407",
|
2013 | 2018 | "metadata": {},
|
2014 | 2019 | "source": [
|
| 2020 | + "Let's parallelize this." |
| 2021 | + ] |
| 2022 | + }, |
| 2023 | + { |
| 2024 | + "cell_type": "code", |
| 2025 | + "execution_count": null, |
| 2026 | + "id": "42875293", |
| 2027 | + "metadata": { |
| 2028 | + "tags": [] |
| 2029 | + }, |
| 2030 | + "outputs": [], |
| 2031 | + "source": [ |
| 2032 | + "import copy\n", |
| 2033 | + "import logging\n", |
| 2034 | + "\n", |
| 2035 | + "from joblib import Parallel, delayed\n", |
| 2036 | + "\n", |
| 2037 | + "model.to(\"cuda:0\")\n", |
| 2038 | + "z_grid = torch.tensor(curv_norm_learned_profile[\"z_grid\"].values)\n", |
| 2039 | + "z0 = torch.unsqueeze(z_grid[0], dim=0)\n", |
| 2040 | + "\n", |
| 2041 | + "\n", |
| 2042 | + "# TODO (use logging to actually print and know which iterations we are on)\n", |
| 2043 | + "def _geodesic_dist(i_z, z, grid_interval):\n", |
| 2044 | + " if i_z == 0:\n", |
| 2045 | + " return (0, torch.tensor(0.0))\n", |
| 2046 | + " # Parallelize on gpus: is that line really helping,\n", |
| 2047 | + " # i.e. is the copy taking less time than the computation? if not, don't bother\n", |
| 2048 | + " # also: does it change anything since neural metric has already been computed and uses the old version of model?\n", |
| 2049 | + " model_copy = copy.deepcopy(model).to(f\"cuda:{i_z % 9}\")\n", |
| 2050 | + " z = torch.unsqueeze(z, dim=0)\n", |
| 2051 | + " z_previous = torch.unsqueeze(z_grid[i_z - grid_interval], dim=0)\n", |
| 2052 | + " start = time.time()\n", |
| 2053 | + " # Tricks to speed up this computation:\n", |
| 2054 | + " # 1. Compute with less number of steps for the integration of the geodesic eqn\n", |
| 2055 | + " # 2. Compute distance between neighboring z's\n", |
| 2056 | + " dist = neural_metric.dist(z_previous, z, n_steps=7)\n", |
| 2057 | + " duration = time.time() - start\n", |
| 2058 | + " logging.info(f\"Time (it: {i_z}): {duration:.3f}\")\n", |
| 2059 | + " return (i_z, dist)\n", |
| 2060 | + "\n", |
| 2061 | + "\n", |
| 2062 | + "# To try this code , use z_grid[:5] to run on smaller batch\n", |
| 2063 | + "# Currently, the z_grid is too big (~800) --> go to 100\n", |
| 2064 | + "# Note: each distance computation takes ~5s.\n", |
| 2065 | + "grid_interval = 1\n", |
| 2066 | + "res = Parallel(n_jobs=-1)(\n", |
| 2067 | + " delayed(_geodesic_dist)(i_z, z, grid_interval)\n", |
| 2068 | + " for i_z, z in enumerate(z_grid)\n", |
| 2069 | + " if i_z % grid_interval == 0\n", |
| 2070 | + ")" |
| 2071 | + ] |
| 2072 | + }, |
| 2073 | + { |
| 2074 | + "cell_type": "code", |
| 2075 | + "execution_count": 16, |
| 2076 | + "id": "8641f943-5883-4fdc-bc18-f4ccb2f4f5b9", |
| 2077 | + "metadata": {}, |
| 2078 | + "outputs": [ |
| 2079 | + { |
| 2080 | + "data": { |
| 2081 | + "text/plain": [ |
| 2082 | + "pandas.core.series.Series" |
| 2083 | + ] |
| 2084 | + }, |
| 2085 | + "execution_count": 16, |
| 2086 | + "metadata": {}, |
| 2087 | + "output_type": "execute_result" |
| 2088 | + } |
| 2089 | + ], |
| 2090 | + "source": [ |
| 2091 | + "type(curv_norm_learned_profile[\"curv_norm_learned\"])  # sanity check: the recorded output is pandas.core.series.Series" |
| 2092 | + ] |
| 2093 | + }, |
| 2094 | + { |
| 2095 | + "cell_type": "code", |
| 2096 | + "execution_count": null, |
| 2097 | + "id": "e3889082", |
| 2098 | + "metadata": { |
| 2099 | + "tags": [] |
| 2100 | + }, |
| 2101 | + "outputs": [], |
| 2102 | + "source": [ |
| 2103 | + "geodesic_dists = torch.zeros(len(res))\n", |
| 2104 | + "curv_norms = torch.zeros(len(res))\n", |
| 2105 | + "for i_z, dist in res:\n", |
| 2106 | + " geodesic_dists[i_z] = dist\n", |
| 2107 | + " curv_norms[i_z] = curv_norm_learned_profile[\"curv_norm_learned\"].values[i_z]\n", |
| 2108 | + "\n", |
| 2109 | + "print(geodesic_dists[:10])\n", |
| 2110 | + "print(curv_norms[:10])\n", |
| 2111 | + "print(1 / curv_norms[:10])\n", |
| 2112 | + "\n", |
| 2113 | + "print(len(geodesic_dists))" |
| 2114 | + ] |
| 2115 | + }, |
| 2116 | + { |
| 2117 | + "cell_type": "code", |
| 2118 | + "execution_count": null, |
| 2119 | + "id": "17f4214c", |
| 2120 | + "metadata": {}, |
| 2121 | + "outputs": [], |
| 2122 | + "source": [ |
| 2123 | + "cumul_geodesic_dists = torch.cumsum(geodesic_dists, dim=0)\n", |
| 2124 | + "cumul_geodesic_dists[:10]\n", |
| 2125 | + "print(cumul_geodesic_dists.max())" |
| 2126 | + ] |
| 2127 | + }, |
| 2128 | + { |
| 2129 | + "cell_type": "markdown", |
| 2130 | + "id": "a6d31bbf-b37b-46a3-9b8d-ec48e0b39597", |
| 2131 | + "metadata": {}, |
| 2132 | + "source": [ |
| 2133 | + "## Plot invariant neural manifold (colored)" |
| 2134 | + ] |
| 2135 | + }, |
| 2136 | + { |
| 2137 | + "cell_type": "code", |
| 2138 | + "execution_count": null, |
| 2139 | + "id": "b4807db9", |
| 2140 | + "metadata": { |
| 2141 | + "tags": [] |
| 2142 | + }, |
| 2143 | + "outputs": [], |
| 2144 | + "source": [ |
| 2145 | + "import matplotlib.cm as cm\n", |
| 2146 | + "import matplotlib.pyplot as plt\n", |
| 2147 | + "import numpy as np\n", |
| 2148 | + "\n", |
| 2149 | + "stats = [\n", |
| 2150 | + " \"mean_velocities\",\n", |
| 2151 | + " \"median_velocities\",\n", |
| 2152 | + " \"std_velocities\",\n", |
| 2153 | + " \"min_velocities\",\n", |
| 2154 | + " \"max_velocities\",\n", |
| 2155 | + "]\n", |
| 2156 | + "cmaps = [\"viridis\", \"viridis\", \"magma\", \"Blues\", \"Reds\"]\n", |
| 2157 | + "\n", |
| 2158 | + "fig, axes = plt.subplots(\n", |
| 2159 | + " nrows=len(stats), ncols=1, figsize=(20, 20), subplot_kw={\"projection\": \"polar\"}\n", |
| 2160 | + ")\n", |
| 2161 | + "\n", |
| 2162 | + "i_zs = [i_z for i_z, _ in res]\n", |
| 2163 | + "subgrid_profile = curv_norm_learned_profile # .take(i_zs)\n", |
| 2164 | + "print(len(subgrid_profile))\n", |
| 2165 | + "\n", |
| 2166 | + "for i_stat, stat_velocities in enumerate(stats):\n", |
| 2167 | + " ax = axes[i_stat]\n", |
| 2168 | + " ax.scatter(\n", |
| 2169 | + " cumul_geodesic_dists,\n", |
| 2170 | + " 1 / curv_norms,\n", |
| 2171 | + " c=subgrid_profile[stat_velocities],\n", |
| 2172 | + " cmap=cmaps[i_stat],\n", |
| 2173 | + " )\n", |
| 2174 | + " ax.plot(\n", |
| 2175 | + " cumul_geodesic_dists,\n", |
| 2176 | + " 1 / curv_norms,\n", |
| 2177 | + " c=\"black\",\n", |
| 2178 | + " )\n", |
| 2179 | + " # ax.set_rticks([0.5, 1, 1.5, 2]) # Less radial ticks\n", |
| 2180 | + " ax.set_rlabel_position(-22.5) # Move radial labels away from plotted line\n", |
| 2181 | + " ax.grid(True)\n", |
| 2182 | + " ax.set_title(\"Color: \" + stat_velocities, va=\"bottom\")\n", |
| 2183 | + "fig.tight_layout()" |
| 2184 | + ] |
| 2185 | + }, |
| 2186 | + { |
| 2187 | + "cell_type": "markdown", |
| 2188 | + "id": "a20ad8fc-90cf-4f3a-aa59-85ee4a794550", |
| 2189 | + "metadata": {}, |
| 2190 | + "source": [ |
| 2191 | + "## Plot invariant neural manifold (uncolored)" |
| 2192 | + ] |
| 2193 | + }, |
| 2194 | + { |
| 2195 | + "cell_type": "code", |
| 2196 | + "execution_count": null, |
| 2197 | + "id": "fd80354e-322f-4e1d-a364-8d1c28b93d81", |
| 2198 | + "metadata": { |
| 2199 | + "tags": [] |
| 2200 | + }, |
| 2201 | + "outputs": [], |
| 2202 | + "source": [ |
| 2203 | + "closed_geodesic_dists = torch.concat(\n", |
| 2204 | + " [cumul_geodesic_dists, torch.tensor([cumul_geodesic_dists[0]])]\n", |
| 2205 | + ")\n", |
| 2206 | + "closed_curv_norms = torch.concat([curv_norms, torch.tensor([curv_norms[0]])])" |
| 2207 | + ] |
| 2208 | + }, |
| 2209 | + { |
| 2210 | + "cell_type": "code", |
| 2211 | + "execution_count": null, |
| 2212 | + "id": "af04df43-29d5-4a9a-b359-e7fa1e9cf9f0", |
| 2213 | + "metadata": { |
| 2214 | + "tags": [] |
| 2215 | + }, |
| 2216 | + "outputs": [], |
| 2217 | + "source": [ |
| 2218 | + "import matplotlib.cm as cm\n", |
| 2219 | + "import matplotlib.pyplot as plt\n", |
| 2220 | + "import numpy as np\n", |
| 2221 | + "\n", |
| 2222 | + "stats = [\n", |
| 2223 | + " \"mean_velocities\",\n", |
| 2224 | + " \"median_velocities\",\n", |
| 2225 | + " \"std_velocities\",\n", |
| 2226 | + " \"min_velocities\",\n", |
| 2227 | + " \"max_velocities\",\n", |
| 2228 | + "]\n", |
| 2229 | + "cmaps = [\"viridis\", \"viridis\", \"magma\", \"Blues\", \"Reds\"]\n", |
| 2230 | + "\n", |
| 2231 | + "fig, ax = plt.subplots(\n", |
| 2232 | + " nrows=1, ncols=1, figsize=(6, 6), subplot_kw={\"projection\": \"polar\"}\n", |
| 2233 | + ")\n", |
| 2234 | + "\n", |
| 2235 | + "i_zs = [i_z for i_z, _ in res]\n", |
| 2236 | + "subgrid_profile = curv_norm_learned_profile.take(i_zs)\n", |
| 2237 | + "\n", |
| 2238 | + "ax.plot(closed_geodesic_dists, 1 / closed_curv_norms, c=\"black\")\n", |
| 2239 | + "# ax.set_rticks([0.5, 1, 1.5, 2]) # Less radial ticks\n", |
| 2240 | + "ax.set_rlabel_position(-22.5) # Move radial labels away from plotted line\n", |
| 2241 | + "ax.grid(True)\n", |
| 2242 | + "ax.set_title(\"Color: \" + stat_velocities, va=\"bottom\")\n", |
| 2243 | + "fig.tight_layout()\n", |
| 2244 | + "# import os\n", |
| 2245 | + "# print(os.getcwd())\n", |
| 2246 | + "fig.savefig(f\"notebooks/figures/run_{run_id}_invariant_manifold.svg\")" |
| 2247 | + ] |
| 2248 | + }, |
| 2249 | + { |
| 2250 | + "cell_type": "markdown", |
| 2251 | + "id": "dc68132e-7d46-415b-a895-344cad5ddf49", |
| 2252 | + "metadata": {}, |
| 2253 | + "source": [ |
| 2254 | + "Note: the way the curve loops around is unexpected. The cumulative geodesic distances go well over 14, i.e. far beyond 2π ≈ 6.28, which is one full turn of the polar plot. There must be an error." |
| 2255 | + ] |
| 2256 | + }, |
| 2257 | + { |
| 2258 | + "cell_type": "code", |
| 2259 | + "execution_count": null, |
| 2260 | + "id": "4dc802a5-3f1e-4ec8-9ad2-8c82d100240f", |
| 2261 | + "metadata": {}, |
| 2262 | + "outputs": [], |
| 2263 | + "source": [] |
| 2264 | + } |
| 2265 | + ], |
| 2266 | + "metadata": { |
| 2267 | + "kernelspec": { |
| 2268 | + "display_name": "Python 3 (ipykernel)", |
| 2269 | + "language": "python", |
| 2270 | + "name": "python3" |
| 2271 | + }, |
| 2272 | + "language_info": { |
| 2273 | + "codemirror_mode": { |
| 2274 | + "name": "ipython", |
| 2275 | + "version": 3 |
| 2276 | + }, |
| 2277 | + "file_extension": ".py", |
| 2278 | + "mimetype": "text/x-python", |
| 2279 | + "name": "python", |
| 2280 | + "nbconvert_exporter": "python", |
| 2281 | + "pygments_lexer": "ipython3", |
| 2282 | + "version": "3.8.16" |
| 2283 | + } |
| 2284 | + }, |
| 2285 | + "nbformat": 4, |
| 2286 | + "nbformat_minor": 5 |
| 2287 | +} |
0 commit comments