Skip to content

Commit 586ad4b

Browse files
committed
feat(LMES): Support custom template and prompt
Expand the LMEvalJob CRD to support custom templates and system prompts. This is mainly for custom unitxt task recipes. Now, users can use the `template` and `systemPrompt` fields under the `taskRecipes` to specify the custom template and system prompt. Signed-off-by: Yihong Wang <[email protected]>
1 parent 6ec0cdf commit 586ad4b

File tree

9 files changed

+420
-59
lines changed

9 files changed

+420
-59
lines changed

Dockerfile.lmes-job

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,9 @@ RUN curl -L https://github.com/opendatahub-io/lm-evaluation-harness/archive/refs
2020

2121
RUN python -c 'from lm_eval.tasks.unitxt import task; import os.path; print("class: !function " + task.__file__.replace("task.py", "task.Unitxt"))' > ./my_tasks/unitxt
2222

23-
ENV PYTHONPATH=/opt/app-root/src/.local/lib/python3.11/site-packages:/opt/app-root/src/lm-evaluation-harness:/opt/app-root/src:/opt/app-root/src/server
23+
ENV PYTHONPATH=/opt/app-root/src/.local/lib/python3.11/site-packages:/opt/app-root/src:/opt/app-root/src/server
2424
ENV HF_HOME=/opt/app-root/src/hf_home
25-
ENV UNITXT_ARTIFACTORIES=/opt/app-root/src/my_catalogs
25+
ENV UNITXT_CATALOGS=/opt/app-root/src/my_catalogs
2626

2727
USER 65532:65532
2828
CMD ["/opt/app-root/bin/python"]

api/lmes/v1alpha1/lmevaljob_types.go

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,14 +76,38 @@ type Card struct {
7676
Custom string `json:"custom,omitempty"`
7777
}
7878

79+
type Template struct {
80+
// Unitxt template ID
81+
// +optional
82+
Name string `json:"name,omitempty"`
83+
// A JSON string for a custom unitxt template.
84+
// Use the documentation here: https://www.unitxt.ai/en/latest/docs/adding_template.html
85+
// to compose a custom template, store it as a JSON file by calling the
86+
// add_to_catalog API: https://www.unitxt.ai/en/latest/docs/saving_and_loading_from_catalog.html#adding-assets-to-the-catalog,
87+
// and use the JSON content as the value here.
88+
// +optional
89+
Custom string `json:"custom,omitempty"`
90+
}
91+
92+
type SystemPrompt struct {
93+
// Unitxt System Prompt id
94+
Name string `json:"name,omitempty"`
95+
// A custom system prompt string
96+
Custom string `json:"custom,omitempty"`
97+
}
98+
7999
// Use a task recipe to form a custom task. It maps to the Unitxt Recipe
80100
// Find details of the Unitxt Recipe here:
81101
// https://www.unitxt.ai/en/latest/unitxt.standard.html#unitxt.standard.StandardRecipe
82102
type TaskRecipe struct {
83103
// The Unitxt dataset card
84104
Card Card `json:"card"`
85105
// The Unitxt template
86-
Template string `json:"template"`
106+
// +optional
107+
Template *Template `json:"template,omitempty"`
108+
// The Unitxt System Prompt
109+
// +optional
110+
SystemPrompt *SystemPrompt `json:"systemPrompt,omitempty"`
87111
// The Unitxt Task
88112
// +optional
89113
Task *string `json:"task,omitempty"`
@@ -111,9 +135,17 @@ type TaskList struct {
111135
TaskRecipes []TaskRecipe `json:"taskRecipes,omitempty"`
112136
}
113137

138+
// Use the tp_idx and sp_idx to point to the corresponding custom template
139+
// and custom system_prompt
114140
func (t *TaskRecipe) String() string {
115141
var b strings.Builder
116-
b.WriteString(fmt.Sprintf("card=%s,template=%s", t.Card.Name, t.Template))
142+
b.WriteString(fmt.Sprintf("card=%s", t.Card.Name))
143+
if t.Template != nil && t.Template.Name != "" {
144+
b.WriteString(fmt.Sprintf(",template=%s", t.Template.Name))
145+
}
146+
if t.SystemPrompt != nil && t.SystemPrompt.Name != "" {
147+
b.WriteString(fmt.Sprintf(",system_prompt=%s", t.SystemPrompt.Name))
148+
}
117149
if t.Task != nil {
118150
b.WriteString(fmt.Sprintf(",task=%s", *t.Task))
119151
}

api/lmes/v1alpha1/zz_generated.deepcopy.go

Lines changed: 40 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

cmd/lmes_driver/main.go

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -50,19 +50,23 @@ func (t *strArrayArg) String() string {
5050
}
5151

5252
var (
53-
taskRecipes strArrayArg
54-
customCards strArrayArg
55-
copy = flag.String("copy", "", "copy this binary to specified destination path")
56-
getStatus = flag.Bool("get-status", false, "Get current status")
57-
shutdown = flag.Bool("shutdown", false, "Shutdown the driver")
58-
outputPath = flag.String("output-path", OutputPath, "output path")
59-
detectDevice = flag.Bool("detect-device", false, "detect available device(s), CUDA or CPU")
60-
driverLog = ctrl.Log.WithName("driver")
53+
taskRecipes strArrayArg
54+
customCards strArrayArg
55+
customTemplates strArrayArg
56+
customSystemPrompts strArrayArg
57+
copy = flag.String("copy", "", "copy this binary to specified destination path")
58+
getStatus = flag.Bool("get-status", false, "Get current status")
59+
shutdown = flag.Bool("shutdown", false, "Shutdown the driver")
60+
outputPath = flag.String("output-path", OutputPath, "output path")
61+
detectDevice = flag.Bool("detect-device", false, "detect available device(s), CUDA or CPU")
62+
driverLog = ctrl.Log.WithName("driver")
6163
)
6264

6365
func init() {
6466
flag.Var(&taskRecipes, "task-recipe", "task recipe")
6567
flag.Var(&customCards, "custom-card", "A JSON string represents a custom card")
68+
flag.Var(&customTemplates, "custom-template", "A JSON string represents a custom template")
69+
flag.Var(&customSystemPrompts, "custom-prompt", "A string represents a custom system_prompt")
6670
}
6771

6872
func main() {
@@ -105,13 +109,15 @@ func main() {
105109
}
106110

107111
driverOpt := driver.DriverOption{
108-
Context: ctx,
109-
OutputPath: *outputPath,
110-
DetectDevice: *detectDevice,
111-
Logger: driverLog,
112-
TaskRecipes: taskRecipes,
113-
CustomCards: customCards,
114-
Args: args,
112+
Context: ctx,
113+
OutputPath: *outputPath,
114+
DetectDevice: *detectDevice,
115+
Logger: driverLog,
116+
TaskRecipes: taskRecipes,
117+
CustomCards: customCards,
118+
CustomTemplates: customTemplates,
119+
CustomSystemPrompt: customSystemPrompts,
120+
Args: args,
115121
}
116122

117123
driver, err := driver.NewDriver(&driverOpt)

config/crd/bases/trustyai.opendatahub.io_lmevaljobs.yaml

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4632,15 +4632,36 @@ spec:
46324632
numDemos:
46334633
description: Number of fewshot
46344634
type: integer
4635+
systemPrompt:
4636+
description: The Unitxt System Prompt
4637+
properties:
4638+
custom:
4639+
description: A custom system prompt string
4640+
type: string
4641+
name:
4642+
description: Unitxt System Prompt id
4643+
type: string
4644+
type: object
46354645
task:
46364646
description: The Unitxt Task
46374647
type: string
46384648
template:
46394649
description: The Unitxt template
4640-
type: string
4650+
properties:
4651+
custom:
4652+
description: |-
4653+
A JSON string for a custom unitxt template.
4654+
Use the documentation here: https://www.unitxt.ai/en/latest/docs/adding_template.html
4655+
to compose a custom template, store it as a JSON file by calling the
4656+
add_to_catalog API: https://www.unitxt.ai/en/latest/docs/saving_and_loading_from_catalog.html#adding-assets-to-the-catalog,
4657+
and use the JSON content as the value here.
4658+
type: string
4659+
name:
4660+
description: Unitxt template ID
4661+
type: string
4662+
type: object
46414663
required:
46424664
- card
4643-
- template
46444665
type: object
46454666
type: array
46464667
type: object

controllers/lmes/driver/driver.go

Lines changed: 80 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020
"bufio"
2121
"context"
2222
"encoding/json"
23+
"errors"
2324
"fmt"
2425
"io"
2526
"io/fs"
@@ -43,26 +44,30 @@ var (
4344

4445
const (
4546
// put the domain socket under /tmp. may move to emptydir to share across containers
46-
socketPath = "/tmp/ta-lmes-driver.sock"
47-
DefaultTaskRecipesPath = "/opt/app-root/src/my_tasks"
48-
DefaultCatalogPath = "/opt/app-root/src/my_catalogs"
49-
TaskRecipePrefix = "tr"
50-
CustomCardPrefix = "custom"
51-
ShutdownURI = "/Shutdown"
52-
GetStatusURI = "/GetStatus"
47+
socketPath = "/tmp/ta-lmes-driver.sock"
48+
DefaultTaskRecipesPath = "/opt/app-root/src/my_tasks"
49+
DefaultCatalogPath = "/opt/app-root/src/my_catalogs"
50+
TaskRecipePrefix = "tr"
51+
CustomCardPrefix = "custom"
52+
CustomTemplatePrefix = "tp"
53+
CustomSystemPromptPrefix = "sp"
54+
ShutdownURI = "/Shutdown"
55+
GetStatusURI = "/GetStatus"
5356
)
5457

5558
type DriverOption struct {
56-
Context context.Context
57-
OutputPath string
58-
DetectDevice bool
59-
TaskRecipesPath string
60-
TaskRecipes []string
61-
CatalogPath string
62-
CustomCards []string
63-
Logger logr.Logger
64-
Args []string
65-
SocketPath string
59+
Context context.Context
60+
OutputPath string
61+
DetectDevice bool
62+
TaskRecipesPath string
63+
TaskRecipes []string
64+
CatalogPath string
65+
CustomCards []string
66+
CustomTemplates []string
67+
CustomSystemPrompt []string
68+
Logger logr.Logger
69+
Args []string
70+
SocketPath string
6671
}
6772

6873
type Driver interface {
@@ -313,9 +318,19 @@ func (d *driverImpl) exec() error {
313318
return fmt.Errorf("failed to create task recipes: %v", err)
314319
}
315320

321+
if err := d.prepDir4CustomArtifacts(); err != nil {
322+
return fmt.Errorf("failed to create the directories for custom artifacts: %v", err)
323+
}
324+
316325
if err := d.createCustomCards(); err != nil {
317326
return fmt.Errorf("failed to create custom cards: %v", err)
318327
}
328+
if err := d.createCustomTemplates(); err != nil {
329+
return fmt.Errorf("failed to create custom templates: %v", err)
330+
}
331+
if err := d.createCustomSystemPrompts(); err != nil {
332+
return fmt.Errorf("failed to create custom system_prompts: %v", err)
333+
}
319334

320335
// Detect available devices if needed
321336
if err := d.detectDevice(); err != nil {
@@ -478,6 +493,15 @@ func (d *driverImpl) createTaskRecipes() error {
478493
return nil
479494
}
480495

496+
func (d *driverImpl) prepDir4CustomArtifacts() error {
497+
subDirs := []string{"cards", "templates", "system_prompts"}
498+
var errs []error
499+
for _, dir := range subDirs {
500+
errs = append(errs, mkdirIfNotExist(filepath.Join(d.Option.CatalogPath, dir)))
501+
}
502+
return errors.Join(errs...)
503+
}
504+
481505
func (d *driverImpl) createCustomCards() error {
482506
for i, customCard := range d.Option.CustomCards {
483507
err := os.WriteFile(
@@ -491,3 +515,42 @@ func (d *driverImpl) createCustomCards() error {
491515
}
492516
return nil
493517
}
518+
519+
func (d *driverImpl) createCustomTemplates() error {
520+
for i, customTemplate := range d.Option.CustomTemplates {
521+
err := os.WriteFile(
522+
filepath.Join(d.Option.CatalogPath, "templates", fmt.Sprintf("%s_%d.json", CustomTemplatePrefix, i)),
523+
[]byte(customTemplate),
524+
0666,
525+
)
526+
if err != nil {
527+
return err
528+
}
529+
}
530+
return nil
531+
}
532+
533+
func (d *driverImpl) createCustomSystemPrompts() error {
534+
for i, systemPrompt := range d.Option.CustomSystemPrompt {
535+
err := os.WriteFile(
536+
filepath.Join(d.Option.CatalogPath, "system_prompts", fmt.Sprintf("%s_%d.json", CustomSystemPromptPrefix, i)),
537+
[]byte(fmt.Sprintf(`{ "__type__": "textual_system_prompt", "text": "%s" }`, systemPrompt)),
538+
0666,
539+
)
540+
if err != nil {
541+
return err
542+
}
543+
}
544+
return nil
545+
}
546+
547+
func mkdirIfNotExist(path string) error {
548+
fi, err := os.Stat(path)
549+
if err == nil && !fi.IsDir() {
550+
return fmt.Errorf("%s is a file. can not create a directory", path)
551+
}
552+
if os.IsNotExist(err) {
553+
return os.MkdirAll(path, 0770)
554+
}
555+
return nil
556+
}

0 commit comments

Comments
 (0)