Skip to content

Commit 8c4e7c7

Browse files
maxi297 authored and harrytou committed
Issue 28893/infer schema csv (airbytehq#29099)
1 parent a67b7a1 commit 8c4e7c7

File tree

16 files changed

+768
-317
lines changed

16 files changed

+768
-317
lines changed

airbyte-cdk/python/airbyte_cdk/sources/file_based/availability_strategy/default_file_based_availability_strategy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def _check_parse_record(self, stream: "AbstractFileBasedStream", file: RemoteFil
8282
parser = stream.get_parser(stream.config.file_type)
8383

8484
try:
85-
record = next(iter(parser.parse_records(stream.config, file, self.stream_reader, logger)))
85+
record = next(iter(parser.parse_records(stream.config, file, self.stream_reader, logger, discovered_schema=None)))
8686
except StopIteration:
8787
# The file is empty. We've verified that we can open it, so will
8888
# consider the connection check successful even though it means

airbyte-cdk/python/airbyte_cdk/sources/file_based/config/csv_format.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,11 @@ class QuotingBehavior(Enum):
1717
QUOTE_NONE = "Quote None"
1818

1919

20+
class InferenceType(Enum):
21+
NONE = "None"
22+
PRIMITIVE_TYPES_ONLY = "Primitive Types Only"
23+
24+
2025
DEFAULT_TRUE_VALUES = ["y", "yes", "t", "true", "on", "1"]
2126
DEFAULT_FALSE_VALUES = ["n", "no", "f", "false", "off", "0"]
2227

@@ -81,6 +86,12 @@ class Config:
8186
default=DEFAULT_FALSE_VALUES,
8287
description="A set of case-sensitive strings that should be interpreted as false values.",
8388
)
89+
inference_type: InferenceType = Field(
90+
title="Inference Type",
91+
default=InferenceType.NONE,
92+
description="How to infer the types of the columns. If none, inference default to strings.",
93+
airbyte_hidden=True,
94+
)
8495

8596
@validator("delimiter")
8697
def validate_delimiter(cls, v: str) -> str:

airbyte-cdk/python/airbyte_cdk/sources/file_based/file_types/avro_parser.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,15 @@
44

55
import logging
66
import uuid
7-
from typing import Any, Dict, Iterable, Mapping
7+
from typing import Any, Dict, Iterable, Mapping, Optional
88

99
import fastavro
1010
from airbyte_cdk.sources.file_based.config.avro_format import AvroFormat
1111
from airbyte_cdk.sources.file_based.config.file_based_stream_config import FileBasedStreamConfig
1212
from airbyte_cdk.sources.file_based.file_based_stream_reader import AbstractFileBasedStreamReader, FileReadMode
1313
from airbyte_cdk.sources.file_based.file_types.file_type_parser import FileTypeParser
1414
from airbyte_cdk.sources.file_based.remote_file import RemoteFile
15+
from airbyte_cdk.sources.file_based.schema_helpers import SchemaType
1516

1617
AVRO_TYPE_TO_JSON_TYPE = {
1718
"null": "null",
@@ -47,7 +48,7 @@ async def infer_schema(
4748
file: RemoteFile,
4849
stream_reader: AbstractFileBasedStreamReader,
4950
logger: logging.Logger,
50-
) -> Dict[str, Any]:
51+
) -> SchemaType:
5152
avro_format = config.format or AvroFormat()
5253
if not isinstance(avro_format, AvroFormat):
5354
raise ValueError(f"Expected ParquetFormat, got {avro_format}")
@@ -132,6 +133,7 @@ def parse_records(
132133
file: RemoteFile,
133134
stream_reader: AbstractFileBasedStreamReader,
134135
logger: logging.Logger,
136+
discovered_schema: Optional[Mapping[str, SchemaType]],
135137
) -> Iterable[Dict[str, Any]]:
136138
avro_format = config.format or AvroFormat()
137139
if not isinstance(avro_format, AvroFormat):

airbyte-cdk/python/airbyte_cdk/sources/file_based/file_types/csv_parser.py

Lines changed: 289 additions & 144 deletions
Large diffs are not rendered by default.

airbyte-cdk/python/airbyte_cdk/sources/file_based/file_types/file_type_parser.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,13 @@
44

55
import logging
66
from abc import ABC, abstractmethod
7-
from typing import Any, Dict, Iterable
7+
from typing import Any, Dict, Iterable, Mapping, Optional
88

99
from airbyte_cdk.sources.file_based.config.file_based_stream_config import FileBasedStreamConfig
1010
from airbyte_cdk.sources.file_based.file_based_stream_reader import AbstractFileBasedStreamReader, FileReadMode
1111
from airbyte_cdk.sources.file_based.remote_file import RemoteFile
12+
from airbyte_cdk.sources.file_based.schema_helpers import SchemaType
1213

13-
Schema = Dict[str, str]
1414
Record = Dict[str, Any]
1515

1616

@@ -27,7 +27,7 @@ async def infer_schema(
2727
file: RemoteFile,
2828
stream_reader: AbstractFileBasedStreamReader,
2929
logger: logging.Logger,
30-
) -> Schema:
30+
) -> SchemaType:
3131
"""
3232
Infer the JSON Schema for this file.
3333
"""
@@ -40,6 +40,7 @@ def parse_records(
4040
file: RemoteFile,
4141
stream_reader: AbstractFileBasedStreamReader,
4242
logger: logging.Logger,
43+
discovered_schema: Optional[Mapping[str, SchemaType]],
4344
) -> Iterable[Record]:
4445
"""
4546
Parse and emit each record.

airbyte-cdk/python/airbyte_cdk/sources/file_based/file_types/jsonl_parser.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,14 @@
44

55
import json
66
import logging
7-
from typing import Any, Dict, Iterable
7+
from typing import Any, Dict, Iterable, Mapping, Optional
88

99
from airbyte_cdk.sources.file_based.config.file_based_stream_config import FileBasedStreamConfig
1010
from airbyte_cdk.sources.file_based.exceptions import FileBasedSourceError, RecordParseError
1111
from airbyte_cdk.sources.file_based.file_based_stream_reader import AbstractFileBasedStreamReader, FileReadMode
1212
from airbyte_cdk.sources.file_based.file_types.file_type_parser import FileTypeParser
1313
from airbyte_cdk.sources.file_based.remote_file import RemoteFile
14-
from airbyte_cdk.sources.file_based.schema_helpers import PYTHON_TYPE_MAPPING, merge_schemas
14+
from airbyte_cdk.sources.file_based.schema_helpers import PYTHON_TYPE_MAPPING, SchemaType, merge_schemas
1515

1616

1717
class JsonlParser(FileTypeParser):
@@ -25,12 +25,12 @@ async def infer_schema(
2525
file: RemoteFile,
2626
stream_reader: AbstractFileBasedStreamReader,
2727
logger: logging.Logger,
28-
) -> Dict[str, Any]:
28+
) -> SchemaType:
2929
"""
3030
Infers the schema for the file by inferring the schema for each line, and merging
3131
it with the previously-inferred schema.
3232
"""
33-
inferred_schema: Dict[str, Any] = {}
33+
inferred_schema: Mapping[str, Any] = {}
3434

3535
for entry in self._parse_jsonl_entries(file, stream_reader, logger, read_limit=True):
3636
line_schema = self._infer_schema_for_record(entry)
@@ -44,6 +44,7 @@ def parse_records(
4444
file: RemoteFile,
4545
stream_reader: AbstractFileBasedStreamReader,
4646
logger: logging.Logger,
47+
discovered_schema: Optional[Mapping[str, SchemaType]],
4748
) -> Iterable[Dict[str, Any]]:
4849
"""
4950
This code supports parsing json objects over multiple lines even though this does not align with the JSONL format. This is for

airbyte-cdk/python/airbyte_cdk/sources/file_based/file_types/parquet_parser.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import json
66
import logging
77
import os
8-
from typing import Any, Dict, Iterable, List, Mapping
8+
from typing import Any, Dict, Iterable, List, Mapping, Optional
99
from urllib.parse import unquote
1010

1111
import pyarrow as pa
@@ -15,6 +15,7 @@
1515
from airbyte_cdk.sources.file_based.file_based_stream_reader import AbstractFileBasedStreamReader, FileReadMode
1616
from airbyte_cdk.sources.file_based.file_types.file_type_parser import FileTypeParser
1717
from airbyte_cdk.sources.file_based.remote_file import RemoteFile
18+
from airbyte_cdk.sources.file_based.schema_helpers import SchemaType
1819
from pyarrow import Scalar
1920

2021

@@ -28,7 +29,7 @@ async def infer_schema(
2829
file: RemoteFile,
2930
stream_reader: AbstractFileBasedStreamReader,
3031
logger: logging.Logger,
31-
) -> Dict[str, Any]:
32+
) -> SchemaType:
3233
parquet_format = config.format or ParquetFormat()
3334
if not isinstance(parquet_format, ParquetFormat):
3435
raise ValueError(f"Expected ParquetFormat, got {parquet_format}")
@@ -51,6 +52,7 @@ def parse_records(
5152
file: RemoteFile,
5253
stream_reader: AbstractFileBasedStreamReader,
5354
logger: logging.Logger,
55+
discovered_schema: Optional[Mapping[str, SchemaType]],
5456
) -> Iterable[Dict[str, Any]]:
5557
parquet_format = config.format or ParquetFormat()
5658
if not isinstance(parquet_format, ParquetFormat):

airbyte-cdk/python/airbyte_cdk/sources/file_based/schema_helpers.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from airbyte_cdk.sources.file_based.exceptions import ConfigValidationError, FileBasedSourceError, SchemaInferenceError
1212

1313
JsonSchemaSupportedType = Union[List[str], Literal["string"], str]
14-
SchemaType = Dict[str, Dict[str, JsonSchemaSupportedType]]
14+
SchemaType = Mapping[str, Mapping[str, JsonSchemaSupportedType]]
1515

1616
schemaless_schema = {"type": "object", "properties": {"data": {"type": "object"}}}
1717

@@ -99,7 +99,7 @@ def merge_schemas(schema1: SchemaType, schema2: SchemaType) -> SchemaType:
9999
if not isinstance(t, dict) or "type" not in t or not _is_valid_type(t["type"]):
100100
raise SchemaInferenceError(FileBasedSourceError.UNRECOGNIZED_TYPE, key=k, type=t)
101101

102-
merged_schema: Dict[str, Any] = deepcopy(schema1)
102+
merged_schema: Dict[str, Any] = deepcopy(schema1) # type: ignore # as of 2023-08-08, deepcopy can copy Mapping
103103
for k2, t2 in schema2.items():
104104
t1 = merged_schema.get(k2)
105105
if t1 is None:
@@ -116,7 +116,7 @@ def _is_valid_type(t: JsonSchemaSupportedType) -> bool:
116116
return t == "array" or get_comparable_type(t) is not None
117117

118118

119-
def _choose_wider_type(key: str, t1: Dict[str, Any], t2: Dict[str, Any]) -> Dict[str, Any]:
119+
def _choose_wider_type(key: str, t1: Mapping[str, Any], t2: Mapping[str, Any]) -> Mapping[str, Any]:
120120
if (t1["type"] == "array" or t2["type"] == "array") and t1 != t2:
121121
raise SchemaInferenceError(
122122
FileBasedSourceError.SCHEMA_INFERENCE_ERROR,

airbyte-cdk/python/airbyte_cdk/sources/file_based/stream/default_file_based_stream.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import itertools
77
import traceback
88
from functools import cache
9-
from typing import Any, Dict, Iterable, List, Mapping, MutableMapping, Optional, Set, Union
9+
from typing import Any, Iterable, List, Mapping, MutableMapping, Optional, Set, Union
1010

1111
from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage, Level
1212
from airbyte_cdk.models import Type as MessageType
@@ -20,7 +20,7 @@
2020
StopSyncPerValidationPolicy,
2121
)
2222
from airbyte_cdk.sources.file_based.remote_file import RemoteFile
23-
from airbyte_cdk.sources.file_based.schema_helpers import merge_schemas, schemaless_schema
23+
from airbyte_cdk.sources.file_based.schema_helpers import SchemaType, merge_schemas, schemaless_schema
2424
from airbyte_cdk.sources.file_based.stream import AbstractFileBasedStream
2525
from airbyte_cdk.sources.file_based.stream.cursor import AbstractFileBasedCursor
2626
from airbyte_cdk.sources.file_based.types import StreamSlice
@@ -84,7 +84,7 @@ def read_records_from_slice(self, stream_slice: StreamSlice) -> Iterable[Mapping
8484
n_skipped = line_no = 0
8585

8686
try:
87-
for record in parser.parse_records(self.config, file, self._stream_reader, self.logger):
87+
for record in parser.parse_records(self.config, file, self._stream_reader, self.logger, schema):
8888
line_no += 1
8989
if self.config.schemaless:
9090
record = {"data": record}
@@ -231,8 +231,8 @@ async def _infer_schema(self, files: List[RemoteFile]) -> Mapping[str, Any]:
231231
Each file type has a corresponding `infer_schema` handler.
232232
Dispatch on file type.
233233
"""
234-
base_schema: Dict[str, Any] = {}
235-
pending_tasks: Set[asyncio.tasks.Task[Dict[str, Any]]] = set()
234+
base_schema: SchemaType = {}
235+
pending_tasks: Set[asyncio.tasks.Task[SchemaType]] = set()
236236

237237
n_started, n_files = 0, len(files)
238238
files_iterator = iter(files)
@@ -251,7 +251,7 @@ async def _infer_schema(self, files: List[RemoteFile]) -> Mapping[str, Any]:
251251

252252
return base_schema
253253

254-
async def _infer_file_schema(self, file: RemoteFile) -> Dict[str, Any]:
254+
async def _infer_file_schema(self, file: RemoteFile) -> SchemaType:
255255
try:
256256
return await self.get_parser(self.config.file_type).infer_schema(self.config, file, self._stream_reader, self.logger)
257257
except Exception as exc:

airbyte-cdk/python/unit_tests/sources/file_based/config/test_csv_format.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
pytest.param(0, False, None, id="test_no_skip_rows_before_header_and_no_autogenerate_column_names"),
1616
]
1717
)
18-
def test_csv_format(skip_rows_before_header, autogenerate_column_names, expected_error):
18+
def test_csv_format_skip_rows_and_autogenerate_column_names(skip_rows_before_header, autogenerate_column_names, expected_error) -> None:
1919
if expected_error:
2020
with pytest.raises(expected_error):
2121
CsvFormat(skip_rows_before_header=skip_rows_before_header, autogenerate_column_names=autogenerate_column_names)

0 commit comments

Comments (0)