Commit 686f5df

Add support for categorical/dictionary types (#6892)
* Add support for dictionary types
* Add unit tests
* Style fix
* bump pyarrow
* convert in Dataset init
* remove beam from tests

Co-authored-by: Quentin Lhoest <[email protected]>

1 parent a2dc287 · commit 686f5df

File tree: 6 files changed, +37 -6 lines

.github/workflows/ci.yml (+1 -1)

@@ -65,7 +65,7 @@ jobs:
         run: uv pip install --system --upgrade pyarrow huggingface-hub dill
       - name: Install dependencies (minimum versions)
         if: ${{ matrix.deps_versions != 'deps-latest' }}
-        run: uv pip install --system pyarrow==12.0.0 huggingface-hub==0.21.2 transformers dill==0.3.1.1
+        run: uv pip install --system pyarrow==15.0.0 huggingface-hub==0.21.2 transformers dill==0.3.1.1
       - name: Test with pytest
         run: |
           python -m pytest -rfExX -m ${{ matrix.test }} -n 2 --dist loadfile -sv ./tests/

setup.py (+2 -3)

@@ -113,8 +113,8 @@
     # We use numpy>=1.17 to have np.random.Generator (Dataset shuffling)
     "numpy>=1.17",
     # Backend and serialization.
-    # Minimum 12.0.0 to be able to concatenate extension arrays
-    "pyarrow>=12.0.0",
+    # Minimum 15.0.0 to be able to cast dictionary types to their underlying types
+    "pyarrow>=15.0.0",
     # As long as we allow pyarrow < 14.0.1, to fix vulnerability CVE-2023-47248
     "pyarrow-hotfix",
     # For smart caching dataset processing
@@ -166,7 +166,6 @@
     "pytest-datadir",
     "pytest-xdist",
     # optional dependencies
-    "apache-beam>=2.26.0; sys_platform != 'win32' and python_version<'3.10'",  # doesn't support recent dill versions for recent python versions and on windows requires pyarrow<12.0.0
     "elasticsearch<8.0.0",  # 8.0 asks users to provide hosts or cloud_id when instantiating ElasticSearch()
     "faiss-cpu>=1.6.4",
     "jax>=0.3.14; sys_platform != 'win32'",

src/datasets/arrow_dataset.py (+5)

@@ -711,6 +711,11 @@ def __init__(
                     f"{e}\nThe 'source' features come from dataset_info.json, and the 'target' ones are those of the dataset arrow file."
                 )
 
+        # In case there are types like pa.dictionary that we need to convert to the underlying type
+
+        if self.data.schema != self.info.features.arrow_schema:
+            self._data = self.data.cast(self.info.features.arrow_schema)
+
         # Infer fingerprint if None
 
         if self._fingerprint is None:
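For context, a minimal sketch of what this cast achieves at the pyarrow level (illustrative names and values, not from the commit; casting a dictionary type to its value type is the pyarrow >= 15.0.0 capability the dependency bump above refers to):

    import pyarrow as pa

    # A column stored as a dictionary (categorical) type
    animals = pa.array(["cat", "dog", "cat"]).cast(pa.dictionary(pa.int32(), pa.string()))
    table = pa.Table.from_arrays([animals], names=["animals"])
    print(table.schema)  # animals: dictionary<values=string, indices=int32, ordered=0>

    # Casting to the features' schema decodes the dictionary into its value type,
    # which is what Dataset.__init__ now does when the schemas differ
    plain = table.cast(pa.schema([pa.field("animals", pa.string())]))
    print(plain.schema)  # animals: string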

src/datasets/features/features.py (+5 -2)

@@ -109,6 +109,8 @@ def _arrow_to_datasets_dtype(arrow_type: pa.DataType) -> str:
         return "string"
     elif pyarrow.types.is_large_string(arrow_type):
         return "large_string"
+    elif pyarrow.types.is_dictionary(arrow_type):
+        return _arrow_to_datasets_dtype(arrow_type.value_type)
     else:
         raise ValueError(f"Arrow type {arrow_type} does not have a datasets dtype equivalent.")
@@ -1434,8 +1436,6 @@ def generate_from_arrow_type(pa_type: pa.DataType) -> FeatureType:
     elif isinstance(pa_type, _ArrayXDExtensionType):
         array_feature = [None, None, Array2D, Array3D, Array4D, Array5D][pa_type.ndims]
         return array_feature(shape=pa_type.shape, dtype=pa_type.value_type)
-    elif isinstance(pa_type, pa.DictionaryType):
-        raise NotImplementedError  # TODO(thom) this will need access to the dictionary as well (for labels). I.e. to the py_table
     elif isinstance(pa_type, pa.DataType):
         return Value(dtype=_arrow_to_datasets_dtype(pa_type))
     else:
@@ -1705,6 +1705,9 @@ def from_arrow_schema(cls, pa_schema: pa.Schema) -> "Features":
         It also checks the schema metadata for Hugging Face Datasets features.
         Non-nullable fields are not supported and set to nullable.
 
+        Also, pa.dictionary is not supported; its underlying value type is used instead.
+        Therefore, datasets converts DictionaryArray objects to their actual values.
+
         Args:
             pa_schema (`pyarrow.Schema`):
                 Arrow Schema.
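Taken together with the Dataset.__init__ change, the new dictionary branch means an Arrow schema with categorical columns maps to plain Value features. A small sketch (assuming a datasets build that includes this commit):

    import pyarrow as pa
    from datasets import Features, Value

    schema = pa.schema({"animals": pa.dictionary(pa.int32(), pa.string())})

    # The dictionary type resolves to its underlying value type
    features = Features.from_arrow_schema(schema)
    assert features == Features({"animals": Value("string")})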

tests/features/test_features.py (+6)

@@ -98,6 +98,12 @@ def test_string_to_arrow_bijection_for_primitive_types(self):
         with self.assertRaises(ValueError):
             string_to_arrow(sdt)
 
+    def test_categorical_one_way(self):
+        # Categorical types (aka dictionary types) need special handling as there isn't a bijection
+        categorical_type = pa.dictionary(pa.int32(), pa.string())
+
+        self.assertEqual("string", _arrow_to_datasets_dtype(categorical_type))
+
     def test_feature_named_type(self):
         """reference: issue #1110"""
         features = Features({"_type": Value("string")})
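The "one way" in the test name is the point: mapping the resulting dtype string back to Arrow recovers the plain value type, not the original dictionary type, so there is no bijection. A minimal sketch using the same internal helpers the test exercises (internal APIs, subject to change):

    import pyarrow as pa
    from datasets.features.features import _arrow_to_datasets_dtype, string_to_arrow

    categorical_type = pa.dictionary(pa.int32(), pa.string())
    dtype = _arrow_to_datasets_dtype(categorical_type)  # "string"

    # Going back yields the underlying type, not the dictionary type
    assert string_to_arrow(dtype) == pa.string()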

tests/test_arrow_dataset.py (+18)

@@ -4826,3 +4826,21 @@ def test_dataset_getitem_raises():
         ds[False]
     with pytest.raises(TypeError):
         ds._getitem(True)
+
+
+def test_categorical_dataset(tmpdir):
+    n_legs = pa.array([2, 4, 5, 100])
+    animals = pa.array(["Flamingo", "Horse", "Brittle stars", "Centipede"]).cast(
+        pa.dictionary(pa.int32(), pa.string())
+    )
+    names = ["n_legs", "animals"]
+
+    table = pa.Table.from_arrays([n_legs, animals], names=names)
+    table_path = str(tmpdir / "data.parquet")
+    pa.parquet.write_table(table, table_path)
+
+    dataset = Dataset.from_parquet(table_path)
+    entry = dataset[0]
+
+    # Categorical types get transparently converted to string
+    assert entry["animals"] == "Flamingo"
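The same transparent decoding should apply to other entry points that yield dictionary-typed Arrow columns, since the cast happens in Dataset.__init__. A hedged usage sketch with a pandas categorical column (assuming pa.Table.from_pandas keeps the column as an Arrow dictionary type, its default behavior):

    import pandas as pd
    from datasets import Dataset

    df = pd.DataFrame({"animals": pd.Categorical(["Flamingo", "Horse", "Flamingo"])})
    dataset = Dataset.from_pandas(df)

    # The categorical column comes back as plain strings
    assert dataset.features["animals"].dtype == "string"
    assert dataset[0]["animals"] == "Flamingo"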
