Commit 8e54152

girarda and jbfbell authored and committed
Add CSV options to the CSV parser (#28491)
* remove invalid legacy option
* remove unused option
* the tests pass but this is quite messy
* very slight clean up
* Add skip options to csv format
* fix some of the typing issues
* fixme comment
* remove extra log message
* fix typing issues
* skip before header
* skip after header
* format
* add another test
* Automated Commit - Formatting Changes
* auto generate column names
* delete dead code
* update title and description
* true and false values
* Update the tests
* Add comment
* missing test
* rename
* update expected spec
* move to method
* Update comment
* fix typo
* remove unused import
* Add a comment
* None records do not pass the WaitForDiscoverPolicy
* format
* remove second branch to ensure we always go through the same processing
* Raise an exception if the record is None
* reset
* Update tests
* handle unquoted newlines
* Automated Commit - Formatting Changes
* Update test case so the quoting is explicit
* Update comment
* Automated Commit - Formatting Changes
* Fail validation if skipping rows before header and header is autogenerated
* always fail if a record cannot be parsed
* format
* set write line_no in error message
* remove none check
* Automated Commit - Formatting Changes
* enable autogenerate test
* remove duplicate test
* missing unit tests
* Update
* remove branching
* remove unused none check
* Update tests
* remove branching
* format
* extract to function
* comment
* missing type
* type annotation
* use set
* Document that the strings are case-sensitive
* public -> private
* add unit test
* newline

---------

Co-authored-by: girarda <[email protected]>
1 parent e76059f commit 8e54152
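At a glance, the user-facing change is a set of new options on the CSV format model: null values, rows to skip before and after the header, autogenerated column names, and configurable true/false values. Below is a minimal sketch of what configuring them looks like, using only the field names introduced in this commit; the values are illustrative and the surrounding connector/stream configuration is not part of this diff.

from airbyte_cdk.sources.file_based.config.csv_format import CsvFormat

# Illustrative values; every field name comes from the CsvFormat model added below.
csv_options = CsvFormat(
    skip_rows_before_header=0,          # rows to drop before reading the header row
    skip_rows_after_header=1,           # rows to drop between the header and the data
    null_values={"NA", "N/A"},          # case-sensitive strings read back as null
    true_values={"y", "yes", "true"},   # case-sensitive strings cast to True for boolean columns
    false_values={"n", "no", "false"},  # case-sensitive strings cast to False for boolean columns
    autogenerate_column_names=False,    # keep reading column names from the first non-skipped row
)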

File tree: 8 files changed (+1533 / -130 lines changed)


airbyte-cdk/python/airbyte_cdk/sources/file_based/config/csv_format.py

Lines changed: 42 additions & 6 deletions
@@ -4,9 +4,9 @@
 
 import codecs
 from enum import Enum
-from typing import Optional
+from typing import Any, Mapping, Optional, Set
 
-from pydantic import BaseModel, Field, validator
+from pydantic import BaseModel, Field, root_validator, validator
 from typing_extensions import Literal
 
 
@@ -17,6 +17,10 @@ class QuotingBehavior(Enum):
     QUOTE_NONE = "Quote None"
 
 
+DEFAULT_TRUE_VALUES = ["y", "yes", "t", "true", "on", "1"]
+DEFAULT_FALSE_VALUES = ["n", "no", "f", "false", "off", "0"]
+
+
 class CsvFormat(BaseModel):
     filetype: Literal["csv"] = "csv"
     delimiter: str = Field(
@@ -46,10 +50,34 @@ class CsvFormat(BaseModel):
         default=QuotingBehavior.QUOTE_SPECIAL_CHARACTERS,
         description="The quoting behavior determines when a value in a row should have quote marks added around it. For example, if Quote Non-numeric is specified, while reading, quotes are expected for row values that do not contain numbers. Or for Quote All, every row value will be expecting quotes.",
     )
-
-    # Noting that the existing S3 connector had a config option newlines_in_values. This was only supported by pyarrow and not
-    # the Python csv package. It has a little adoption, but long term we should ideally phase this out because of the drawbacks
-    # of using pyarrow
+    null_values: Set[str] = Field(
+        title="Null Values",
+        default=[],
+        description="A set of case-sensitive strings that should be interpreted as null values. For example, if the value 'NA' should be interpreted as null, enter 'NA' in this field.",
+    )
+    skip_rows_before_header: int = Field(
+        title="Skip Rows Before Header",
+        default=0,
+        description="The number of rows to skip before the header row. For example, if the header row is on the 3rd row, enter 2 in this field.",
+    )
+    skip_rows_after_header: int = Field(
+        title="Skip Rows After Header", default=0, description="The number of rows to skip after the header row."
+    )
+    autogenerate_column_names: bool = Field(
+        title="Autogenerate Column Names",
+        default=False,
+        description="Whether to autogenerate column names if column_names is empty. If true, column names will be of the form “f0”, “f1”… If false, column names will be read from the first CSV row after skip_rows_before_header.",
+    )
+    true_values: Set[str] = Field(
+        title="True Values",
+        default=DEFAULT_TRUE_VALUES,
+        description="A set of case-sensitive strings that should be interpreted as true values.",
+    )
+    false_values: Set[str] = Field(
+        title="False Values",
+        default=DEFAULT_FALSE_VALUES,
+        description="A set of case-sensitive strings that should be interpreted as false values.",
+    )
 
     @validator("delimiter")
     def validate_delimiter(cls, v: str) -> str:
@@ -78,3 +106,11 @@ def validate_encoding(cls, v: str) -> str:
         except LookupError:
             raise ValueError(f"invalid encoding format: {v}")
         return v
+
+    @root_validator
+    def validate_option_combinations(cls, values: Mapping[str, Any]) -> Mapping[str, Any]:
+        skip_rows_before_header = values.get("skip_rows_before_header", 0)
+        auto_generate_column_names = values.get("autogenerate_column_names", False)
+        if skip_rows_before_header > 0 and auto_generate_column_names:
+            raise ValueError("Cannot skip rows before header and autogenerate column names at the same time.")
+        return values
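The root_validator above encodes one combination that cannot work: when column names are autogenerated there is no real header row for the skip to be relative to, and _auto_generate_headers (further down in this commit) rewinds the file to the start, which would undo the skip. A rough sketch of how this surfaces to a user, assuming pydantic v1's usual behavior of wrapping validator errors in a ValidationError (the commit's own unit test catches it as a ValueError, which ValidationError subclasses):

from pydantic import ValidationError

from airbyte_cdk.sources.file_based.config.csv_format import CsvFormat

# Fine: skip two leading rows and read the header from the third physical row.
CsvFormat(skip_rows_before_header=2, autogenerate_column_names=False)

# Rejected by validate_option_combinations; pydantic re-raises it as a ValidationError.
try:
    CsvFormat(skip_rows_before_header=2, autogenerate_column_names=True)
except ValidationError as exc:
    print(exc)  # mentions "Cannot skip rows before header and autogenerate column names at the same time."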

airbyte-cdk/python/airbyte_cdk/sources/file_based/file_types/csv_parser.py

Lines changed: 114 additions & 59 deletions
@@ -5,12 +5,13 @@
 import csv
 import json
 import logging
-from distutils.util import strtobool
-from typing import Any, Dict, Iterable, Mapping, Optional
+from functools import partial
+from io import IOBase
+from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Set
 
 from airbyte_cdk.sources.file_based.config.csv_format import CsvFormat, QuotingBehavior
 from airbyte_cdk.sources.file_based.config.file_based_stream_config import FileBasedStreamConfig
-from airbyte_cdk.sources.file_based.exceptions import FileBasedSourceError
+from airbyte_cdk.sources.file_based.exceptions import FileBasedSourceError, RecordParseError
 from airbyte_cdk.sources.file_based.file_based_stream_reader import AbstractFileBasedStreamReader, FileReadMode
 from airbyte_cdk.sources.file_based.file_types.file_type_parser import FileTypeParser
 from airbyte_cdk.sources.file_based.remote_file import RemoteFile
@@ -34,30 +35,25 @@ async def infer_schema(
         stream_reader: AbstractFileBasedStreamReader,
         logger: logging.Logger,
     ) -> Dict[str, Any]:
-        config_format = config.format.get(config.file_type) if config.format else None
-        if config_format:
-            if not isinstance(config_format, CsvFormat):
-                raise ValueError(f"Invalid format config: {config_format}")
-            dialect_name = config.name + DIALECT_NAME
-            csv.register_dialect(
-                dialect_name,
-                delimiter=config_format.delimiter,
-                quotechar=config_format.quote_char,
-                escapechar=config_format.escape_char,
-                doublequote=config_format.double_quote,
-                quoting=config_to_quoting.get(config_format.quoting_behavior, csv.QUOTE_MINIMAL),
-            )
-            with stream_reader.open_file(file, self.file_read_mode, logger) as fp:
-                # todo: the existing InMemoryFilesSource.open_file() test source doesn't currently require an encoding, but actual
-                # sources will likely require one. Rather than modify the interface now we can wait until the real use case
-                reader = csv.DictReader(fp, dialect=dialect_name)  # type: ignore
-                schema = {field.strip(): {"type": "string"} for field in next(reader)}
-                csv.unregister_dialect(dialect_name)
-                return schema
-        else:
-            with stream_reader.open_file(file, self.file_read_mode, logger) as fp:
-                reader = csv.DictReader(fp)  # type: ignore
-                return {field.strip(): {"type": "string"} for field in next(reader)}
+        config_format = config.format.get(config.file_type) if config.format else CsvFormat()
+        if not isinstance(config_format, CsvFormat):
+            raise ValueError(f"Invalid format config: {config_format}")
+        dialect_name = config.name + DIALECT_NAME
+        csv.register_dialect(
+            dialect_name,
+            delimiter=config_format.delimiter,
+            quotechar=config_format.quote_char,
+            escapechar=config_format.escape_char,
+            doublequote=config_format.double_quote,
+            quoting=config_to_quoting.get(config_format.quoting_behavior, csv.QUOTE_MINIMAL),
+        )
+        with stream_reader.open_file(file, self.file_read_mode, logger) as fp:
+            # todo: the existing InMemoryFilesSource.open_file() test source doesn't currently require an encoding, but actual
+            # sources will likely require one. Rather than modify the interface now we can wait until the real use case
+            headers = self._get_headers(fp, config_format, dialect_name)
+            schema = {field.strip(): {"type": "string"} for field in headers}
+            csv.unregister_dialect(dialect_name)
+            return schema
 
     def parse_records(
         self,
@@ -67,38 +63,36 @@ def parse_records(
         logger: logging.Logger,
     ) -> Iterable[Dict[str, Any]]:
         schema: Mapping[str, Any] = config.input_schema  # type: ignore
-        config_format = config.format.get(config.file_type) if config.format else None
-        if config_format:
-            if not isinstance(config_format, CsvFormat):
-                raise ValueError(f"Invalid format config: {config_format}")
-            # Formats are configured individually per-stream so a unique dialect should be registered for each stream.
-            # Wwe don't unregister the dialect because we are lazily parsing each csv file to generate records
-            dialect_name = config.name + DIALECT_NAME
-            csv.register_dialect(
-                dialect_name,
-                delimiter=config_format.delimiter,
-                quotechar=config_format.quote_char,
-                escapechar=config_format.escape_char,
-                doublequote=config_format.double_quote,
-                quoting=config_to_quoting.get(config_format.quoting_behavior, csv.QUOTE_MINIMAL),
-            )
-            with stream_reader.open_file(file, self.file_read_mode, logger) as fp:
-                # todo: the existing InMemoryFilesSource.open_file() test source doesn't currently require an encoding, but actual
-                # sources will likely require one. Rather than modify the interface now we can wait until the real use case
-                reader = csv.DictReader(fp, dialect=dialect_name)  # type: ignore
-                yield from self._read_and_cast_types(reader, schema, logger)
-        else:
-            with stream_reader.open_file(file, self.file_read_mode, logger) as fp:
-                reader = csv.DictReader(fp)  # type: ignore
-                yield from self._read_and_cast_types(reader, schema, logger)
+        config_format = config.format.get(config.file_type) if config.format else CsvFormat()
+        if not isinstance(config_format, CsvFormat):
+            raise ValueError(f"Invalid format config: {config_format}")
+        # Formats are configured individually per-stream so a unique dialect should be registered for each stream.
+        # We don't unregister the dialect because we are lazily parsing each csv file to generate records
+        # This will potentially be a problem if we ever process multiple streams concurrently
+        dialect_name = config.name + DIALECT_NAME
+        csv.register_dialect(
+            dialect_name,
+            delimiter=config_format.delimiter,
+            quotechar=config_format.quote_char,
+            escapechar=config_format.escape_char,
+            doublequote=config_format.double_quote,
+            quoting=config_to_quoting.get(config_format.quoting_behavior, csv.QUOTE_MINIMAL),
+        )
+        with stream_reader.open_file(file, self.file_read_mode, logger) as fp:
+            # todo: the existing InMemoryFilesSource.open_file() test source doesn't currently require an encoding, but actual
+            # sources will likely require one. Rather than modify the interface now we can wait until the real use case
+            self._skip_rows_before_header(fp, config_format.skip_rows_before_header)
+            field_names = self._auto_generate_headers(fp, config_format) if config_format.autogenerate_column_names else None
+            reader = csv.DictReader(fp, dialect=dialect_name, fieldnames=field_names)  # type: ignore
+            yield from self._read_and_cast_types(reader, schema, config_format, logger)
 
     @property
     def file_read_mode(self) -> FileReadMode:
         return FileReadMode.READ
 
     @staticmethod
     def _read_and_cast_types(
-        reader: csv.DictReader, schema: Optional[Mapping[str, Any]], logger: logging.Logger  # type: ignore
+        reader: csv.DictReader, schema: Optional[Mapping[str, Any]], config_format: CsvFormat, logger: logging.Logger  # type: ignore
     ) -> Iterable[Dict[str, Any]]:
         """
         If the user provided a schema, attempt to cast the record values to the associated type.
@@ -107,16 +101,65 @@ def _read_and_cast_types(
         cast it to a string. Downstream, the user's validation policy will determine whether the
         record should be emitted.
         """
-        if not schema:
-            yield from reader
+        cast_fn = CsvParser._get_cast_function(schema, config_format, logger)
+        for i, row in enumerate(reader):
+            if i < config_format.skip_rows_after_header:
+                continue
+            # The row was not properly parsed if any of the values are None
+            if any(val is None for val in row.values()):
+                raise RecordParseError(FileBasedSourceError.ERROR_PARSING_RECORD)
+            else:
+                yield CsvParser._to_nullable(cast_fn(row), config_format.null_values)
 
-        else:
+    @staticmethod
+    def _get_cast_function(
+        schema: Optional[Mapping[str, Any]], config_format: CsvFormat, logger: logging.Logger
+    ) -> Callable[[Mapping[str, str]], Mapping[str, str]]:
+        # Only cast values if the schema is provided
+        if schema:
             property_types = {col: prop["type"] for col, prop in schema["properties"].items()}
-            for row in reader:
-                yield cast_types(row, property_types, logger)
+            return partial(_cast_types, property_types=property_types, config_format=config_format, logger=logger)
+        else:
+            # If no schema is provided, yield the rows as they are
+            return _no_cast
+
+    @staticmethod
+    def _to_nullable(row: Mapping[str, str], null_values: Set[str]) -> Dict[str, Optional[str]]:
+        nullable = row | {k: None if v in null_values else v for k, v in row.items()}
+        return nullable
+
+    @staticmethod
+    def _skip_rows_before_header(fp: IOBase, rows_to_skip: int) -> None:
+        """
+        Skip rows before the header. This has to be done on the file object itself, not the reader
+        """
+        for _ in range(rows_to_skip):
+            fp.readline()
+
+    def _get_headers(self, fp: IOBase, config_format: CsvFormat, dialect_name: str) -> List[str]:
+        # Note that this method assumes the dialect has already been registered if we're parsing the headers
+        if config_format.autogenerate_column_names:
+            return self._auto_generate_headers(fp, config_format)
+        else:
+            # If we're not autogenerating column names, we need to skip the rows before the header
+            self._skip_rows_before_header(fp, config_format.skip_rows_before_header)
+            # Then read the header
+            reader = csv.DictReader(fp, dialect=dialect_name)  # type: ignore
+            return next(reader)  # type: ignore
 
+    def _auto_generate_headers(self, fp: IOBase, config_format: CsvFormat) -> List[str]:
+        """
+        Generates field names as [f0, f1, ...] in the same way as pyarrow's csv reader with autogenerate_column_names=True.
+        See https://arrow.apache.org/docs/python/generated/pyarrow.csv.ReadOptions.html
+        """
+        next_line = next(fp).strip()
+        number_of_columns = len(next_line.split(config_format.delimiter))  # type: ignore
+        # Reset the file pointer to the beginning of the file so that the first row is not skipped
+        fp.seek(0)
+        return [f"f{i}" for i in range(number_of_columns)]
 
-def cast_types(row: Dict[str, str], property_types: Dict[str, Any], logger: logging.Logger) -> Dict[str, Any]:
+
+def _cast_types(row: Dict[str, str], property_types: Dict[str, Any], config_format: CsvFormat, logger: logging.Logger) -> Dict[str, Any]:
     """
     Casts the values in the input 'row' dictionary according to the types defined in the JSON schema.
@@ -142,7 +185,7 @@ def cast_types(row: Dict[str, str], property_types: Dict[str, Any], logger: logg
 
         elif python_type == bool:
             try:
-                cast_value = strtobool(value)
+                cast_value = _value_to_bool(value, config_format.true_values, config_format.false_values)
             except ValueError:
                 warnings.append(_format_warning(key, value, prop_type))
 
@@ -178,5 +221,17 @@ def cast_types(row: Dict[str, str], property_types: Dict[str, Any], logger: logg
     return result
 
 
+def _value_to_bool(value: str, true_values: Set[str], false_values: Set[str]) -> bool:
+    if value in true_values:
+        return True
+    if value in false_values:
+        return False
+    raise ValueError(f"Value {value} is not a valid boolean value")
+
+
 def _format_warning(key: str, value: str, expected_type: Optional[Any]) -> str:
     return f"{key}: value={value},expected_type={expected_type}"
+
+
+def _no_cast(row: Mapping[str, str]) -> Mapping[str, str]:
+    return row
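Two of the new helpers are easy to reason about in isolation. The sketch below mirrors _value_to_bool and _to_nullable as standalone functions (reimplemented locally rather than importing the private helpers) and applies them to a single already-parsed row. In the real parser, boolean casting only happens for columns the schema declares as booleans, and null substitution is applied to every emitted row.

from typing import Dict, Mapping, Optional, Set

# Mirrors DEFAULT_TRUE_VALUES / DEFAULT_FALSE_VALUES from csv_format.py (written as sets here).
DEFAULT_TRUE_VALUES = {"y", "yes", "t", "true", "on", "1"}
DEFAULT_FALSE_VALUES = {"n", "no", "f", "false", "off", "0"}


def value_to_bool(value: str, true_values: Set[str], false_values: Set[str]) -> bool:
    # Mirrors _value_to_bool: membership is case-sensitive, anything else is an error.
    if value in true_values:
        return True
    if value in false_values:
        return False
    raise ValueError(f"Value {value} is not a valid boolean value")


def to_nullable(row: Mapping[str, str], null_values: Set[str]) -> Dict[str, Optional[str]]:
    # Mirrors _to_nullable: any cell whose raw string is in null_values becomes None.
    return {k: None if v in null_values else v for k, v in row.items()}


row = {"id": "42", "active": "yes", "nickname": "NA"}
print(value_to_bool(row["active"], DEFAULT_TRUE_VALUES, DEFAULT_FALSE_VALUES))  # True
print(to_nullable(row, {"NA"}))  # {'id': '42', 'active': 'yes', 'nickname': None}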

airbyte-cdk/python/airbyte_cdk/sources/file_based/stream/default_file_based_stream.py

Lines changed: 13 additions & 0 deletions
@@ -15,6 +15,7 @@
     FileBasedSourceError,
     InvalidSchemaError,
     MissingSchemaError,
+    RecordParseError,
     SchemaInferenceError,
     StopSyncPerValidationPolicy,
 )
@@ -105,6 +106,18 @@ def read_records_from_slice(self, stream_slice: StreamSlice) -> Iterable[Mapping
                 )
                 break
 
+            except RecordParseError:
+                # Increment line_no because the exception was raised before we could increment it
+                line_no += 1
+                yield AirbyteMessage(
+                    type=MessageType.LOG,
+                    log=AirbyteLogMessage(
+                        level=Level.ERROR,
+                        message=f"{FileBasedSourceError.ERROR_PARSING_RECORD.value} stream={self.name} file={file.uri} line_no={line_no} n_skipped={n_skipped}",
+                        stack_trace=traceback.format_exc(),
+                    ),
+                )
+
             except Exception:
                 yield AirbyteMessage(
                     type=MessageType.LOG,
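The new except RecordParseError branch reports an unparseable row as an ERROR-level log message instead of letting the exception escape, and it bumps line_no first because the parser raises before the loop body gets a chance to increment it. A stripped-down sketch of that bookkeeping, with hypothetical names and only the control flow kept:

from airbyte_cdk.sources.file_based.exceptions import RecordParseError


def read_with_line_numbers(records):
    # `records` is any iterator that may raise RecordParseError mid-iteration.
    line_no = 0
    try:
        for record in records:
            line_no += 1
            yield record
    except RecordParseError:
        # The failing line never reached the increment above, so account for it here
        # before reporting, just as read_records_from_slice does in the hunk above.
        line_no += 1
        print(f"ERROR parsing record: line_no={line_no}")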
Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
+#
+# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
+#
+
+import pytest as pytest
+from airbyte_cdk.sources.file_based.config.csv_format import CsvFormat
+
+
+@pytest.mark.parametrize(
+    "skip_rows_before_header, autogenerate_column_names, expected_error",
+    [
+        pytest.param(1, True, ValueError, id="test_skip_rows_before_header_and_autogenerate_column_names"),
+        pytest.param(1, False, None, id="test_skip_rows_before_header_and_no_autogenerate_column_names"),
+        pytest.param(0, True, None, id="test_no_skip_rows_before_header_and_autogenerate_column_names"),
+        pytest.param(0, False, None, id="test_no_skip_rows_before_header_and_no_autogenerate_column_names"),
+    ]
+)
+def test_csv_format(skip_rows_before_header, autogenerate_column_names, expected_error):
+    if expected_error:
+        with pytest.raises(expected_error):
+            CsvFormat(skip_rows_before_header=skip_rows_before_header, autogenerate_column_names=autogenerate_column_names)
+    else:
+        CsvFormat(skip_rows_before_header=skip_rows_before_header, autogenerate_column_names=autogenerate_column_names)
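A natural follow-up test, not part of this commit and sketched here only as an illustration, would assert that the default true/false sets stay disjoint so a cell value can never be read as both:

from airbyte_cdk.sources.file_based.config.csv_format import CsvFormat


def test_default_true_and_false_values_do_not_overlap():
    fmt = CsvFormat()
    # DEFAULT_TRUE_VALUES and DEFAULT_FALSE_VALUES share no entries, so a value is never both.
    assert set(fmt.true_values).isdisjoint(set(fmt.false_values))
    assert "true" in fmt.true_values and "false" in fmt.false_values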
