Skip to content

Commit 23d8f5b

Browse files
committed
Allow users to define task groups in LMEvalJob
Add a new field, TaskGroups, under TaskList to support custom task groups. Users can define a custom task group and specify a list of aggregate metrics for it. In the result JSON, the task groups have a dedicated section for their results. Signed-off-by: Yihong Wang <[email protected]>
1 parent 10c0fc3 commit 23d8f5b

File tree

8 files changed

+641
-21
lines changed

8 files changed

+641
-21
lines changed

api/lmes/v1alpha1/lmevaljob_types.go

+30
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,9 @@ func (c *CustomArtifacts) GetTasks() []CustomArtifact {
181181
// Find details of the Unitxt Recipe here:
182182
// https://www.unitxt.ai/en/latest/unitxt.standard.html#unitxt.standard.StandardRecipe
183183
type TaskRecipe struct {
184+
// The name of the TaskRecipe
185+
// +optional
186+
Name *string `json:"name,omitempty"`
184187
// The Unitxt dataset card
185188
Card Card `json:"card"`
186189
// The Unitxt template
@@ -236,11 +239,35 @@ type CustomTasks struct {
236239
Source CustomTaskSource `json:"source,omitempty"`
237240
}
238241

242+
// AggregateMetric defines an aggregate metric using 'mean' aggregation.
243+
type AggregateMetric struct {
244+
// MetricName is the name of the metric to aggregate.
245+
MetricName string `json:"metricName,omitempty"`
246+
// WeightBySize indicates whether to weight by size or not. Default value is True.
247+
// +optional
248+
WeightBySize *bool `json:"weightBySize,omitempty"`
249+
}
250+
251+
// TaskGroup is a named, user-defined group of tasks together with the
// aggregate metrics to calculate for that group.
type TaskGroup struct {
252+
// Name is the name of the task group.
253+
Name string `json:"name"`
254+
// TaskNames from lm-eval's task list and/or from custom tasks if CustomTasks is defined.
255+
// +optional
256+
TaskNames []string `json:"taskNames,omitempty"`
257+
// TaskRecipes are the task recipes specifically for the Unitxt tasks.
258+
// +optional
259+
TaskRecipes []TaskRecipe `json:"taskRecipes,omitempty"`
260+
// AggregateMetrics is a list of aggregate metrics to calculate for the task group.
261+
// +optional
262+
AggregateMetrics []AggregateMetric `json:"aggregateMetrics,omitempty"`
263+
}
264+
239265
type TaskList struct {
240266
// TaskNames from lm-eval's task list and/or from custom tasks if CustomTasks is defined
241267
TaskNames []string `json:"taskNames,omitempty"`
242268
// Task Recipes specifically for Unitxt
243269
TaskRecipes []TaskRecipe `json:"taskRecipes,omitempty"`
270+
TaskGroups []TaskGroup `json:"taskGroups,omitempty"`
244271
// Custom Unitxt artifacts that can be used in a TaskRecipe
245272
CustomArtifacts *CustomArtifacts `json:"custom,omitempty"`
246273
// CustomTasks is a list of external tasks
@@ -340,6 +367,9 @@ func (t *TaskRecipe) String() string {
340367
if t.DemosPoolSize != nil {
341368
b.WriteString(fmt.Sprintf(",demos_pool_size=%d", *t.DemosPoolSize))
342369
}
370+
if t.Name != nil && *t.Name != "" {
371+
b.WriteString(fmt.Sprintf("|%s", *t.Name))
372+
}
343373
return b.String()
344374
}
345375

api/lmes/v1alpha1/zz_generated.deepcopy.go

+66
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

cmd/lmes_driver/main.go

+3
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ func (t *strArrayArg) String() string {
5151

5252
var (
5353
taskRecipes strArrayArg
54+
taskGroups strArrayArg
5455
customArtifactArgs strArrayArg
5556
taskNames strArrayArg
5657
copy = flag.String("copy", "", "copy this binary to specified destination path")
@@ -70,6 +71,7 @@ var (
7071

7172
// init registers the command-line flags whose values are collected through
// flag.Var into the package-level strArrayArg variables: task-recipe,
// task-group, custom-artifact and task-name.
// NOTE(review): presumably each of these flags may be given more than once on
// the command line — confirm against strArrayArg.Set, which is not visible here.
func init() {
7273
flag.Var(&taskRecipes, "task-recipe", "task recipe")
74+
flag.Var(&taskGroups, "task-group", "task group")
7375
flag.Var(&customArtifactArgs, "custom-artifact", "A string contains an artifact's type, name and value. Use | as separator")
7476
flag.Var(&taskNames, "task-name", "A task name for custom tasks")
7577
}
@@ -125,6 +127,7 @@ func main() {
125127
DetectDevice: *detectDevice,
126128
Logger: driverLog,
127129
TaskRecipes: taskRecipes,
130+
TaskGroups: taskGroups,
128131
CustomArtifacts: customArtifacts,
129132
Args: args,
130133
CommPort: *commPort,

config/crd/bases/trustyai.opendatahub.io_lmevaljobs.yaml

+133
Original file line numberDiff line numberDiff line change
@@ -4819,6 +4819,136 @@ spec:
48194819
type: object
48204820
type: object
48214821
type: object
4822+
taskGroups:
4823+
items:
4824+
properties:
4825+
aggregateMetrics:
4826+
description: A list of aggregate metrics to calculate for
4827+
the task group
4828+
items:
4829+
description: Define an aggregate metric using 'mean' aggregation.
4830+
properties:
4831+
metricName:
4832+
description: The name of the metric to aggregate
4833+
type: string
4834+
weightBySize:
4835+
description: Weight by size or not. Default value
4836+
is True
4837+
type: boolean
4838+
type: object
4839+
type: array
4840+
name:
4841+
description: The name of the task group
4842+
type: string
4843+
taskNames:
4844+
description: TaskNames from lm-eval's task list and/or from
4845+
custom tasks if CustomTasks is defined
4846+
items:
4847+
type: string
4848+
type: array
4849+
taskRecipes:
4850+
description: Task Recipes specifically for the Unitxt tasks
4851+
items:
4852+
description: |-
4853+
Use a task recipe to form a custom task. It maps to the Unitxt Recipe
4854+
Find details of the Unitxt Recipe here:
4855+
https://www.unitxt.ai/en/latest/unitxt.standard.html#unitxt.standard.StandardRecipe
4856+
properties:
4857+
card:
4858+
description: The Unitxt dataset card
4859+
properties:
4860+
custom:
4861+
description: |-
4862+
A JSON string for a custom unitxt card which contains the custom dataset.
4863+
Use the documentation here: https://www.unitxt.ai/en/latest/docs/adding_dataset.html#adding-to-the-catalog
4864+
to compose a custom card, store it as a JSON file, and use the JSON content as the value here.
4865+
type: string
4866+
name:
4867+
description: Unitxt card's ID
4868+
type: string
4869+
type: object
4870+
demosPoolSize:
4871+
description: The pool size for the fewshot
4872+
type: integer
4873+
format:
4874+
description: The Unitxt format
4875+
type: string
4876+
loaderLimit:
4877+
description: A limit number of records to load
4878+
type: integer
4879+
metrics:
4880+
description: Metrics
4881+
items:
4882+
properties:
4883+
name:
4884+
description: Unitxt metric id
4885+
type: string
4886+
ref:
4887+
description: |-
4888+
The name of the custom metric in the custom field. Its value is a JSON string
4889+
for a custom Unitxt metric. Use the documentation here: https://www.unitxt.ai/en/latest/docs/adding_metric.html#adding-a-new-instance-metric
4890+
to compose a custom metric, store it as a JSON file by calling the
4891+
add_to_catalog API: https://www.unitxt.ai/en/latest/docs/saving_and_loading_from_catalog.html#adding-assets-to-the-catalog,
4892+
and use the JSON content as the value here.
4893+
type: string
4894+
type: object
4895+
type: array
4896+
name:
4897+
description: The name of the TaskRecipe
4898+
type: string
4899+
numDemos:
4900+
description: Number of fewshot
4901+
type: integer
4902+
systemPrompt:
4903+
description: The Unitxt System Prompt
4904+
properties:
4905+
name:
4906+
description: Unitxt System Prompt id
4907+
type: string
4908+
ref:
4909+
description: The name of the custom systemPrompt
4910+
in the custom field. Its value is a custom system
4911+
prompt string
4912+
type: string
4913+
type: object
4914+
task:
4915+
description: The Unitxt Task
4916+
properties:
4917+
name:
4918+
description: Unitxt task id
4919+
type: string
4920+
ref:
4921+
description: |-
4922+
The name of the custom task in the custom field. Its value is a JSON string
4923+
for a custom Unitxt task. Use the documentation here: https://www.unitxt.ai/en/latest/docs/adding_task.html
4924+
to compose a custom task, store it as a JSON file by calling the
4925+
add_to_catalog API: https://www.unitxt.ai/en/latest/docs/saving_and_loading_from_catalog.html#adding-assets-to-the-catalog,
4926+
and use the JSON content as the value here.
4927+
type: string
4928+
type: object
4929+
template:
4930+
description: The Unitxt template
4931+
properties:
4932+
name:
4933+
description: Unitxt template ID
4934+
type: string
4935+
ref:
4936+
description: |-
4937+
The name of the custom template in the custom field. Its value is a JSON string
4938+
for a custom Unitxt template. Use the documentation here: https://www.unitxt.ai/en/latest/docs/adding_template.html
4939+
to compose a custom template, store it as a JSON file by calling the
4940+
add_to_catalog API: https://www.unitxt.ai/en/latest/docs/saving_and_loading_from_catalog.html#adding-assets-to-the-catalog,
4941+
and use the JSON content as the value here.
4942+
type: string
4943+
type: object
4944+
required:
4945+
- card
4946+
type: object
4947+
type: array
4948+
required:
4949+
- name
4950+
type: object
4951+
type: array
48224952
taskNames:
48234953
description: TaskNames from lm-eval's task list and/or from custom
48244954
tasks if CustomTasks is defined
@@ -4872,6 +5002,9 @@ spec:
48725002
type: string
48735003
type: object
48745004
type: array
5005+
name:
5006+
description: The name of the TaskRecipe
5007+
type: string
48755008
numDemos:
48765009
description: Number of fewshot
48775010
type: integer

controllers/lmes/driver/driver.go

+37-3
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ type DriverOption struct {
6060
DetectDevice bool
6161
TaskRecipesPath string
6262
TaskRecipes []string
63+
TaskGroups []string
6364
CatalogPath string
6465
CustomArtifacts []CustomArtifact
6566
Logger logr.Logger
@@ -344,6 +345,9 @@ func (d *driverImpl) exec() error {
344345
if err := d.createTaskRecipes(); err != nil {
345346
return fmt.Errorf("failed to create task recipes: %v", err)
346347
}
348+
if err := d.createTaskGroups(); err != nil {
349+
return fmt.Errorf("failed to create task groups: %v", err)
350+
}
347351

348352
if err := d.prepDir4CustomArtifacts(); err != nil {
349353
return fmt.Errorf("failed to create the directories for custom artifacts: %v", err)
@@ -507,12 +511,21 @@ func (d *driverImpl) updateProgress(msg string) {
507511
}
508512

509513
func (d *driverImpl) createTaskRecipes() error {
510-
for i, taskRecipe := range d.Option.TaskRecipes {
514+
id := 0
515+
for _, rString := range d.Option.TaskRecipes {
516+
tokens := strings.SplitN(rString, "|", 2)
517+
taskRecipe := tokens[0]
518+
taskName := fmt.Sprintf("%s_%d", TaskRecipePrefix, id)
519+
if len(tokens) == 2 {
520+
taskName = tokens[1]
521+
} else {
522+
id++
523+
}
511524
err := os.WriteFile(
512-
filepath.Join(d.Option.TaskRecipesPath, fmt.Sprintf("%s_%d.yaml", TaskRecipePrefix, i)),
525+
filepath.Join(d.Option.TaskRecipesPath, fmt.Sprintf("%s.yaml", taskName)),
513526
[]byte(fmt.Sprintf(
514527
"task: %s\ninclude: unitxt\nrecipe: %s",
515-
fmt.Sprintf("%s_%d", TaskRecipePrefix, i),
528+
taskName,
516529
taskRecipe,
517530
)),
518531
0666,
@@ -524,6 +537,27 @@ func (d *driverImpl) createTaskRecipes() error {
524537
return nil
525538
}
526539

540+
func (d *driverImpl) createTaskGroups() error {
541+
for _, rString := range d.Option.TaskGroups {
542+
tokens := strings.SplitN(rString, "|", 2)
543+
taskGroupName := tokens[0]
544+
definition := tokens[1]
545+
err := os.WriteFile(
546+
filepath.Join(d.Option.TaskRecipesPath, fmt.Sprintf("%s.yaml", taskGroupName)),
547+
[]byte(fmt.Sprintf(
548+
"group: %s\n%s",
549+
taskGroupName,
550+
definition,
551+
)),
552+
0666,
553+
)
554+
if err != nil {
555+
return err
556+
}
557+
}
558+
return nil
559+
}
560+
527561
func (d *driverImpl) prepDir4CustomArtifacts() error {
528562
subDirs := []string{"cards", "templates", "system_prompts"}
529563
var errs []error

0 commit comments

Comments
 (0)