Skip to content

Commit b6b3aa0

Browse files
author
Varad Bhatnagar
committed
Modify add_column() to optionally accept a pyarrow schema as param
1 parent ca58154 commit b6b3aa0

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

src/datasets/arrow_dataset.py

+4 -2
Original file line number | Diff line number | Diff line change
@@ -5613,7 +5613,9 @@ def push_to_hub(
5613 5613

5614 5614
@transmit_format
5615 5615
@fingerprint_transform(inplace=False)
5616-
def add_column(self, name: str, column: Union[list, np.array], new_fingerprint: str):
5616+
def add_column(
5617+
self, name: str, column: Union[list, np.array], new_fingerprint: str, pyarrow_schema: pa.schema = None
5618+
):
5617 5619
"""Add column to Dataset.
5618 5620
5619 5621
<Added version="1.7"/>
@@ -5640,7 +5642,7 @@ def add_column(self, name: str, column: Union[list, np.array], new_fingerprint:
5640 5642
})
5641 5643
```
5642 5644
"""
5643-
column_table = InMemoryTable.from_pydict({name: column})
5645+
column_table = InMemoryTable.from_pydict({name: column}, schema=pyarrow_schema)
5644 5646
_check_column_names(self._data.column_names + column_table.column_names)
5645 5647
dataset = self.flatten_indices() if self._indices is not None else self
5646 5648
# Concatenate tables horizontally

0 commit comments

Comments
 (0)