|
1 | 1 | # ruff: noqa: E501
|
2 | 2 | import logging
|
3 | 3 |
|
| 4 | +import cratedb_sqlparse |
4 | 5 | import sqlparse
|
5 | 6 |
|
6 | 7 | from cratedb_mcp.settings import PERMIT_ALL_STATEMENTS
|
@@ -131,25 +132,55 @@ def sql_expression_permitted(expression: str) -> bool:
|
131 | 132 |
|
132 | 133 | def _sql_expression_permitted(expression: str) -> bool:
|
133 | 134 |
|
| 135 | + if expression: |
| 136 | + expression = expression.strip() |
| 137 | + |
134 | 138 | if not expression:
|
135 | 139 | return False
|
136 | 140 |
|
137 | 141 | if PERMIT_ALL_STATEMENTS:
|
138 | 142 | return True
|
139 | 143 |
|
140 | 144 | # Parse the SQL statement.
|
141 |
| - parsed = sqlparse.parse(expression.strip()) |
142 |
| - if not parsed: |
143 |
| - return False |
| 145 | + parsed = cratedb_sqlparse.sqlparse(expression) |
144 | 146 |
|
145 |
| - # Check for multiple statements (potential SQL injection). |
| 147 | + # Reject multiple statements to prevent potential SQL injections. |
146 | 148 | if len(parsed) > 1:
|
147 | 149 | return False
|
148 | 150 |
|
149 | 151 | # Check if the expression is valid and if it's a SELECT statement,
|
150 | 152 | # also trying to consider `SELECT ... INTO ...` statements.
|
151 |
| - operation = parsed[0].get_type().upper() |
| 153 | + operation = parsed[0].type.upper() |
| 154 | + is_select = operation == 'SELECT' |
| 155 | + is_rejected = _sql_is_select_into(expression) or _sql_is_evasive(expression) |
| 156 | + if is_select and not is_rejected: |
| 157 | + return True |
| 158 | + return False |
| 159 | + |
| 160 | + |
| 161 | +def _sql_is_select_into(expression: str) -> bool: |
| 162 | + """ |
| 163 | + Helper function using traditional `sqlparse` for catching `SELECT ... INTO ...` statements. |
| 164 | +
|
| 165 | + Examples: |
| 166 | +
|
| 167 | + SELECT * INTO foobar FROM bazqux |
| 168 | + SELECT * FROM bazqux INTO foobar |
| 169 | + """ |
| 170 | + parsed = sqlparse.parse(expression) |
152 | 171 | tokens = [str(item).upper() for item in parsed[0]]
|
153 |
| - if operation != 'SELECT' or (operation == 'SELECT' and 'INTO' in tokens): |
154 |
| - return False |
155 |
| - return True |
| 172 | + return "INTO" in tokens |
| 173 | + |
| 174 | + |
| 175 | +def _sql_is_evasive(expression: str) -> bool: |
| 176 | + """ |
| 177 | + Helper function using traditional `sqlparse` for catching evasive SQL statements. |
| 178 | +
|
| 179 | + Reject multiple statements to prevent potential SQL injections. |
| 180 | +
|
| 181 | + Examples: |
| 182 | +
|
| 183 | + SELECT * FROM users; \uff1b DROP TABLE users |
| 184 | + """ |
| 185 | + parsed = sqlparse.parse(expression) |
| 186 | + return len(parsed) > 1 |
0 commit comments