Skip to content

Commit cd3cb4d

Browse files
committed
Fix tests
1 parent 359f3df commit cd3cb4d

File tree

8 files changed

+30
-23
lines changed

8 files changed

+30
-23
lines changed
Original file line numberDiff line numberDiff line change
@@ -1,2 +0,0 @@
1-
import hyperspherical.distributions
2-
import hyperspherical.ops
Original file line numberDiff line numberDiff line change
@@ -1,2 +0,0 @@
1-
from hyperspherical.distributions.hyperspherical_uniform import HypersphericalUniform
2-
from hyperspherical.distributions.von_mises_fisher import VonMisesFisher

neurometry/curvature/hyperspherical/distributions/von_mises_fisher.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
import math
22

33
import torch
4-
from hyperspherical.distributions.hyperspherical_uniform import HypersphericalUniform
5-
from hyperspherical.ops.ive import ive
64
from torch.distributions.kl import register_kl
75

6+
from neurometry.curvature.hyperspherical.distributions.hyperspherical_uniform import (
7+
HypersphericalUniform,
8+
)
9+
from neurometry.curvature.hyperspherical.ops.ive import ive
10+
811

912
class VonMisesFisher(torch.distributions.Distribution):
1013
arg_constraints = {
Original file line numberDiff line numberDiff line change
@@ -1 +0,0 @@
1-
from hyperspherical.ops.ive import ive

neurometry/curvature/losses.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
"""Losses."""
22

33
import torch
4-
from hyperspherical.distributions import HypersphericalUniform, VonMisesFisher
4+
5+
from neurometry.curvature.hyperspherical.distributions import (
6+
HypersphericalUniform,
7+
VonMisesFisher,
8+
)
59

610

711
def elbo(x, x_mu, posterior_params, z, labels, config):

notebooks/test_plot.py renamed to notebooks/plot_klein_bottle.py

+17-13
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import numpy as np
33

44

5-
def klein_sphere_points(num_points, radius = 1):
5+
def klein_sphere_points(num_points, radius=1):
66
"""
77
Generate points on a Klein sphere.
88
@@ -13,8 +13,8 @@ def klein_sphere_points(num_points, radius = 1):
1313
- points: Array of points on the Klein sphere.
1414
"""
1515
# Generate random points on a 2D plane
16-
theta = np.random.uniform(0, 2*np.pi, num_points)
17-
phi = np.random.uniform(0, 2*np.pi, num_points)
16+
theta = np.random.uniform(0, 2 * np.pi, num_points)
17+
phi = np.random.uniform(0, 2 * np.pi, num_points)
1818

1919
# Parametric equations for a Klein sphere
2020
x = radius * (np.cos(theta) * (1 + np.sin(phi)))
@@ -34,7 +34,7 @@ def plot_klein_sphere(points):
3434
fig = plt.figure()
3535
ax = fig.add_subplot(111, projection="3d")
3636

37-
ax.scatter(points[:,0], points[:,1], points[:,2], c="r", marker="o")
37+
ax.scatter(points[:, 0], points[:, 1], points[:, 2], c="r", marker="o")
3838

3939
ax.set_xlabel("X")
4040
ax.set_ylabel("Y")
@@ -44,38 +44,42 @@ def plot_klein_sphere(points):
4444

4545
plt.show()
4646

47+
4748
# Generate points on the Klein sphere
4849
num_points = 1000
49-
scales = [0.5,1,0.5]
50+
scales = [0.5, 1, 0.5]
5051
points = klein_sphere_points(num_points, scales)
5152

5253
# Plot the generated points
5354
plot_klein_sphere(points)
5455

55-
#https://mathworld.wolfram.com/KleinBottle.html
56-
#Klein Bagel
56+
# https://mathworld.wolfram.com/KleinBottle.html
57+
# Klein Bagel
58+
5759

5860
def klein_bottle_points(num_points, scale=1):
59-
u = np.linspace(0, 2*np.pi, num_points)
60-
v = np.linspace(0, 2*np.pi, num_points)
61+
u = np.linspace(0, 2 * np.pi, num_points)
62+
v = np.linspace(0, 2 * np.pi, num_points)
6163
U, V = np.meshgrid(u, v)
6264

63-
X = (scale + np.cos(U/2) *np.sin(V) - np.sin(U/2)*np.sin(2*V))*np.cos(U)
64-
Y = (scale + np.cos(U/2) *np.sin(V) - np.sin(U/2)*np.sin(2*V))*np.sin(U)
65-
Z = np.sin(U/2)*np.sin(V) + np.cos(U/2)*np.sin(2*V)
65+
X = (scale + np.cos(U / 2) * np.sin(V) - np.sin(U / 2) * np.sin(2 * V)) * np.cos(U)
66+
Y = (scale + np.cos(U / 2) * np.sin(V) - np.sin(U / 2) * np.sin(2 * V)) * np.sin(U)
67+
Z = np.sin(U / 2) * np.sin(V) + np.cos(U / 2) * np.sin(2 * V)
6668
return X, Y, Z
6769

70+
6871
def plot_klein_bottle(X, Y, Z):
6972
fig = plt.figure()
7073
ax = fig.add_subplot(111, projection="3d")
71-
ax.plot_surface(X, Y, Z, cmap="viridis", alpha = 0.7)
74+
ax.plot_surface(X, Y, Z, cmap="viridis", alpha=0.7)
7275

7376
ax.set_xlabel("X")
7477
ax.set_ylabel("Y")
7578
ax.set_zlabel("Z")
7679
ax.set_title("Klein Bottle")
7780
plt.show()
7881

82+
7983
# Example usage
8084
num_points = 100
8185
scale = 3

pyproject.toml

+2-1
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,8 @@ doc = [
5858
"pydata-sphinx-theme"
5959
]
6060
lint = [
61-
"pre-commit"
61+
"pre-commit",
62+
"ruff"
6263
]
6364
test = [
6465
"pytest",

tests/test_losses.py renamed to tests/test_curvature.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import numpy as np
22
import torch
33

4-
from neurometry.losses import latent_regularization_loss
4+
from neurometry.curvature.losses import latent_regularization_loss
55

66

77
class AttrDict(dict):

0 commit comments

Comments
 (0)