Skip to content

fix(app): recursive cursor errors #7727

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Mar 3, 2025
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,18 @@


class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
_conn: sqlite3.Connection
_cursor: sqlite3.Cursor

def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._conn = db.conn
self._cursor = self._conn.cursor()

def add_image_to_board(
self,
board_id: str,
image_name: str,
) -> None:
try:
self._cursor.execute(
cursor = self._conn.cursor()
cursor.execute(
"""--sql
INSERT INTO board_images (board_id, image_name)
VALUES (?, ?)
Expand All @@ -44,7 +41,8 @@ def remove_image_from_board(
image_name: str,
) -> None:
try:
self._cursor.execute(
cursor = self._conn.cursor()
cursor.execute(
"""--sql
DELETE FROM board_images
WHERE image_name = ?;
Expand All @@ -63,7 +61,8 @@ def get_images_for_board(
limit: int = 10,
) -> OffsetPaginatedResults[ImageRecord]:
# TODO: this isn't paginated yet?
self._cursor.execute(
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT images.*
FROM board_images
Expand All @@ -73,15 +72,15 @@ def get_images_for_board(
""",
(board_id,),
)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
result = cast(list[sqlite3.Row], cursor.fetchall())
images = [deserialize_image_record(dict(r)) for r in result]

self._cursor.execute(
cursor.execute(
"""--sql
SELECT COUNT(*) FROM images WHERE 1=1;
"""
)
count = cast(int, self._cursor.fetchone()[0])
count = cast(int, cursor.fetchone()[0])

return OffsetPaginatedResults(items=images, offset=offset, limit=limit, total=count)

Expand Down Expand Up @@ -128,31 +127,34 @@ def get_all_board_image_names_for_board(
stmt += ";"

# Execute the query
self._cursor.execute(stmt, params)
cursor = self._conn.cursor()
cursor.execute(stmt, params)

result = cast(list[sqlite3.Row], self._cursor.fetchall())
result = cast(list[sqlite3.Row], cursor.fetchall())
image_names = [r[0] for r in result]
return image_names

def get_board_for_image(
self,
image_name: str,
) -> Optional[str]:
self._cursor.execute(
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT board_id
FROM board_images
WHERE image_name = ?;
""",
(image_name,),
)
result = self._cursor.fetchone()
result = cursor.fetchone()
if result is None:
return None
return cast(str, result[0])

def get_image_count_for_board(self, board_id: str) -> int:
self._cursor.execute(
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT COUNT(*)
FROM board_images
Expand All @@ -162,5 +164,5 @@ def get_image_count_for_board(self, board_id: str) -> int:
""",
(board_id,),
)
count = cast(int, self._cursor.fetchone()[0])
count = cast(int, cursor.fetchone()[0])
return count
37 changes: 20 additions & 17 deletions invokeai/app/services/board_records/board_records_sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,14 @@


class SqliteBoardRecordStorage(BoardRecordStorageBase):
_conn: sqlite3.Connection
_cursor: sqlite3.Cursor

def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._conn = db.conn
self._cursor = self._conn.cursor()

def delete(self, board_id: str) -> None:
try:
self._cursor.execute(
cursor = self._conn.cursor()
cursor.execute(
"""--sql
DELETE FROM boards
WHERE board_id = ?;
Expand All @@ -46,7 +43,8 @@ def save(
) -> BoardRecord:
try:
board_id = uuid_string()
self._cursor.execute(
cursor = self._conn.cursor()
cursor.execute(
"""--sql
INSERT OR IGNORE INTO boards (board_id, board_name)
VALUES (?, ?);
Expand All @@ -64,7 +62,8 @@ def get(
board_id: str,
) -> BoardRecord:
try:
self._cursor.execute(
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT *
FROM boards
Expand All @@ -73,7 +72,7 @@ def get(
(board_id,),
)

result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
except sqlite3.Error as e:
raise BoardRecordNotFoundException from e
if result is None:
Expand All @@ -86,9 +85,10 @@ def update(
changes: BoardChanges,
) -> BoardRecord:
try:
cursor = self._conn.cursor()
# Change the name of a board
if changes.board_name is not None:
self._cursor.execute(
cursor.execute(
"""--sql
UPDATE boards
SET board_name = ?
Expand All @@ -99,7 +99,7 @@ def update(

# Change the cover image of a board
if changes.cover_image_name is not None:
self._cursor.execute(
cursor.execute(
"""--sql
UPDATE boards
SET cover_image_name = ?
Expand All @@ -110,7 +110,7 @@ def update(

# Change the archived status of a board
if changes.archived is not None:
self._cursor.execute(
cursor.execute(
"""--sql
UPDATE boards
SET archived = ?
Expand All @@ -133,6 +133,8 @@ def get_many(
limit: int = 10,
include_archived: bool = False,
) -> OffsetPaginatedResults[BoardRecord]:
cursor = self._conn.cursor()

# Build base query
base_query = """
SELECT *
Expand All @@ -150,9 +152,9 @@ def get_many(
)

# Execute query to fetch boards
self._cursor.execute(final_query, (limit, offset))
cursor.execute(final_query, (limit, offset))

result = cast(list[sqlite3.Row], self._cursor.fetchall())
result = cast(list[sqlite3.Row], cursor.fetchall())
boards = [deserialize_board_record(dict(r)) for r in result]

# Determine count query
Expand All @@ -169,15 +171,16 @@ def get_many(
"""

# Execute count query
self._cursor.execute(count_query)
cursor.execute(count_query)

count = cast(int, self._cursor.fetchone()[0])
count = cast(int, cursor.fetchone()[0])

return OffsetPaginatedResults[BoardRecord](items=boards, offset=offset, limit=limit, total=count)

def get_all(
self, order_by: BoardRecordOrderBy, direction: SQLiteDirection, include_archived: bool = False
) -> list[BoardRecord]:
cursor = self._conn.cursor()
if order_by == BoardRecordOrderBy.Name:
base_query = """
SELECT *
Expand All @@ -199,9 +202,9 @@ def get_all(
archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
)

self._cursor.execute(final_query)
cursor.execute(final_query)

result = cast(list[sqlite3.Row], self._cursor.fetchall())
result = cast(list[sqlite3.Row], cursor.fetchall())
boards = [deserialize_board_record(dict(r)) for r in result]

return boards
Loading
Loading