@@ -122,6 +122,9 @@ type VisitFailuresOptions struct {
122
122
// Context is the same for every call of a visit, callers should not store it.
123
123
// Visitor is free to mutate the passed failure struct.
124
124
Visitor func(*VisitFailuresContext, *failure.Failure) (error)
125
+ // Will be called for each Any encountered. If not set, the default is to recurse into the Any
126
+ // object, unmarshal it, visit, and re-marshal it always (even if there are no changes).
127
+ WellKnownAnyVisitor func(*VisitFailuresContext, *anypb.Any) error
125
128
}
126
129
127
130
// VisitFailures calls the options.Visitor function for every Failure proto within msg.
@@ -162,6 +165,25 @@ func NewFailureVisitorInterceptor(options FailureVisitorInterceptorOptions) (grp
162
165
}, nil
163
166
}
164
167
168
+ func (o *VisitFailuresOptions) defaultWellKnownAnyVisitor(ctx *VisitFailuresContext, p *anypb.Any) error {
169
+ child, err := p.UnmarshalNew()
170
+ if err != nil {
171
+ return fmt.Errorf("failed to unmarshal any: %w", err)
172
+ }
173
+ // We choose to visit and re-marshal always instead of cloning, visiting,
174
+ // and checking if anything changed before re-marshaling. It is assumed the
175
+ // clone + equality check is not much cheaper than re-marshal.
176
+ if err := visitFailures(ctx, o, child); err != nil {
177
+ return err
178
+ }
179
+ // Confirmed this replaces both Any fields on non-error, there is nothing
180
+ // left over
181
+ if err := p.MarshalFrom(child); err != nil {
182
+ return fmt.Errorf("failed to marshal any: %w", err)
183
+ }
184
+ return nil
185
+ }
186
+
165
187
func (o *VisitPayloadsOptions) defaultWellKnownAnyVisitor(ctx *VisitPayloadsContext, p *anypb.Any) error {
166
188
child, err := p.UnmarshalNew()
167
189
if err != nil {
@@ -299,6 +321,20 @@ func visitFailures(ctx *VisitFailuresContext, options *VisitFailuresOptions, obj
299
321
if o == nil { continue }
300
322
if err := options.Visitor(ctx, o); err != nil { return err }
301
323
if err := visitFailures(ctx, options, o.GetCause()); err != nil { return err }
324
+ case *anypb.Any:
325
+ if o == nil {
326
+ continue
327
+ }
328
+ visitor := options.WellKnownAnyVisitor
329
+ if visitor == nil {
330
+ visitor = options.defaultWellKnownAnyVisitor
331
+ }
332
+ ctx.Parent = o
333
+ err := visitor(ctx, o)
334
+ ctx.Parent = nil
335
+ if err != nil {
336
+ return err
337
+ }
302
338
{{range $type, $record := .FailureTypes}}
303
339
{{if $record.Slice}}
304
340
case []{{$type}}:
@@ -508,17 +544,19 @@ func generateInterceptor(cfg config) error {
508
544
if err != nil {
509
545
return err
510
546
}
511
- // For the purposes of payloads, we also consider the Any well known type as
547
+
548
+ failureTypes , err := lookupTypes ("go.temporal.io/api/failure/v1" , []string {"Failure" })
549
+ if err != nil {
550
+ return err
551
+ }
552
+
553
+ // For the purposes of payloads and failures, we also consider the Any well known type as
512
554
// possible
513
555
if anyTypes , err := lookupTypes ("google.golang.org/protobuf/types/known/anypb" , []string {"Any" }); err != nil {
514
556
return err
515
557
} else {
516
558
payloadTypes = append (payloadTypes , anyTypes ... )
517
- }
518
-
519
- failureTypes , err := lookupTypes ("go.temporal.io/api/failure/v1" , []string {"Failure" })
520
- if err != nil {
521
- return err
559
+ failureTypes = append (failureTypes , anyTypes ... )
522
560
}
523
561
524
562
// UnimplementedWorkflowServiceServer is auto-generated via our API package
@@ -542,6 +580,11 @@ func generateInterceptor(cfg config) error {
542
580
}
543
581
workflowExecutions := types .NewPointer (exportTypes [0 ])
544
582
583
+ updateTypes , err := lookupTypes ("go.temporal.io/api/update/v1" , []string {"Acceptance" , "Rejection" , "Response" })
584
+ if err != nil {
585
+ return err
586
+ }
587
+
545
588
payloadRecords := map [string ]* TypeRecord {}
546
589
failureRecords := map [string ]* TypeRecord {}
547
590
@@ -572,6 +615,11 @@ func generateInterceptor(cfg config) error {
572
615
walk (payloadTypes , workflowExecutions , & payloadRecords , true )
573
616
walk (failureTypes , workflowExecutions , & failureRecords , false )
574
617
618
+ for _ , ut := range updateTypes {
619
+ walk (payloadTypes , types .NewPointer (ut ), & payloadRecords , true )
620
+ walk (failureTypes , types .NewPointer (ut ), & failureRecords , false )
621
+ }
622
+
575
623
payloadRecords = pruneRecords (payloadRecords )
576
624
failureRecords = pruneRecords (failureRecords )
577
625
0 commit comments