Local scoring (aka Sparkless) using Aardpfark #41
@@ -0,0 +1,17 @@
repositories {
    // TODO: remove once the Aardpfark release is official
    maven { url 'https://jitpack.io' }
}

dependencies {
    compile project(':core')
    testCompile project(':testkit')

    // PFA serialization for Spark models
    // TODO: replace with official Aardpfark release when ready
    compile "com.github.relateiq:aardpfark:0.1.0-SNAPSHOT"

    // Hadrian PFA runtime for JVM
    compileOnly "com.opendatagroup:hadrian:0.8.5"
    testRuntime "com.opendatagroup:hadrian:0.8.5"
}
@@ -0,0 +1,98 @@
/*
 * Copyright (c) 2017, Salesforce.com, Inc.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * * Redistributions of source code must retain the above copyright notice, this
 *   list of conditions and the following disclaimer.
 *
 * * Redistributions in binary form must reproduce the above copyright notice,
 *   this list of conditions and the following disclaimer in the documentation
 *   and/or other materials provided with the distribution.
 *
 * * Neither the name of the copyright holder nor the names of its
 *   contributors may be used to endorse or promote products derived from
 *   this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */
package com.salesforce.op.local

import com.ibm.aardpfark.spark.ml.SparkSupport.toPFA
import com.opendatagroup.hadrian.jvmcompiler.PFAEngine
import com.salesforce.op.features.types.OPVector
import com.salesforce.op.stages.{OpPipelineStage, OpTransformer}
import com.salesforce.op.stages.sparkwrappers.generic.SparkWrapperParams
import com.salesforce.op.utils.json.JsonUtils
import com.salesforce.op.{OpParams, OpWorkflow}
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.linalg.Vector


/**
 * A class for running a TransmogrifAI workflow without Spark.
 *
 * @param workflow the workflow to run (note: the workflow must have its result features set)
 */
class OpWorkflowRunnerLocal(val workflow: OpWorkflow) {

  type ScoreFun = Map[String, Any] => Map[String, Any]

  /**
   * Load the model & prepare a local scoring function
   *
   * @param params params to use during scoring
   * @return local scoring function
   */
  def score(params: OpParams): ScoreFun = {
    require(params.modelLocation.isDefined, "Model location must be set in params")
    val model = workflow.loadModel(params.modelLocation.get)
Review comment: will the standard load method work on Spark models that use parquet storage without a Spark context?
Reply: None of the Spark ML readers require the context explicitly, but I will need to verify, because they might get/create a Spark context internally. Do you have a model in mind that I can check against?
Reply: maybe try PCA
Reply: oh snap, they simply create a Spark context internally when loading models 🤦♂️ https://github.com/apache/spark/blob/5264164a67df498b73facae207eda12ee133be7d/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala#L212
Reply: well, we also use the Spark context when reading the model & stages: https://github.com/salesforce/TransmogrifAI/blob/master/core/src/main/scala/com/salesforce/op/OpWorkflowModelReader.scala#L61
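For reference, a minimal sketch of the loading pattern the thread is pointing at, assuming the behavior of Spark ML readers (the object and method below are illustrative condensations, not Spark's actual class layout):

import org.apache.spark.sql.SparkSession

object MLReaderLoadSketch {
  // Hypothetical condensation of what a Spark ML reader does under the hood:
  // a SparkSession is obtained lazily and created if none exists yet.
  def load(path: String): Unit = {
    val spark = SparkSession.builder().getOrCreate() // spins up a context when absent
    val data = spark.read.parquet(s"$path/data")     // model parameters are stored as parquet
    // ... reconstruct the model from `data` ...
  }
}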

    // Split the model stages: OP transformers run natively, while Spark ML
    // stages are converted to PFA documents and run with the Hadrian engine
    val stagesWithIndex = model.stages.zipWithIndex
    val opStages = stagesWithIndex.collect { case (s: OpTransformer, i) => s -> i }
    val sparkStages = stagesWithIndex.filterNot(_._1.isInstanceOf[OpTransformer]).collect {
      case (s: SparkWrapperParams[_], i) => s.getSparkMlStage().map(_ -> i)
      case (s: Transformer, i) => Some(s -> i)
    }.flatten.map(v => v._1.asInstanceOf[Transformer] -> v._2)

    val pfaStages = sparkStages.map { case (s, i) => toPFA(s, pretty = true) -> i }
    val engines = pfaStages.map { case (s, i) => PFAEngine.fromJson(s, multiplicity = 1).head -> i }
    // Restore the original stage order for scoring
    val loadedStages = (opStages ++ engines).sortBy(_._2)

    row: Map[String, Any] => {
      val rowMap = collection.mutable.Map.empty ++ row
      val transformedRow = loadedStages.foldLeft(rowMap) { (r, s) =>
        s match {
          case (s: OpTransformer, _) =>
            r += s.asInstanceOf[OpPipelineStage[_]].getOutputFeatureName -> s.transformKeyValue(r.apply)

          case (e: PFAEngine[AnyRef, AnyRef], i) =>
            // Feed the PFA engine its vector input as JSON and collect the result
            val stage = stagesWithIndex.find(_._2 == i).map(_._1.asInstanceOf[OpPipelineStage[_]]).get
            val outName = stage.getOutputFeatureName
            val inputName = stage.getInputFeatures().collect {
              case f if f.isSubtypeOf[OPVector] => f.name
            }.head
            val vector = r(inputName).asInstanceOf[Vector].toArray
            val input = s"""{"$inputName":${vector.mkString("[", ",", "]")}}"""
            val res = e.action(e.jsonInput(input)).toString
Review comment: @MLnick is using JSON the most efficient way to call the engine action?
            r += outName -> JsonUtils.fromString[Map[String, Any]](res).get
        }
      }
      // Keep only the workflow result features in the returned scores
      val resultFeatures = model.getResultFeatures().map(_.name)
      transformedRow.collect { case r @ (k, _) if resultFeatures.contains(k) => r }.toMap
    }
  }

}
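For context, a minimal usage sketch of the runner above, assuming a trained workflow; the model path and raw feature names are hypothetical:

// Hypothetical usage: score one record as a Map, no SparkSession needed at scoring time
val params = new OpParams().withValues(modelLocation = Some("/tmp/my-model")) // hypothetical path
val scoreFn = new OpWorkflowRunnerLocal(workflow).score(params)
val scores = scoreFn(Map("height" -> 175.0, "weight" -> 70.0)) // hypothetical raw features
// `scores` contains only the workflow's result features, keyed by feature name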
@@ -0,0 +1,94 @@
/*
 * Copyright (c) 2017, Salesforce.com, Inc.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * * Redistributions of source code must retain the above copyright notice, this
 *   list of conditions and the following disclaimer.
 *
 * * Redistributions in binary form must reproduce the above copyright notice,
 *   this list of conditions and the following disclaimer in the documentation
 *   and/or other materials provided with the distribution.
 *
 * * Neither the name of the copyright holder nor the names of its
 *   contributors may be used to endorse or promote products derived from
 *   this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */
package com.salesforce.op.local

import java.io.File

import com.salesforce.op.stages.impl.classification.BinaryClassificationModelSelector
import com.salesforce.op.stages.impl.classification.BinaryClassificationModelsToTry._
import com.salesforce.op.test.{PassengerSparkFixtureTest, TestCommon}
import com.salesforce.op.utils.spark.RichRow._
import com.salesforce.op.utils.spark.RichDataset._
import com.salesforce.op.{OpParams, OpWorkflow}
import org.junit.runner.RunWith
import org.scalatest.FlatSpec
import org.scalatest.junit.JUnitRunner


@RunWith(classOf[JUnitRunner])
class OpWorkflowRunnerLocalTest extends FlatSpec with PassengerSparkFixtureTest with TestCommon {

  val features = Seq(height, weight, gender, description, age).transmogrify()
  val survivedNum = survived.occurs()

  val prediction = BinaryClassificationModelSelector.withTrainValidationSplit(
    splitter = None, modelTypesToUse = Seq(OpLogisticRegression)
  ).setInput(survivedNum, features).getOutput()

  val workflow = new OpWorkflow().setResultFeatures(prediction, survivedNum).setReader(dataReader)

  lazy val model = workflow.train()

  lazy val modelLocation = {
    val path = new File(tempDir + "/op-runner-local-test-model").toString
    model.save(path)
    path
  }

  lazy val rawData = dataReader.generateDataFrame(model.rawFeatures).collect().map(_.toMap)

  lazy val expectedScores = model.score().collect(prediction, survivedNum)

  // TODO: actually test a Spark-wrapped stage with PFA
  Spec(classOf[OpWorkflowRunnerLocal]) should "produce scores without Spark" in {
    val params = new OpParams().withValues(modelLocation = Some(modelLocation))
    val scoreFn = new OpWorkflowRunnerLocal(workflow).score(params)
    val _ = rawData.map(row => scoreFn(row)) // warm up

    // Score the dataset repeatedly, checking results against Spark scores
    // and measuring the local scoring throughput
    val numOfRuns = 1000
    var elapsed = 0L
    for { _ <- 0 until numOfRuns } {
      val start = System.currentTimeMillis()
      val scores = rawData.map(row => scoreFn(row))
      elapsed += System.currentTimeMillis() - start
      for {
        (score, (predV, survivedV)) <- scores.zip(expectedScores)
        expected = Map(
          prediction.name -> predV.value,
          survivedNum.name -> survivedV.value.get
        )
      } score shouldBe expected
    }
    println(s"Scored ${expectedScores.length * numOfRuns} records in ${elapsed}ms")
    println(s"Average time per record: ${elapsed.toDouble / (expectedScores.length * numOfRuns)}ms")
  }

}
Review comment: https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala#L354
Reply: So you are saying that row.getValuesMap[Any] should work as well? Let me try.
Reply: OK, so my function is faster: getValuesMap calls def getAs[T](fieldName: String): T = getAs[T](fieldIndex(fieldName)) for each value, while my function operates on indices.
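To make the difference concrete, a minimal sketch of the two access patterns, assuming a plain Spark Row (the function names here are illustrative):

import org.apache.spark.sql.Row

// Name-based: getValuesMap resolves every field name back to an index
// via fieldIndex, paying a name lookup per column
def byName(row: Row): Map[String, Any] =
  row.getValuesMap[Any](row.schema.fieldNames)

// Index-based: walk the schema once and read each value positionally
def byIndex(row: Row): Map[String, Any] =
  row.schema.fieldNames.zipWithIndex.map { case (name, i) => name -> row.get(i) }.toMap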