diff --git a/core/scheduler/src/main/scala/org/apache/openwhisk/core/scheduler/queue/ContainerCounter.scala b/core/scheduler/src/main/scala/org/apache/openwhisk/core/scheduler/queue/ContainerCounter.scala
new file mode 100644
index 00000000000..7859a190d7e
--- /dev/null
+++ b/core/scheduler/src/main/scala/org/apache/openwhisk/core/scheduler/queue/ContainerCounter.scala
@@ -0,0 +1,169 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.openwhisk.core.scheduler.queue
+
+import java.util.concurrent.atomic.AtomicInteger
+
+import akka.actor.{Actor, ActorRef, ActorSystem, Props}
+import org.apache.openwhisk.common.Logging
+import org.apache.openwhisk.core.etcd.EtcdClient
+import org.apache.openwhisk.core.etcd.EtcdKV.ContainerKeys
+import org.apache.openwhisk.core.service.{DeleteEvent, PutEvent, UnwatchEndpoint, WatchEndpoint, WatchEndpointOperation}
+
+import scala.collection.concurrent.TrieMap
+import scala.concurrent.{ExecutionContext, Future}
+
+/**
+ * Tracks how many in-progress and existing containers are recorded in etcd for one invocation namespace.
+ *
+ * The counter registers prefix watches with `watcherService` for the two namespace-level container
+ * prefixes and re-reads the matching count from etcd whenever a prefix watch operation arrives for one
+ * of them. Instances are reference counted: obtain one through [[NamespaceContainerCount]] and release
+ * it with `close()`.
+ *
+ * @param invocationNamespace the namespace whose containers are counted
+ * @param etcdClient client used to query the number of keys under a prefix
+ * @param watcherService actor that receives the `WatchEndpoint` and `UnwatchEndpoint` registrations
+ */
+class ContainerCounter(invocationNamespace: String, etcdClient: EtcdClient, watcherService: ActorRef)(
+  implicit val actorSystem: ActorSystem,
+  ec: ExecutionContext,
+  logging: Logging) {
+  // Both counters are written from Future callbacks running on `ec` and may be read from other threads,
+  // so they are volatile to make the latest value visible to readers.
+  @volatile private[queue] var existingContainerNumByNamespace: Int = 0
+  @volatile private[queue] var inProgressContainerNumByNamespace: Int = 0
+  private[queue] val references = new AtomicInteger(0)
+  private val watcherName = s"container-counter-$invocationNamespace"
+
+  private val inProgressContainerPrefixKeyByNamespace =
+    ContainerKeys.inProgressContainerPrefixByNamespace(invocationNamespace)
+  private val existingContainerPrefixKeyByNamespace =
+    ContainerKeys.existingContainersPrefixByNamespace(invocationNamespace)
+
+  private val watchedKeys = Seq(inProgressContainerPrefixKeyByNamespace, existingContainerPrefixKeyByNamespace)
+
+  // Created lazily: `TrieMap.getOrElseUpdate` may build an instance that is never stored when several
+  // callers race for the same namespace, and such a discarded instance must not leave an actor behind.
+  // The actor keeps at most one count query in flight per key and remembers whether the key was touched
+  // again while that query was running.
+  private lazy val watcher: ActorRef =
+    actorSystem.actorOf(Props(new Actor {
+      // keys with a count query currently in flight
+      private var countingKeys = Set.empty[String]
+      // keys that need one more query after the running one finishes
+      private var waitingForCountKeys = Set.empty[String]
+
+      override def receive: Receive = {
+        case operation: WatchEndpointOperation if operation.isPrefix =>
+          if (countingKeys.contains(operation.watchKey))
+            waitingForCountKeys += operation.watchKey
+          else {
+            countingKeys += operation.watchKey
+            refreshContainerCount(operation.watchKey)
+          }
+
+        case ReadyToGetCount(key) =>
+          if (waitingForCountKeys.contains(key)) {
+            // the key was touched during the previous query, so read the count once more
+            waitingForCountKeys -= key
+            refreshContainerCount(key)
+          } else
+            countingKeys -= key
+      }
+    }))
+
+  /**
+   * Reads the number of keys under `key` from etcd and stores it in the matching counter.
+   *
+   * The watcher actor is told that the query for `key` has finished whether it succeeded or failed.
+   * A failed query is logged and leaves the counters unchanged.
+   */
+  private def refreshContainerCount(key: String): Future[Unit] = {
+    etcdClient
+      .getCount(key)
+      .map { count =>
+        key match {
+          case `inProgressContainerPrefixKeyByNamespace` => inProgressContainerNumByNamespace = count.toInt
+          case `existingContainerPrefixKeyByNamespace`   => existingContainerNumByNamespace = count.toInt
+        }
+        watcher ! ReadyToGetCount(key)
+      }
+      .recover {
+        case t: Throwable =>
+          logging.error(
+            this,
+            s"failed to get the number of containers under ${key} for ${invocationNamespace} due to ${t}.")
+          watcher ! ReadyToGetCount(key)
+      }
+  }
+
+  /**
+   * Adds one reference to this counter.
+   *
+   * The first reference registers a prefix watch for every watched key with `watcherService`, passing
+   * the internal watcher actor as the sender of the registration.
+   *
+   * @return this counter
+   */
+  def increaseReference(): ContainerCounter = {
+    if (references.incrementAndGet() == 1) {
+      watchedKeys.foreach { key =>
+        watcherService.tell(WatchEndpoint(key, "", true, watcherName, Set(PutEvent, DeleteEvent)), watcher)
+      }
+    }
+    this
+  }
+
+  /**
+   * Releases one reference to this counter.
+   *
+   * When the reference count drops to zero, the prefix watches are unregistered and this instance is
+   * removed from `NamespaceContainerCount.instances`.
+   */
+  def close(): Unit = {
+    if (references.decrementAndGet() == 0) {
+      watchedKeys.foreach { key =>
+        watcherService ! UnwatchEndpoint(key, true, watcherName)
+      }
+      // NOTE(review): the watcher actor itself is not stopped here; confirm whether it should be.
+      NamespaceContainerCount.instances.remove(invocationNamespace)
+    }
+  }
+}
+
+/** Registry of the shared [[ContainerCounter]] instances, keyed by namespace. */
+object NamespaceContainerCount {
+  private[queue] val instances = TrieMap[String, ContainerCounter]()
+
+  /**
+   * Returns the counter registered for `namespace`, creating it if there is none, and adds one
+   * reference to it. Release the reference with `close()` on the returned counter.
+   */
+  def apply(namespace: String, etcdClient: EtcdClient, watcherService: ActorRef)(
+    implicit actorSystem: ActorSystem,
+    ec: ExecutionContext,
+    logging: Logging): ContainerCounter = {
+    instances
+      .getOrElseUpdate(namespace, new ContainerCounter(namespace, etcdClient, watcherService))
+      .increaseReference()
+  }
+}
+
+/** Sent to the watcher actor of a [[ContainerCounter]] once the count query for `key` has finished. */
+final case class ReadyToGetCount(key: String)
diff --git a/tests/src/test/scala/org/apache/openwhisk/core/scheduler/queue/test/ContainerCounterTests.scala b/tests/src/test/scala/org/apache/openwhisk/core/scheduler/queue/test/ContainerCounterTests.scala
new file mode 100644
index 00000000000..e9e6694e046
--- /dev/null
+++ b/tests/src/test/scala/org/apache/openwhisk/core/scheduler/queue/test/ContainerCounterTests.scala
@@ -0,0 +1,356 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license
agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.openwhisk.core.scheduler.queue.test + +import java.{lang, util} +import java.util.concurrent.Executor + +import akka.actor.ActorSystem +import akka.testkit.{TestKit, TestProbe} +import com.google.protobuf.ByteString +import com.ibm.etcd.api.Event.EventType +import com.ibm.etcd.api.{Event, KeyValue, LeaseKeepAliveResponse, ResponseHeader, TxnResponse} +import com.ibm.etcd.client.kv.KvClient.Watch +import com.ibm.etcd.client.kv.WatchUpdate +import com.ibm.etcd.client.{EtcdClient => Client} +import common.StreamLogging +import org.apache.openwhisk.core.entity.{ + CreationId, + DocRevision, + EntityName, + EntityPath, + FullyQualifiedEntityName, + SchedulerInstanceId +} +import org.apache.openwhisk.core.etcd.EtcdClient +import org.apache.openwhisk.core.etcd.EtcdKV.ContainerKeys +import org.apache.openwhisk.core.etcd.EtcdKV.ContainerKeys.inProgressContainer +import org.apache.openwhisk.core.scheduler.queue.NamespaceContainerCount +import org.apache.openwhisk.core.service.{DeleteEvent, PutEvent, UnwatchEndpoint, WatchEndpoint, WatcherService} +import org.junit.runner.RunWith +import org.scalamock.scalatest.MockFactory +import org.scalatest.concurrent.ScalaFutures +import org.scalatest.{FlatSpecLike, Matchers} +import org.scalatest.junit.JUnitRunner + 
+import scala.concurrent.Future
+import scala.concurrent.duration.TimeUnit
+import org.scalatest.BeforeAndAfterAll
+
+/**
+ * Tests for the reference-counted, per-namespace container counters handed out by
+ * [[NamespaceContainerCount]].
+ *
+ * The first two tests use a mocked etcd client and a `TestProbe` in place of the watcher service; the
+ * last one drives a `WatcherService` actor with [[MockEtcdClient]], which keeps its data in memory.
+ */
+@RunWith(classOf[JUnitRunner])
+class ContainerCounterTests
+    extends TestKit(ActorSystem("ContainerCounter"))
+    with FlatSpecLike
+    with Matchers
+    with MockFactory
+    with ScalaFutures
+    with StreamLogging
+    with BeforeAndAfterAll {
+
+  private implicit val ec = system.dispatcher
+
+  private val namespace = "testNamespace"
+  private val namespace2 = "testNamespace2"
+  private val action = "testAction"
+  private val action2 = "testAction2"
+  private val schedulerId = SchedulerInstanceId("0")
+  private val fqn = FullyQualifiedEntityName(EntityPath(namespace), EntityName(action))
+  private val revision = DocRevision("1-testRev1")
+  private val fqn2 = FullyQualifiedEntityName(EntityPath(namespace), EntityName(action2))
+  private val revision2 = DocRevision("1-testRev2")
+  private val fqn3 = FullyQualifiedEntityName(EntityPath(namespace2), EntityName(action2))
+  private val revision3 = DocRevision("1-testRev3")
+  private val watcherName = s"container-counter-$namespace"
+  private val inProgressContainerPrefixKeyByNamespace =
+    ContainerKeys.inProgressContainerPrefixByNamespace(namespace)
+  private val existingContainerPrefixKeyByNamespace =
+    ContainerKeys.existingContainersPrefixByNamespace(namespace)
+
+  val client: Client = {
+    val hostAndPorts = "172.17.0.1:2379"
+    Client.forEndpoints(hostAndPorts).withPlainText().build()
+  }
+
+  /** Releases what the suite created: registered counters, the etcd client and the actor system. */
+  override def afterAll(): Unit = {
+    NamespaceContainerCount.instances.clear()
+    client.close()
+    TestKit.shutdownActorSystem(system)
+    super.afterAll()
+  }
+
+  it should "be shared for a same namespace" in {
+    val etcd = mock[EtcdClient]
+    val watcher = TestProbe()
+    // request the counter for the same namespace 100 times concurrently
+    val res = Future.sequence {
+      (0 to 99).map { _ =>
+        Future {
+          NamespaceContainerCount(namespace, etcd, watcher.ref)
+        }
+      }
+    }.futureValue
+
+    // only create one instance
+    res.toSet.size shouldBe 1
+    res.head.references.intValue shouldBe 100
+
+    // only register watch endpoint once
+    watcher.expectMsgAllOf(
+      WatchEndpoint(inProgressContainerPrefixKeyByNamespace, "", true, watcherName, Set(PutEvent, DeleteEvent)),
+      WatchEndpoint(existingContainerPrefixKeyByNamespace, "", true, watcherName, Set(PutEvent, DeleteEvent)))
+    watcher.expectNoMessage()
+    NamespaceContainerCount.instances.size shouldBe 1
+    NamespaceContainerCount.instances.clear()
+  }
+
+  it should "and only should be closed when all references are closed" in {
+    val etcd = mock[EtcdClient]
+    val watcher = TestProbe()
+    // request the counter for the same namespace 100 times concurrently
+    val res = Future.sequence {
+      (0 to 99).map { _ =>
+        Future {
+          NamespaceContainerCount(namespace, etcd, watcher.ref)
+        }
+      }
+    }.futureValue
+
+    // only create one instance
+    res.toSet.size shouldBe 1
+    res.head.references.intValue shouldBe 100
+
+    // only register watch endpoint once
+    watcher.expectMsgAllOf(
+      WatchEndpoint(inProgressContainerPrefixKeyByNamespace, "", true, watcherName, Set(PutEvent, DeleteEvent)),
+      WatchEndpoint(existingContainerPrefixKeyByNamespace, "", true, watcherName, Set(PutEvent, DeleteEvent)))
+    watcher.expectNoMessage()
+    NamespaceContainerCount.instances.size shouldBe 1
+
+    // close 50 times
+    Future.sequence {
+      (0 to 49).map { _ =>
+        Future(res.head.close())
+      }
+    }.futureValue
+    res.head.references.intValue shouldBe 50
+
+    // should not unregister watch endpoint
+    watcher.expectNoMessage()
+    NamespaceContainerCount.instances.size shouldBe 1
+
+    // close left 50 times
+    Future.sequence {
+      (0 to 49).map { _ =>
+        Future(res.head.close())
+      }
+    }.futureValue
+    res.head.references.intValue shouldBe 0
+
+    // only unregister watch endpoint once
+    watcher.expectMsgAllOf(
+      UnwatchEndpoint(inProgressContainerPrefixKeyByNamespace, true, watcherName),
+      UnwatchEndpoint(existingContainerPrefixKeyByNamespace, true, watcherName))
+    watcher.expectNoMessage()
+    NamespaceContainerCount.instances.size shouldBe 0
+  }
+
+  it should "update the number of containers based on Watch event" in {
+    val mockEtcdClient = new MockEtcdClient(client, true)
+    val watcher = system.actorOf(WatcherService.props(mockEtcdClient))
+
+    val ns = NamespaceContainerCount(namespace, mockEtcdClient, watcher)
+    // the watches are registered through actor messages, so give them time to be processed
+    Thread.sleep(1000)
+
+    ns.inProgressContainerNumByNamespace shouldBe 0
+    ns.existingContainerNumByNamespace shouldBe 0
+
+    val invoker = "invoker0"
+
+    mockEtcdClient.publishEvents(
+      EventType.PUT,
+      inProgressContainer(namespace, fqn, revision, schedulerId, CreationId("testId")),
+      "test-value")
+
+    mockEtcdClient.publishEvents(
+      EventType.PUT,
+      s"${ContainerKeys.existingContainers(namespace, fqn, DocRevision.empty)}/${invoker}/test-container",
+      "test-value")
+
+    // the counters are refreshed asynchronously after each event
+    Thread.sleep(1000)
+    ns.inProgressContainerNumByNamespace shouldBe 1
+    ns.existingContainerNumByNamespace shouldBe 1
+
+    // containers of another action in the same namespace are counted as well
+    mockEtcdClient.publishEvents(
+      EventType.PUT,
+      inProgressContainer(namespace, fqn2, revision2, schedulerId, CreationId("testId2")),
+      "test-value")
+
+    mockEtcdClient.publishEvents(
+      EventType.PUT,
+      s"${ContainerKeys.existingContainers(namespace, fqn2, DocRevision.empty)}/${invoker}/test-container2",
+      "test-value")
+
+    Thread.sleep(1000)
+    ns.inProgressContainerNumByNamespace shouldBe 2
+    ns.existingContainerNumByNamespace shouldBe 2
+
+    // other namespace's containers should have no influence
+    mockEtcdClient.publishEvents(
+      EventType.PUT,
+      inProgressContainer(namespace2, fqn3, revision3, schedulerId, CreationId("testId3")),
+      "test-value")
+
+    mockEtcdClient.publishEvents(
+      EventType.PUT,
+      s"${ContainerKeys.existingContainers(namespace2, fqn3, DocRevision.empty)}/${invoker}/test-container3",
+      "test-value")
+
+    Thread.sleep(1000)
+    ns.inProgressContainerNumByNamespace shouldBe 2
+    ns.existingContainerNumByNamespace shouldBe 2
+
+    // inProgress containers should have no effect on existing containers
+    mockEtcdClient.publishEvents(
+      EventType.DELETE,
+      inProgressContainer(namespace, fqn, revision, schedulerId, CreationId("testId")),
+      "test-value")
+
+    mockEtcdClient.publishEvents(
+      EventType.DELETE,
+      inProgressContainer(namespace, fqn2, revision2, schedulerId, CreationId("testId2")),
+      "test-value")
+
+    Thread.sleep(1000)
+    ns.inProgressContainerNumByNamespace shouldBe 0
+    ns.existingContainerNumByNamespace shouldBe 2
+
+    // existing containers should have no effect on inProgress containers
+    mockEtcdClient.publishEvents(
+      EventType.DELETE,
+      s"${ContainerKeys.existingContainers(namespace, fqn, DocRevision.empty)}/${invoker}/test-container",
+      "test-value")
+
+    mockEtcdClient.publishEvents(
+      EventType.DELETE,
+      s"${ContainerKeys.existingContainers(namespace, fqn2, DocRevision.empty)}/${invoker}/test-container2",
+      "test-value")
+
+    Thread.sleep(1000)
+    ns.inProgressContainerNumByNamespace shouldBe 0
+    ns.existingContainerNumByNamespace shouldBe 0
+
+    NamespaceContainerCount.instances.clear()
+  }
+
+  /**
+   * In-memory stand-in for [[EtcdClient]].
+   *
+   * The overridden methods work on `dataMap` and `watchCallbackMap` instead of the wrapped client.
+   * `leaseNotFound` and `failedCount` are not used by the overrides below.
+   *
+   * @param isLeader whether `putTxn` records the value and reports success
+   */
+  class MockEtcdClient(client: Client, isLeader: Boolean, leaseNotFound: Boolean = false, failedCount: Int = 1)
+      extends EtcdClient(client)(ec) {
+    var count = 0
+    var storedValues = List.empty[(String, String, Long, Long)]
+    // written by the test through publishEvents and read by getCount, possibly on another thread
+    @volatile var dataMap = Map[String, String]()
+
+    /** Records the write when this client acts as the leader; the response succeeds only in that case. */
+    override def putTxn[T](key: String, value: T, cmpVersion: Long, leaseId: Long): Future[TxnResponse] = {
+      if (isLeader) {
+        storedValues = (key, value.toString, cmpVersion, leaseId) :: storedValues
+      }
+      Future.successful(TxnResponse.newBuilder().setSucceeded(isLeader).build())
+    }
+
+    /** Counts the entries of `dataMap` whose key starts with the given prefix. */
+    override def getCount(prefixKey: String): Future[Long] = {
+      Future.successful { dataMap.count(data => data._1.startsWith(prefixKey)) }
+    }
+
+    // presumably filled on the thread of the actor that calls watchAllKeys, and read by the test
+    // thread in publishEvents -- hence volatile
+    @volatile var watchCallbackMap = Map[String, WatchUpdate => Unit]()
+
+    override def keepAliveOnce(leaseId: Long): Future[LeaseKeepAliveResponse] =
+      Future.successful(LeaseKeepAliveResponse.newBuilder().setID(leaseId).build())
+
+    /**
+     * Registers `next` in `watchCallbackMap` under the empty prefix, which every key starts with, so the
+     * callback is invoked for each event published through `publishEvents`. `error` and `completed`
+     * are ignored, and the returned [[Watch]] is a stub that performs no work.
+     */
+    override def watchAllKeys(next: WatchUpdate => Unit, error: Throwable => Unit, completed: () => Unit): Watch = {
+
+      watchCallbackMap += "" -> next
+      new Watch {
+        override def close(): Unit = {}
+
+        override def addListener(listener: Runnable, executor: Executor): Unit = {}
+
+        override def cancel(mayInterruptIfRunning: Boolean): Boolean = true
+
+        override def isCancelled: Boolean = true
+
+        override def isDone: Boolean = true
+
+        override def get(): lang.Boolean = true
+
+        override def get(timeout: Long, unit: TimeUnit): lang.Boolean = true
+      }
+    }
+
+    /**
+     * Simulates an etcd change: applies the event to `dataMap` (PUT stores the entry, DELETE removes it)
+     * and then hands a [[WatchUpdate]] containing the event to every registered callback whose prefix
+     * matches `key`.
+     *
+     * Watches must be registered before events are published; callbacks are looked up by prefix only.
+     */
+    def publishEvents(eventType: EventType, key: String, value: String): Unit = {
+      eventType match {
+        case EventType.PUT          => dataMap += key -> value
+        case EventType.DELETE       => dataMap -= key
+        case EventType.UNRECOGNIZED => ()
+      }
+
+      // the same key/value pair is reported as both the previous and the current entry
+      val keyValue = KeyValue
+        .newBuilder()
+        .setKey(ByteString.copyFromUtf8(key))
+        .setValue(ByteString.copyFromUtf8(value))
+        .build()
+      val event = Event
+        .newBuilder()
+        .setType(eventType)
+        .setPrevKv(keyValue)
+        .setKv(keyValue)
+        .build()
+
+      // find the callbacks which have the proper prefix for the given key
+      watchCallbackMap
+        .filter { case (prefix, _) => key.startsWith(prefix) }
+        .foreach { case (_, callback) => callback(new mockWatchUpdate().addEvents(event)) }
+    }
+  }
+
+  /** Minimal [[WatchUpdate]] that only carries the events added through `addEvents`. */
+  class mockWatchUpdate extends WatchUpdate {
+    private val eventLists: util.List[Event] = new util.ArrayList[Event]()
+    // the header is never read by these tests, so it is left unimplemented
+    override def getHeader: ResponseHeader = ???
+
+    def addEvents(event: Event): WatchUpdate = {
+      eventLists.add(event)
+      this
+    }
+
+    override def getEvents: util.List[Event] = eventLists
+  }
+}