diff --git a/chasm/tree.go b/chasm/tree.go index ca94e5b56d9..98c4ca845ce 100644 --- a/chasm/tree.go +++ b/chasm/tree.go @@ -18,6 +18,7 @@ import ( "go.temporal.io/server/common/clock" "go.temporal.io/server/common/definition" "go.temporal.io/server/common/log" + "go.temporal.io/server/common/persistence" "go.temporal.io/server/common/persistence/serialization" "go.temporal.io/server/common/persistence/transitionhistory" "go.temporal.io/server/common/softassert" @@ -766,7 +767,13 @@ func (n *Node) closeTransactionUpdateComponentTasks() error { validateContext := NewContext(context.Background(), n) var validationErr error deleteFunc := func(existingTask *persistencespb.ChasmComponentAttributes_Task) bool { - valid, err := node.validateComponentTask(validateContext, existingTask) + existingTaskInstance, err := node.deserializeComponentTask(existingTask) + if err != nil { + validationErr = err + return false + } + + valid, err := node.validateTask(validateContext, existingTaskInstance) if err != nil { validationErr = err return false @@ -826,30 +833,42 @@ func (n *Node) closeTransactionUpdateComponentTasks() error { return nil } -func (n *Node) validateComponentTask( - validateContext Context, +func (n *Node) deserializeComponentTask( componentTask *persistencespb.ChasmComponentAttributes_Task, -) (bool, error) { +) (any, error) { registableTask, ok := n.registry.task(componentTask.Type) if !ok { - return false, serviceerror.NewInternal(fmt.Sprintf("task type %s is not registered", componentTask.Type)) + return nil, serviceerror.NewInternal(fmt.Sprintf("task type %s is not registered", componentTask.Type)) } - // TODO: cache validateMethod (reflect.Value) in the registry - validator := registableTask.validator - validateMethod := reflect.ValueOf(validator).MethodByName("Validate") - // TODO: cache deserialized task value (reflect.Value) in the node, // use task VT and offset as the key - deserizedTaskValue, err := deserializeTask(registableTask, componentTask.Data) + taskValue, err := deserializeTask(registableTask, componentTask.Data) if err != nil { - return false, err + return nil, err + } + + return taskValue.Interface(), nil +} + +func (n *Node) validateTask( + validateContext Context, + taskInstance any, +) (bool, error) { + registableTask, ok := n.registry.taskFor(taskInstance) + if !ok { + return false, serviceerror.NewInternal( + fmt.Sprintf("task type for goType %s is not registered", reflect.TypeOf(taskInstance).Name())) } + // TODO: cache validateMethod (reflect.Value) in the registry + validator := registableTask.validator + validateMethod := reflect.ValueOf(validator).MethodByName("Validate") + retValues := validateMethod.Call([]reflect.Value{ reflect.ValueOf(validateContext), reflect.ValueOf(n.value), - deserizedTaskValue, + reflect.ValueOf(taskInstance), }) if !retValues[1].IsNil() { //revive:disable-next-line:unchecked-type-assertion @@ -1241,6 +1260,66 @@ func (n *Node) isValueNeedSerialize() bool { return false } +// isComponentTaskExpired returns true when the task's scheduled time is equal +// or before the reference time. The caller should also make sure to account +// for skew between the physical task queue and the database by adjusting +// referenceTime in advance. +func isComponentTaskExpired( + referenceTime time.Time, + task *persistencespb.ChasmComponentAttributes_Task, +) bool { + if task.ScheduledTime == nil { + return false + } + + scheduledTime := task.ScheduledTime.AsTime().Truncate(persistence.ScheduledTaskMinPrecision) + referenceTime = referenceTime.Truncate(persistence.ScheduledTaskMinPrecision) + + return !scheduledTime.After(referenceTime) +} + +// EachPureTask runs the callback for all expired/runnable pure tasks within the +// CHASM tree (including invalid tasks). The CHASM tree is left untouched, even +// if invalid tasks are detected (these are cleaned up as part of transaction +// close). +func (n *Node) EachPureTask( + referenceTime time.Time, + callback func(node *Node, task any) error, +) error { + // Walk the tree to find all runnable tasks. + for _, node := range n.andAllChildren() { + // Skip nodes that aren't serialized yet. + if node.serializedNode == nil || node.serializedNode.Metadata == nil { + continue + } + + componentAttr := node.serializedNode.Metadata.GetComponentAttributes() + // Skip nodes that aren't components. + if componentAttr == nil { + continue + } + + for _, task := range componentAttr.GetPureTasks() { + if !isComponentTaskExpired(referenceTime, task) { + // Pure tasks are stored in-order, so we can skip scanning the rest once we hit + // an unexpired task deadline. + break + } + + taskValue, err := node.deserializeComponentTask(task) + if err != nil { + return err + } + + if err = callback(node, taskValue); err != nil { + return err + } + } + } + + return nil +} + func newNode( base *nodeBase, parent *Node, @@ -1433,3 +1512,60 @@ func serializeTask( return blob, nil } + +// ExecutePureTask validates and then executes the given taskInstance against the +// node's component. Executing an invalid task is a no-op (no error returned). +func (n *Node) ExecutePureTask(baseCtx context.Context, taskInstance any) error { + registrableTask, ok := n.registry.taskFor(taskInstance) + if !ok { + return fmt.Errorf("unknown task type for task instance goType '%s'", reflect.TypeOf(taskInstance).Name()) + } + + if !registrableTask.isPureTask { + return fmt.Errorf("ExecutePureTask called on a SideEffect task '%s'", registrableTask.fqType()) + } + + // TODO - instantiate CHASM engine and attach to context + ctx := NewContext(baseCtx, n) + + // Ensure this node's component value is hydrated before execution. Component + // will also check access rules. + component, err := n.Component(ctx, ComponentRef{}) + if err != nil { + return err + } + + // Run the task's registered value before execution. + valid, err := n.validateTask(ctx, taskInstance) + if err != nil { + return err + } + if !valid { + return nil + } + + executor := registrableTask.handler + if executor == nil { + return fmt.Errorf("no handler registered for task type '%s'", registrableTask.taskType) + } + + fn := reflect.ValueOf(executor).MethodByName("Execute") + result := fn.Call([]reflect.Value{ + reflect.ValueOf(ctx), + reflect.ValueOf(component), + reflect.ValueOf(taskInstance), + }) + if !result[0].IsNil() { + //nolint:revive // type cast result is unchecked + return result[0].Interface().(error) + } + + // TODO - a task validator must succeed validation after a task executes + // successfully (without error), otherwise it will generate an infinite loop. + // Check for this case by marking the in-memory task as having executed, which the + // CloseTransaction method will check against. + // + // See: https://github.com/temporalio/temporal/pull/7701#discussion_r2072026993 + + return nil +} diff --git a/chasm/tree_test.go b/chasm/tree_test.go index 6b5f40577b8..b46dffdc6a8 100644 --- a/chasm/tree_test.go +++ b/chasm/tree_test.go @@ -1438,3 +1438,173 @@ func (s *nodeSuite) testComponentTree() *Node { return node // maybe tc too } + +func (s *nodeSuite) TestEachPureTask() { + now := s.timeSource.Now() + + payload := &commonpb.Payload{ + Data: []byte("some-random-data"), + } + taskBlob, err := serialization.ProtoEncodeBlob(payload, enumspb.ENCODING_TYPE_PROTO3) + s.NoError(err) + + // Set up a tree with expired and unexpired pure tasks. + persistenceNodes := map[string]*persistencespb.ChasmNode{ + "": { + Metadata: &persistencespb.ChasmNodeMetadata{ + InitialVersionedTransition: &persistencespb.VersionedTransition{TransitionCount: 1}, + Attributes: &persistencespb.ChasmNodeMetadata_ComponentAttributes{ + ComponentAttributes: &persistencespb.ChasmComponentAttributes{ + Type: "TestLibrary.test_component", + PureTasks: []*persistencespb.ChasmComponentAttributes_Task{ + { + // Expired + Type: "TestLibrary.test_pure_task", + ScheduledTime: timestamppb.New(now), + VersionedTransition: &persistencespb.VersionedTransition{TransitionCount: 1}, + VersionedTransitionOffset: 1, + PhysicalTaskStatus: physicalTaskStatusCreated, + Data: taskBlob, + }, + }, + }, + }, + }, + }, + "child": { + Metadata: &persistencespb.ChasmNodeMetadata{ + InitialVersionedTransition: &persistencespb.VersionedTransition{TransitionCount: 1}, + Attributes: &persistencespb.ChasmNodeMetadata_ComponentAttributes{ + ComponentAttributes: &persistencespb.ChasmComponentAttributes{ + Type: "TestLibrary.test_component", + PureTasks: []*persistencespb.ChasmComponentAttributes_Task{ + { + Type: "TestLibrary.test_pure_task", + // Unexpired + ScheduledTime: timestamppb.New(now.Add(time.Hour)), + VersionedTransition: &persistencespb.VersionedTransition{TransitionCount: 1}, + VersionedTransitionOffset: 1, + PhysicalTaskStatus: physicalTaskStatusCreated, + Data: taskBlob, + }, + }, + }, + }, + }, + }, + "child/grandchild1": { + Metadata: &persistencespb.ChasmNodeMetadata{ + InitialVersionedTransition: &persistencespb.VersionedTransition{TransitionCount: 1}, + Attributes: &persistencespb.ChasmNodeMetadata_ComponentAttributes{ + ComponentAttributes: &persistencespb.ChasmComponentAttributes{ + Type: "TestLibrary.test_component", + PureTasks: []*persistencespb.ChasmComponentAttributes_Task{ + { + Type: "TestLibrary.test_pure_task", + // Expired, and physical task not created + ScheduledTime: timestamppb.New(now), + VersionedTransition: &persistencespb.VersionedTransition{TransitionCount: 1}, + VersionedTransitionOffset: 2, + PhysicalTaskStatus: physicalTaskStatusNone, + Data: taskBlob, + }, + { + Type: "TestLibrary.test_pure_task", + // Expired + ScheduledTime: timestamppb.New(now), + VersionedTransition: &persistencespb.VersionedTransition{TransitionCount: 1}, + VersionedTransitionOffset: 1, + PhysicalTaskStatus: physicalTaskStatusCreated, + Data: taskBlob, + }, + }, + }, + }, + }, + }, + } + + root, err := NewTree(persistenceNodes, s.registry, s.timeSource, s.nodeBackend, s.nodePathEncoder, s.logger) + s.NoError(err) + s.NotNil(root) + + actualTaskCount := 0 + err = root.EachPureTask(now.Add(time.Minute), func(node *Node, task any) error { + s.NotNil(node) + + _, ok := task.(*TestPureTask) + s.True(ok) + + actualTaskCount += 1 + return nil + }) + s.NoError(err) + s.Equal(3, actualTaskCount) +} + +func (s *nodeSuite) TestExecutePureTask() { + persistenceNodes := map[string]*persistencespb.ChasmNode{ + "": { + Metadata: &persistencespb.ChasmNodeMetadata{ + InitialVersionedTransition: &persistencespb.VersionedTransition{TransitionCount: 1}, + Attributes: &persistencespb.ChasmNodeMetadata_ComponentAttributes{ + ComponentAttributes: &persistencespb.ChasmComponentAttributes{ + Type: "TestLibrary.test_component", + }, + }, + }, + }, + } + + pureTask := &TestPureTask{ + Payload: &commonpb.Payload{ + Data: []byte("some-random-data"), + }, + } + + rt, ok := s.registry.Task("TestLibrary.test_pure_task") + s.True(ok) + + root, err := NewTree(persistenceNodes, s.registry, s.timeSource, s.nodeBackend, s.nodePathEncoder, s.logger) + s.NoError(err) + s.NotNil(root) + ctx := context.Background() + + expectExecute := func(result error) { + rt.handler.(*MockPureTaskExecutor[any, *TestPureTask]).EXPECT(). + Execute( + gomock.Any(), + gomock.AssignableToTypeOf(&TestComponent{}), + gomock.Eq(pureTask), + ).Return(result).Times(1) + } + + expectValidate := func(retValue bool, errValue error) { + rt.validator.(*MockTaskValidator[any, *TestPureTask]).EXPECT(). + Validate(gomock.Any(), gomock.Any(), gomock.Any()).Return(retValue, errValue).Times(1) + } + + // Succeed task execution and validation (happy case). + expectExecute(nil) + expectValidate(true, nil) + err = root.ExecutePureTask(ctx, pureTask) + s.NoError(err) + + expectedErr := errors.New("dummy") + + // Succeed validation, fail execution. + expectExecute(expectedErr) + expectValidate(true, nil) + err = root.ExecutePureTask(ctx, pureTask) + s.ErrorIs(expectedErr, err) + + // Fail task validation (no execution occurs). + expectValidate(false, nil) + err = root.ExecutePureTask(ctx, pureTask) + s.NoError(err) + + // Error during task validation (no execution occurs). + expectValidate(false, expectedErr) + err = root.ExecutePureTask(ctx, pureTask) + s.ErrorIs(expectedErr, err) +} diff --git a/service/history/interfaces/chasm_tree.go b/service/history/interfaces/chasm_tree.go index daf46dfc913..a3562f0ee2b 100644 --- a/service/history/interfaces/chasm_tree.go +++ b/service/history/interfaces/chasm_tree.go @@ -3,6 +3,8 @@ package interfaces import ( + "time" + persistencespb "go.temporal.io/server/api/persistence/v1" "go.temporal.io/server/chasm" ) @@ -15,4 +17,9 @@ type ChasmTree interface { ApplyMutation(chasm.NodesMutation) error ApplySnapshot(chasm.NodesSnapshot) error IsDirty() bool + + EachPureTask( + deadline time.Time, + callback func(node *chasm.Node, task any) error, + ) error } diff --git a/service/history/interfaces/chasm_tree_mock.go b/service/history/interfaces/chasm_tree_mock.go index c157103e44f..6a45bfe32c4 100644 --- a/service/history/interfaces/chasm_tree_mock.go +++ b/service/history/interfaces/chasm_tree_mock.go @@ -11,6 +11,7 @@ package interfaces import ( reflect "reflect" + time "time" persistence "go.temporal.io/server/api/persistence/v1" chasm "go.temporal.io/server/chasm" @@ -84,6 +85,20 @@ func (mr *MockChasmTreeMockRecorder) CloseTransaction() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseTransaction", reflect.TypeOf((*MockChasmTree)(nil).CloseTransaction)) } +// EachPureTask mocks base method. +func (m *MockChasmTree) EachPureTask(deadline time.Time, callback func(*chasm.Node, any) error) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "EachPureTask", deadline, callback) + ret0, _ := ret[0].(error) + return ret0 +} + +// EachPureTask indicates an expected call of EachPureTask. +func (mr *MockChasmTreeMockRecorder) EachPureTask(deadline, callback any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EachPureTask", reflect.TypeOf((*MockChasmTree)(nil).EachPureTask), deadline, callback) +} + // IsDirty mocks base method. func (m *MockChasmTree) IsDirty() bool { m.ctrl.T.Helper() diff --git a/service/history/workflow/noop_chasm_tree.go b/service/history/workflow/noop_chasm_tree.go index ddada4203b3..92e98ce4c62 100644 --- a/service/history/workflow/noop_chasm_tree.go +++ b/service/history/workflow/noop_chasm_tree.go @@ -1,6 +1,8 @@ package workflow import ( + "time" + persistencespb "go.temporal.io/server/api/persistence/v1" "go.temporal.io/server/chasm" historyi "go.temporal.io/server/service/history/interfaces" @@ -29,3 +31,10 @@ func (*noopChasmTree) ApplySnapshot(chasm.NodesSnapshot) error { func (*noopChasmTree) IsDirty() bool { return false } + +func (*noopChasmTree) EachPureTask( + deadline time.Time, + callback func(node *chasm.Node, task any) error, +) error { + return nil +}