Skip to content

Commit 9ccc1f3

Browse files
Fix casting list array to fixed size list (#7021)
* Test array_cast * Test array values as well * Fix array_cast by using list_size instead of length
1 parent 4ba47a3 commit 9ccc1f3

File tree

2 files changed

+12
-1
lines changed

2 files changed

+12
-1
lines changed

src/datasets/table.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1919,7 +1919,7 @@ def array_cast(
19191919
)
19201920
else:
19211921
array_values = array.values[
1922-
array.offset * pa_type.length : (array.offset + len(array)) * pa_type.length
1922+
array.offset * pa_type.list_size : (array.offset + len(array)) * pa_type.list_size
19231923
]
19241924
return pa.FixedSizeListArray.from_arrays(_c(array_values, pa_type.value_type), pa_type.list_size)
19251925
elif pa.types.is_list(pa_type):

tests/test_table.py

+11
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
_in_memory_arrow_table_from_file,
2020
_interpolation_search,
2121
_memory_mapped_arrow_table_from_file,
22+
array_cast,
2223
cast_array_to_feature,
2324
concat_tables,
2425
embed_array_storage,
@@ -1323,3 +1324,13 @@ def test_table_iter(table, batch_size, drop_last_batch):
13231324
if num_rows > 0:
13241325
reloaded = pa.concat_tables(subtables)
13251326
assert table.slice(0, num_rows).to_pydict() == reloaded.to_pydict()
1327+
1328+
1329+
@pytest.mark.parametrize("to_type", ["list", "fixed_size_list"])
1330+
@pytest.mark.parametrize("from_type", ["list", "fixed_size_list"])
1331+
def test_array_cast(from_type, to_type):
1332+
array_type = {"list": pa.list_(pa.int64()), "fixed_size_list": pa.list_(pa.int64(), 2)}
1333+
arr = pa.array([[0, 1]], type=array_type[from_type])
1334+
cast_arr = array_cast(arr, array_type[to_type])
1335+
assert cast_arr.type == array_type[to_type]
1336+
assert cast_arr.values == arr.values

0 commit comments

Comments
 (0)