fix: exclude manually added columns from copy #598

Merged · 3 commits · Apr 10, 2025
19 changes: 14 additions & 5 deletions projects/pgai/pgai/vectorizer/vectorizer.py
@@ -968,24 +968,33 @@ async def _load_copy_types(self, conn: AsyncConnection) -> None:
Args:
conn (AsyncConnection): The database connection.
"""
target_columns: list[str] = list(self.queries.pk_attnames) + [
"chunk_seq",
"chunk",
"embedding",
]
async with conn.cursor() as cursor:
await cursor.execute(
"""
select a.atttypid
select a.attname, a.atttypid
from pg_catalog.pg_class k
inner join pg_catalog.pg_namespace n
on (k.relnamespace operator(pg_catalog.=) n.oid)
inner join pg_catalog.pg_attribute a
on (k.oid operator(pg_catalog.=) a.attrelid)
where n.nspname operator(pg_catalog.=) %s
and k.relname operator(pg_catalog.=) %s
and a.attname operator(pg_catalog.!=) 'embedding_uuid'
AND a.attname = ANY(%s)
and a.attnum operator(pg_catalog.>) 0
order by a.attnum
""",
(self.vectorizer.target_schema, self.vectorizer.target_table),
(
self.vectorizer.target_schema,
self.vectorizer.target_table,
target_columns,
),
)
self.copy_types = [row[0] for row in await cursor.fetchall()]
column_name_to_type = {row[0]: row[1] for row in await cursor.fetchall()}
self.copy_types = [column_name_to_type[col] for col in target_columns]
assert self.copy_types is not None
# len(source_pk) + chunk_seq + chunk + embedding
assert len(self.copy_types) == len(self.vectorizer.source_pk) + 3
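
For readers following along, below is a minimal standalone sketch of the technique the hunk above applies: the catalog query is restricted to an allow-list of column names with "= any(%s)" (psycopg adapts the Python list to a Postgres array), and the returned rows are re-keyed by name so the resulting type list follows target_columns order rather than attnum order. This is a simplified illustration, not the vectorizer's actual method; the connection string, schema, table, and primary-key column names are placeholder assumptions, not values taken from this PR.

# Standalone sketch (hedged): the DSN, schema, table, and column names below
# are placeholders, not values from this PR.
import asyncio

from psycopg import AsyncConnection


async def load_copy_types(
    dsn: str, schema: str, table: str, target_columns: list[str]
) -> list[int]:
    """Return the pg_type OIDs for target_columns, in target_columns order."""
    async with await AsyncConnection.connect(dsn) as conn:
        async with conn.cursor() as cursor:
            await cursor.execute(
                """
                select a.attname, a.atttypid
                from pg_catalog.pg_class k
                inner join pg_catalog.pg_namespace n
                    on (k.relnamespace = n.oid)
                inner join pg_catalog.pg_attribute a
                    on (k.oid = a.attrelid)
                where n.nspname = %s
                  and k.relname = %s
                  and a.attname = any(%s)  -- psycopg adapts the list to an array
                  and a.attnum > 0
                """,
                (schema, table, target_columns),
            )
            rows = await cursor.fetchall()
    # Re-key by column name so manually added columns can never shift or
    # extend the type list; order follows target_columns, not attnum.
    name_to_type = {name: oid for name, oid in rows}
    return [name_to_type[col] for col in target_columns]


if __name__ == "__main__":
    print(
        asyncio.run(
            load_copy_types(
                "postgresql://localhost/postgres",  # placeholder DSN
                "public",
                "blog_embedding_store",
                ["id", "chunk_seq", "chunk", "embedding"],
            )
        )
    )

Re-keying by name, instead of relying on catalog order, is what keeps the subsequent COPY column list stable when a user adds their own columns to the embedding store or when the embedding column is dropped and re-added in a different position.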
107 changes: 107 additions & 0 deletions projects/pgai/tests/vectorizer/cli/test_additional_colums.py
@@ -0,0 +1,107 @@
import pytest
from psycopg import Connection
from psycopg.rows import dict_row

from tests.vectorizer.cli.conftest import (
    TestDatabase,
    configure_vectorizer,
    run_vectorizer_worker,
    setup_source_table,
)


@pytest.mark.parametrize(
    "column_def",
    [
        "new_column text",
        "new_column text NOT NULL DEFAULT 'default_value'",
    ],
)
def test_additional_columns_are_added_to_target_table(
    cli_db: tuple[TestDatabase, Connection],
    cli_db_url: str,
    column_def: str,
):
    """Test that if additional columns are added to the target table,
    the vectorizer still works"""
    _, connection = cli_db
    table_name = setup_source_table(connection, 2)
    vectorizer_id = configure_vectorizer(
        table_name,
        cli_db[1],
    )
    with connection.cursor(row_factory=dict_row) as cur:
        cur.execute(f"ALTER TABLE blog_embedding_store ADD COLUMN {column_def}")  # type: ignore

    result = run_vectorizer_worker(cli_db_url, vectorizer_id)
    print(result.stdout)
    assert result.exit_code == 0

    with connection.cursor(row_factory=dict_row) as cur:
        cur.execute("SELECT * FROM blog_embedding_store")
        rows = cur.fetchall()
        assert len(rows) == 2


def test_embedding_column_removal_and_readd(
    cli_db: tuple[TestDatabase, Connection],
    cli_db_url: str,
):
    """Test that the vectorizer still works when the embedding column is removed,
    another column is added, and then the embedding column is re-added."""
    _, connection = cli_db
    table_name = setup_source_table(connection, 2)
    vectorizer_id = configure_vectorizer(
        table_name,
        cli_db[1],
    )

    # First run to create original rows
    result = run_vectorizer_worker(cli_db_url, vectorizer_id)
    assert result.exit_code == 0

    # Check original rows were created
    with connection.cursor(row_factory=dict_row) as cur:
        cur.execute("SELECT * FROM blog_embedding_store")
        rows = cur.fetchall()
        assert len(rows) == 2
        # Verify embedding column exists
        assert "embedding" in rows[0]

        # Drop View so we can change column order
        cur.execute("DROP VIEW IF EXISTS blog_embedding")

        # Remove embedding column
        cur.execute("ALTER TABLE blog_embedding_store DROP COLUMN embedding")

        # Add another optional column
        cur.execute("ALTER TABLE blog_embedding_store ADD COLUMN extra_data text")

        # Re-add embedding column with same type
        cur.execute(
            "ALTER TABLE blog_embedding_store ADD COLUMN embedding vector(1536)"
        )

        # Remove original rows
        cur.execute("DELETE FROM blog")

        # Add new rows
        values = [(i, i, f"post_{i}") for i in range(1, 3)]
        cur.executemany(
            "INSERT INTO blog(id, id2, content) VALUES (%s, %s, %s)",
            values,
        )

    # Run vectorizer again
    result = run_vectorizer_worker(cli_db_url, vectorizer_id)
    print(result.stdout)
    assert result.exit_code == 0

    # Verify vectorizer still works
    with connection.cursor(row_factory=dict_row) as cur:
        cur.execute("SELECT * FROM blog_embedding_store")
        rows = cur.fetchall()
        assert len(rows) == 2
        # Verify embedding column exists and has data
        assert "embedding" in rows[0]
        assert rows[0]["embedding"] is not None