-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy path function_calling.rb
executable file
·106 lines (85 loc) · 3.23 KB
/
function_calling.rb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
#!/usr/bin/env ruby
# frozen_string_literal: true
require 'json'
require 'bundler/setup'
require 'dotenv/load'
require 'mistral'
# Sample payment records used as the lookup table for the tool functions.
# They are written row by row (one transaction per row), then transposed into
# the column-oriented shape the lookup functions index into. Each key maps to
# an array, and position i of every array belongs to the same transaction.
data = [
  ['T1001', 'C001', 125.50, '2021-10-05', 'Paid'],
  ['T1002', 'C002', 89.99, '2021-10-06', 'Unpaid'],
  ['T1003', 'C003', 120.00, '2021-10-07', 'Paid'],
  ['T1004', 'C002', 54.30, '2021-10-05', 'Paid'],
  ['T1005', 'C001', 210.20, '2021-10-08', 'Pending']
].transpose.then do |columns|
  %w[transaction_id customer_id payment_amount payment_date payment_status].zip(columns).to_h
end
# Look up the payment status of a transaction in the column-oriented +data+ table.
#
# @param data [Hash{String => Array}] parallel column arrays; index i of each
#   column describes the same transaction
# @param transaction_id [String] id to search for in data['transaction_id']
# @return [String] JSON object {"status": <status>} for the first matching row,
#   or {"status": "Error - transaction id not found"} when there is no match
def retrieve_payment_status(data, transaction_id)
  # Array#index replaces the hand-rolled each_with_index search and, like it,
  # yields the first matching position (0 is truthy in Ruby, so `unless row`
  # only triggers on nil).
  row = data['transaction_id'].index(transaction_id)
  return { status: 'Error - transaction id not found' }.to_json unless row

  { status: data['payment_status'][row] }.to_json
end
# Look up the payment date of a transaction in the column-oriented +data+ table.
#
# @param data [Hash{String => Array}] parallel column arrays; index i of each
#   column describes the same transaction
# @param transaction_id [String] id to search for in data['transaction_id']
# @return [String] JSON object {"date": <date>} for the first matching row, or
#   {"status": "Error - transaction id not found"} when there is no match
#   (the error uses the "status" key, not "date")
def retrieve_payment_date(data, transaction_id)
  # Array#index replaces the hand-rolled each_with_index search and, like it,
  # yields the first matching position (0 is truthy in Ruby, so `unless row`
  # only triggers on nil).
  row = data['transaction_id'].index(transaction_id)
  return { status: 'Error - transaction id not found' }.to_json unless row

  { date: data['payment_date'][row] }.to_json
end
# Dispatch table: tool name (as the model reports it in a tool call) => a
# one-argument callable taking the transaction id. The shared `data` table is
# captured by each closure, and the target method is resolved at call time.
names_to_functions = %w[retrieve_payment_status retrieve_payment_date].to_h do |fn_name|
  [fn_name, ->(transaction_id) { method(fn_name).call(data, transaction_id) }]
end
# Tool specifications advertised to the model on every chat call. Both tools
# accept the same single required string argument `transaction_id`, so each
# spec is generated from a (name, description) pair. A fresh parameters hash is
# built for each tool, so the two specs share no objects.
tools = [
  ['retrieve_payment_status', 'Get payment status of a transaction id'],
  ['retrieve_payment_date', 'Get payment date of a transaction id']
].map do |tool_name, tool_description|
  transaction_id_param = { 'type' => 'string', 'description' => 'The transaction id.' }
  schema = {
    'type' => 'object',
    'required' => ['transaction_id'],
    'properties' => { 'transaction_id' => transaction_id_param }
  }
  {
    'type' => 'function',
    'function' => Mistral::Function.new(name: tool_name, description: tool_description, parameters: schema)
  }
end
# Conversation driver. ENV.fetch raises KeyError when MISTRAL_API_KEY is unset.
api_key = ENV.fetch('MISTRAL_API_KEY')
model = 'mistral-small-latest'
client = Mistral::Client.new(api_key: api_key)

# Turn 1: ask about "my transaction" without giving an id, with the tools
# offered. The reply is printed and kept in the history.
messages = [Mistral::ChatMessage.new(role: 'user', content: "What's the status of my transaction?")]
response = client.chat(model: model, messages: messages, tools: tools)
first_reply = response.choices[0].message.content
puts first_reply
messages << Mistral::ChatMessage.new(role: 'assistant', content: first_reply)

# Turn 2: supply the transaction id; the model may now request a tool.
messages << Mistral::ChatMessage.new(role: 'user', content: 'My transaction ID is T1001.')
response = client.chat(model: model, messages: messages, tools: tools)
assistant_message = response.choices[0].message

# tool_calls can be nil or empty when the model answers directly. Indexing it
# unguarded crashed with NoMethodError, so show the direct answer and stop.
tool_call = assistant_message.tool_calls&.first
unless tool_call
  puts assistant_message.content
  exit
end

function_name = tool_call.function.name
function_params = JSON.parse(tool_call.function.arguments)
puts "calling function_name: #{function_name}, with function_params: #{function_params}"

# A tool name missing from the dispatch table previously raised NoMethodError
# on nil. Instead, report it back to the model using the same
# {"status": "Error - ..."} JSON shape the tool functions return.
handler = names_to_functions[function_name]
function_result =
  if handler
    handler.call(function_params['transaction_id'])
  else
    { status: "Error - unknown function #{function_name}" }.to_json
  end

# Echo the assistant's tool-call message into the history, followed by the
# tool result linked to it through tool_call_id.
messages << assistant_message
messages << Mistral::ChatMessage.new(
  role: 'tool', name: function_name, content: function_result, tool_call_id: tool_call.id
)

# Turn 3: let the model answer using the tool result, and print that answer.
response = client.chat(model: model, messages: messages, tools: tools)
puts response.choices[0].message.content