Skip to content

Test figure returned by WulffShape.get_plot() contains single Axes3D #2953

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions pymatgen/analysis/tests/test_wulff.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import os
import unittest

from mpl_toolkits.mplot3d import Axes3D
from pytest import approx

from pymatgen.analysis.wulff import WulffShape
Expand Down Expand Up @@ -70,14 +71,13 @@ def setUp(self):

self.surface_properties = surface_properties

@unittest.skipIf("DISPLAY" not in os.environ, "Need display")
def test_get_plot(self):
# Basic test, not really a unittest.
self.wulff_Ti.get_plot()
self.wulff_Nb.get_plot()
self.wulff_Ir.get_plot()
# Basic test to check figure contains a single Axes3D object
for wulff in (self.wulff_Nb, self.wulff_Ir, self.wulff_Ti):
plt = wulff.get_plot()
assert len(plt.gcf().get_axes()) == 1
assert isinstance(plt.gcf().get_axes()[0], Axes3D)

@unittest.skipIf("DISPLAY" not in os.environ, "Need display")
def test_get_plotly(self):
# Basic test, not really a unittest.
self.wulff_Ti.get_plotly()
Expand Down
6 changes: 3 additions & 3 deletions pymatgen/analysis/wulff.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,7 @@ def get_plot(
"""
import matplotlib as mpl
import matplotlib.pyplot as plt
import mpl_toolkits.mplot3d as mpl3
from mpl_toolkits.mplot3d import Axes3D, art3d

colors = self._get_colors(color_set, alpha, off_color, custom_colors=custom_colors or {})
color_list, color_proxy, color_proxy_on_wulff, miller_on_wulff, e_surf_on_wulff = colors
Expand All @@ -438,7 +438,7 @@ def get_plot(

wulff_pt_list = self.wulff_pt_list

ax = mpl3.Axes3D(fig, azim=azim, elev=elev)
ax = Axes3D(fig, azim=azim, elev=elev)
fig.add_axes(ax)

for plane in self.facets:
Expand All @@ -451,7 +451,7 @@ def get_plot(
plane_color = color_list[plane.index]
pt = self.get_line_in_facet(plane)
# plot from the sorted pts from [simpx]
tri = mpl3.art3d.Poly3DCollection([pt])
tri = art3d.Poly3DCollection([pt])
tri.set_color(plane_color)
tri.set_edgecolor("#808080")
ax.add_collection3d(tri)
Expand Down