Commit 6cfe303

wengh authored and allisonwang-db committed
[SPARK-51919][PYTHON] Allow overwriting statically registered Python Data Source
### What changes were proposed in this pull request?

- Allow overwriting static Python Data Sources during registration
- Update documentation to clarify Python Data Source behavior and registration options

### Why are the changes needed?

Static registration is a bit obscure and doesn't always work as expected (e.g. when the module providing DefaultSource is installed after `lookup_data_sources` already ran). So in practice users (or LLM agents) often want to explicitly register the data source even if it is provided as a DefaultSource. Raising an error in this case interrupts the workflow, making LLM agents spend extra tokens regenerating the same code but without registration.

This change also makes the behavior consistent with user data source registrations, which are already allowed to overwrite previous user registrations.

### Does this PR introduce _any_ user-facing change?

Yes. Previously, registering a Python Data Source with the same name as a statically registered one would throw an error. With this change, it will overwrite the static registration.

### How was this patch tested?

Added a test in `PythonDataSourceSuite.scala` to verify that static sources can be overwritten correctly.

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #50716 from wengh/pyds-overwrite-static.

Authored-by: Haoyu Weng <[email protected]>
Signed-off-by: Allison Wang <[email protected]>
1 parent 8915c60 commit 6cfe303

3 files changed: +32 -12 lines changed
python/docs/source/tutorial/sql/python_data_source.rst

Lines changed: 3 additions & 1 deletion
@@ -520,4 +520,6 @@ The following example demonstrates how to implement a basic Data Source using Ar
 Usage Notes
 -----------
 
-- During Data Source resolution, built-in and Scala/Java Data Sources take precedence over Python Data Sources with the same name; to explicitly use a Python Data Source, make sure its name does not conflict with the other Data Sources.
+- During Data Source resolution, built-in and Scala/Java Data Sources take precedence over Python Data Sources with the same name; to explicitly use a Python Data Source, make sure its name does not conflict with the other non-Python Data Sources.
+- It is allowed to register multiple Python Data Sources with the same name. Later registrations will overwrite earlier ones.
+- To automatically register a data source, export it as ``DefaultSource`` in a top level module with name prefix ``pyspark_``. See `pyspark_huggingface <https://github.com/huggingface/pyspark_huggingface>`_ for an example.
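
The static registration convention described in the last bullet can be sketched in plain Python. This is a simplified illustration under stated assumptions, not Spark's actual `lookup_data_sources` implementation: the function name `discover_default_sources` is hypothetical, and it only checks for a `DefaultSource` attribute on top-level modules whose names start with `pyspark_`.

```python
import importlib
import pkgutil


def discover_default_sources(prefix="pyspark_"):
    """Return {module_name: DefaultSource} for importable top-level modules.

    Hypothetical helper: a rough model of the discovery convention, not Spark's code.
    """
    found = {}
    # With no path argument, iter_modules() walks the top-level modules on sys.path.
    for module_info in pkgutil.iter_modules():
        if not module_info.name.startswith(prefix):
            continue
        try:
            module = importlib.import_module(module_info.name)
        except Exception:
            # A broken plugin should not prevent discovery of the others.
            continue
        source = getattr(module, "DefaultSource", None)
        if source is not None:
            found[module_info.name] = source
    return found
```

A scan like this only sees modules that are importable at the moment it runs. A package installed afterwards is missed until the scan is repeated, which is the situation the commit message gives as the reason users fall back to explicit registration.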

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala

Lines changed: 11 additions & 11 deletions
@@ -48,14 +48,13 @@ class DataSourceManager extends Logging {
    */
   def registerDataSource(name: String, source: UserDefinedPythonDataSource): Unit = {
     val normalizedName = normalize(name)
-    if (staticDataSourceBuilders.contains(normalizedName)) {
-      // Cannot overwrite static Python Data Sources.
-      throw QueryCompilationErrors.dataSourceAlreadyExists(name)
-    }
     val previousValue = runtimeDataSourceBuilders.put(normalizedName, source)
     if (previousValue != null) {
       logWarning(log"The data source ${MDC(DATA_SOURCE, name)} replaced a previously " +
         log"registered data source.")
+    } else if (staticDataSourceBuilders.contains(normalizedName)) {
+      logWarning(log"The data source ${MDC(DATA_SOURCE, name)} replaced a statically " +
+        log"registered data source.")
     }
   }
 
@@ -64,11 +63,7 @@ class DataSourceManager extends Logging {
    * it does not exist.
    */
   def lookupDataSource(name: String): UserDefinedPythonDataSource = {
-    if (dataSourceExists(name)) {
-      val normalizedName = normalize(name)
-      staticDataSourceBuilders.getOrElse(
-        normalizedName, runtimeDataSourceBuilders.get(normalizedName))
-    } else {
+    getDataSource(name).getOrElse {
       throw QueryCompilationErrors.dataSourceDoesNotExist(name)
     }
   }
@@ -77,9 +72,14 @@ class DataSourceManager extends Logging {
    * Checks if a data source with the specified name exists (case-insensitive).
    */
   def dataSourceExists(name: String): Boolean = {
+    getDataSource(name).isDefined
+  }
+
+  private def getDataSource(name: String): Option[UserDefinedPythonDataSource] = {
     val normalizedName = normalize(name)
-    staticDataSourceBuilders.contains(normalizedName) ||
-      runtimeDataSourceBuilders.containsKey(normalizedName)
+    // Runtime registration takes precedence over static.
+    Option(runtimeDataSourceBuilders.get(normalizedName))
+      .orElse(staticDataSourceBuilders.get(normalizedName))
   }
 
   override def clone(): DataSourceManager = {
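
The new resolution order in `DataSourceManager` can be modelled in a few lines of Python. This is a toy sketch, not the Scala implementation: the class name `DataSourceRegistry`, the `warnings` list standing in for `logWarning`, and the use of `str.lower()` in place of `normalize` are all assumptions made for illustration.

```python
class DataSourceRegistry:
    """Toy model of the lookup order in this commit; not Spark's implementation."""

    def __init__(self, static_sources=None):
        # Static sources are discovered once; runtime sources are registered by the user.
        # Lowercasing is an assumed stand-in for DataSourceManager.normalize.
        self._static = {k.lower(): v for k, v in (static_sources or {}).items()}
        self._runtime = {}
        self.warnings = []  # stands in for logWarning

    def register(self, name, source):
        key = name.lower()
        previous = self._runtime.get(key)
        self._runtime[key] = source
        if previous is not None:
            self.warnings.append(f"{name} replaced a previously registered data source.")
        elif key in self._static:
            # Before this commit, this case raised an error instead.
            self.warnings.append(f"{name} replaced a statically registered data source.")

    def _get(self, name):
        key = name.lower()
        # Runtime registration takes precedence over static.
        if key in self._runtime:
            return self._runtime[key]
        return self._static.get(key)

    def exists(self, name):
        return self._get(name) is not None

    def lookup(self, name):
        source = self._get(name)
        if source is None:
            raise LookupError(f"Data source {name!r} does not exist.")
        return source
```

Routing both `exists` and `lookup` through one private getter mirrors the refactoring in the diff: the precedence rule lives in a single place, so the two public methods cannot disagree about which registration wins.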

sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala

Lines changed: 18 additions & 0 deletions
@@ -126,6 +126,24 @@ class PythonDataSourceSuite extends PythonDataSourceSuiteBase {
     assume(shouldTestPandasUDFs)
     val df = spark.read.format(staticSourceName).load()
     checkAnswer(df, Seq(Row(0, 0), Row(0, 1), Row(1, 0), Row(1, 1), Row(2, 0), Row(2, 1)))
+
+    // Overwrite the static source
+    val errorText = "static source overwritten"
+    val dataSourceScript =
+      s"""
+         |from pyspark.sql.datasource import DataSource
+         |
+         |class $staticSourceName(DataSource):
+         |    def schema(self) -> str:
+         |        raise Exception("$errorText")
+         |""".stripMargin
+    val dataSource = createUserDefinedPythonDataSource(
+      name = staticSourceName, pythonScript = dataSourceScript)
+    spark.dataSource.registerPython(staticSourceName, dataSource)
+    val err = intercept[AnalysisException] {
+      spark.read.format(staticSourceName).load()
+    }
+    assert(err.getMessage.contains(errorText))
   }
 
   test("simple data source") {
