@@ -450,18 +450,17 @@ def add_child_plots(
450
450
451
451
def add_ele_symbols (
452
452
self ,
453
- text : Callable [[Element ], str ] = lambda elem : elem .symbol ,
453
+ text : str | Callable [[Element ], str ] = lambda elem : elem .symbol ,
454
454
pos : tuple [float , float ] = (0.5 , 0.5 ),
455
455
kwargs : dict [str , Any ] | None = None ,
456
456
) -> None :
457
457
"""Add element symbols for each tile.
458
458
459
459
Args:
460
- text: A callable or string specifying how to display
461
- the element symbol. If a callable is provided,
462
- it should accept an Element object and return a string.
463
- If a string is provided, it can contain a format
464
- specifier for the element symbol, e.g., "{elem.symbol}".
460
+ text (str | Callable): The text to add to the tile.
461
+ If a callable, it should accept a pymatgen Element object and return a
462
+ string. If a string, it can contain a format
463
+ specifier for an `elem` variable which will be replaced by the element.
465
464
pos: The position of the text relative to the axes.
466
465
kwargs: Additional keyword arguments to pass to the `ax.text`.
467
466
"""
@@ -476,13 +475,9 @@ def add_ele_symbols(
476
475
row , column = df_ptable .loc [symbol , ["row" , "column" ]]
477
476
ax : plt .Axes = self .axes [row - 1 ][column - 1 ]
478
477
478
+ anno = text (element ) if callable (text ) else text .format (elem = element )
479
479
ax .text (
480
- * pos ,
481
- text (element ) if callable (text ) else text .format (elem = element ), # type: ignore[attr-defined]
482
- ha = "center" ,
483
- va = "center" ,
484
- transform = ax .transAxes ,
485
- ** kwargs ,
480
+ * pos , anno , ha = "center" , va = "center" , transform = ax .transAxes , ** kwargs
486
481
)
487
482
488
483
def add_colorbar (
@@ -998,7 +993,7 @@ def ptable_heatmap_splits(
998
993
999
994
# Add element symbols
1000
995
plotter .add_ele_symbols (
1001
- text = symbol_text , # type: ignore[arg-type]
996
+ text = symbol_text ,
1002
997
pos = symbol_pos ,
1003
998
kwargs = symbol_kwargs ,
1004
999
)
@@ -1623,7 +1618,7 @@ def ptable_scatters(
1623
1618
1624
1619
# Add element symbols
1625
1620
plotter .add_ele_symbols (
1626
- text = symbol_text , # type: ignore[arg-type]
1621
+ text = symbol_text ,
1627
1622
pos = symbol_pos ,
1628
1623
kwargs = symbol_kwargs ,
1629
1624
)
@@ -1697,7 +1692,7 @@ def ptable_lines(
1697
1692
1698
1693
# Add element symbols
1699
1694
plotter .add_ele_symbols (
1700
- text = symbol_text , # type: ignore[arg-type]
1695
+ text = symbol_text ,
1701
1696
pos = symbol_pos ,
1702
1697
kwargs = symbol_kwargs ,
1703
1698
)
0 commit comments