18
18
# ' @examplesIf has_credentials("openai")
19
19
# ' chat <- chat_openai(echo = TRUE)
20
20
# ' chat$chat("Tell me a funny joke")
21
- Chat <- R6 :: R6Class(" Chat" ,
21
+ Chat <- R6 :: R6Class(
22
+ " Chat" ,
22
23
public = list (
23
24
# ' @param provider A provider object.
24
25
# ' @param turns An unnamed list of turns to start the chat with (i.e.,
@@ -108,30 +109,34 @@ Chat <- R6::R6Class("Chat",
108
109
invisible (self )
109
110
},
110
111
111
- # ' @description A data frame with a `tokens` column that proides the
112
- # ' number of input tokens used by user turns and the number of
112
+ # ' @description A data frame with a `tokens` column that proides the
113
+ # ' number of input tokens used by user turns and the number of
113
114
# ' output tokens used by assistant turns.
114
- # ' @param include_system_prompt Whether to include the system prompt in
115
+ # ' @param include_system_prompt Whether to include the system prompt in
115
116
# ' the turns (if any exists).
116
117
tokens = function (include_system_prompt = FALSE ) {
117
118
turns <- self $ get_turns(include_system_prompt = FALSE )
118
119
assistant_turns <- keep(turns , function (x ) x @ role == " assistant" )
119
120
120
121
n <- length(assistant_turns )
121
- tokens <- t(vapply(assistant_turns , function (turn ) turn @ tokens , double(2 )))
122
+ tokens <- t(vapply(
123
+ assistant_turns ,
124
+ function (turn ) turn @ tokens ,
125
+ double(2 )
126
+ ))
122
127
if (n > 1 ) {
123
128
# Compute just the new tokens
124
- tokens [- 1 , 1 ] <- tokens [seq(2 , n ), 1 ] -
129
+ tokens [- 1 , 1 ] <- tokens [seq(2 , n ), 1 ] -
125
130
(tokens [seq(1 , n - 1 ), 1 ] + tokens [seq(1 , n - 1 ), 2 ])
126
131
}
127
132
# collapse into a single vector
128
133
tokens_v <- c(t(tokens ))
129
-
134
+
130
135
tokens_df <- data.frame (
131
136
role = rep(c(" user" , " assistant" ), times = n ),
132
137
tokens = tokens_v
133
138
)
134
-
139
+
135
140
if (include_system_prompt && private $ has_system_prompt()) {
136
141
# How do we compute this?
137
142
tokens_df <- rbind(data.frame (role = " system" , tokens = 0 ), tokens_df )
@@ -149,7 +154,8 @@ Chat <- R6::R6Class("Chat",
149
154
role <- arg_match(role )
150
155
151
156
n <- length(private $ .turns )
152
- switch (role ,
157
+ switch (
158
+ role ,
153
159
system = if (private $ has_system_prompt()) private $ .turns [[1 ]],
154
160
assistant = if (n > 1 ) private $ .turns [[n ]],
155
161
user = if (n > 1 ) private $ .turns [[n - 1 ]]
@@ -169,7 +175,11 @@ Chat <- R6::R6Class("Chat",
169
175
170
176
# Returns a single turn (the final response from the assistant), even if
171
177
# multiple rounds of back and forth happened.
172
- coro :: collect(private $ chat_impl(turn , stream = echo != " none" , echo = echo ))
178
+ coro :: collect(private $ chat_impl(
179
+ turn ,
180
+ stream = echo != " none" ,
181
+ echo = echo
182
+ ))
173
183
174
184
text <- self $ last_turn()@ text
175
185
if (echo == " none" ) text else invisible (text )
@@ -277,7 +287,12 @@ Chat <- R6::R6Class("Chat",
277
287
278
288
map(json , function (json ) {
279
289
turn <- value_turn(private $ provider , json , has_type = TRUE )
280
- extract_data(turn , type , convert = convert , needs_wrapper = needs_wrapper )
290
+ extract_data(
291
+ turn ,
292
+ type ,
293
+ convert = convert ,
294
+ needs_wrapper = needs_wrapper
295
+ )
281
296
})
282
297
},
283
298
@@ -415,16 +430,29 @@ Chat <- R6::R6Class("Chat",
415
430
if (private $ .turns [[i ]]@ role != " user" ) {
416
431
private $ .turns [[i + 1 ]] <- Turn(" user" , contents )
417
432
} else {
418
- private $ .turns [[i ]]@ contents <- c(private $ .turns [[i ]]@ contents , contents )
433
+ private $ .turns [[i ]]@ contents <- c(
434
+ private $ .turns [[i ]]@ contents ,
435
+ contents
436
+ )
419
437
}
420
438
invisible (self )
421
439
},
422
440
423
441
# If stream = TRUE, yields completion deltas. If stream = FALSE, yields
424
442
# complete assistant turns.
425
- chat_impl = generator_method(function (self , private , user_turn , stream , echo ) {
426
- while (! is.null(user_turn )) {
427
- for (chunk in private $ submit_turns(user_turn , stream = stream , echo = echo )) {
443
+ chat_impl = generator_method(function (
444
+ self ,
445
+ private ,
446
+ user_turn ,
447
+ stream ,
448
+ echo
449
+ ) {
450
+ while (! is.null(user_turn )) {
451
+ for (chunk in private $ submit_turns(
452
+ user_turn ,
453
+ stream = stream ,
454
+ echo = echo
455
+ )) {
428
456
yield(chunk )
429
457
}
430
458
user_turn <- private $ invoke_tools()
@@ -433,9 +461,19 @@ Chat <- R6::R6Class("Chat",
433
461
434
462
# If stream = TRUE, yields completion deltas. If stream = FALSE, yields
435
463
# complete assistant turns.
436
- chat_impl_async = async_generator_method(function (self , private , user_turn , stream , echo ) {
437
- while (! is.null(user_turn )) {
438
- for (chunk in await_each(private $ submit_turns_async(user_turn , stream = stream , echo = echo ))) {
464
+ chat_impl_async = async_generator_method(function (
465
+ self ,
466
+ private ,
467
+ user_turn ,
468
+ stream ,
469
+ echo
470
+ ) {
471
+ while (! is.null(user_turn )) {
472
+ for (chunk in await_each(private $ submit_turns_async(
473
+ user_turn ,
474
+ stream = stream ,
475
+ echo = echo
476
+ ))) {
439
477
yield(chunk )
440
478
}
441
479
user_turn <- await(private $ invoke_tools_async())
@@ -447,8 +485,14 @@ Chat <- R6::R6Class("Chat",
447
485
448
486
# If stream = TRUE, yields completion deltas. If stream = FALSE, yields
449
487
# complete assistant turns.
450
- submit_turns = generator_method(function (self , private , user_turn , stream , echo , type = NULL ) {
451
-
488
+ submit_turns = generator_method(function (
489
+ self ,
490
+ private ,
491
+ user_turn ,
492
+ stream ,
493
+ echo ,
494
+ type = NULL
495
+ ) {
452
496
if (echo == " all" ) {
453
497
cat_line(format(user_turn ), prefix = " > " )
454
498
}
@@ -490,7 +534,11 @@ Chat <- R6::R6Class("Chat",
490
534
cat_line(formatted , prefix = " < " )
491
535
}
492
536
} else {
493
- turn <- value_turn(private $ provider , response , has_type = ! is.null(type ))
537
+ turn <- value_turn(
538
+ private $ provider ,
539
+ response ,
540
+ has_type = ! is.null(type )
541
+ )
494
542
text <- turn @ text
495
543
if (! is.null(text )) {
496
544
text <- paste0(text , " \n " )
@@ -508,7 +556,14 @@ Chat <- R6::R6Class("Chat",
508
556
509
557
# If stream = TRUE, yields completion deltas. If stream = FALSE, yields
510
558
# complete assistant turns.
511
- submit_turns_async = async_generator_method(function (self , private , user_turn , stream , echo , type = NULL ) {
559
+ submit_turns_async = async_generator_method(function (
560
+ self ,
561
+ private ,
562
+ user_turn ,
563
+ stream ,
564
+ echo ,
565
+ type = NULL
566
+ ) {
512
567
response <- chat_perform(
513
568
provider = private $ provider ,
514
569
mode = if (stream ) " async-stream" else " async-value" ,
@@ -582,29 +637,36 @@ is_chat <- function(x) {
582
637
# ' @export
583
638
print.Chat <- function (x , ... ) {
584
639
turns <- x $ get_turns(include_system_prompt = TRUE )
585
-
640
+
586
641
tokens <- x $ tokens(include_system_prompt = TRUE )
587
642
tokens_user <- sum(tokens $ tokens [tokens $ role == " user" ])
588
643
tokens_assistant <- sum(tokens $ tokens [tokens $ role == " assistant" ])
589
644
590
645
cat(paste0(
591
- " <Chat" ,
592
- " turns=" , length(turns ),
593
- " tokens=" , tokens_user , " /" , tokens_assistant ,
646
+ " <Chat" ,
647
+ " turns=" ,
648
+ length(turns ),
649
+ " tokens=" ,
650
+ tokens_user ,
651
+ " /" ,
652
+ tokens_assistant ,
594
653
" >\n "
595
654
))
596
-
655
+
597
656
for (i in seq_along(turns )) {
598
657
turn <- turns [[i ]]
599
-
600
- color <- switch (turn @ role ,
658
+
659
+ color <- switch (
660
+ turn @ role ,
601
661
user = cli :: col_blue ,
602
662
assistant = cli :: col_green ,
603
663
system = cli :: col_br_white ,
604
664
identity
605
665
)
606
666
607
- cli :: cat_rule(cli :: format_inline(" {color(turn@role)} [{tokens$tokens[[i]]}]" ))
667
+ cli :: cat_rule(cli :: format_inline(
668
+ " {color(turn@role)} [{tokens$tokens[[i]]}]"
669
+ ))
608
670
for (content in turn @ contents ) {
609
671
cat_line(format(content ))
610
672
}
@@ -613,7 +675,10 @@ print.Chat <- function(x, ...) {
613
675
invisible (x )
614
676
}
615
677
616
- method(contents_markdown , new_S3_class(" Chat" )) <- function (content , heading_level = 2 ) {
678
+ method(contents_markdown , new_S3_class(" Chat" )) <- function (
679
+ content ,
680
+ heading_level = 2
681
+ ) {
617
682
turns <- content $ get_turns()
618
683
if (length(turns ) == 0 ) {
619
684
return (" " )
@@ -628,7 +693,7 @@ method(contents_markdown, new_S3_class("Chat")) <- function(content, heading_lev
628
693
res [i ] <- glue :: glue(" {hh} {role}\n\n {contents_markdown(turns[[i]])}" )
629
694
}
630
695
631
- paste(res , collapse = " \n\n " )
696
+ paste(res , collapse = " \n\n " )
632
697
}
633
698
634
699
extract_data <- function (turn , type , convert = TRUE , needs_wrapper = FALSE ) {
0 commit comments