Skip to content

feat: support custom dataset #309

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Dockerfile.lmes-job
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ WORKDIR /opt/app-root/src
RUN mkdir /opt/app-root/src/hf_home && chmod g+rwx /opt/app-root/src/hf_home
RUN mkdir /opt/app-root/src/output && chmod g+rwx /opt/app-root/src/output
RUN mkdir /opt/app-root/src/my_tasks && chmod g+rwx /opt/app-root/src/my_tasks
RUN mkdir -p /opt/app-root/src/my_catalogs/cards && chmod -R g+rwx /opt/app-root/src/my_catalogs
RUN mkdir /opt/app-root/src/.cache
ENV PATH="/opt/app-root/bin:/opt/app-root/src/.local/bin/:/opt/app-root/src/bin:/usr/local/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin"

Expand All @@ -23,6 +24,7 @@ RUN python -c 'from lm_eval.tasks.unitxt import task; import os.path; print("cla

ENV PYTHONPATH=/opt/app-root/src/.local/lib/python3.11/site-packages:/opt/app-root/src/lm-evaluation-harness:/opt/app-root/src:/opt/app-root/src/server
ENV HF_HOME=/opt/app-root/src/hf_home
ENV UNITXT_ARTIFACTORIES=/opt/app-root/src/my_catalogs

CMD ["/opt/app-root/bin/python"]

103 changes: 82 additions & 21 deletions api/lmes/v1alpha1/lmevaljob_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,30 +63,23 @@ type Arg struct {
Value string `json:"value,omitempty"`
}

type EnvSecret struct {
// Environment's name
Env string `json:"env"`
// The secret is from a secret object
type Card struct {
// Unitxt card's ID
// +optional
SecretRef *corev1.SecretKeySelector `json:"secretRef,omitempty"`
// The secret is from a plain text
Name string `json:"name,omitempty"`
// A JSON string for a custom unitxt card which contains the custom dataset.
// Use the documentation here: https://www.unitxt.ai/en/latest/docs/adding_dataset.html#adding-to-the-catalog
// to compose a custom card, store it as a JSON file, and use the JSON content as the value here.
// +optional
Secret *string `json:"secret,omitempty"`
}

type FileSecret struct {
// The secret object
SecretRef corev1.SecretVolumeSource `json:"secretRef,omitempty"`
// The path to mount the secret
MountPath string `json:"mountPath"`
Custom string `json:"custom,omitempty"`
}

// Use a task recipe to form a custom task. It maps to the Unitxt Recipe
// Find details of the Unitxt Recipe here:
// https://www.unitxt.ai/en/latest/unitxt.standard.html#unitxt.standard.StandardRecipe
type TaskRecipe struct {
// The Unitxt dataset card
Card string `json:"card"`
Card Card `json:"card"`
// The Unitxt template
Template string `json:"template"`
// The Unitxt Task
Expand Down Expand Up @@ -118,7 +111,7 @@ type TaskList struct {

func (t *TaskRecipe) String() string {
var b strings.Builder
b.WriteString(fmt.Sprintf("card=%s,template=%s", t.Card, t.Template))
b.WriteString(fmt.Sprintf("card=%s,template=%s", t.Card.Name, t.Template))
if t.Task != nil {
b.WriteString(fmt.Sprintf(",task=%s", *t.Task))
}
Expand All @@ -140,6 +133,76 @@ func (t *TaskRecipe) String() string {
return b.String()
}

// LMEvalContainer carries optional, user-supplied settings for the main
// container: environment variables, volume mounts, and compute resources.
// Every field may be omitted.
type LMEvalContainer struct {
	// Define Env information for the main container
	// +optional
	Env []corev1.EnvVar `json:"env,omitempty"`
	// Define the volume mount information
	// +optional
	VolumeMounts []corev1.VolumeMount `json:"volumeMounts,omitempty"`
	// Compute Resources required by this container.
	// More info: https://kubernetes.io/docs/concepts/configuration/manage-resources-containers/
	// +optional
	Resources *corev1.ResourceRequirements `json:"resources,omitempty"`
}

// The following Getter-ish functions avoid nil pointer panic

// GetEnv returns the Env entries of the container. It may be called on a nil
// receiver, in which case it returns nil instead of panicking.
func (c *LMEvalContainer) GetEnv() []corev1.EnvVar {
	if c != nil {
		return c.Env
	}
	return nil
}

// GetVolumeMounts returns the VolumeMounts of the container. It may be called
// on a nil receiver, in which case it returns nil instead of panicking.
func (c *LMEvalContainer) GetVolumeMounts() []corev1.VolumeMount {
	if c == nil {
		return nil
	}
	return c.VolumeMounts
}

// GetVolumMounts is the original, misspelled name of GetVolumeMounts. It is
// kept so that existing callers continue to compile and behaves identically.
// Prefer GetVolumeMounts in new code.
func (c *LMEvalContainer) GetVolumMounts() []corev1.VolumeMount {
	return c.GetVolumeMounts()
}

// GetResources returns the Resources pointer of the container. It may be
// called on a nil receiver, in which case it returns nil instead of panicking.
// The result is also nil when the Resources field itself is unset.
func (c *LMEvalContainer) GetResources() *corev1.ResourceRequirements {
	if c != nil {
		return c.Resources
	}
	return nil
}

// LMEvalPodSpec carries optional, user-supplied additions to the lm-eval job's
// pod: extra settings for the lm-eval container, volumes, and sidecar
// containers. Every field may be omitted.
type LMEvalPodSpec struct {
	// Extra container data for the lm-eval container
	// +optional
	Container *LMEvalContainer `json:"container,omitempty"`
	// Specify the volumes information for the lm-eval and sidecar containers
	// +optional
	Volumes []corev1.Volume `json:"volumes,omitempty"`
	// Specify extra containers for the lm-eval job
	// FIXME: aggregate the sidecar containers into the pod
	// +optional
	SideCars []corev1.Container `json:"sideCars,omitempty"`
}

// The following Getter-ish functions avoid nil pointer panic

// GetContainer returns the Container pointer of the pod spec. It may be called
// on a nil receiver, in which case it returns nil instead of panicking. The
// result is also nil when the Container field itself is unset.
func (p *LMEvalPodSpec) GetContainer() *LMEvalContainer {
	if p != nil {
		return p.Container
	}
	return nil
}

// GetVolumes returns the Volumes of the pod spec. It may be called on a nil
// receiver, in which case it returns nil instead of panicking.
func (p *LMEvalPodSpec) GetVolumes() []corev1.Volume {
	if p != nil {
		return p.Volumes
	}
	return nil
}

// GetSideCars returns the SideCars containers of the pod spec. It may be
// called on a nil receiver, in which case it returns nil instead of panicking.
func (p *LMEvalPodSpec) GetSideCars() []corev1.Container {
	if p == nil {
		return nil
	}
	return p.SideCars
}

// GetSideCards is the original, misspelled name of GetSideCars (the field it
// reads is SideCars). It is kept so that existing callers continue to compile
// and behaves identically. Prefer GetSideCars in new code.
func (p *LMEvalPodSpec) GetSideCards() []corev1.Container {
	return p.GetSideCars()
}

// LMEvalJobSpec defines the desired state of LMEvalJob
type LMEvalJobSpec struct {
// INSERT ADDITIONAL SPEC FIELDS - desired state of cluster
Expand Down Expand Up @@ -167,14 +230,12 @@ type LMEvalJobSpec struct {
// model, will be saved at per-document granularity
// +optional
LogSamples *bool `json:"logSamples,omitempty"`
// Assign secrets to the environment variables
// +optional
EnvSecrets []EnvSecret `json:"envSecrets,omitempty"`
// Use secrets as files
FileSecrets []FileSecret `json:"fileSecrets,omitempty"`
// Batch size for the evaluation. This is used by models that are loaded and run
// locally; it does not apply to commercial APIs.
BatchSize *int `json:"batchSize,omitempty"`
// Specify extra information for the lm-eval job's pod
// +optional
Pod *LMEvalPodSpec `json:"pod,omitempty"`
}

// LMEvalJobStatus defines the observed state of LMEvalJob
Expand Down
100 changes: 67 additions & 33 deletions api/lmes/v1alpha1/zz_generated.deepcopy.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 7 additions & 4 deletions cmd/lmes_driver/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,20 +37,21 @@ const (
OutputPath = "/opt/app-root/src/output"
)

type taskRecipeArg []string
type strArrayArg []string

func (t *taskRecipeArg) Set(value string) error {
// Set appends value to the collected list and always returns nil. Because it
// accumulates rather than overwrites, every call adds one more entry.
func (t *strArrayArg) Set(value string) error {
	collected := append(*t, value)
	*t = collected
	return nil
}

func (t *taskRecipeArg) String() string {
// String renders all collected values as a single string, placing ":" between
// consecutive entries.
func (t *strArrayArg) String() string {
	// ":" was chosen as the separator on the assumption that it does not
	// occur inside a task recipe.
	const sep = ":"
	return strings.Join(*t, sep)
}

var (
taskRecipes taskRecipeArg
taskRecipes strArrayArg
customCards strArrayArg
copy = flag.String("copy", "", "copy this binary to specified destination path")
jobNameSpace = flag.String("job-namespace", "", "Job's namespace ")
jobName = flag.String("job-name", "", "Job's name")
Expand All @@ -64,6 +65,7 @@ var (

// init registers the -task-recipe and -custom-card flags on the default
// command-line flag set. Both are backed by strArrayArg values, whose Set
// method appends, so each flag can be given more than once.
func init() {
	flag.CommandLine.Var(&taskRecipes, "task-recipe", "task recipe")
	flag.CommandLine.Var(&customCards, "custom-card", "A JSON string represents a custom card")
}

func main() {
Expand Down Expand Up @@ -105,6 +107,7 @@ func main() {
DetectDevice: *detectDevice,
Logger: driverLog,
TaskRecipes: taskRecipes,
CustomCards: customCards,
Args: args,
ReportInterval: *reportInterval,
}
Expand Down
2 changes: 1 addition & 1 deletion cmd/lmes_driver/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func Test_ArgParsing(t *testing.T) {
assert.Equal(t, "/opt/app-root/src/output", *outputPath)
assert.Equal(t, true, *detectDevice)
assert.Equal(t, time.Second*10, *reportInterval)
assert.Equal(t, taskRecipeArg{
assert.Equal(t, strArrayArg{
"card=unitxt.card1,template=unitxt.template,metrics=[unitxt.metric1,unitxt.metric2],format=unitxt.format,num_demos=5,demos_pool_size=10",
"card=unitxt.card2,template=unitxt.template2,metrics=[unitxt.metric3,unitxt.metric4],format=unitxt.format,num_demos=5,demos_pool_size=10",
}, taskRecipes)
Expand Down
Loading
Loading