Skip to content

Commit 43b1fe1

Browse files
varadhbhatnagarVarad Bhatnagar
and
Varad Bhatnagar
authored
Modify add_column() to optionally accept a FeatureType as param (#7143)
* Modify add_column() to optionally accept a FeatureType param * Add feature param to add_column() docstring --------- Co-authored-by: Varad Bhatnagar <[email protected]>
1 parent e4bba5e commit 43b1fe1

File tree

1 file changed

+12
-2
lines changed

1 file changed

+12
-2
lines changed

src/datasets/arrow_dataset.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -5613,7 +5613,9 @@ def push_to_hub(
56135613

56145614
@transmit_format
56155615
@fingerprint_transform(inplace=False)
5616-
def add_column(self, name: str, column: Union[list, np.array], new_fingerprint: str):
5616+
def add_column(
5617+
self, name: str, column: Union[list, np.array], new_fingerprint: str, feature: Optional[FeatureType] = None
5618+
):
56175619
"""Add column to Dataset.
56185620
56195621
<Added version="1.7"/>
@@ -5623,6 +5625,8 @@ def add_column(self, name: str, column: Union[list, np.array], new_fingerprint:
56235625
Column name.
56245626
column (`list` or `np.array`):
56255627
Column data to be added.
5628+
feature (`FeatureType` or `None`, defaults to `None`):
5629+
Column datatype.
56265630
56275631
Returns:
56285632
[`Dataset`]
@@ -5640,7 +5644,13 @@ def add_column(self, name: str, column: Union[list, np.array], new_fingerprint:
56405644
})
56415645
```
56425646
"""
5643-
column_table = InMemoryTable.from_pydict({name: column})
5647+
5648+
if feature:
5649+
pyarrow_schema = Features({name: feature}).arrow_schema
5650+
else:
5651+
pyarrow_schema = None
5652+
5653+
column_table = InMemoryTable.from_pydict({name: column}, schema=pyarrow_schema)
56445654
_check_column_names(self._data.column_names + column_table.column_names)
56455655
dataset = self.flatten_indices() if self._indices is not None else self
56465656
# Concatenate tables horizontally

0 commit comments

Comments
 (0)