Skip to content

Commit 422ebc5

Browse files
authored
Fix the string representation of archetypes (#9297)
`Archetype.__str__` was broken at some point. This PR fixes it. It also adds a (small) test. Before: ``` rr.Points3D( ) ``` After: ``` rr.Points3D( positions=[[11.0, 2.0, 3.0], [2.0, 3.0, 2.0], [3.0, 2.0, 3.0], [2.0, 3.0, 2.0], [3.0, 2.0, 3.0], [2.0, 3.0, 2.0], [3.0, 2.0, 3.0], [2.0, 3.0, 2.0], [3.0, 2.0, 3.0], [2.0, 3.0, 2.0], [3.0, 2.0, 3.0], [2.0, 3.0, 3.0]], radii=[1.0, 2.0, 3.0] ) ```
1 parent 25eeda6 commit 422ebc5

File tree

2 files changed

+128
-10
lines changed

2 files changed

+128
-10
lines changed

rerun_py/rerun_sdk/rerun/_baseclasses.py

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import re
34
from collections.abc import Iterable, Iterator
45
from typing import Generic, Protocol, TypeVar, runtime_checkable
56

@@ -201,18 +202,35 @@ class Archetype(AsComponents):
201202
"""Base class for all archetypes."""
202203

203204
def __str__(self) -> str:
204-
cls = type(self)
205+
from pprint import pformat
205206

206-
s = f"rr.{cls.__name__}(\n"
207-
for fld in fields(cls):
208-
if "component" in fld.metadata:
209-
comp = getattr(self, fld.name)
210-
datatype = getattr(comp, "type", None)
211-
if datatype:
212-
s += f" {datatype.extension_name}<{datatype}>(\n {comp.to_pylist()}\n )\n"
213-
s += ")"
207+
cls = type(self)
214208

215-
return s
209+
def fields_repr() -> Iterable[str]:
210+
for fld in fields(cls):
211+
if "component" in fld.metadata:
212+
comp = getattr(self, fld.name)
213+
if comp is None:
214+
continue
215+
216+
as_arrow_array = getattr(comp, "as_arrow_array", None)
217+
218+
if as_arrow_array is None:
219+
comp_contents = "<unknown>"
220+
else:
221+
# Note: the regex here is necessary because for some reason pformat add spurious spaces when
222+
# indent > 1.
223+
comp_contents = re.sub(
224+
r"\[\s+\[", "[[", pformat(as_arrow_array().to_pylist(), compact=True, indent=4)
225+
)
226+
227+
yield f" {fld.name}={comp_contents}"
228+
229+
args = ",\n".join(fields_repr())
230+
if args:
231+
return f"rr.{cls.__name__}(\n{args}\n)"
232+
else:
233+
return f"rr.{cls.__name__}()"
216234

217235
@classmethod
218236
def archetype_name(cls) -> str:
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
from __future__ import annotations
2+
3+
import pytest
4+
import rerun as rr
5+
6+
7+
@pytest.mark.parametrize(
8+
["archetype", "expected"],
9+
[
10+
[
11+
rr.Transform3D(),
12+
(
13+
"rr.Transform3D(\n"
14+
" translation=[],\n"
15+
" rotation_axis_angle=[],\n"
16+
" quaternion=[],\n"
17+
" scale=[],\n"
18+
" mat3x3=[],\n"
19+
" relation=[],\n"
20+
" axis_length=[]\n"
21+
")"
22+
),
23+
],
24+
[
25+
rr.Transform3D(translation=[10, 10, 10]),
26+
(
27+
"rr.Transform3D(\n"
28+
" translation=[[10.0, 10.0, 10.0]],\n"
29+
" rotation_axis_angle=[],\n"
30+
" quaternion=[],\n"
31+
" scale=[],\n"
32+
" mat3x3=[],\n"
33+
" relation=[],\n"
34+
" axis_length=[]\n"
35+
")"
36+
),
37+
],
38+
[
39+
rr.Points2D(positions=[[0, 0], [1, 1], [2, 2]]),
40+
"rr.Points2D(\n positions=[[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]]\n)",
41+
],
42+
[
43+
rr.Points2D(positions=[0, 0, 1, 1, 2, 2], radii=[4, 5, 6]),
44+
"rr.Points2D(\n positions=[[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]],\n radii=[4.0, 5.0, 6.0]\n)",
45+
],
46+
[rr.Points2D.from_fields(), "rr.Points2D()"],
47+
[
48+
rr.Points3D(
49+
[
50+
11,
51+
2,
52+
3,
53+
2,
54+
3,
55+
2,
56+
3,
57+
2,
58+
3,
59+
2,
60+
3,
61+
2,
62+
3,
63+
2,
64+
3,
65+
2,
66+
3,
67+
2,
68+
3,
69+
2,
70+
3,
71+
2,
72+
3,
73+
2,
74+
3,
75+
2,
76+
3,
77+
2,
78+
3,
79+
2,
80+
3,
81+
2,
82+
3,
83+
2,
84+
3,
85+
3,
86+
],
87+
radii=[1, 2, 3],
88+
),
89+
"""\
90+
rr.Points3D(
91+
positions=[[11.0, 2.0, 3.0], [2.0, 3.0, 2.0], [3.0, 2.0, 3.0], [2.0, 3.0, 2.0],
92+
[3.0, 2.0, 3.0], [2.0, 3.0, 2.0], [3.0, 2.0, 3.0], [2.0, 3.0, 2.0],
93+
[3.0, 2.0, 3.0], [2.0, 3.0, 2.0], [3.0, 2.0, 3.0], [2.0, 3.0, 3.0]],
94+
radii=[1.0, 2.0, 3.0]
95+
)""",
96+
],
97+
],
98+
)
99+
def test_archetype_str(archetype: rr._baseclasses.Archetype, expected: str) -> None:
100+
assert str(archetype) == expected

0 commit comments

Comments
 (0)