Skip to content

Commit 2950fd7

Browse files
authored
[Models] [Postgres] Check if the dynamically-added index is in the table schema before adding (#32731)
* Check if the index is in the table schema before adding * add pre-condition assertion * static checks * Update test_models.py * integrate upstream auth manager changes
1 parent 04d3f45 commit 2950fd7

File tree

2 files changed

+68
-2
lines changed

2 files changed

+68
-2
lines changed

airflow/auth/managers/fab/models/__init__.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -255,11 +255,15 @@ class RegisterUser(Model):
255255
def add_index_on_ab_user_username_postgres(table, conn, **kw):
256256
if conn.dialect.name != "postgresql":
257257
return
258-
table.indexes.add(Index("idx_ab_user_username", func.lower(table.c.username), unique=True))
258+
index_name = "idx_ab_user_username"
259+
if not any(table_index.name == index_name for table_index in table.indexes):
260+
table.indexes.add(Index(index_name, func.lower(table.c.username), unique=True))
259261

260262

261263
@event.listens_for(RegisterUser.__table__, "before_create")
262264
def add_index_on_ab_register_user_username_postgres(table, conn, **kw):
263265
if conn.dialect.name != "postgresql":
264266
return
265-
table.indexes.add(Index("idx_ab_register_user_username", func.lower(table.c.username), unique=True))
267+
index_name = "idx_ab_register_user_username"
268+
if not any(table_index.name == index_name for table_index in table.indexes):
269+
table.indexes.add(Index(index_name, func.lower(table.c.username), unique=True))
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
from __future__ import annotations
18+
19+
from unittest import mock
20+
21+
from sqlalchemy import Column, MetaData, String, Table
22+
23+
from airflow.auth.managers.fab.models import (
24+
add_index_on_ab_register_user_username_postgres,
25+
add_index_on_ab_user_username_postgres,
26+
)
27+
28+
_mock_conn = mock.MagicMock()
29+
_mock_conn.dialect = mock.MagicMock()
30+
_mock_conn.dialect.name = "postgresql"
31+
32+
33+
def test_add_index_on_ab_user_username_postgres():
34+
table = Table("test_table", MetaData(), Column("username", String))
35+
36+
assert len(table.indexes) == 0
37+
38+
add_index_on_ab_user_username_postgres(table, _mock_conn)
39+
40+
# Assert that the index was added to the table
41+
assert len(table.indexes) == 1
42+
43+
add_index_on_ab_user_username_postgres(table, _mock_conn)
44+
45+
# Assert that index is not re-added when the schema is recreated
46+
assert len(table.indexes) == 1
47+
48+
49+
def test_add_index_on_ab_register_user_username_postgres():
50+
table = Table("test_table", MetaData(), Column("username", String))
51+
52+
assert len(table.indexes) == 0
53+
54+
add_index_on_ab_register_user_username_postgres(table, _mock_conn)
55+
56+
# Assert that the index was added to the table
57+
assert len(table.indexes) == 1
58+
59+
add_index_on_ab_register_user_username_postgres(table, _mock_conn)
60+
61+
# Assert that index is not re-added when the schema is recreated
62+
assert len(table.indexes) == 1

0 commit comments

Comments
 (0)