Skip to content

Commit 388678d

Browse files
authored
Re-format with air (#387)
And configure positron to automatically apply
1 parent 5f07aa7 commit 388678d

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

48 files changed

+1035
-581
lines changed

.Rbuildignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,5 @@
1111
_cache/
1212
^cran-comments\.md$
1313
^CRAN-SUBMISSION$
14+
^[\.]?air\.toml$
15+
^\.vscode$

.vscode/extensions.json

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
{
2+
"recommendations": [
3+
"Posit.air-vscode"
4+
]
5+
}

.vscode/settings.json

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
{
2+
"[r]": {
3+
"editor.formatOnSave": true,
4+
"editor.defaultFormatter": "Posit.air-vscode"
5+
}
6+
}

R/chat-parallel.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@ parallel_requests <- function(
1313
tools = tools,
1414
stream = FALSE,
1515
type = type
16-
)}
17-
)
16+
)
17+
})
1818
reqs <- map(reqs, function(req) {
1919
req_throttle(req, capacity = rpm, fill_time_s = 60)
2020
})

R/chat.R

Lines changed: 97 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@ NULL
1818
#' @examplesIf has_credentials("openai")
1919
#' chat <- chat_openai(echo = TRUE)
2020
#' chat$chat("Tell me a funny joke")
21-
Chat <- R6::R6Class("Chat",
21+
Chat <- R6::R6Class(
22+
"Chat",
2223
public = list(
2324
#' @param provider A provider object.
2425
#' @param turns An unnamed list of turns to start the chat with (i.e.,
@@ -108,30 +109,34 @@ Chat <- R6::R6Class("Chat",
108109
invisible(self)
109110
},
110111

111-
#' @description A data frame with a `tokens` column that proides the
112-
#' number of input tokens used by user turns and the number of
112+
#' @description A data frame with a `tokens` column that proides the
113+
#' number of input tokens used by user turns and the number of
113114
#' output tokens used by assistant turns.
114-
#' @param include_system_prompt Whether to include the system prompt in
115+
#' @param include_system_prompt Whether to include the system prompt in
115116
#' the turns (if any exists).
116117
tokens = function(include_system_prompt = FALSE) {
117118
turns <- self$get_turns(include_system_prompt = FALSE)
118119
assistant_turns <- keep(turns, function(x) x@role == "assistant")
119120

120121
n <- length(assistant_turns)
121-
tokens <- t(vapply(assistant_turns, function(turn) turn@tokens, double(2)))
122+
tokens <- t(vapply(
123+
assistant_turns,
124+
function(turn) turn@tokens,
125+
double(2)
126+
))
122127
if (n > 1) {
123128
# Compute just the new tokens
124-
tokens[-1, 1] <- tokens[seq(2, n), 1] -
129+
tokens[-1, 1] <- tokens[seq(2, n), 1] -
125130
(tokens[seq(1, n - 1), 1] + tokens[seq(1, n - 1), 2])
126131
}
127132
# collapse into a single vector
128133
tokens_v <- c(t(tokens))
129-
134+
130135
tokens_df <- data.frame(
131136
role = rep(c("user", "assistant"), times = n),
132137
tokens = tokens_v
133138
)
134-
139+
135140
if (include_system_prompt && private$has_system_prompt()) {
136141
# How do we compute this?
137142
tokens_df <- rbind(data.frame(role = "system", tokens = 0), tokens_df)
@@ -149,7 +154,8 @@ Chat <- R6::R6Class("Chat",
149154
role <- arg_match(role)
150155

151156
n <- length(private$.turns)
152-
switch(role,
157+
switch(
158+
role,
153159
system = if (private$has_system_prompt()) private$.turns[[1]],
154160
assistant = if (n > 1) private$.turns[[n]],
155161
user = if (n > 1) private$.turns[[n - 1]]
@@ -169,7 +175,11 @@ Chat <- R6::R6Class("Chat",
169175

170176
# Returns a single turn (the final response from the assistant), even if
171177
# multiple rounds of back and forth happened.
172-
coro::collect(private$chat_impl(turn, stream = echo != "none", echo = echo))
178+
coro::collect(private$chat_impl(
179+
turn,
180+
stream = echo != "none",
181+
echo = echo
182+
))
173183

174184
text <- self$last_turn()@text
175185
if (echo == "none") text else invisible(text)
@@ -277,7 +287,12 @@ Chat <- R6::R6Class("Chat",
277287

278288
map(json, function(json) {
279289
turn <- value_turn(private$provider, json, has_type = TRUE)
280-
extract_data(turn, type, convert = convert, needs_wrapper = needs_wrapper)
290+
extract_data(
291+
turn,
292+
type,
293+
convert = convert,
294+
needs_wrapper = needs_wrapper
295+
)
281296
})
282297
},
283298

@@ -415,16 +430,29 @@ Chat <- R6::R6Class("Chat",
415430
if (private$.turns[[i]]@role != "user") {
416431
private$.turns[[i + 1]] <- Turn("user", contents)
417432
} else {
418-
private$.turns[[i]]@contents <- c(private$.turns[[i]]@contents, contents)
433+
private$.turns[[i]]@contents <- c(
434+
private$.turns[[i]]@contents,
435+
contents
436+
)
419437
}
420438
invisible(self)
421439
},
422440

423441
# If stream = TRUE, yields completion deltas. If stream = FALSE, yields
424442
# complete assistant turns.
425-
chat_impl = generator_method(function(self, private, user_turn, stream, echo) {
426-
while(!is.null(user_turn)) {
427-
for (chunk in private$submit_turns(user_turn, stream = stream, echo = echo)) {
443+
chat_impl = generator_method(function(
444+
self,
445+
private,
446+
user_turn,
447+
stream,
448+
echo
449+
) {
450+
while (!is.null(user_turn)) {
451+
for (chunk in private$submit_turns(
452+
user_turn,
453+
stream = stream,
454+
echo = echo
455+
)) {
428456
yield(chunk)
429457
}
430458
user_turn <- private$invoke_tools()
@@ -433,9 +461,19 @@ Chat <- R6::R6Class("Chat",
433461

434462
# If stream = TRUE, yields completion deltas. If stream = FALSE, yields
435463
# complete assistant turns.
436-
chat_impl_async = async_generator_method(function(self, private, user_turn, stream, echo) {
437-
while(!is.null(user_turn)) {
438-
for (chunk in await_each(private$submit_turns_async(user_turn, stream = stream, echo = echo))) {
464+
chat_impl_async = async_generator_method(function(
465+
self,
466+
private,
467+
user_turn,
468+
stream,
469+
echo
470+
) {
471+
while (!is.null(user_turn)) {
472+
for (chunk in await_each(private$submit_turns_async(
473+
user_turn,
474+
stream = stream,
475+
echo = echo
476+
))) {
439477
yield(chunk)
440478
}
441479
user_turn <- await(private$invoke_tools_async())
@@ -447,8 +485,14 @@ Chat <- R6::R6Class("Chat",
447485

448486
# If stream = TRUE, yields completion deltas. If stream = FALSE, yields
449487
# complete assistant turns.
450-
submit_turns = generator_method(function(self, private, user_turn, stream, echo, type = NULL) {
451-
488+
submit_turns = generator_method(function(
489+
self,
490+
private,
491+
user_turn,
492+
stream,
493+
echo,
494+
type = NULL
495+
) {
452496
if (echo == "all") {
453497
cat_line(format(user_turn), prefix = "> ")
454498
}
@@ -490,7 +534,11 @@ Chat <- R6::R6Class("Chat",
490534
cat_line(formatted, prefix = "< ")
491535
}
492536
} else {
493-
turn <- value_turn(private$provider, response, has_type = !is.null(type))
537+
turn <- value_turn(
538+
private$provider,
539+
response,
540+
has_type = !is.null(type)
541+
)
494542
text <- turn@text
495543
if (!is.null(text)) {
496544
text <- paste0(text, "\n")
@@ -508,7 +556,14 @@ Chat <- R6::R6Class("Chat",
508556

509557
# If stream = TRUE, yields completion deltas. If stream = FALSE, yields
510558
# complete assistant turns.
511-
submit_turns_async = async_generator_method(function(self, private, user_turn, stream, echo, type = NULL) {
559+
submit_turns_async = async_generator_method(function(
560+
self,
561+
private,
562+
user_turn,
563+
stream,
564+
echo,
565+
type = NULL
566+
) {
512567
response <- chat_perform(
513568
provider = private$provider,
514569
mode = if (stream) "async-stream" else "async-value",
@@ -582,29 +637,36 @@ is_chat <- function(x) {
582637
#' @export
583638
print.Chat <- function(x, ...) {
584639
turns <- x$get_turns(include_system_prompt = TRUE)
585-
640+
586641
tokens <- x$tokens(include_system_prompt = TRUE)
587642
tokens_user <- sum(tokens$tokens[tokens$role == "user"])
588643
tokens_assistant <- sum(tokens$tokens[tokens$role == "assistant"])
589644

590645
cat(paste0(
591-
"<Chat",
592-
" turns=", length(turns),
593-
" tokens=", tokens_user, "/", tokens_assistant,
646+
"<Chat",
647+
" turns=",
648+
length(turns),
649+
" tokens=",
650+
tokens_user,
651+
"/",
652+
tokens_assistant,
594653
">\n"
595654
))
596-
655+
597656
for (i in seq_along(turns)) {
598657
turn <- turns[[i]]
599-
600-
color <- switch(turn@role,
658+
659+
color <- switch(
660+
turn@role,
601661
user = cli::col_blue,
602662
assistant = cli::col_green,
603663
system = cli::col_br_white,
604664
identity
605665
)
606666

607-
cli::cat_rule(cli::format_inline("{color(turn@role)} [{tokens$tokens[[i]]}]"))
667+
cli::cat_rule(cli::format_inline(
668+
"{color(turn@role)} [{tokens$tokens[[i]]}]"
669+
))
608670
for (content in turn@contents) {
609671
cat_line(format(content))
610672
}
@@ -613,7 +675,10 @@ print.Chat <- function(x, ...) {
613675
invisible(x)
614676
}
615677

616-
method(contents_markdown, new_S3_class("Chat")) <- function(content, heading_level = 2) {
678+
method(contents_markdown, new_S3_class("Chat")) <- function(
679+
content,
680+
heading_level = 2
681+
) {
617682
turns <- content$get_turns()
618683
if (length(turns) == 0) {
619684
return("")
@@ -628,7 +693,7 @@ method(contents_markdown, new_S3_class("Chat")) <- function(content, heading_lev
628693
res[i] <- glue::glue("{hh} {role}\n\n{contents_markdown(turns[[i]])}")
629694
}
630695

631-
paste(res, collapse="\n\n")
696+
paste(res, collapse = "\n\n")
632697
}
633698

634699
extract_data <- function(turn, type, convert = TRUE, needs_wrapper = FALSE) {

R/content-tools.R

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -59,19 +59,21 @@ invoke_tool <- function(fun, arguments, id) {
5959
)
6060
}
6161

62-
on_load(invoke_tool_async <- coro::async(function(fun, arguments, id) {
63-
if (is.null(fun)) {
64-
return(ContentToolResult(id = id, error = "Unknown tool"))
65-
}
66-
67-
tryCatch(
68-
{
69-
result <- await(do.call(fun, arguments))
70-
ContentToolResult(id, result)
71-
},
72-
error = function(e) {
73-
# TODO: We need to report this somehow; it's way too hidden from the user
74-
ContentToolResult(id, error = conditionMessage(e))
62+
on_load(
63+
invoke_tool_async <- coro::async(function(fun, arguments, id) {
64+
if (is.null(fun)) {
65+
return(ContentToolResult(id = id, error = "Unknown tool"))
7566
}
76-
)
77-
}))
67+
68+
tryCatch(
69+
{
70+
result <- await(do.call(fun, arguments))
71+
ContentToolResult(id, result)
72+
},
73+
error = function(e) {
74+
# TODO: We need to report this somehow; it's way too hidden from the user
75+
ContentToolResult(id, error = conditionMessage(e))
76+
}
77+
)
78+
})
79+
)

0 commit comments

Comments
 (0)