diff --git a/common/scala/src/main/scala/org/apache/openwhisk/core/connector/Message.scala b/common/scala/src/main/scala/org/apache/openwhisk/core/connector/Message.scala index 88f4f11b4f5..ba05c17964c 100644 --- a/common/scala/src/main/scala/org/apache/openwhisk/core/connector/Message.scala +++ b/common/scala/src/main/scala/org/apache/openwhisk/core/connector/Message.scala @@ -60,7 +60,7 @@ case class ActivationMessage(override val transid: TransactionId, lockedArgs: Map[String, String] = Map.empty, cause: Option[ActivationId] = None, traceContext: Option[Map[String, String]] = None) - extends Message { + extends Message { override def serialize = ActivationMessage.serdes.write(this).compactPrint @@ -78,6 +78,7 @@ case class ActivationMessage(override val transid: TransactionId, */ abstract class AcknowledegmentMessage(private val tid: TransactionId) extends Message { override val transid: TransactionId = tid + override def serialize: String = AcknowledegmentMessage.serdes.write(this).compactPrint /** Pithy descriptor for logging. */ @@ -115,17 +116,23 @@ abstract class AcknowledegmentMessage(private val tid: TransactionId) extends Me * The constructor is private so that callers must use the more restrictive constructors which ensure the respose is always * Right when this message is created. */ -case class CombinedCompletionAndResultMessage private (override val transid: TransactionId, - response: Either[ActivationId, WhiskActivation], - override val isSystemError: Option[Boolean], - instance: InstanceId) - extends AcknowledegmentMessage(transid) { +case class CombinedCompletionAndResultMessage private(override val transid: TransactionId, + response: Either[ActivationId, WhiskActivation], + override val isSystemError: Option[Boolean], + instance: InstanceId) + extends AcknowledegmentMessage(transid) { override def messageType = "combined" + override def result = Some(response) + override def isSlotFree = Some(instance) + override def activationId = response.fold(identity, _.activationId) + override def toJson = CombinedCompletionAndResultMessage.serdes.write(this) + override def shrink = copy(response = response.flatMap(a => Left(a.activationId))) + override def toString = activationId.asString } @@ -135,16 +142,21 @@ case class CombinedCompletionAndResultMessage private (override val transid: Tra * phase notification to the load balancer where an invoker first sends a `ResultMessage` and later sends the * `CompletionMessage`. */ -case class CompletionMessage private (override val transid: TransactionId, - override val activationId: ActivationId, - override val isSystemError: Option[Boolean], - instance: InstanceId) - extends AcknowledegmentMessage(transid) { +case class CompletionMessage private(override val transid: TransactionId, + override val activationId: ActivationId, + override val isSystemError: Option[Boolean], + instance: InstanceId) + extends AcknowledegmentMessage(transid) { override def messageType = "completion" + override def result = None + override def isSlotFree = Some(instance) + override def toJson = CompletionMessage.serdes.write(this) + override def shrink = this + override def toString = activationId.asString } @@ -156,15 +168,22 @@ case class CompletionMessage private (override val transid: TransactionId, * The constructor is private so that callers must use the more restrictive constructors which ensure the respose is always * Right when this message is created. */ -case class ResultMessage private (override val transid: TransactionId, response: Either[ActivationId, WhiskActivation]) - extends AcknowledegmentMessage(transid) { +case class ResultMessage private(override val transid: TransactionId, response: Either[ActivationId, WhiskActivation]) + extends AcknowledegmentMessage(transid) { override def messageType = "result" + override def result = Some(response) + override def isSlotFree = None + override def isSystemError = response.fold(_ => None, a => Some(a.response.isWhiskError)) + override def activationId = response.fold(identity, _.activationId) + override def toJson = ResultMessage.serdes.write(this) + override def shrink = copy(response = response.flatMap(a => Left(a.activationId))) + override def toString = activationId.asString } @@ -234,7 +253,7 @@ object AcknowledegmentMessage extends DefaultJsonProtocol { Left(value.convertTo[ActivationId]) case _: JsObject => Right(value.convertTo[WhiskActivation]) - case _ => deserializationError("could not read ResultMessage") + case _ => deserializationError("could not read ResultMessage") } } @@ -265,6 +284,7 @@ case class PingMessage(instance: InvokerInstanceId) extends Message { object PingMessage extends DefaultJsonProtocol { def parse(msg: String) = Try(serdes.read(msg.parseJson)) + implicit val serdes = jsonFormat(PingMessage.apply _, "name") } @@ -276,7 +296,7 @@ object EventMessageBody extends DefaultJsonProtocol { implicit val format = new JsonFormat[EventMessageBody] { def write(eventMessageBody: EventMessageBody) = eventMessageBody match { - case m: Metric => m.toJson + case m: Metric => m.toJson case a: Activation => a.toJson } @@ -301,9 +321,11 @@ case class Activation(name: String, causedBy: Option[String], size: Option[Int] = None, userDefinedStatusCode: Option[Int] = None) - extends EventMessageBody { + extends EventMessageBody { val typeName = Activation.typeName + override def serialize = toJson.compactPrint + def entityPath: FullyQualifiedEntityName = EntityPath(name).toFullyQualifiedEntityName def toJson = Activation.activationFormat.write(this) @@ -327,12 +349,12 @@ object Activation extends DefaultJsonProtocol { private implicit val durationFormat = new RootJsonFormat[Duration] { override def write(obj: Duration): JsValue = obj match { case o if o.isFinite => JsNumber(o.toMillis) - case _ => JsNumber.zero + case _ => JsNumber.zero } override def read(json: JsValue): Duration = json match { case JsNumber(n) if n <= 0 => Duration.Zero - case JsNumber(n) => toDuration(n.longValue) + case JsNumber(n) => toDuration(n.longValue) } } @@ -352,7 +374,7 @@ object Activation extends DefaultJsonProtocol { "size", "userDefinedStatusCode") - /** Get "StatusCode" from result response set by action developer **/ + /** Get "StatusCode" from result response set by action developer * */ def userDefinedStatusCode(result: Option[JsValue]): Option[Int] = { val statusCode = JsHelpers .getFieldPath(result.get.asJsObject, ERROR_FIELD, "statusCode") @@ -394,13 +416,17 @@ object Activation extends DefaultJsonProtocol { case class Metric(metricName: String, metricValue: Long) extends EventMessageBody { val typeName = "Metric" + override def serialize = toJson.compactPrint + def toJson = Metric.metricFormat.write(this).asJsObject } object Metric extends DefaultJsonProtocol { val typeName = "Metric" + def parse(msg: String) = Try(metricFormat.read(msg.parseJson)) + implicit val metricFormat = jsonFormat(Metric.apply _, "metricName", "metricValue") } @@ -411,7 +437,7 @@ case class EventMessage(source: String, userId: UUID, eventType: String, timestamp: Long = System.currentTimeMillis()) - extends Message { + extends Message { override def serialize = EventMessage.format.write(this).compactPrint } @@ -434,7 +460,7 @@ case class InvokerResourceMessage(status: String, inProgressMemory: Long, tags: Seq[String], dedicatedNamespaces: Seq[String]) - extends Message { + extends Message { /** * Serializes message to string. Must be idempotent. @@ -444,6 +470,7 @@ case class InvokerResourceMessage(status: String, object InvokerResourceMessage extends DefaultJsonProtocol { def parse(msg: String): Try[InvokerResourceMessage] = Try(serdes.read(msg.parseJson)) + implicit val serdes = jsonFormat( InvokerResourceMessage.apply _, @@ -462,23 +489,25 @@ object InvokerResourceMessage extends DefaultJsonProtocol { * * [ * ... - * { - * "data": "RunningData", - * "fqn": "whisk.system/elasticsearch/status-alarm@0.0.2", - * "invocationNamespace": "style95", - * "status": "Running", - * "waitingActivation": 1 - * }, + * { + * "data": "RunningData", + * "fqn": "whisk.system/elasticsearch/status-alarm@0.0.2", + * "invocationNamespace": "style95", + * "status": "Running", + * "waitingActivation": 1 + * }, * ... * ] */ object StatusQuery + case class StatusData(invocationNamespace: String, fqn: String, waitingActivation: Int, status: String, data: String) - extends Message { + extends Message { override def serialize: String = StatusData.serdes.write(this).compactPrint } + object StatusData extends DefaultJsonProtocol { implicit val serdes = @@ -495,9 +524,10 @@ case class ContainerCreationMessage(override val transid: TransactionId, rpcPort: Int, retryCount: Int = 0, creationId: CreationId = CreationId.generate()) - extends ContainerMessage(transid) { + extends ContainerMessage(transid) { override def toJson: JsValue = ContainerCreationMessage.serdes.write(this) + override def serialize: String = toJson.compactPrint } @@ -526,8 +556,9 @@ case class ContainerDeletionMessage(override val transid: TransactionId, action: FullyQualifiedEntityName, revision: DocRevision, whiskActionMetaData: WhiskActionMetaData) - extends ContainerMessage(transid) { + extends ContainerMessage(transid) { override def toJson: JsValue = ContainerDeletionMessage.serdes.write(this) + override def serialize: String = toJson.compactPrint } @@ -544,6 +575,7 @@ object ContainerDeletionMessage extends DefaultJsonProtocol { abstract class ContainerMessage(private val tid: TransactionId) extends Message { override val transid: TransactionId = tid + override def serialize: String = ContainerMessage.serdes.write(this).compactPrint /** Serializes the message to JSON. */ @@ -569,18 +601,31 @@ object ContainerMessage extends DefaultJsonProtocol { } sealed trait ContainerCreationError + object ContainerCreationError extends Enumeration { + case object NoAvailableInvokersError extends ContainerCreationError + case object NoAvailableResourceInvokersError extends ContainerCreationError + case object ResourceNotEnoughError extends ContainerCreationError + case object WhiskError extends ContainerCreationError + case object UnknownError extends ContainerCreationError + case object TimeoutError extends ContainerCreationError + case object ShuttingDownError extends ContainerCreationError + case object NonExecutableActionError extends ContainerCreationError + case object DBFetchError extends ContainerCreationError + case object BlackBoxError extends ContainerCreationError + case object ZeroNamespaceLimit extends ContainerCreationError + case object TooManyConcurrentRequests extends ContainerCreationError val whiskErrors: Set[ContainerCreationError] = @@ -594,26 +639,27 @@ object ContainerCreationError extends Enumeration { TimeoutError, ZeroNamespaceLimit) - def fromName(name: String) = name.toUpperCase match { - case "NOAVAILABLEINVOKERSERROR" => NoAvailableInvokersError + private def parse(name: String) = name.toUpperCase match { + case "NOAVAILABLEINVOKERSERROR" => NoAvailableInvokersError case "NOAVAILABLERESOURCEINVOKERSERROR" => NoAvailableResourceInvokersError - case "RESOURCENOTENOUGHERROR" => ResourceNotEnoughError - case "NONEXECUTBLEACTIONERROR" => NonExecutableActionError - case "DBFETCHERROR" => DBFetchError - case "WHISKERROR" => WhiskError - case "BLACKBOXERROR" => BlackBoxError - case "TIMEOUTERROR" => TimeoutError - case "ZERONAMESPACELIMIT" => ZeroNamespaceLimit - case "TOOMANYCONCURRENTREQUESTS" => TooManyConcurrentRequests - case "UNKNOWNERROR" => UnknownError + case "RESOURCENOTENOUGHERROR" => ResourceNotEnoughError + case "NONEXECUTBLEACTIONERROR" => NonExecutableActionError + case "DBFETCHERROR" => DBFetchError + case "WHISKERROR" => WhiskError + case "BLACKBOXERROR" => BlackBoxError + case "TIMEOUTERROR" => TimeoutError + case "ZERONAMESPACELIMIT" => ZeroNamespaceLimit + case "TOOMANYCONCURRENTREQUESTS" => TooManyConcurrentRequests + case "UNKNOWNERROR" => UnknownError } implicit val serds = new RootJsonFormat[ContainerCreationError] { override def write(error: ContainerCreationError): JsValue = JsString(error.toString) + override def read(json: JsValue): ContainerCreationError = Try { val JsString(str) = json - ContainerCreationError.fromName(str.trim.toUpperCase) + ContainerCreationError.parse(str.trim.toUpperCase) } getOrElse { throw deserializationError("ContainerCreationError must be a valid string") } @@ -632,7 +678,7 @@ case class ContainerCreationAckMessage(override val transid: TransactionId, retryCount: Int = 0, error: Option[ContainerCreationError] = None, reason: Option[String] = None) - extends Message { + extends Message { /** * Serializes message to string. Must be idempotent. @@ -642,6 +688,7 @@ case class ContainerCreationAckMessage(override val transid: TransactionId, object ContainerCreationAckMessage extends DefaultJsonProtocol { def parse(msg: String): Try[ContainerCreationAckMessage] = Try(serdes.read(msg.parseJson)) + private implicit val fqnSerdes = FullyQualifiedEntityName.serdes private implicit val byteSizeSerdes = size.serdes implicit val serdes = jsonFormat12(ContainerCreationAckMessage.apply) diff --git a/common/scala/src/main/scala/org/apache/openwhisk/core/entity/Size.scala b/common/scala/src/main/scala/org/apache/openwhisk/core/entity/Size.scala index ded9a6add1a..258fc89f319 100644 --- a/common/scala/src/main/scala/org/apache/openwhisk/core/entity/Size.scala +++ b/common/scala/src/main/scala/org/apache/openwhisk/core/entity/Size.scala @@ -29,36 +29,55 @@ object SizeUnits extends Enumeration { sealed abstract class Unit() { def toBytes(n: Long): Long + def toKBytes(n: Long): Long + def toMBytes(n: Long): Long + def toGBytes(n: Long): Long } case object BYTE extends Unit { def toBytes(n: Long): Long = n + def toKBytes(n: Long): Long = n / 1024 + def toMBytes(n: Long): Long = n / 1024 / 1024 + def toGBytes(n: Long): Long = n / 1024 / 1024 / 1024 } + case object KB extends Unit { def toBytes(n: Long): Long = n * 1024 + def toKBytes(n: Long): Long = n + def toMBytes(n: Long): Long = n / 1024 + def toGBytes(n: Long): Long = n / 1024 / 1024 } + case object MB extends Unit { def toBytes(n: Long): Long = n * 1024 * 1024 + def toKBytes(n: Long): Long = n * 1024 + def toMBytes(n: Long): Long = n + def toGBytes(n: Long): Long = n / 1024 } + case object GB extends Unit { def toBytes(n: Long): Long = n * 1024 * 1024 * 1024 + def toKBytes(n: Long): Long = n * 1024 * 1024 + def toMBytes(n: Long): Long = n * 1024 + def toGBytes(n: Long): Long = n } + } case class ByteSize(size: Long, unit: SizeUnits.Unit) extends Ordered[ByteSize] { @@ -66,7 +85,9 @@ case class ByteSize(size: Long, unit: SizeUnits.Unit) extends Ordered[ByteSize] require(size >= 0, "a negative size of an object is not allowed.") def toBytes = unit.toBytes(size) + def toKB = unit.toKBytes(size) + def toMB = unit.toMBytes(size) def +(other: ByteSize): ByteSize = { @@ -102,15 +123,15 @@ case class ByteSize(size: Long, unit: SizeUnits.Unit) extends Ordered[ByteSize] override def equals(that: Any): Boolean = that match { case t: ByteSize => compareTo(t) == 0 - case _ => false + case _ => false } override def toString = { unit match { case SizeUnits.BYTE => s"$size B" - case SizeUnits.KB => s"$size KB" - case SizeUnits.MB => s"$size MB" - case SizeUnits.GB => s"$size GB" + case SizeUnits.KB => s"$size KB" + case SizeUnits.MB => s"$size MB" + case SizeUnits.GB => s"$size GB" } } } @@ -138,6 +159,7 @@ object ByteSize { } object size { + implicit class SizeInt(n: Int) extends SizeConversion { def sizeIn(unit: SizeUnits.Unit): ByteSize = ByteSize(n, unit) } @@ -163,24 +185,31 @@ object size { implicit val pureconfigReader = ConfigReader[ConfigValue].map(v => ByteSize(v.atKey("key").getBytes("key"), SizeUnits.BYTE)) - implicit val serdes = new RootJsonFormat[ByteSize] { + protected[core] implicit val serdes = new RootJsonFormat[ByteSize] { def write(b: ByteSize) = JsString(b.toString) def read(value: JsValue): ByteSize = value match { case JsString(s) => ByteSize.fromString(s) - case _ => deserializationError(formatError) + case _ => deserializationError(formatError) } } } trait SizeConversion { def B = sizeIn(SizeUnits.BYTE) + def KB = sizeIn(SizeUnits.KB) + def MB = sizeIn(SizeUnits.MB) + def GB: ByteSize = sizeIn(SizeUnits.GB) + def bytes = B + def kilobytes = KB + def megabytes = MB + def gigabytes: ByteSize = GB def sizeInBytes = sizeIn(SizeUnits.BYTE) diff --git a/core/invoker/src/main/scala/org/apache/openwhisk/core/invoker/ContainerMessageConsumer.scala b/core/invoker/src/main/scala/org/apache/openwhisk/core/invoker/ContainerMessageConsumer.scala new file mode 100644 index 00000000000..05cddd83880 --- /dev/null +++ b/core/invoker/src/main/scala/org/apache/openwhisk/core/invoker/ContainerMessageConsumer.scala @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.openwhisk.core.invoker + +import java.nio.charset.StandardCharsets + +import akka.actor.{ActorRef, ActorSystem, Props} +import org.apache.kafka.clients.producer.RecordMetadata +import org.apache.openwhisk.common.{GracefulShutdown, Logging, TransactionId} +import org.apache.openwhisk.core.WarmUp.isWarmUpAction +import org.apache.openwhisk.core.WhiskConfig +import org.apache.openwhisk.core.connector.ContainerCreationError.DBFetchError +import org.apache.openwhisk.core.connector._ +import org.apache.openwhisk.core.containerpool.v2.{CreationContainer, DeletionContainer} +import org.apache.openwhisk.core.database.{ + ArtifactStore, + DocumentTypeMismatchException, + DocumentUnreadable, + NoDocumentException +} +import org.apache.openwhisk.core.entity._ +import org.apache.openwhisk.http.Messages + +import scala.concurrent.duration._ +import scala.concurrent.{ExecutionContext, Future} +import scala.util.{Failure, Success} + +class ContainerMessageConsumer( + invokerInstanceId: InvokerInstanceId, + containerPool: ActorRef, + entityStore: ArtifactStore[WhiskEntity], + config: WhiskConfig, + msgProvider: MessagingProvider, + longPollDuration: FiniteDuration, + maxPeek: Int, + sendAckToScheduler: (SchedulerInstanceId, ContainerCreationAckMessage) => Future[RecordMetadata])( + implicit actorSystem: ActorSystem, + executionContext: ExecutionContext, + logging: Logging) { + + private val topic = s"${Invoker.topicPrefix}invoker${invokerInstanceId.toInt}" + private val consumer = + msgProvider.getConsumer(config, topic, topic, maxPeek, maxPollInterval = TimeLimit.MAX_DURATION + 1.minute) + + private def handler(bytes: Array[Byte]): Future[Unit] = Future { + val raw = new String(bytes, StandardCharsets.UTF_8) + ContainerMessage.parse(raw) match { + case Success(creation: ContainerCreationMessage) if isWarmUpAction(creation.action) => + logging.info( + this, + s"container creation message for ${creation.invocationNamespace}/${creation.action} is received (creationId: ${creation.creationId})") + feed ! MessageFeed.Processed + + case Success(creation: ContainerCreationMessage) => + implicit val transid: TransactionId = creation.transid + logging + .info(this, s"container creation message for ${creation.invocationNamespace}/${creation.action} is received") + WhiskAction + .get(entityStore, creation.action.toDocId, creation.revision, fromCache = true) + .map { action => + containerPool ! CreationContainer(creation, action) + feed ! MessageFeed.Processed + } + .recover { + case t => + val message = t match { + case _: NoDocumentException => + Messages.actionRemovedWhileInvoking + case _: DocumentTypeMismatchException | _: DocumentUnreadable => + Messages.actionMismatchWhileInvoking + case e: Throwable => + logging.error(this, s"An unknown DB connection error occurred while fetching an action: $e.") + Messages.actionFetchErrorWhileInvoking + } + logging.error( + this, + s"failed to fetch action ${creation.invocationNamespace}/${creation.action}, error: $message (creationId: ${creation.creationId})") + + val ack = ContainerCreationAckMessage( + creation.transid, + creation.creationId, + creation.invocationNamespace, + creation.action, + creation.revision, + creation.whiskActionMetaData, + invokerInstanceId, + creation.schedulerHost, + creation.rpcPort, + creation.retryCount, + Some(DBFetchError), + Some(message)) + sendAckToScheduler(creation.rootSchedulerIndex, ack) + feed ! MessageFeed.Processed + } + case Success(deletion: ContainerDeletionMessage) => + implicit val transid: TransactionId = deletion.transid + logging.info(this, s"deletion message for ${deletion.invocationNamespace}/${deletion.action} is received") + containerPool ! DeletionContainer(deletion) + feed ! MessageFeed.Processed + case Failure(t) => + logging.error(this, s"Failed to parse $bytes, error: ${t.getMessage}") + feed ! MessageFeed.Processed + + case _ => + logging.error(this, s"Unexpected message received $raw") + feed ! MessageFeed.Processed + } + } + + private val feed = actorSystem.actorOf(Props { + new MessageFeed("containerCreation", logging, consumer, maxPeek, longPollDuration, handler) + }) + + def close(): Unit = { + feed ! GracefulShutdown + } +} diff --git a/tests/src/test/scala/org/apache/openwhisk/core/invoker/test/ContainerMessageConsumerTests.scala b/tests/src/test/scala/org/apache/openwhisk/core/invoker/test/ContainerMessageConsumerTests.scala new file mode 100644 index 00000000000..5ceddfb9384 --- /dev/null +++ b/tests/src/test/scala/org/apache/openwhisk/core/invoker/test/ContainerMessageConsumerTests.scala @@ -0,0 +1,328 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.openwhisk.core.invoker.test + +import java.nio.charset.StandardCharsets + +import akka.actor.ActorSystem +import akka.stream.ActorMaterializer +import akka.testkit.{TestKit, TestProbe} +import common.StreamLogging +import org.apache.kafka.clients.producer.RecordMetadata +import org.apache.openwhisk.common.{Logging, TransactionId} +import org.apache.openwhisk.core.{WarmUp, WhiskConfig} +import org.apache.openwhisk.core.connector.ContainerCreationError._ +import org.apache.openwhisk.core.connector._ +import org.apache.openwhisk.core.connector.test.TestConnector +import org.apache.openwhisk.core.containerpool.v2.CreationContainer +import org.apache.openwhisk.core.database.test.DbUtils +import org.apache.openwhisk.core.entity.ExecManifest.{ImageName, RuntimeManifest} +import org.apache.openwhisk.core.entity._ +import org.apache.openwhisk.core.entity.size._ +import org.apache.openwhisk.core.entity.test.ExecHelpers +import org.apache.openwhisk.core.invoker.ContainerMessageConsumer +import org.apache.openwhisk.http.Messages +import org.apache.openwhisk.utils.{retry => utilRetry} +import org.junit.runner.RunWith +import org.scalamock.scalatest.MockFactory +import org.scalatest.junit.JUnitRunner +import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FlatSpecLike, Matchers} + +import scala.concurrent.Future +import scala.concurrent.duration._ +import scala.util.Try + +@RunWith(classOf[JUnitRunner]) +class ContainerMessageConsumerTests + extends TestKit(ActorSystem("ContainerMessageConsumer")) + with FlatSpecLike + with Matchers + with BeforeAndAfterEach + with BeforeAndAfterAll + with StreamLogging + with MockFactory + with DbUtils + with ExecHelpers { + + implicit val actualActorSystem = system // Use system for duplicate system and actorSystem. + implicit val ec = actualActorSystem.dispatcher + implicit val materializer = ActorMaterializer() + implicit val transId = TransactionId.testing + implicit val creationId = CreationId.generate() + + override def afterAll(): Unit = { + TestKit.shutdownActorSystem(system) + super.afterAll() + } + + private val whiskConfig = new WhiskConfig( + Map( + WhiskConfig.actionInvokePerMinuteLimit -> null, + WhiskConfig.triggerFirePerMinuteLimit -> null, + WhiskConfig.actionInvokeConcurrentLimit -> null, + WhiskConfig.runtimesManifest -> null, + WhiskConfig.actionSequenceMaxLimit -> null)) + + private val entityStore = WhiskEntityStore.datastore() + private val producer = stub[MessageProducer] + + private val defaultUserMemory: ByteSize = 1024.MB + private val invokerInstance = InvokerInstanceId(0, userMemory = defaultUserMemory) + private val schedulerInstanceId = SchedulerInstanceId("0") + + private val invocationNamespace = EntityName("invocationSpace") + + private val schedulerHost = "127.17.0.1" + + private val rpcPort = 13001 + + override def afterEach(): Unit = { + cleanup() + } + + private def fakeMessageProvider(consumer: TestConnector): MessagingProvider = { + new MessagingProvider { + override def getConsumer( + whiskConfig: WhiskConfig, + groupId: String, + topic: String, + maxPeek: Int, + maxPollInterval: FiniteDuration)(implicit logging: Logging, actorSystem: ActorSystem): MessageConsumer = + consumer + + override def getProducer(config: WhiskConfig, maxRequestSize: Option[ByteSize])( + implicit logging: Logging, + actorSystem: ActorSystem): MessageProducer = consumer.getProducer() + + override def ensureTopic(config: WhiskConfig, + topic: String, + topicConfig: String, + maxMessageBytes: Option[ByteSize])(implicit logging: Logging): Try[Unit] = Try {} + } + } + + def sendAckToScheduler(producer: MessageProducer)(schedulerInstanceId: SchedulerInstanceId, + ackMessage: ContainerCreationAckMessage): Future[RecordMetadata] = { + val topic = s"creationAck${schedulerInstanceId.asString}" + producer.send(topic, ackMessage) + } + + private def createAckMsg(creationMessage: ContainerCreationMessage, + error: Option[ContainerCreationError], + reason: Option[String]) = { + ContainerCreationAckMessage( + creationMessage.transid, + creationMessage.creationId, + creationMessage.invocationNamespace, + creationMessage.action, + creationMessage.revision, + creationMessage.whiskActionMetaData, + invokerInstance, + creationMessage.schedulerHost, + creationMessage.rpcPort, + creationMessage.retryCount, + error, + reason) + } + + it should "forward ContainerCreationMessage to containerPool" in { + val pool = TestProbe() + val mockConsumer = new TestConnector("fakeTopic", 4, true) + val msgProvider = fakeMessageProvider(mockConsumer) + + val consumer = + new ContainerMessageConsumer( + invokerInstance, + pool.ref, + entityStore, + whiskConfig, + msgProvider, + 200.milliseconds, + 500, + sendAckToScheduler(producer)) + + val exec = CodeExecAsString(RuntimeManifest("nodejs:10", ImageName("testImage")), "testCode", None) + val action = + WhiskAction(EntityPath("testns"), EntityName("testAction"), exec, limits = ActionLimits(TimeLimit(1.minute))) + put(entityStore, action) + val execMetadata = + CodeExecMetaDataAsString(exec.manifest, entryPoint = exec.entryPoint) + val actionMetadata = + WhiskActionMetaData( + action.namespace, + action.name, + execMetadata, + action.parameters, + action.limits, + action.version, + action.publish, + action.annotations) + + val msg = + ContainerCreationMessage( + transId, + invocationNamespace.asString, + action.fullyQualifiedName(true), + DocRevision.empty, + actionMetadata, + schedulerInstanceId, + schedulerHost, + rpcPort, + creationId = creationId) + + mockConsumer.send(msg) + + pool.expectMsgPF() { + case CreationContainer(_, _) => true + } + } + + it should "send ack(failed) to scheduler when failed to get action from DB " in { + val pool = TestProbe() + val creationConsumer = new TestConnector("creation", 4, true) + val msgProvider = fakeMessageProvider(creationConsumer) + + val ackTopic = "ack" + val ackConsumer = new TestConnector(ackTopic, 4, true) + + val consumer = + new ContainerMessageConsumer( + invokerInstance, + pool.ref, + entityStore, + whiskConfig, + msgProvider, + 200.milliseconds, + 500, + sendAckToScheduler(ackConsumer.getProducer())) + + val exec = CodeExecAsString(RuntimeManifest("nodejs:10", ImageName("testImage")), "testCode", None) + val whiskAction = + WhiskAction(EntityPath("testns"), EntityName("testAction2"), exec, limits = ActionLimits(TimeLimit(1.minute))) + val execMetadata = + CodeExecMetaDataAsString(exec.manifest, entryPoint = exec.entryPoint) + val actionMetadata = + WhiskActionMetaData( + whiskAction.namespace, + whiskAction.name, + execMetadata, + whiskAction.parameters, + whiskAction.limits, + whiskAction.version, + whiskAction.publish, + whiskAction.annotations) + + val creationMessage = + ContainerCreationMessage( + transId, + invocationNamespace.asString, + whiskAction.fullyQualifiedName(true), + DocRevision.empty, + actionMetadata, + schedulerInstanceId, + schedulerHost, + rpcPort, + creationId = creationId) + + // action doesn't exist + val ackMessage = createAckMsg(creationMessage, Some(DBFetchError), Some(Messages.actionRemovedWhileInvoking)) + creationConsumer.send(creationMessage) + + within(5.seconds) { + utilRetry({ + val buffer = ackConsumer.peek(50.millisecond) + buffer.size shouldBe 1 + buffer.head._1 shouldBe ackTopic + new String(buffer.head._4, StandardCharsets.UTF_8) shouldBe ackMessage.serialize + }, 10, Some(500.millisecond)) + pool.expectNoMessage(2.seconds) + } + + // action exist but version mismatch + put(entityStore, whiskAction) + val actualCreationMessage = creationMessage.copy(revision = DocRevision("1-fake")) + val fetchErrorAckMessage = + createAckMsg(actualCreationMessage, Some(DBFetchError), Some(Messages.actionFetchErrorWhileInvoking)) + creationConsumer.send(actualCreationMessage) + + within(5.seconds) { + utilRetry({ + val buffer2 = ackConsumer.peek(50.millisecond) + buffer2.size shouldBe 1 + buffer2.head._1 shouldBe ackTopic + new String(buffer2.head._4, StandardCharsets.UTF_8) shouldBe fetchErrorAckMessage.serialize + }, 10, Some(500.millisecond)) + pool.expectNoMessage(2.seconds) + } + } + + it should "drop messages of warm-up action" in { + val pool = TestProbe() + val mockConsumer = new TestConnector("fakeTopic", 4, true) + val msgProvider = fakeMessageProvider(mockConsumer) + + val consumer = + new ContainerMessageConsumer( + invokerInstance, + pool.ref, + entityStore, + whiskConfig, + msgProvider, + 200.milliseconds, + 500, + sendAckToScheduler(producer)) + + val exec = CodeExecAsString(RuntimeManifest("nodejs:10", ImageName("testImage")), "testCode", None) + val action = + WhiskAction( + WarmUp.warmUpAction.namespace.toPath, + WarmUp.warmUpAction.name, + exec, + limits = ActionLimits(TimeLimit(1.minute))) + val doc = put(entityStore, action) + val execMetadata = + CodeExecMetaDataAsString(exec.manifest, entryPoint = exec.entryPoint) + + val actionMetadata = + WhiskActionMetaData( + action.namespace, + action.name, + execMetadata, + action.parameters, + action.limits, + action.version, + action.publish, + action.annotations) + + val msg = + ContainerCreationMessage( + transId, + invocationNamespace.asString, + action.fullyQualifiedName(false), + DocRevision.empty, + actionMetadata, + schedulerInstanceId, + schedulerHost, + rpcPort, + creationId = creationId) + + mockConsumer.send(msg) + + pool.expectNoMessage(1.seconds) + } +}