Skip to content

Commit 4e1757c

Browse files
committed
feat: allow use of providers that don't return errors
1 parent 9c08a58 commit 4e1757c

File tree

3 files changed

+35
-3
lines changed

3 files changed

+35
-3
lines changed

callbacks.go

+12-3
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,17 @@ func (b bindings) addTo(impl, iface any) {
3434
func (b bindings) addProvider(provider any) error {
3535
pv := reflect.ValueOf(provider)
3636
t := pv.Type()
37-
if t.Kind() != reflect.Func || t.NumOut() != 2 || t.Out(1) != reflect.TypeOf((*error)(nil)).Elem() {
38-
return fmt.Errorf("%T must be a function with the signature func(...)(T, error)", provider)
37+
if t.Kind() != reflect.Func {
38+
return fmt.Errorf("%T must be a function", provider)
39+
}
40+
41+
if t.NumOut() == 0 {
42+
return fmt.Errorf("%T must be a function with the signature func(...)(T, error) or func(...) T", provider)
43+
}
44+
if t.NumOut() == 2 {
45+
if t.Out(1) != reflect.TypeOf((*error)(nil)).Elem() {
46+
return fmt.Errorf("missing error; %T must be a function with the signature func(...)(T, error) or func(...) T", provider)
47+
}
3948
}
4049
rt := pv.Type().Out(0)
4150
b[rt] = provider
@@ -143,7 +152,7 @@ func callAnyFunction(f reflect.Value, bindings bindings) (out []any, err error)
143152
if err != nil {
144153
return nil, fmt.Errorf("%s: %w", pt, err)
145154
}
146-
if ferrv := reflect.ValueOf(argv[len(argv)-1]); ferrv.IsValid() && !ferrv.IsNil() {
155+
if ferrv := reflect.ValueOf(argv[len(argv)-1]); ferrv.IsValid() && ferrv.Type().Implements(callbackReturnSignature) && !ferrv.IsNil() {
147156
return nil, ferrv.Interface().(error) //nolint:forcetypeassert
148157
}
149158
in = append(in, reflect.ValueOf(argv[0]))

context.go

+3
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,9 @@ func (c *Context) BindTo(impl, iface any) {
119119
//
120120
// This is useful when the Run() function of different commands require different values that may
121121
// not all be initialisable from the main() function.
122+
//
123+
// "provider" must be a function with the signature func(...) (T, error) or func(...) T, where
124+
// ... will be recursively injected with bound values.
122125
func (c *Context) BindToProvider(provider any) error {
123126
return c.bindings.addProvider(provider)
124127
}

kong_test.go

+20
Original file line numberDiff line numberDiff line change
@@ -2521,3 +2521,23 @@ func TestIssue483EmptyRootNodeNoRun(t *testing.T) {
25212521
assert.Error(t, err)
25222522
assert.Contains(t, err.Error(), "no command selected")
25232523
}
2524+
2525+
type providerWithoutErrorCLI struct {
2526+
}
2527+
2528+
func (p *providerWithoutErrorCLI) Run(name string) error {
2529+
if name == "Bob" {
2530+
return nil
2531+
}
2532+
return fmt.Errorf("name %s is not Bob", name)
2533+
}
2534+
2535+
func TestProviderWithoutError(t *testing.T) {
2536+
k := mustNew(t, &providerWithoutErrorCLI{})
2537+
kctx, err := k.Parse(nil)
2538+
assert.NoError(t, err)
2539+
err = kctx.BindToProvider(func() string { return "Bob" })
2540+
assert.NoError(t, err)
2541+
err = kctx.Run()
2542+
assert.NoError(t, err)
2543+
}

0 commit comments

Comments
 (0)