Skip to content

Commit b01eef0

Browse files
authored
more dtype support: int and float types (#101)
1 parent ffd57d8 commit b01eef0

File tree

2 files changed

+59
-6
lines changed

2 files changed

+59
-6
lines changed

pandera/dtypes.py

+22
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,18 @@ class PandasDtype(Enum):
88
DateTime = "datetime64[ns]"
99
Category = "category"
1010
Float = "float64"
11+
Float16 = "float16"
12+
Float32 = "float32"
13+
Float64 = "float64"
1114
Int = "int64"
15+
Int8 = "int8"
16+
Int16 = "int16"
17+
Int32 = "int32"
18+
Int64 = "int64"
19+
UInt8 = "uint8"
20+
UInt16 = "uint16"
21+
UInt32 = "uint32"
22+
UInt64 = "uint64"
1223
Object = "object"
1324
String = "object"
1425
Timedelta = "timedelta64[ns]"
@@ -18,7 +29,18 @@ class PandasDtype(Enum):
1829
DateTime = PandasDtype.DateTime
1930
Category = PandasDtype.Category
2031
Float = PandasDtype.Float
32+
Float16 = PandasDtype.Float16
33+
Float32 = PandasDtype.Float32
34+
Float64 = PandasDtype.Float64
2135
Int = PandasDtype.Int
36+
Int8 = PandasDtype.Int8
37+
Int16 = PandasDtype.Int16
38+
Int32 = PandasDtype.Int32
39+
Int64 = PandasDtype.Int64
40+
UInt8 = PandasDtype.UInt8
41+
UInt16 = PandasDtype.UInt16
42+
UInt32 = PandasDtype.UInt32
43+
UInt64 = PandasDtype.UInt64
2244
Object = PandasDtype.Object
2345
String = PandasDtype.String
2446
Timedelta = PandasDtype.Timedelta

tests/test_pandera.py

+37-6
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,7 @@
88
from pandera import Column, DataFrameSchema, Index, MultiIndex, \
99
SeriesSchema, Bool, Category, Check, DateTime, Float, Int, Object, \
1010
String, Timedelta, check_input, check_output, Hypothesis
11+
from pandera import dtypes
1112
from scipy import stats
1213

1314

@@ -384,12 +385,42 @@ def _assert_expectation(result_df):
384385
transformer.transform_secord_arg_with_dict_getter(None, dataframe))
385386

386387

387-
def test_string_dtypes():
388-
# TODO: add tests for all datatypes
389-
schema = DataFrameSchema(
390-
{"col": Column("float64", nullable=True)})
391-
df = pd.DataFrame({"col": [np.nan, 1.0, 2.0]})
392-
assert isinstance(schema.validate(df), pd.DataFrame)
388+
def test_dtypes():
    """Check that float, int and uint dtypes validate against a schema.

    For every ``dtypes`` member under test, a single-column dataframe is
    built using that member's ``value`` as the pandas dtype and validated
    against a non-nullable ``Column`` declared with the same member.
    """
    # Each sample list is chosen to fit the narrowest dtype of its family
    # (float16 / int8 / uint8). Integers outside that range would be cast
    # lossily, or rejected outright, while the DataFrame is being built --
    # which of the two happens depends on the numpy/pandas version -- so
    # the test would exercise casting rather than schema validation.
    cases = [
        (
            [dtypes.Float, dtypes.Float16, dtypes.Float32, dtypes.Float64],
            [-123.1, -7654.321, 1.0, 1.1, 1199.51, 5.1, 4.6],
        ),
        (
            [dtypes.Int, dtypes.Int8, dtypes.Int16, dtypes.Int32,
             dtypes.Int64],
            [-128, -71, -4, 0, 1, 5, 77, 123, 127],
        ),
        (
            [dtypes.UInt8, dtypes.UInt16, dtypes.UInt32, dtypes.UInt64],
            [0, 1, 5, 77, 123, 255],
        ),
    ]
    for dtype_group, values in cases:
        for dtype in dtype_group:
            schema = DataFrameSchema({"col": Column(dtype, nullable=False)})
            validated_df = schema.validate(
                pd.DataFrame({"col": values}, dtype=dtype.value))
            assert isinstance(validated_df, pd.DataFrame)
393424

394425

395426
def test_nullable_int():

0 commit comments

Comments
 (0)