Skip to content

Commit 15d18af

Browse files
Allow a subquery to be passed into is_in and not_in (#1217)
* adds a subquery for is_in and not_in for tables with fk relationships * pass `QueryString` in, instead of `Select` * allow querystring to be passed in to - cuts down on repetition * add more tests * tweak docs * update docs --------- Co-authored-by: Daniel Townsend <[email protected]>
1 parent 8909679 commit 15d18af

File tree

5 files changed

+164
-13
lines changed

5 files changed

+164
-13
lines changed

docs/src/piccolo/query_clauses/where.rst

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,45 @@ And all rows with a value not contained in the list:
112112
Band.name.not_in(['Terrible Band', 'Awful Band'])
113113
)
114114
115+
You can also pass a subquery into the ``is_in`` clause:
116+
117+
.. code-block:: python
118+
119+
await Band.select().where(
120+
Band.id.is_in(
121+
Concert.select(Concert.band_1).where(
122+
Concert.starts >= datetime.datetime(year=2025, month=1, day=1)
123+
)
124+
)
125+
)
126+
127+
.. hint::
128+
In SQL there are often several ways of solving the same problem. You
129+
can also solve the above using :meth:`join_on <piccolo.columns.base.Column.join_on>`.
130+
131+
.. code-block:: python
132+
133+
>>> await Band.select().where(
134+
... Band.id.join_on(Concert.band_1).starts >= datetime.datetime(
135+
... year=2025, month=1, day=1
136+
... )
137+
... )
138+
139+
Use whichever you prefer, and whichever suits the situation best.
140+
141+
Subqueries can also be passed into the ``not_in`` clause:
142+
143+
.. code-block:: python
144+
145+
await Band.select().where(
146+
Band.id.not_in(
147+
Concert.select(Concert.band_1).where(
148+
Concert.starts >= datetime.datetime(year=2025, month=1, day=1)
149+
)
150+
)
151+
)
152+
153+
115154
-------------------------------------------------------------------------------
116155

117156
``is_null`` / ``is_not_null``

piccolo/columns/base.py

Lines changed: 33 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646

4747
if TYPE_CHECKING: # pragma: no cover
4848
from piccolo.columns.column_types import ForeignKey
49+
from piccolo.query.methods.select import Select
4950
from piccolo.table import Table
5051

5152

@@ -599,18 +600,40 @@ def _validate_choices(
599600

600601
return True
601602

602-
def is_in(self, values: list[Any]) -> Where:
603-
if len(values) == 0:
604-
raise ValueError(
605-
"The `values` list argument must contain at least one value."
606-
)
603+
def is_in(self, values: Union[Select, QueryString, list[Any]]) -> Where:
604+
from piccolo.query.methods.select import Select
605+
606+
if isinstance(values, list):
607+
if len(values) == 0:
608+
raise ValueError(
609+
"The `values` list argument must contain at least one "
610+
"value."
611+
)
612+
elif isinstance(values, Select):
613+
if len(values.columns_delegate.selected_columns) != 1:
614+
raise ValueError(
615+
"A sub select must only return a single column."
616+
)
617+
values = values.querystrings[0]
618+
607619
return Where(column=self, values=values, operator=In)
608620

609-
def not_in(self, values: list[Any]) -> Where:
610-
if len(values) == 0:
611-
raise ValueError(
612-
"The `values` list argument must contain at least one value."
613-
)
621+
def not_in(self, values: Union[Select, QueryString, list[Any]]) -> Where:
622+
from piccolo.query.methods.select import Select
623+
624+
if isinstance(values, list):
625+
if len(values) == 0:
626+
raise ValueError(
627+
"The `values` list argument must contain at least one "
628+
"value."
629+
)
630+
elif isinstance(values, Select):
631+
if len(values.columns_delegate.selected_columns) != 1:
632+
raise ValueError(
633+
"A sub select must only return a single column."
634+
)
635+
values = values.querystrings[0]
636+
614637
return Where(column=self, values=values, operator=NotIn)
615638

616639
def like(self, value: str) -> Where:

piccolo/columns/combination.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ def __init__(
146146
self,
147147
column: Column,
148148
value: Any = UNDEFINED,
149-
values: Union[CustomIterable, Undefined] = UNDEFINED,
149+
values: Union[CustomIterable, Undefined, QueryString] = UNDEFINED,
150150
operator: type[ComparisonOperator] = ComparisonOperator,
151151
) -> None:
152152
"""
@@ -156,7 +156,7 @@ def __init__(
156156
self.column = column
157157

158158
self.value = value if value == UNDEFINED else self.clean_value(value)
159-
if values == UNDEFINED:
159+
if (values == UNDEFINED) or isinstance(values, QueryString):
160160
self.values = values
161161
else:
162162
self.values = [self.clean_value(i) for i in values] # type: ignore
@@ -192,6 +192,9 @@ def clean_value(self, value: Any) -> Any:
192192
def values_querystring(self) -> QueryString:
193193
values = self.values
194194

195+
if isinstance(values, QueryString):
196+
return values
197+
195198
if isinstance(values, Undefined):
196199
raise ValueError("values is undefined")
197200

tests/columns/test_combination.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import unittest
22

3-
from tests.example_apps.music.tables import Band
3+
from tests.example_apps.music.tables import Band, Concert
44

55

66
class TestWhere(unittest.TestCase):
@@ -29,6 +29,20 @@ def test_is_in(self):
2929
with self.assertRaises(ValueError):
3030
Band.name.is_in([])
3131

32+
def test_is_in_subquery(self):
33+
_where = Band.id.is_in(
34+
Concert.select(Concert.band_1).where(Concert.band_1 == 1)
35+
)
36+
sql = _where.__str__()
37+
self.assertEqual(
38+
sql,
39+
'"band"."id" IN (SELECT ALL "concert"."band_1" AS "band_1" FROM "concert" WHERE "concert"."band_1" = 1)', # noqa: E501
40+
)
41+
42+
# a sub select must only return a single column
43+
with self.assertRaises(ValueError):
44+
Band.id.is_in(Concert.select().where(Concert.band_1 == 1))
45+
3246
def test_not_in(self):
3347
_where = Band.name.not_in(["CSharps"])
3448
sql = _where.__str__()
@@ -37,6 +51,20 @@ def test_not_in(self):
3751
with self.assertRaises(ValueError):
3852
Band.name.not_in([])
3953

54+
def test_not_in_subquery(self):
55+
_where = Band.id.not_in(
56+
Concert.select(Concert.band_1).where(Concert.band_1 == 1)
57+
)
58+
sql = _where.__str__()
59+
self.assertEqual(
60+
sql,
61+
'"band"."id" NOT IN (SELECT ALL "concert"."band_1" AS "band_1" FROM "concert" WHERE "concert"."band_1" = 1)', # noqa: E501
62+
)
63+
64+
# a sub select must only return a single column
65+
with self.assertRaises(ValueError):
66+
Band.id.not_in(Concert.select().where(Concert.band_1 == 1))
67+
4068

4169
class TestAnd(unittest.TestCase):
4270
def test_get_column_values(self):

tests/table/test_select.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,64 @@ def test_where_greater_than(self):
258258

259259
self.assertEqual(response, [{"name": "Rustaceans"}])
260260

261+
def test_is_in(self):
262+
self.insert_rows()
263+
264+
response = (
265+
Band.select(Band.name)
266+
.where(Band.manager._.name.is_in(["Guido"]))
267+
.run_sync()
268+
)
269+
270+
self.assertListEqual(response, [{"name": "Pythonistas"}])
271+
272+
def test_is_in_subquery(self):
273+
self.insert_rows()
274+
275+
# This is a contrived example, just for testing.
276+
response = (
277+
Band.select(Band.name)
278+
.where(
279+
Band.manager.is_in(
280+
Manager.select(Manager.id).where(Manager.name == "Guido")
281+
)
282+
)
283+
.run_sync()
284+
)
285+
286+
self.assertListEqual(response, [{"name": "Pythonistas"}])
287+
288+
def test_not_in(self):
289+
self.insert_rows()
290+
291+
response = (
292+
Band.select(Band.name)
293+
.where(Band.manager._.name.not_in(["Guido"]))
294+
.run_sync()
295+
)
296+
297+
self.assertListEqual(
298+
response, [{"name": "Rustaceans"}, {"name": "CSharps"}]
299+
)
300+
301+
def test_not_in_subquery(self):
302+
self.insert_rows()
303+
304+
# This is a contrived example, just for testing.
305+
response = (
306+
Band.select(Band.name)
307+
.where(
308+
Band.manager.not_in(
309+
Manager.select(Manager.id).where(Manager.name == "Guido")
310+
)
311+
)
312+
.run_sync()
313+
)
314+
315+
self.assertListEqual(
316+
response, [{"name": "Rustaceans"}, {"name": "CSharps"}]
317+
)
318+
261319
def test_where_is_null(self):
262320
self.insert_rows()
263321

0 commit comments

Comments
 (0)