Skip to content

Commit 2d6c818

Browse files
committed
1. Expose QueryBuilder as lmql.QueryBuilder and move the implementation to lmql/api
2. Wrap the output of `QueryBuilder` as `QueryExecution` class which has `async run()` and `run_sync()` methods. Usage example: `result = lmql.QueryBuilder().set_prompt("What is the capital of France? [ANSWER]").set_model("gpt2").build().run_sync()`
1 parent 0c3ec95 commit 2d6c818

File tree

4 files changed

+32
-12
lines changed

4 files changed

+32
-12
lines changed

docs/docs/language/reference.md

+3-3
Original file line numberDiff line numberDiff line change
@@ -204,16 +204,16 @@ distribution
204204
Instead of writing the query in string, you could also write it in a more programmatic way with query builder.
205205
```python
206206
import lmql
207-
from lmql.language.query_builder import QueryBuilder
208207

209-
query = (QueryBuilder()
208+
query = (lmql.QueryBuilder()
210209
.set_decoder('argmax')
211210
.set_prompt('What is the capital of France? [ANSWER]')
212211
.set_model('gpt2')
213212
.set_distribution('ANSWER', '["A", "B"]')
214213
.build())
215214

216-
lmql.run_sync(query)
215+
query.run_sync()
216+
# You can also run it asynchronously with query.run_async() and await the result
217217
```
218218

219219

src/lmql/__init__.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -31,4 +31,7 @@
3131
from lmql.runtime.lmql_runtime import (LMQLQueryFunction, compiled_query, tag)
3232

3333
# event loop utils
34-
from lmql.runtime.loop import main
34+
from lmql.runtime.loop import main
35+
36+
# query builder
37+
from lmql.api.query_builder import QueryBuilder

src/lmql/language/query_builder.py renamed to src/lmql/api/query_builder.py

+19-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,19 @@
1+
from lmql.api import run, run_sync
2+
3+
4+
class QueryExecution:
5+
def __init__(self, query_string):
6+
self.query_string = query_string
7+
8+
async def run(self, *args, **kwargs):
9+
# This method should asynchronously execute the query_string
10+
return await run(self.query_string, *args, **kwargs)
11+
12+
def run_sync(self, *args, **kwargs):
13+
# This method should synchronously execute the query_string
14+
return run_sync(self.query_string, *args, **kwargs)
15+
16+
117
class QueryBuilder:
218
def __init__(self):
319
self.decoder = None
@@ -56,5 +72,7 @@ def build(self):
5672
variable, expr = self.distribution_expr
5773
components.append(f'distribution {variable} in {expr}')
5874

59-
return ' '.join(components)
75+
query_string = ' '.join(components)
76+
# Return an instance of QueryExecution instead of a string
77+
return QueryExecution(query_string)
6078

src/lmql/tests/test_query_builder.py

+6-7
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,11 @@
33

44
from lmql.tests.expr_test_utils import run_all_tests
55

6-
from lmql.language.query_builder import QueryBuilder
76

87

98
def test_query_builder():
109
# Example usage:
11-
prompt = (QueryBuilder()
10+
query = (lmql.QueryBuilder()
1211
.set_decoder('argmax')
1312
.set_prompt('What is the capital of France? [ANSWER]')
1413
.set_model('gpt2')
@@ -18,12 +17,12 @@ def test_query_builder():
1817

1918
expected = 'argmax "What is the capital of France? [ANSWER]" from "gpt2" where len(TOKENS(ANSWER)) < 10 and len(TOKENS(ANSWER)) > 2'
2019

21-
assert expected==prompt, f"Expected: {expected}, got: {prompt}"
22-
out = lmql.run_sync(prompt,)
20+
assert expected==query.query_string, f"Expected: {expected}, got: {query.query_string}"
21+
out = query.run_sync()
2322

2423
def test_query_builder_with_dist():
2524

26-
prompt = (QueryBuilder()
25+
query = (lmql.QueryBuilder()
2726
.set_decoder('argmax')
2827
.set_prompt('What is the capital of France? [ANSWER]')
2928
.set_model('gpt2')
@@ -32,8 +31,8 @@ def test_query_builder_with_dist():
3231

3332
expected = 'argmax "What is the capital of France? [ANSWER]" from "gpt2" distribution ANSWER in ["Paris", "London"]'
3433

35-
assert expected==prompt, f"Expected: {expected}, got: {prompt}"
36-
out = lmql.run_sync(prompt,)
34+
assert expected==query.query_string, f"Expected: {expected}, got: {query.query_string}"
35+
out = query.run_sync()
3736

3837
if __name__ == "__main__":
3938
run_all_tests(globals())

0 commit comments

Comments
 (0)