Skip to content

Commit 6c30b2b

Browse files
committed
refactor, Separate LLM prompt template for "write test" task so we can add a template
Part of #350
1 parent 3c368f8 commit 6c30b2b

File tree

2 files changed

+40
-28
lines changed

2 files changed

+40
-28
lines changed

model/llm/llm.go

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ func (m *Model) MetaInformation() (metaInformation *model.MetaInformation) {
6363
return m.metaInformation
6464
}
6565

66-
// llmSourceFilePromptContext is the context for template for generating an LLM test generation prompt.
66+
// llmSourceFilePromptContext is the base template context for an LLM generation prompt.
6767
type llmSourceFilePromptContext struct {
6868
// Language holds the programming language name.
6969
Language language.Language
@@ -76,8 +76,14 @@ type llmSourceFilePromptContext struct {
7676
ImportPath string
7777
}
7878

79-
// llmGenerateTestForFilePromptTemplate is the template for generating an LLM test generation prompt.
80-
var llmGenerateTestForFilePromptTemplate = template.Must(template.New("model-llm-generate-test-for-file-prompt").Parse(bytesutil.StringTrimIndentations(`
79+
// llmWriteTestSourceFilePromptContext is the template context for a write test LLM prompt.
80+
type llmWriteTestSourceFilePromptContext struct {
81+
// llmSourceFilePromptContext holds the context for a source file prompt.
82+
llmSourceFilePromptContext
83+
}
84+
85+
// llmWriteTestForFilePromptTemplate is the template for generating an LLM test generation prompt.
86+
var llmWriteTestForFilePromptTemplate = template.Must(template.New("model-llm-write-test-for-file-prompt").Parse(bytesutil.StringTrimIndentations(`
8187
Given the following {{ .Language.Name }} code file "{{ .FilePath }}" with package "{{ .ImportPath }}", provide a test file for this code{{ with $testFramework := .Language.TestFramework }} with {{ $testFramework }} as a test framework{{ end }}.
8288
The tests should produce 100 percent code coverage and must compile.
8389
The response must contain only the test code in a fenced code block and nothing else.
@@ -87,14 +93,14 @@ var llmGenerateTestForFilePromptTemplate = template.Must(template.New("model-llm
8793
` + "```" + `
8894
`)))
8995

90-
// llmGenerateTestForFilePrompt returns the prompt for generating an LLM test generation.
91-
func llmGenerateTestForFilePrompt(data *llmSourceFilePromptContext) (message string, err error) {
96+
// llmWriteTestForFilePrompt returns the prompt for generating an LLM test generation.
97+
func llmWriteTestForFilePrompt(data *llmWriteTestSourceFilePromptContext) (message string, err error) {
9298
// Use Linux paths even when running the evaluation on Windows to ensure consistency in prompting.
9399
data.FilePath = filepath.ToSlash(data.FilePath)
94100
data.Code = strings.TrimSpace(data.Code)
95101

96102
var b strings.Builder
97-
if err := llmGenerateTestForFilePromptTemplate.Execute(&b, data); err != nil {
103+
if err := llmWriteTestForFilePromptTemplate.Execute(&b, data); err != nil {
98104
return "", pkgerrors.WithStack(err)
99105
}
100106

@@ -198,12 +204,14 @@ func (m *Model) WriteTests(ctx model.Context) (assessment metrics.Assessments, e
198204

199205
importPath := ctx.Language.ImportPath(ctx.RepositoryPath, ctx.FilePath)
200206

201-
request, err := llmGenerateTestForFilePrompt(&llmSourceFilePromptContext{
202-
Language: ctx.Language,
207+
request, err := llmWriteTestForFilePrompt(&llmWriteTestSourceFilePromptContext{
208+
llmSourceFilePromptContext: llmSourceFilePromptContext{
209+
Language: ctx.Language,
203210

204-
Code: fileContent,
205-
FilePath: ctx.FilePath,
206-
ImportPath: importPath,
211+
Code: fileContent,
212+
FilePath: ctx.FilePath,
213+
ImportPath: importPath,
214+
},
207215
})
208216
if err != nil {
209217
return nil, err

model/llm/llm_test.go

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -84,12 +84,14 @@ func TestModelGenerateTestsForFile(t *testing.T) {
8484
func main() {}
8585
`
8686
sourceFilePath := "simple.go"
87-
promptMessage, err := llmGenerateTestForFilePrompt(&llmSourceFilePromptContext{
88-
Language: &golang.Language{},
87+
promptMessage, err := llmWriteTestForFilePrompt(&llmWriteTestSourceFilePromptContext{
88+
llmSourceFilePromptContext: llmSourceFilePromptContext{
89+
Language: &golang.Language{},
8990

90-
Code: bytesutil.StringTrimIndentations(sourceFileContent),
91-
FilePath: sourceFilePath,
92-
ImportPath: "native",
91+
Code: bytesutil.StringTrimIndentations(sourceFileContent),
92+
FilePath: sourceFilePath,
93+
ImportPath: "native",
94+
},
9395
})
9496
require.NoError(t, err)
9597
validate(t, &testCase{
@@ -291,14 +293,14 @@ func TestLLMGenerateTestForFilePrompt(t *testing.T) {
291293
type testCase struct {
292294
Name string
293295

294-
Data *llmSourceFilePromptContext
296+
Data *llmWriteTestSourceFilePromptContext
295297

296298
ExpectedMessage string
297299
}
298300

299301
validate := func(t *testing.T, tc *testCase) {
300302
t.Run(tc.Name, func(t *testing.T) {
301-
actualMessage, actualErr := llmGenerateTestForFilePrompt(tc.Data)
303+
actualMessage, actualErr := llmWriteTestForFilePrompt(tc.Data)
302304
require.NoError(t, actualErr)
303305

304306
assert.Equal(t, tc.ExpectedMessage, actualMessage)
@@ -308,18 +310,20 @@ func TestLLMGenerateTestForFilePrompt(t *testing.T) {
308310
validate(t, &testCase{
309311
Name: "Plain",
310312

311-
Data: &llmSourceFilePromptContext{
312-
Language: &golang.Language{},
313+
Data: &llmWriteTestSourceFilePromptContext{
314+
llmSourceFilePromptContext: llmSourceFilePromptContext{
315+
Language: &golang.Language{},
313316

314-
Code: bytesutil.StringTrimIndentations(`
315-
package increment
317+
Code: bytesutil.StringTrimIndentations(`
318+
package increment
316319
317-
func increment(i int) int
318-
return i + 1
319-
}
320-
`),
321-
FilePath: filepath.Join("path", "to", "increment.go"),
322-
ImportPath: "increment",
320+
func increment(i int) int
321+
return i + 1
322+
}
323+
`),
324+
FilePath: filepath.Join("path", "to", "increment.go"),
325+
ImportPath: "increment",
326+
},
323327
},
324328

325329
ExpectedMessage: bytesutil.StringTrimIndentations(`

0 commit comments

Comments
 (0)