Skip to content

Commit 8ca624c

Browse files
authored
Support multiple variables in run (#64)
2 parents 223a619 + c7593e3 commit 8ca624c

File tree

4 files changed

+175
-8
lines changed

4 files changed

+175
-8
lines changed

cmd/run/run.go

Lines changed: 64 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -204,10 +204,16 @@ func NewRunCommand(cfg *command.Config) *cobra.Command {
204204
If you know which model you want to run inference with, you can run the request in a single command
205205
as %[1]sgh models run [model] [prompt]%[1]s
206206
207+
When using prompt files, you can pass template variables using the %[1]s--var%[1]s flag:
208+
%[1]sgh models run --file prompt.yml --var name=Alice --var topic=AI%[1]s
209+
207210
The return value will be the response to your prompt from the selected model.
208211
`, "`"),
209-
Example: "gh models run openai/gpt-4o-mini \"how many types of hyena are there?\"",
210-
Args: cobra.ArbitraryArgs,
212+
Example: heredoc.Doc(`
213+
gh models run openai/gpt-4o-mini "how many types of hyena are there?"
214+
gh models run --file prompt.yml --var name=Alice --var topic="machine learning"
215+
`),
216+
Args: cobra.ArbitraryArgs,
211217
RunE: func(cmd *cobra.Command, args []string) error {
212218
filePath, _ := cmd.Flags().GetString("file")
213219
var pf *prompt.File
@@ -223,6 +229,12 @@ func NewRunCommand(cfg *command.Config) *cobra.Command {
223229
}
224230
}
225231

232+
// Parse template variables from flags
233+
templateVars, err := parseTemplateVariables(cmd.Flags())
234+
if err != nil {
235+
return err
236+
}
237+
226238
cmdHandler := newRunCommandHandler(cmd, cfg, args)
227239
if cmdHandler == nil {
228240
return nil
@@ -270,16 +282,22 @@ func NewRunCommand(cfg *command.Config) *cobra.Command {
270282
}
271283

272284
// If there is no prompt file, add the initialPrompt to the conversation.
273-
// If a prompt file is passed, load the messages from the file, templating {{input}}
274-
// using the initialPrompt.
285+
// If a prompt file is passed, load the messages from the file, templating variables
286+
// using the provided template variables and initialPrompt.
275287
if pf == nil {
276288
conversation.AddMessage(azuremodels.ChatMessageRoleUser, initialPrompt)
277289
} else {
278290
interactiveMode = false
279291

280-
// Template the messages with the input
281-
templateData := map[string]interface{}{
282-
"input": initialPrompt,
292+
// Template the messages with the variables
293+
templateData := make(map[string]interface{})
294+
295+
// Add the input variable (backward compatibility)
296+
templateData["input"] = initialPrompt
297+
298+
// Add custom variables
299+
for key, value := range templateVars {
300+
templateData[key] = value
283301
}
284302

285303
for _, m := range pf.Messages {
@@ -385,6 +403,7 @@ func NewRunCommand(cfg *command.Config) *cobra.Command {
385403
}
386404

387405
cmd.Flags().String("file", "", "Path to a .prompt.yml file.")
406+
cmd.Flags().StringSlice("var", []string{}, "Template variables for prompt files (can be used multiple times: --var name=value)")
388407
cmd.Flags().String("max-tokens", "", "Limit the maximum tokens for the model response.")
389408
cmd.Flags().String("temperature", "", "Controls randomness in the response, use lower to be more deterministic.")
390409
cmd.Flags().String("top-p", "", "Controls text diversity by selecting the most probable words until a set probability is reached.")
@@ -393,6 +412,43 @@ func NewRunCommand(cfg *command.Config) *cobra.Command {
393412
return cmd
394413
}
395414

415+
// parseTemplateVariables parses template variables from the --var flags
416+
func parseTemplateVariables(flags *pflag.FlagSet) (map[string]string, error) {
417+
varFlags, err := flags.GetStringSlice("var")
418+
if err != nil {
419+
return nil, err
420+
}
421+
422+
templateVars := make(map[string]string)
423+
for _, varFlag := range varFlags {
424+
// Handle empty strings
425+
if strings.TrimSpace(varFlag) == "" {
426+
continue
427+
}
428+
429+
parts := strings.SplitN(varFlag, "=", 2)
430+
if len(parts) != 2 {
431+
return nil, fmt.Errorf("invalid variable format '%s', expected 'key=value'", varFlag)
432+
}
433+
434+
key := strings.TrimSpace(parts[0])
435+
value := parts[1] // Don't trim value to preserve intentional whitespace
436+
437+
if key == "" {
438+
return nil, fmt.Errorf("variable key cannot be empty in '%s'", varFlag)
439+
}
440+
441+
// Check for duplicate keys
442+
if _, exists := templateVars[key]; exists {
443+
return nil, fmt.Errorf("duplicate variable key '%s'", key)
444+
}
445+
446+
templateVars[key] = value
447+
}
448+
449+
return templateVars, nil
450+
}
451+
396452
type runCommandHandler struct {
397453
ctx context.Context
398454
cfg *command.Config
@@ -445,7 +501,7 @@ func (h *runCommandHandler) getModelNameFromArgs(models []*azuremodels.ModelSumm
445501
}
446502

447503
func validateModelName(modelName string, models []*azuremodels.ModelSummary) (string, error) {
448-
noMatchErrorMessage := "The specified model name is not found. Run 'gh models list' to see available models or 'gh models run' to select interactively."
504+
noMatchErrorMessage := fmt.Sprintf("The specified model '%s' is not found. Run 'gh models list' to see available models or 'gh models run' to select interactively.", modelName)
449505

450506
if modelName == "" {
451507
return "", errors.New(noMatchErrorMessage)

cmd/run/run_test.go

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
"github.com/github/gh-models/internal/sse"
1212
"github.com/github/gh-models/pkg/command"
1313
"github.com/github/gh-models/pkg/util"
14+
"github.com/spf13/pflag"
1415
"github.com/stretchr/testify/require"
1516
)
1617

@@ -331,3 +332,74 @@ messages:
331332
require.Equal(t, "User message", *capturedReq.Messages[1].Content)
332333
})
333334
}
335+
336+
func TestParseTemplateVariables(t *testing.T) {
337+
tests := []struct {
338+
name string
339+
varFlags []string
340+
expected map[string]string
341+
expectErr bool
342+
}{
343+
{
344+
name: "empty vars",
345+
varFlags: []string{},
346+
expected: map[string]string{},
347+
},
348+
{
349+
name: "single var",
350+
varFlags: []string{"name=John"},
351+
expected: map[string]string{"name": "John"},
352+
},
353+
{
354+
name: "multiple vars",
355+
varFlags: []string{"name=John", "age=25", "city=New York"},
356+
expected: map[string]string{"name": "John", "age": "25", "city": "New York"},
357+
},
358+
{
359+
name: "multi-word values",
360+
varFlags: []string{"full_name=John Smith", "description=A senior developer"},
361+
expected: map[string]string{"full_name": "John Smith", "description": "A senior developer"},
362+
},
363+
{
364+
name: "value with equals sign",
365+
varFlags: []string{"equation=x = y + 2"},
366+
expected: map[string]string{"equation": "x = y + 2"},
367+
},
368+
{
369+
name: "empty strings are skipped",
370+
varFlags: []string{"", "name=John", " "},
371+
expected: map[string]string{"name": "John"},
372+
},
373+
{
374+
name: "invalid format - no equals",
375+
varFlags: []string{"invalid"},
376+
expectErr: true,
377+
},
378+
{
379+
name: "invalid format - empty key",
380+
varFlags: []string{"=value"},
381+
expectErr: true,
382+
},
383+
{
384+
name: "duplicate keys",
385+
varFlags: []string{"name=John", "name=Jane"},
386+
expectErr: true,
387+
},
388+
}
389+
390+
for _, tt := range tests {
391+
t.Run(tt.name, func(t *testing.T) {
392+
flags := pflag.NewFlagSet("test", pflag.ContinueOnError)
393+
flags.StringSlice("var", tt.varFlags, "test flag")
394+
395+
result, err := parseTemplateVariables(flags)
396+
397+
if tt.expectErr {
398+
require.Error(t, err)
399+
} else {
400+
require.NoError(t, err)
401+
require.Equal(t, tt.expected, result)
402+
}
403+
})
404+
}
405+
}

examples/advanced_template_prompt.yml

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# Advanced Template Variables Example
2+
name: Advanced Template Example
3+
description: Demonstrates advanced usage of template variables
4+
model: openai/gpt-4o-mini
5+
modelParameters:
6+
temperature: 0.7
7+
maxTokens: 300
8+
messages:
9+
- role: system
10+
content: |
11+
You are {{assistant_persona}}, a {{expertise_level}} {{domain}} specialist.
12+
Your communication style should be {{tone}} and {{formality_level}}.
13+
14+
Context: You are helping {{user_name}} who works as a {{user_role}} at {{company}}.
15+
16+
- role: user
17+
content: |
18+
Hello! I'm {{user_name}} from {{company}}.
19+
20+
Background: {{background_info}}
21+
22+
Question: {{input}}
23+
24+
Please provide your response considering my role as {{user_role}} and
25+
make it appropriate for a {{formality_level}} setting.
26+
27+
Additional context: {{additional_context}}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
# Example demonstrating arbitrary template variables
#
# Each {{placeholder}} is supplied via a --var key=value flag; {{input}} is
# the prompt text given on the command line.
# NOTE(review): indentation reconstructed from a rendered diff — verify
# against the committed file.
name: Template Variables Example
description: Shows how to use custom template variables in prompt files
model: openai/gpt-4o
modelParameters:
  temperature: 0.3
  maxTokens: 200
messages:
  - role: system
    content: You are {{persona}}, a helpful assistant specializing in {{domain}}.
  - role: user
    content: Hello {{name}}! I need help with {{topic}}. {{input}}

0 commit comments

Comments
 (0)