Skip to content

Commit a1e87dd

Browse files
committed
feat(LMES): Support custom template and prompt
Expand the LMEvalJob CRD to support custom templates and system prompts. This is mainly for custom unitxt task recipes. Now, users can use the `template` and `systemPrompt` fields under the `taskRecipes` to specify the custom template and system prompt. Signed-off-by: Yihong Wang <[email protected]>
1 parent ffe9cf5 commit a1e87dd

File tree

9 files changed

+471
-84
lines changed

9 files changed

+471
-84
lines changed

Dockerfile.lmes-job

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,9 @@ RUN curl -L https://github.com/opendatahub-io/lm-evaluation-harness/archive/refs
2020

2121
RUN python -c 'from lm_eval.tasks.unitxt import task; import os.path; print("class: !function " + task.__file__.replace("task.py", "task.Unitxt"))' > ./my_tasks/unitxt
2222

23-
ENV PYTHONPATH=/opt/app-root/src/.local/lib/python3.11/site-packages:/opt/app-root/src/lm-evaluation-harness:/opt/app-root/src:/opt/app-root/src/server
23+
ENV PYTHONPATH=/opt/app-root/src/.local/lib/python3.11/site-packages:/opt/app-root/src:/opt/app-root/src/server
2424
ENV HF_HOME=/opt/app-root/src/hf_home
25-
ENV UNITXT_ARTIFACTORIES=/opt/app-root/src/my_catalogs
25+
ENV UNITXT_CATALOGS=/opt/app-root/src/my_catalogs
2626

2727
USER 65532:65532
2828
CMD ["/opt/app-root/bin/python"]

api/lmes/v1alpha1/lmevaljob_types.go

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,14 +76,38 @@ type Card struct {
7676
Custom string `json:"custom,omitempty"`
7777
}
7878

79+
type Template struct {
80+
// Unitxt template ID
81+
// +optional
82+
Name string `json:"name,omitempty"`
83+
// A JSON string for a custom unitxt template.
84+
// Use the documentation here: https://www.unitxt.ai/en/latest/docs/adding_template.html
85+
// to compose a custom template, store it as a JSON file by calling the
86+
// add_to_catalog API: https://www.unitxt.ai/en/latest/docs/saving_and_loading_from_catalog.html#adding-assets-to-the-catalog,
87+
// and use the JSON content as the value here.
88+
// +optional
89+
Custom string `json:"custom,omitempty"`
90+
}
91+
92+
type SystemPrompt struct {
93+
// Unitxt System Prompt id
94+
Name string `json:"name,omitempty"`
95+
// A custom system prompt string
96+
Custom string `json:"custom,omitempty"`
97+
}
98+
7999
// Use a task recipe to form a custom task. It maps to the Unitxt Recipe
80100
// Find details of the Unitxt Recipe here:
81101
// https://www.unitxt.ai/en/latest/unitxt.standard.html#unitxt.standard.StandardRecipe
82102
type TaskRecipe struct {
83103
// The Unitxt dataset card
84104
Card Card `json:"card"`
85105
// The Unitxt template
86-
Template string `json:"template"`
106+
// +optional
107+
Template *Template `json:"template,omitempty"`
108+
// The Unitxt System Prompt
109+
// +optional
110+
SystemPrompt *SystemPrompt `json:"systemPrompt,omitempty"`
87111
// The Unitxt Task
88112
// +optional
89113
Task *string `json:"task,omitempty"`
@@ -111,9 +135,17 @@ type TaskList struct {
111135
TaskRecipes []TaskRecipe `json:"taskRecipes,omitempty"`
112136
}
113137

138+
// Use the tp_idx and sp_idx to point to the corresponding custom template
139+
// and custom system_prompt
114140
func (t *TaskRecipe) String() string {
115141
var b strings.Builder
116-
b.WriteString(fmt.Sprintf("card=%s,template=%s", t.Card.Name, t.Template))
142+
b.WriteString(fmt.Sprintf("card=%s", t.Card.Name))
143+
if t.Template != nil && t.Template.Name != "" {
144+
b.WriteString(fmt.Sprintf(",template=%s", t.Template.Name))
145+
}
146+
if t.SystemPrompt != nil && t.SystemPrompt.Name != "" {
147+
b.WriteString(fmt.Sprintf(",system_prompt=%s", t.SystemPrompt.Name))
148+
}
117149
if t.Task != nil {
118150
b.WriteString(fmt.Sprintf(",task=%s", *t.Task))
119151
}

api/lmes/v1alpha1/zz_generated.deepcopy.go

Lines changed: 40 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

cmd/lmes_driver/main.go

Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -50,21 +50,25 @@ func (t *strArrayArg) String() string {
5050
}
5151

5252
var (
53-
taskRecipes strArrayArg
54-
customCards strArrayArg
55-
copy = flag.String("copy", "", "copy this binary to specified destination path")
56-
getStatus = flag.Bool("get-status", false, "Get current status")
57-
shutdown = flag.Bool("shutdown", false, "Shutdown the driver")
58-
outputPath = flag.String("output-path", OutputPath, "output path")
59-
detectDevice = flag.Bool("detect-device", false, "detect available device(s), CUDA or CPU")
60-
commPort = flag.Int("listen-port", driver.DefaultPort, "driver serves APIs on the port")
61-
downloadAssetsS3 = flag.Bool("download-assets-s3", false, "Download assets from S3")
62-
driverLog = ctrl.Log.WithName("driver")
53+
taskRecipes strArrayArg
54+
customCards strArrayArg
55+
customTemplates strArrayArg
56+
customSystemPrompts strArrayArg
57+
copy = flag.String("copy", "", "copy this binary to specified destination path")
58+
getStatus = flag.Bool("get-status", false, "Get current status")
59+
shutdown = flag.Bool("shutdown", false, "Shutdown the driver")
60+
outputPath = flag.String("output-path", OutputPath, "output path")
61+
detectDevice = flag.Bool("detect-device", false, "detect available device(s), CUDA or CPU")
62+
commPort = flag.Int("listen-port", driver.DefaultPort, "driver serves APIs on the port")
63+
downloadAssetsS3 = flag.Bool("download-assets-s3", false, "Download assets from S3")
64+
driverLog = ctrl.Log.WithName("driver")
6365
)
6466

6567
func init() {
6668
flag.Var(&taskRecipes, "task-recipe", "task recipe")
6769
flag.Var(&customCards, "custom-card", "A JSON string represents a custom card")
70+
flag.Var(&customTemplates, "custom-template", "A JSON string represents a custom template")
71+
flag.Var(&customSystemPrompts, "custom-prompt", "A string represents a custom system_prompt")
6872
}
6973

7074
func main() {
@@ -107,15 +111,17 @@ func main() {
107111
}
108112

109113
driverOpt := driver.DriverOption{
110-
Context: ctx,
111-
OutputPath: *outputPath,
112-
DetectDevice: *detectDevice,
113-
Logger: driverLog,
114-
TaskRecipes: taskRecipes,
115-
CustomCards: customCards,
116-
Args: args,
117-
CommPort: *commPort,
118-
DownloadAssetsS3: *downloadAssetsS3,
114+
Context: ctx,
115+
OutputPath: *outputPath,
116+
DetectDevice: *detectDevice,
117+
Logger: driverLog,
118+
TaskRecipes: taskRecipes,
119+
CustomCards: customCards,
120+
CustomTemplates: customTemplates,
121+
CustomSystemPrompt: customSystemPrompts,
122+
Args: args,
123+
CommPort: *commPort,
124+
DownloadAssetsS3: *downloadAssetsS3,
119125
}
120126

121127
driver, err := driver.NewDriver(&driverOpt)

config/crd/bases/trustyai.opendatahub.io_lmevaljobs.yaml

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4764,15 +4764,36 @@ spec:
47644764
numDemos:
47654765
description: Number of fewshot
47664766
type: integer
4767+
systemPrompt:
4768+
description: The Unitxt System Prompt
4769+
properties:
4770+
custom:
4771+
description: A custom system prompt string
4772+
type: string
4773+
name:
4774+
description: Unitxt System Prompt id
4775+
type: string
4776+
type: object
47674777
task:
47684778
description: The Unitxt Task
47694779
type: string
47704780
template:
47714781
description: The Unitxt template
4772-
type: string
4782+
properties:
4783+
custom:
4784+
description: |-
4785+
A JSON string for a custom unitxt template.
4786+
Use the documentation here: https://www.unitxt.ai/en/latest/docs/adding_template.html
4787+
to compose a custom template, store it as a JSON file by calling the
4788+
add_to_catalog API: https://www.unitxt.ai/en/latest/docs/saving_and_loading_from_catalog.html#adding-assets-to-the-catalog,
4789+
and use the JSON content as the value here.
4790+
type: string
4791+
name:
4792+
description: Unitxt template ID
4793+
type: string
4794+
type: object
47734795
required:
47744796
- card
4775-
- template
47764797
type: object
47774798
type: array
47784799
type: object

controllers/lmes/driver/driver.go

Lines changed: 81 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020
"bufio"
2121
"context"
2222
"encoding/json"
23+
"errors"
2324
"fmt"
2425
"io"
2526
"io/fs"
@@ -43,27 +44,31 @@ var (
4344

4445
const (
4546
// the default port for the driver to listen on
46-
DefaultPort = 18080
47-
DefaultTaskRecipesPath = "/opt/app-root/src/my_tasks"
48-
DefaultCatalogPath = "/opt/app-root/src/my_catalogs"
49-
TaskRecipePrefix = "tr"
50-
CustomCardPrefix = "custom"
51-
ShutdownURI = "/Shutdown"
52-
GetStatusURI = "/GetStatus"
47+
DefaultPort = 18080
48+
DefaultTaskRecipesPath = "/opt/app-root/src/my_tasks"
49+
DefaultCatalogPath = "/opt/app-root/src/my_catalogs"
50+
TaskRecipePrefix = "tr"
51+
CustomCardPrefix = "custom"
52+
CustomTemplatePrefix = "tp"
53+
CustomSystemPromptPrefix = "sp"
54+
ShutdownURI = "/Shutdown"
55+
GetStatusURI = "/GetStatus"
5356
)
5457

5558
type DriverOption struct {
56-
Context context.Context
57-
OutputPath string
58-
DetectDevice bool
59-
TaskRecipesPath string
60-
TaskRecipes []string
61-
CatalogPath string
62-
CustomCards []string
63-
Logger logr.Logger
64-
Args []string
65-
CommPort int
66-
DownloadAssetsS3 bool
59+
Context context.Context
60+
OutputPath string
61+
DetectDevice bool
62+
TaskRecipesPath string
63+
TaskRecipes []string
64+
CatalogPath string
65+
CustomCards []string
66+
CustomTemplates []string
67+
CustomSystemPrompt []string
68+
Logger logr.Logger
69+
Args []string
70+
CommPort int
71+
DownloadAssetsS3 bool
6772
}
6873

6974
type Driver interface {
@@ -321,6 +326,10 @@ func (d *driverImpl) exec() error {
321326
return fmt.Errorf("failed to create task recipes: %v", err)
322327
}
323328

329+
if err := d.prepDir4CustomArtifacts(); err != nil {
330+
return fmt.Errorf("failed to create the directories for custom artifacts: %v", err)
331+
}
332+
324333
if err := d.createCustomCards(); err != nil {
325334
return fmt.Errorf("failed to create custom cards: %v", err)
326335
}
@@ -329,6 +338,12 @@ func (d *driverImpl) exec() error {
329338
if err := d.downloadS3Assets(); err != nil {
330339
return err
331340
}
341+
if err := d.createCustomTemplates(); err != nil {
342+
return fmt.Errorf("failed to create custom templates: %v", err)
343+
}
344+
if err := d.createCustomSystemPrompts(); err != nil {
345+
return fmt.Errorf("failed to create custom system_prompts: %v", err)
346+
}
332347

333348
// Detect available devices if needed
334349
if err := d.detectDevice(); err != nil {
@@ -491,6 +506,15 @@ func (d *driverImpl) createTaskRecipes() error {
491506
return nil
492507
}
493508

509+
func (d *driverImpl) prepDir4CustomArtifacts() error {
510+
subDirs := []string{"cards", "templates", "system_prompts"}
511+
var errs []error
512+
for _, dir := range subDirs {
513+
errs = append(errs, mkdirIfNotExist(filepath.Join(d.Option.CatalogPath, dir)))
514+
}
515+
return errors.Join(errs...)
516+
}
517+
494518
func (d *driverImpl) createCustomCards() error {
495519
for i, customCard := range d.Option.CustomCards {
496520
err := os.WriteFile(
@@ -504,3 +528,42 @@ func (d *driverImpl) createCustomCards() error {
504528
}
505529
return nil
506530
}
531+
532+
func (d *driverImpl) createCustomTemplates() error {
533+
for i, customTemplate := range d.Option.CustomTemplates {
534+
err := os.WriteFile(
535+
filepath.Join(d.Option.CatalogPath, "templates", fmt.Sprintf("%s_%d.json", CustomTemplatePrefix, i)),
536+
[]byte(customTemplate),
537+
0666,
538+
)
539+
if err != nil {
540+
return err
541+
}
542+
}
543+
return nil
544+
}
545+
546+
func (d *driverImpl) createCustomSystemPrompts() error {
547+
for i, systemPrompt := range d.Option.CustomSystemPrompt {
548+
err := os.WriteFile(
549+
filepath.Join(d.Option.CatalogPath, "system_prompts", fmt.Sprintf("%s_%d.json", CustomSystemPromptPrefix, i)),
550+
[]byte(fmt.Sprintf(`{ "__type__": "textual_system_prompt", "text": "%s" }`, systemPrompt)),
551+
0666,
552+
)
553+
if err != nil {
554+
return err
555+
}
556+
}
557+
return nil
558+
}
559+
560+
func mkdirIfNotExist(path string) error {
561+
fi, err := os.Stat(path)
562+
if err == nil && !fi.IsDir() {
563+
return fmt.Errorf("%s is a file. can not create a directory", path)
564+
}
565+
if os.IsNotExist(err) {
566+
return os.MkdirAll(path, 0770)
567+
}
568+
return nil
569+
}

0 commit comments

Comments
 (0)