Skip to content

Commit fa56a16

Browse files
committed
feat: new pulling mechanism for job statuses
Update the driver to keep running even the user program finishes. The driver provides two APIs: - GetStatus(): retrieve job status - Shutdown(): properly tear down the driver In the controller side, it uses `pod/exec` resource to run the driver command to invoke the driver APIs to retrieve the job status and shutdown the driver when job is done. Signed-off-by: Yihong Wang <[email protected]>
1 parent b2bec12 commit fa56a16

20 files changed

+574
-1637
lines changed

cmd/lmes_driver/main.go

Lines changed: 79 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,12 @@ package main
1818

1919
import (
2020
"context"
21+
"encoding/json"
2122
"flag"
2223
"fmt"
2324
"io"
2425
"os"
2526
"strings"
26-
"time"
2727

2828
ctrl "sigs.k8s.io/controller-runtime"
2929
"sigs.k8s.io/controller-runtime/pkg/log"
@@ -50,17 +50,14 @@ func (t *strArrayArg) String() string {
5050
}
5151

5252
var (
53-
taskRecipes strArrayArg
54-
customCards strArrayArg
55-
copy = flag.String("copy", "", "copy this binary to specified destination path")
56-
jobNameSpace = flag.String("job-namespace", "", "Job's namespace ")
57-
jobName = flag.String("job-name", "", "Job's name")
58-
grpcService = flag.String("grpc-service", "", "grpc service name")
59-
grpcPort = flag.Int("grpc-port", 8082, "grpc port")
60-
outputPath = flag.String("output-path", OutputPath, "output path")
61-
detectDevice = flag.Bool("detect-device", false, "detect available device(s), CUDA or CPU")
62-
reportInterval = flag.Duration("report-interval", time.Second*10, "specify the druation interval to report the progress")
63-
driverLog = ctrl.Log.WithName("driver")
53+
taskRecipes strArrayArg
54+
customCards strArrayArg
55+
copy = flag.String("copy", "", "copy this binary to specified destination path")
56+
getStatus = flag.Bool("get-status", false, "Get current status")
57+
shutdown = flag.Bool("shutdown", false, "Shutdown the driver")
58+
outputPath = flag.String("output-path", OutputPath, "output path")
59+
detectDevice = flag.Bool("detect-device", false, "detect available device(s), CUDA or CPU")
60+
driverLog = ctrl.Log.WithName("driver")
6461
)
6562

6663
func init() {
@@ -83,7 +80,7 @@ func main() {
8380

8481
if *copy != "" {
8582
// copy exec to destination
86-
if err := CopyExec(*copy); err != nil {
83+
if err := copyExec(*copy); err != nil {
8784
driverLog.Error(err, "failed to copy binary")
8885
os.Exit(1)
8986
return
@@ -92,29 +89,34 @@ func main() {
9289
return
9390
}
9491

92+
if *getStatus {
93+
getStatusOrDie(ctx)
94+
return
95+
}
96+
97+
if *shutdown {
98+
shutdownOrDie(ctx)
99+
return
100+
}
101+
95102
if len(args) == 0 {
96103
driverLog.Error(fmt.Errorf("no user program"), "empty args")
97104
os.Exit(1)
98105
}
99106

100107
driverOpt := driver.DriverOption{
101-
Context: ctx,
102-
JobNamespace: *jobNameSpace,
103-
JobName: *jobName,
104-
OutputPath: *outputPath,
105-
GrpcService: *grpcService,
106-
GrpcPort: *grpcPort,
107-
DetectDevice: *detectDevice,
108-
Logger: driverLog,
109-
TaskRecipes: taskRecipes,
110-
CustomCards: customCards,
111-
Args: args,
112-
ReportInterval: *reportInterval,
108+
Context: ctx,
109+
OutputPath: *outputPath,
110+
DetectDevice: *detectDevice,
111+
Logger: driverLog,
112+
TaskRecipes: taskRecipes,
113+
CustomCards: customCards,
114+
Args: args,
113115
}
114116

115117
driver, err := driver.NewDriver(&driverOpt)
116118
if err != nil {
117-
driverLog.Error(err, "Driver.Run failed")
119+
driverLog.Error(err, "Driver.NewDriver failed")
118120
os.Exit(1)
119121
}
120122

@@ -123,11 +125,10 @@ func main() {
123125
driverLog.Error(err, "Driver.Run failed")
124126
exitCode = 1
125127
}
126-
driver.Cleanup()
127128
os.Exit(exitCode)
128129
}
129130

130-
func CopyExec(destination string) (err error) {
131+
func copyExec(destination string) (err error) {
131132
defer func() {
132133
if err != nil {
133134
err = fmt.Errorf("copy this binary to %s: %w", destination, err)
@@ -161,3 +162,53 @@ func findThisBinary() (string, error) {
161162
}
162163
return bin, nil
163164
}
165+
166+
func getStatusOrDie(ctx context.Context) {
167+
driver, err := driver.NewDriver(&driver.DriverOption{
168+
Context: ctx,
169+
OutputPath: *outputPath,
170+
DetectDevice: *detectDevice,
171+
Logger: driverLog,
172+
})
173+
174+
if err != nil {
175+
driverLog.Error(err, "failed to initialize the driver")
176+
os.Exit(1)
177+
}
178+
179+
status, err := driver.GetStatus()
180+
if err != nil {
181+
driverLog.Error(err, "failed to get status", "error", err.Error())
182+
os.Exit(1)
183+
}
184+
185+
b, err := json.Marshal(status)
186+
if err != nil {
187+
driverLog.Error(err, "json serialization failed", "error", err.Error())
188+
os.Exit(1)
189+
}
190+
191+
fmt.Print(string(b))
192+
os.Exit(0)
193+
}
194+
195+
func shutdownOrDie(ctx context.Context) {
196+
driver, err := driver.NewDriver(&driver.DriverOption{
197+
Context: ctx,
198+
OutputPath: *outputPath,
199+
DetectDevice: *detectDevice,
200+
Logger: driverLog,
201+
})
202+
203+
if err != nil {
204+
driverLog.Error(err, "failed to initialize the driver")
205+
os.Exit(1)
206+
}
207+
208+
err = driver.Shutdown()
209+
if err != nil {
210+
driverLog.Error(err, "failed to shutdown", "error", err.Error())
211+
os.Exit(1)
212+
}
213+
os.Exit(0)
214+
}

cmd/lmes_driver/main_test.go

Lines changed: 6 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ import (
2121
"flag"
2222
"os"
2323
"testing"
24-
"time"
2524

2625
"github.com/stretchr/testify/assert"
2726
"github.com/trustyai-explainability/trustyai-service-operator/controllers/lmes/driver"
@@ -31,13 +30,8 @@ import (
3130
func Test_ArgParsing(t *testing.T) {
3231
os.Args = []string{
3332
"/opt/app-root/src/bin/driver",
34-
"--job-namespace", "default",
35-
"--job-name", "test",
36-
"--grpc-service", "grpc-service.test.svc",
37-
"--grpc-port", "8088",
3833
"--output-path", "/opt/app-root/src/output",
3934
"--detect-device",
40-
"--report-interval", "10s",
4135
"--task-recipe", "card=unitxt.card1,template=unitxt.template,metrics=[unitxt.metric1,unitxt.metric2],format=unitxt.format,num_demos=5,demos_pool_size=10",
4236
"--task-recipe", "card=unitxt.card2,template=unitxt.template2,metrics=[unitxt.metric3,unitxt.metric4],format=unitxt.format,num_demos=5,demos_pool_size=10",
4337
"--",
@@ -53,30 +47,19 @@ func Test_ArgParsing(t *testing.T) {
5347

5448
args := flag.Args()
5549

56-
assert.Equal(t, "default", *jobNameSpace)
57-
assert.Equal(t, "test", *jobName)
58-
assert.Equal(t, "grpc-service.test.svc", *grpcService)
59-
assert.Equal(t, 8088, *grpcPort)
60-
assert.Equal(t, "/opt/app-root/src/output", *outputPath)
6150
assert.Equal(t, true, *detectDevice)
62-
assert.Equal(t, time.Second*10, *reportInterval)
6351
assert.Equal(t, strArrayArg{
6452
"card=unitxt.card1,template=unitxt.template,metrics=[unitxt.metric1,unitxt.metric2],format=unitxt.format,num_demos=5,demos_pool_size=10",
6553
"card=unitxt.card2,template=unitxt.template2,metrics=[unitxt.metric3,unitxt.metric4],format=unitxt.format,num_demos=5,demos_pool_size=10",
6654
}, taskRecipes)
6755

6856
dOption := driver.DriverOption{
69-
Context: context.Background(),
70-
JobNamespace: *jobNameSpace,
71-
JobName: *jobName,
72-
OutputPath: *outputPath,
73-
GrpcService: *grpcService,
74-
GrpcPort: *grpcPort,
75-
DetectDevice: *detectDevice,
76-
Logger: driverLog,
77-
TaskRecipes: taskRecipes,
78-
Args: args,
79-
ReportInterval: *reportInterval,
57+
Context: context.Background(),
58+
OutputPath: *outputPath,
59+
DetectDevice: *detectDevice,
60+
Logger: driverLog,
61+
TaskRecipes: taskRecipes,
62+
Args: args,
8063
}
8164

8265
assert.Equal(t, []string{

config/base/params.env

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ lmes-driver-image=quay.io/trustyai/ta-lmes-driver:latest
66
lmes-pod-image=quay.io/trustyai/ta-lmes-job:latest
77
lmes-pod-checking-interval=10s
88
lmes-image-pull-policy=Always
9-
lmes-grpc-service=lmes-grpc
10-
lmes-grpc-port=8082
119
lmes-max-batch-size=24
1210
lmes-default-batch-size=8
11+
lmes-detect-device=true

config/manager/kustomization.yaml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
resources:
22
- manager.yaml
3-
- service.yaml
43
apiVersion: kustomize.config.k8s.io/v1beta1
54
kind: Kustomization

config/manager/service.yaml

Lines changed: 0 additions & 14 deletions
This file was deleted.

config/overlays/lmes/kustomization.yaml

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,5 @@ kind: Kustomization
44
resources:
55
- ../../base
66

7-
replacements:
8-
- source:
9-
kind: Service
10-
version: v1
11-
name: lmes-grpc
12-
fieldPath: .metadata.name
13-
targets:
14-
- select:
15-
kind: ConfigMap
16-
version: v1
17-
name: config
18-
fieldPaths:
19-
- .data.lmes-grpc-service
20-
217
patchesStrategicMerge:
228
- lmes-only-patch.yaml

config/overlays/odh/kustomization.yaml

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,3 @@ apiVersion: kustomize.config.k8s.io/v1beta1
33
kind: Kustomization
44
resources:
55
- ../../base
6-
7-
replacements:
8-
- source:
9-
kind: Service
10-
version: v1
11-
name: lmes-grpc
12-
fieldPath: .metadata.name
13-
targets:
14-
- select:
15-
kind: ConfigMap
16-
version: v1
17-
name: config
18-
fieldPaths:
19-
- .data.lmes-grpc-service

config/rbac/role.yaml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,16 @@ rules:
3737
- patch
3838
- update
3939
- watch
40+
- apiGroups:
41+
- ""
42+
resources:
43+
- pods/exec
44+
verbs:
45+
- create
46+
- delete
47+
- get
48+
- list
49+
- watch
4050
- apiGroups:
4151
- ""
4252
resources:

0 commit comments

Comments
 (0)