Skip to content

Commit edcdc8f

Browse files
committed
feat: support unitxt recipes
Add new fields in the CRD to support unitxt recipes and leverage the driver to create corresponding yaml files of the unitxt recipes. Signed-off-by: Yihong Wang <[email protected]>
1 parent a626cf8 commit edcdc8f

File tree

10 files changed

+562
-49
lines changed

10 files changed

+562
-49
lines changed

Dockerfile.lmes-job

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,19 @@ USER default
77
WORKDIR /opt/app-root/src
88
RUN mkdir /opt/app-root/src/hf_home && chmod g+rwx /opt/app-root/src/hf_home
99
RUN mkdir /opt/app-root/src/output && chmod g+rwx /opt/app-root/src/output
10+
RUN mkdir /opt/app-root/src/my_tasks && chmod g+rwx /opt/app-root/src/my_tasks
1011
RUN mkdir /opt/app-root/src/.cache
1112
ENV PATH="/opt/app-root/bin:/opt/app-root/src/.local/bin/:/opt/app-root/src/bin:/usr/local/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin"
1213

1314
RUN pip install --no-cache-dir --user --upgrade ibm-generative-ai[lm-eval]
1415
COPY --chown=1001:0 patch /opt/app-root/src/patch
15-
# Clone the Git repository and install the Python package
16+
# Clone the Git repository, check out v0.4.4 and install the Python package
1617
RUN git clone https://github.com/opendatahub-io/lm-evaluation-harness.git && \
17-
cd lm-evaluation-harness && git checkout 568af943e315100af3f00937bfd6947844769ab8 && \
18+
cd lm-evaluation-harness && git checkout 543617fef9ba885e87f8db8930fbbff1d4e2ca49 && \
1819
curl --output lm_eval/models/bam.py https://raw.githubusercontent.com/IBM/ibm-generative-ai/main/src/genai/extensions/lm_eval/model.py && \
19-
git apply /opt/app-root/src/patch/lmes/models.patch && pip install --no-cache-dir --user -e .[unitxt] && \
20-
pip install --no-cache-dir --user -e .[openai]
20+
git apply /opt/app-root/src/patch/lmes/models.patch && pip install --no-cache-dir --user -e .[api]
21+
22+
RUN python -c 'from lm_eval.tasks.unitxt import task; import os.path; print("class: !function " + task.__file__.replace("task.py", "task.Unitxt"))' > ./my_tasks/unitxt
2123

2224
ENV PYTHONPATH=/opt/app-root/src/.local/lib/python3.11/site-packages:/opt/app-root/src/lm-evaluation-harness:/opt/app-root/src:/opt/app-root/src/server
2325
ENV HF_HOME=/opt/app-root/src/hf_home

api/lmes/v1alpha1/lmevaljob_types.go

Lines changed: 64 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@ limitations under the License.
1717
package v1alpha1
1818

1919
import (
20+
"fmt"
21+
"strings"
22+
2023
corev1 "k8s.io/api/core/v1"
2124
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
2225
)
@@ -78,6 +81,65 @@ type FileSecret struct {
7881
MountPath string `json:"mountPath"`
7982
}
8083

84+
// Use a task recipe to form a custom task. It maps to the Unitxt Recipe
85+
// Find details of the Unitxt Recipe here:
86+
// https://www.unitxt.ai/en/latest/unitxt.standard.html#unitxt.standard.StandardRecipe
87+
type TaskRecipe struct {
88+
// The Unitxt dataset card
89+
Card string `json:"card"`
90+
// The Unitxt template
91+
Template string `json:"template"`
92+
// The Unitxt Task
93+
// +optional
94+
Task *string `json:"task,omitempty"`
95+
// Metrics
96+
// +optional
97+
Metrics []string `json:"metrics,omitempty"`
98+
// The Unitxt format
99+
// +optional
100+
Format *string `json:"format,omitempty"`
101+
// A limit number of records to load
102+
// +optional
103+
LoaderLimit *int `json:"loaderLimit,omitempty"`
104+
// Number of fewshot
105+
// +optional
106+
NumDemos *int `json:"numDemos,omitempty"`
107+
// The pool size for the fewshot
108+
// +optional
109+
DemosPoolSize *int `json:"demosPoolSize,omitempty"`
110+
}
111+
112+
type TaskList struct {
113+
// TaskNames from lm-eval's task list
114+
TaskNames []string `json:"taskNames,omitempty"`
115+
// Task Recipes specifically for Unitxt
116+
TaskRecipes []TaskRecipe `json:"taskRecipes,omitempty"`
117+
}
118+
119+
func (t *TaskRecipe) String() string {
120+
var b strings.Builder
121+
b.WriteString(fmt.Sprintf("card=%s,template=%s", t.Card, t.Template))
122+
if t.Task != nil {
123+
b.WriteString(fmt.Sprintf(",task=%s", *t.Task))
124+
}
125+
if len(t.Metrics) > 0 {
126+
b.WriteString(fmt.Sprintf(",metrics=[%s]", strings.Join(t.Metrics, ",")))
127+
}
128+
if t.Format != nil {
129+
b.WriteString(fmt.Sprintf(",format=%s", *t.Format))
130+
}
131+
if t.LoaderLimit != nil {
132+
b.WriteString(fmt.Sprintf(",loader_limit=%d", *t.LoaderLimit))
133+
}
134+
if t.NumDemos != nil {
135+
b.WriteString(fmt.Sprintf(",num_demos=%d", *t.NumDemos))
136+
}
137+
if t.DemosPoolSize != nil {
138+
b.WriteString(fmt.Sprintf(",demos_pool_size=%d", *t.DemosPoolSize))
139+
}
140+
return b.String()
141+
}
142+
81143
// LMEvalJobSpec defines the desired state of LMEvalJob
82144
type LMEvalJobSpec struct {
83145
// INSERT ADDITIONAL SPEC FIELDS - desired state of cluster
@@ -88,8 +150,8 @@ type LMEvalJobSpec struct {
88150
// Args for the model
89151
// +optional
90152
ModelArgs []Arg `json:"modelArgs,omitempty"`
91-
// Evaluation tasks
92-
Tasks []string `json:"tasks"`
153+
// Evaluation task list
154+
TaskList TaskList `json:"taskList"`
93155
// Sets the number of few-shot examples to place in context
94156
// +optional
95157
NumFewShot *int `json:"numFewShot,omitempty"`

api/lmes/v1alpha1/zz_generated.deepcopy.go

Lines changed: 73 additions & 5 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

cmd/lmes_driver/main.go

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import (
2222
"fmt"
2323
"io"
2424
"os"
25+
"strings"
2526
"time"
2627

2728
ctrl "sigs.k8s.io/controller-runtime"
@@ -36,18 +37,35 @@ const (
3637
OutputPath = "/opt/app-root/src/output"
3738
)
3839

40+
type taskRecipeArg []string
41+
42+
func (t *taskRecipeArg) Set(value string) error {
43+
*t = append(*t, value)
44+
return nil
45+
}
46+
47+
func (t *taskRecipeArg) String() string {
48+
// supposedly, use ":" as the separator for task recipe should be safe
49+
return strings.Join(*t, ":")
50+
}
51+
3952
var (
53+
taskRecipes taskRecipeArg
4054
copy = flag.String("copy", "", "copy this binary to specified destination path")
4155
jobNameSpace = flag.String("job-namespace", "", "Job's namespace ")
4256
jobName = flag.String("job-name", "", "Job's name")
4357
grpcService = flag.String("grpc-service", "", "grpc service name")
4458
grpcPort = flag.Int("grpc-port", 8082, "grpc port")
4559
outputPath = flag.String("output-path", OutputPath, "output path")
46-
detectDevice = flag.Bool("detect-device", true, "detect available device(s), CUDA or CPU")
60+
detectDevice = flag.Bool("detect-device", false, "detect available device(s), CUDA or CPU")
4761
reportInterval = flag.Duration("report-interval", time.Second*10, "specify the druation interval to report the progress")
4862
driverLog = ctrl.Log.WithName("driver")
4963
)
5064

65+
func init() {
66+
flag.Var(&taskRecipes, "task-recipe", "task recipe")
67+
}
68+
5169
func main() {
5270
opts := zap.Options{
5371
Development: true,
@@ -86,6 +104,7 @@ func main() {
86104
GrpcPort: *grpcPort,
87105
DetectDevice: *detectDevice,
88106
Logger: driverLog,
107+
TaskRecipes: taskRecipes,
89108
Args: args,
90109
ReportInterval: *reportInterval,
91110
}

cmd/lmes_driver/main_test.go

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
/*
2+
Copyright 2024.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package main
18+
19+
import (
20+
"context"
21+
"flag"
22+
"os"
23+
"testing"
24+
"time"
25+
26+
"github.com/stretchr/testify/assert"
27+
"github.com/trustyai-explainability/trustyai-service-operator/controllers/lmes/driver"
28+
"sigs.k8s.io/controller-runtime/pkg/log/zap"
29+
)
30+
31+
func Test_ArgParsing(t *testing.T) {
32+
os.Args = []string{
33+
"/opt/app-root/src/bin/driver",
34+
"--job-namespace", "default",
35+
"--job-name", "test",
36+
"--grpc-service", "grpc-service.test.svc",
37+
"--grpc-port", "8088",
38+
"--output-path", "/opt/app-root/src/output",
39+
"--detect-device",
40+
"--report-interval", "10s",
41+
"--task-recipe", "card=unitxt.card1,template=unitxt.template,metrics=[unitxt.metric1,unitxt.metric2],format=unitxt.format,num_demos=5,demos_pool_size=10",
42+
"--task-recipe", "card=unitxt.card2,template=unitxt.template2,metrics=[unitxt.metric3,unitxt.metric4],format=unitxt.format,num_demos=5,demos_pool_size=10",
43+
"--",
44+
"sh", "-c", "python",
45+
}
46+
47+
opts := zap.Options{
48+
Development: true,
49+
}
50+
opts.BindFlags(flag.CommandLine)
51+
52+
flag.Parse()
53+
54+
args := flag.Args()
55+
56+
assert.Equal(t, "default", *jobNameSpace)
57+
assert.Equal(t, "test", *jobName)
58+
assert.Equal(t, "grpc-service.test.svc", *grpcService)
59+
assert.Equal(t, 8088, *grpcPort)
60+
assert.Equal(t, "/opt/app-root/src/output", *outputPath)
61+
assert.Equal(t, true, *detectDevice)
62+
assert.Equal(t, time.Second*10, *reportInterval)
63+
assert.Equal(t, taskRecipeArg{
64+
"card=unitxt.card1,template=unitxt.template,metrics=[unitxt.metric1,unitxt.metric2],format=unitxt.format,num_demos=5,demos_pool_size=10",
65+
"card=unitxt.card2,template=unitxt.template2,metrics=[unitxt.metric3,unitxt.metric4],format=unitxt.format,num_demos=5,demos_pool_size=10",
66+
}, taskRecipes)
67+
68+
dOption := driver.DriverOption{
69+
Context: context.Background(),
70+
JobNamespace: *jobNameSpace,
71+
JobName: *jobName,
72+
OutputPath: *outputPath,
73+
GrpcService: *grpcService,
74+
GrpcPort: *grpcPort,
75+
DetectDevice: *detectDevice,
76+
Logger: driverLog,
77+
TaskRecipes: taskRecipes,
78+
Args: args,
79+
ReportInterval: *reportInterval,
80+
}
81+
82+
assert.Equal(t, []string{
83+
"card=unitxt.card1,template=unitxt.template,metrics=[unitxt.metric1,unitxt.metric2],format=unitxt.format,num_demos=5,demos_pool_size=10",
84+
"card=unitxt.card2,template=unitxt.template2,metrics=[unitxt.metric3,unitxt.metric4],format=unitxt.format,num_demos=5,demos_pool_size=10",
85+
}, dOption.TaskRecipes)
86+
87+
assert.Equal(t, []string{
88+
"sh", "-c", "python",
89+
}, dOption.Args)
90+
}

0 commit comments

Comments
 (0)