Skip to content

Commit 7c3d0f1

Browse files
committed
fix #32
1 parent c8ca00d commit 7c3d0f1

File tree

7 files changed

+29
-109
lines changed

7 files changed

+29
-109
lines changed

NEWS.md

+2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
# rollama (development version)
22

3+
* added support for structured output
4+
35
# rollama 0.2.0
46

57
* added make_query() function to facilitate easier annotation

R/embedding.r

+1-2
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,7 @@ embed_text <- function(text,
4949
model_params = model_params) |>
5050
purrr::compact() |>
5151
make_req(server = server,
52-
endpoint = "/api/embeddings",
53-
perform = FALSE)
52+
endpoint = "/api/embeddings")
5453
})
5554

5655
resps <- httr2::req_perform_parallel(reqs, progress = pb)

R/lib.R

+13-24
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,7 @@ build_req <- function(model,
6060
template = template) |>
6161
purrr::compact() |> # remove NULL values
6262
make_req(server = sample(server, 1, prob = as_prob(names(server))),
63-
endpoint = "/api/chat",
64-
perform = FALSE)
63+
endpoint = "/api/chat")
6564
})
6665
}) |>
6766
unlist(recursive = FALSE)
@@ -70,21 +69,14 @@ build_req <- function(model,
7069
}
7170

7271

73-
make_req <- function(req_data, server, endpoint, perform = TRUE) {
72+
make_req <- function(req_data, server, endpoint) {
7473
r <- httr2::request(server) |>
7574
httr2::req_url_path_append(endpoint) |>
7675
httr2::req_body_json(prep_req_data(req_data), auto_unbox = FALSE) |>
77-
# turn off errors since error messages can't be seen in sub-process
78-
httr2::req_error(is_error = function(resp) FALSE) |>
7976
# see https://github.com/JBGruber/rollama/issues/23
8077
httr2::req_options(timeout_ms = 1000 * 60 * 60 * 24,
8178
connecttimeout_ms = 1000 * 60 * 60 * 24) |>
8279
httr2::req_headers(!!!get_headers())
83-
if (perform) {
84-
r <- r |>
85-
httr2::req_perform() |>
86-
httr2::resp_body_json()
87-
}
8880
return(r)
8981
}
9082

@@ -117,17 +109,7 @@ perform_reqs <- function(reqs, verbose) {
117109
if (length(fails) == length(reqs)) {
118110
cli::cli_abort(fails)
119111
} else if (length(fails) < length(reqs) && length(fails) > 0) {
120-
error_counts <- table(fails)
121-
for (f in names(error_counts)) {
122-
if (error_counts[f] > 2) {
123-
cli::cli_alert_danger("error ({error_counts[f]} times): {f}")
124-
} else {
125-
cli::cli_alert_danger("error: {f}")
126-
}
127-
}
128-
for (f in fails) {
129-
cli::cli_alert_danger(f)
130-
}
112+
throw_error(fails)
131113
}
132114

133115
httr2::resps_successes(resps)
@@ -143,15 +125,23 @@ perform_req <- function(reqs, verbose) {
143125
id <- cli::cli_progress_bar(format = "{cli::pb_spin} {model} {?is/are} thinking",
144126
clear = TRUE)
145127

128+
# turn off errors since error messages can't be seen in sub-process
129+
req <- httr2::req_error(reqs[[1]], is_error = function(resp) FALSE)
130+
146131
rp <- callr::r_bg(httr2::req_perform,
147-
args = list(req = reqs[[1]]),
132+
args = list(req = req),
148133
package = TRUE)
149134

150135
while (rp$is_alive()) {
151136
cli::cli_progress_update(id = id)
152137
Sys.sleep(2 / 100)
153138
}
154-
return(list(rp$get_result()))
139+
resp <- rp$get_result()
140+
res <- httr2::resp_body_json(resp)
141+
if (purrr::pluck_exists(res, "error")) {
142+
cli::cli_abort(purrr::pluck(res, "error"))
143+
}
144+
return(list(res))
155145
}
156146

157147
list(httr2::req_perform(reqs[[1]]))
@@ -207,7 +197,6 @@ pgrs <- function(resp) {
207197
jsonlite::stream_in(verbose = FALSE, simplifyVector = FALSE)
208198

209199
status <- setdiff(status, the$str_prgs$pb_done)
210-
211200
for (s in status) {
212201
status_message <- purrr::pluck(s, "status")
213202
if (!purrr::pluck_exists(s, "total")) {

R/utils.r

+11
Original file line numberDiff line numberDiff line change
@@ -113,3 +113,14 @@ check_conversation <- function(msg) {
113113
"least one user message. See {.help query}."))
114114
return(msg)
115115
}
116+
117+
throw_error <- function(fails) {
118+
error_counts <- table(fails)
119+
for (f in names(error_counts)) {
120+
if (error_counts[f] > 2) {
121+
cli::cli_alert_danger("error ({error_counts[f]} times): {f}")
122+
} else {
123+
cli::cli_alert_danger("error: {f}")
124+
}
125+
}
126+
}

man/rollama-package.Rd

+1-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat.R

+1-1
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,5 @@
88

99
library(testthat)
1010
library(rollama)
11-
11+
options(rollama_server = "http://192.168.2.29:11434")
1212
test_check("rollama")

tests/testthat/_snaps/verbose.md

-81
This file was deleted.

0 commit comments

Comments
 (0)