Skip to content

Commit dad8b16

Browse files
Fix update.Response getting skipped by the proxy (#209)
Fix update.Response getting skipped by the proxy
1 parent 44277ce commit dad8b16

File tree

3 files changed

+243
-8
lines changed

3 files changed

+243
-8
lines changed

cmd/proxygenerator/interceptor.go

+54-6
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,9 @@ type VisitFailuresOptions struct {
122122
// Context is the same for every call of a visit, callers should not store it.
123123
// Visitor is free to mutate the passed failure struct.
124124
Visitor func(*VisitFailuresContext, *failure.Failure) (error)
125+
// Will be called for each Any encountered. If not set, the default is to recurse into the Any
126+
// object, unmarshal it, visit, and re-marshal it always (even if there are no changes).
127+
WellKnownAnyVisitor func(*VisitFailuresContext, *anypb.Any) error
125128
}
126129
127130
// VisitFailures calls the options.Visitor function for every Failure proto within msg.
@@ -162,6 +165,25 @@ func NewFailureVisitorInterceptor(options FailureVisitorInterceptorOptions) (grp
162165
}, nil
163166
}
164167
168+
func (o *VisitFailuresOptions) defaultWellKnownAnyVisitor(ctx *VisitFailuresContext, p *anypb.Any) error {
169+
child, err := p.UnmarshalNew()
170+
if err != nil {
171+
return fmt.Errorf("failed to unmarshal any: %w", err)
172+
}
173+
// We choose to visit and re-marshal always instead of cloning, visiting,
174+
// and checking if anything changed before re-marshaling. It is assumed the
175+
// clone + equality check is not much cheaper than re-marshal.
176+
if err := visitFailures(ctx, o, child); err != nil {
177+
return err
178+
}
179+
// Confirmed this replaces both Any fields on non-error, there is nothing
180+
// left over
181+
if err := p.MarshalFrom(child); err != nil {
182+
return fmt.Errorf("failed to marshal any: %w", err)
183+
}
184+
return nil
185+
}
186+
165187
func (o *VisitPayloadsOptions) defaultWellKnownAnyVisitor(ctx *VisitPayloadsContext, p *anypb.Any) error {
166188
child, err := p.UnmarshalNew()
167189
if err != nil {
@@ -299,6 +321,20 @@ func visitFailures(ctx *VisitFailuresContext, options *VisitFailuresOptions, obj
299321
if o == nil { continue }
300322
if err := options.Visitor(ctx, o); err != nil { return err }
301323
if err := visitFailures(ctx, options, o.GetCause()); err != nil { return err }
324+
case *anypb.Any:
325+
if o == nil {
326+
continue
327+
}
328+
visitor := options.WellKnownAnyVisitor
329+
if visitor == nil {
330+
visitor = options.defaultWellKnownAnyVisitor
331+
}
332+
ctx.Parent = o
333+
err := visitor(ctx, o)
334+
ctx.Parent = nil
335+
if err != nil {
336+
return err
337+
}
302338
{{range $type, $record := .FailureTypes}}
303339
{{if $record.Slice}}
304340
case []{{$type}}:
@@ -508,17 +544,19 @@ func generateInterceptor(cfg config) error {
508544
if err != nil {
509545
return err
510546
}
511-
// For the purposes of payloads, we also consider the Any well known type as
547+
548+
failureTypes, err := lookupTypes("go.temporal.io/api/failure/v1", []string{"Failure"})
549+
if err != nil {
550+
return err
551+
}
552+
553+
// For the purposes of payloads and failures, we also consider the Any well known type as
512554
// possible
513555
if anyTypes, err := lookupTypes("google.golang.org/protobuf/types/known/anypb", []string{"Any"}); err != nil {
514556
return err
515557
} else {
516558
payloadTypes = append(payloadTypes, anyTypes...)
517-
}
518-
519-
failureTypes, err := lookupTypes("go.temporal.io/api/failure/v1", []string{"Failure"})
520-
if err != nil {
521-
return err
559+
failureTypes = append(failureTypes, anyTypes...)
522560
}
523561

524562
// UnimplementedWorkflowServiceServer is auto-generated via our API package
@@ -542,6 +580,11 @@ func generateInterceptor(cfg config) error {
542580
}
543581
workflowExecutions := types.NewPointer(exportTypes[0])
544582

583+
updateTypes, err := lookupTypes("go.temporal.io/api/update/v1", []string{"Acceptance", "Rejection", "Response"})
584+
if err != nil {
585+
return err
586+
}
587+
545588
payloadRecords := map[string]*TypeRecord{}
546589
failureRecords := map[string]*TypeRecord{}
547590

@@ -572,6 +615,11 @@ func generateInterceptor(cfg config) error {
572615
walk(payloadTypes, workflowExecutions, &payloadRecords, true)
573616
walk(failureTypes, workflowExecutions, &failureRecords, false)
574617

618+
for _, ut := range updateTypes {
619+
walk(payloadTypes, types.NewPointer(ut), &payloadRecords, true)
620+
walk(failureTypes, types.NewPointer(ut), &failureRecords, false)
621+
}
622+
575623
payloadRecords = pruneRecords(payloadRecords)
576624
failureRecords = pruneRecords(failureRecords)
577625

proxy/interceptor.go

+131
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)