-
Notifications
You must be signed in to change notification settings - Fork 397
Local scoring (aka Sparkless) using Aardpfark #41
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 30 commits
b9028c4
ced9ee7
9e986a7
470f841
eda18ad
2320bbf
f7d690e
3ddbfa7
214c181
e0617bc
43ccba3
ceab612
0bb6172
dd45409
5d1f8b7
2053037
4d6b711
855217b
ebacf5e
f1f3ce0
f621947
87c364e
645a071
fcdaae4
0021d3c
0808f78
d73baba
74d8c26
7ab91f1
68135c5
b796b1e
7b2cdcf
67e42e5
aebef60
377d52f
c84c85a
8ac9daa
05ea9c4
4d72d05
91e9248
903ad95
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -116,6 +116,27 @@ object RichRow { | |
def getFeatureType[T <: FeatureType](f: TransientFeature)(implicit conv: FeatureTypeSparkConverter[T]): T = | ||
conv.fromSpark(getAny(f.name)) | ||
|
||
/** | ||
* Converts row to a [[collection.mutable.Map]] | ||
* | ||
* @return a [[collection.mutable.Map]] with row contents | ||
*/ | ||
def toMutableMap: collection.mutable.Map[String, Any] = { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. so you are saying that There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oook, so my function is faster, because |
||
val res = collection.mutable.Map.empty[String, Any] | ||
val fields = row.schema.fields | ||
for {i <- 0 until row.size} { | ||
res += fields(i).name -> row(i) | ||
} | ||
res | ||
} | ||
|
||
/** | ||
* Converts row to a [[collection.immutable.Map]] | ||
* | ||
* @return a [[collection.immutable.Map]] with row contents | ||
*/ | ||
def toMap: Map[String, Any] = toMutableMap.toMap | ||
|
||
} | ||
|
||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
repositories { | ||
// TODO: remove once Aardpfark release if official | ||
maven { url 'https://jitpack.io' } | ||
} | ||
|
||
dependencies { | ||
compile project(':core') | ||
testCompile project(':testkit') | ||
|
||
// PFA serialization for Spark models | ||
// TODO: replace with official Aardpfark release when ready | ||
compile "com.github.relateiq:aardpfark:$aardpfarkVersion" | ||
|
||
// Hadrian PFA runtime for JVM | ||
compileOnly "com.opendatagroup:hadrian:$hadrianVersion" | ||
testRuntime "com.opendatagroup:hadrian:$hadrianVersion" | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,163 @@ | ||
/* | ||
* Copyright (c) 2017, Salesforce.com, Inc. | ||
* All rights reserved. | ||
* | ||
* Redistribution and use in source and binary forms, with or without | ||
* modification, are permitted provided that the following conditions are met: | ||
* | ||
* * Redistributions of source code must retain the above copyright notice, this | ||
* list of conditions and the following disclaimer. | ||
* | ||
* * Redistributions in binary form must reproduce the above copyright notice, | ||
* this list of conditions and the following disclaimer in the documentation | ||
* and/or other materials provided with the distribution. | ||
* | ||
* * Neither the name of the copyright holder nor the names of its | ||
* contributors may be used to endorse or promote products derived from | ||
* this software without specific prior written permission. | ||
* | ||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | ||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | ||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | ||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE | ||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | ||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | ||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | ||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | ||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | ||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | ||
*/ | ||
|
||
package com.salesforce.op.local | ||
|
||
import com.ibm.aardpfark.spark.ml.SparkSupport | ||
import com.opendatagroup.hadrian.jvmcompiler.PFAEngine | ||
import com.salesforce.op.OpWorkflowModel | ||
import com.salesforce.op.stages.sparkwrappers.generic.SparkWrapperParams | ||
import com.salesforce.op.stages.{OPStage, OpTransformer} | ||
import org.apache.spark.ml.SparkMLSharedParamConstants._ | ||
import org.apache.spark.ml.Transformer | ||
import org.apache.spark.ml.linalg.Vector | ||
import org.apache.spark.ml.param.ParamMap | ||
import org.json4s._ | ||
import org.json4s.native.JsonMethods._ | ||
import org.json4s.native.Serialization | ||
|
||
import scala.collection.mutable | ||
|
||
/** | ||
* Enrichment for [[OpWorkflowModel]] to allow local scoring functionality | ||
*/ | ||
trait OpWorkflowModelLocal { | ||
|
||
/** | ||
* Enrichment for [[OpWorkflowModel]] to allow local scoring functionality | ||
* | ||
* @param model [[OpWorkflowModel]] | ||
*/ | ||
implicit class RichOpWorkflowModel(model: OpWorkflowModel) { | ||
|
||
private implicit val formats = DefaultFormats | ||
|
||
/** | ||
* Internal PFA model representation | ||
* | ||
* @param inputs mode inputs mappings | ||
* @param output output mapping | ||
* @param engine PFA engine | ||
*/ | ||
private case class PFAModel | ||
( | ||
inputs: Map[String, String], | ||
output: (String, String), | ||
engine: PFAEngine[AnyRef, AnyRef] | ||
) | ||
|
||
/** | ||
* Internal OP model representation | ||
* | ||
* @param output output name | ||
* @param model model instance | ||
*/ | ||
private case class OPModel(output: String, model: OPStage with OpTransformer) | ||
|
||
/** | ||
* Prepares a score function for local scoring | ||
* | ||
* @return score function for local scoring | ||
*/ | ||
def scoreFunction: ScoreFunction = { | ||
// Prepare the stages for scoring | ||
val stagesWithIndex = model.stages.zipWithIndex | ||
// Collect all OP stages | ||
val opStages = stagesWithIndex.collect { case (s: OpTransformer, i) => OPModel(s.getOutputFeatureName, s) -> i } | ||
// Collect all Spark wrapped stages | ||
val sparkStages = stagesWithIndex.filterNot(_._1.isInstanceOf[OpTransformer]).collect { | ||
case (s: OPStage with SparkWrapperParams[_], i) if s.getSparkMlStage().isDefined => | ||
(s, s.getSparkMlStage().get.asInstanceOf[Transformer].copy(ParamMap.empty), i) | ||
} | ||
// Convert Spark wrapped stages into PFA models | ||
val pfaStages = sparkStages.map { case (opStage, sparkStage, i) => toPFAModel(opStage, sparkStage) -> i } | ||
// Combine all stages and apply the original order | ||
val allStages = (opStages ++ pfaStages).sortBy(_._2).map(_._1) | ||
val resultFeatures = model.getResultFeatures().map(_.name).toSet | ||
|
||
// Score Function | ||
input: Map[String, Any] => { | ||
val inputMap = mutable.Map.empty ++= input | ||
val transformedRow = allStages.foldLeft(inputMap) { | ||
// For OP Models we simply call transform | ||
case (row, OPModel(output, stage)) => | ||
row += output -> stage.transformKeyValue(row.apply) | ||
|
||
// For PFA Models we execute PFA engine action with json in/out | ||
case (row, PFAModel(inputs, (out, outCol), engine)) => | ||
val inJson = rowToJson(row, inputs) | ||
val engineIn = engine.jsonInput(inJson) | ||
val engineOut = engine.action(engineIn) | ||
val resMap = parse(engineOut.toString).extract[Map[String, Any]] | ||
row += out -> resMap(outCol) | ||
} | ||
transformedRow.filterKeys(resultFeatures.contains).toMap | ||
} | ||
} | ||
|
||
/** | ||
* Convert Spark wrapped staged into PFA Models | ||
*/ | ||
private def toPFAModel(opStage: OPStage with SparkWrapperParams[_], sparkStage: Transformer): PFAModel = { | ||
// Update input/output params for Spark stages to default ones | ||
val inParam = sparkStage.getParam(inputCol.name) | ||
val outParam = sparkStage.getParam(outputCol.name) | ||
val inputs = opStage.getInputFeatures().map(_.name).map { | ||
case n if sparkStage.get(inParam).contains(n) => n -> inputCol.name | ||
case n if sparkStage.get(outParam).contains(n) => n -> outputCol.name | ||
case n => n -> n | ||
}.toMap | ||
val output = opStage.getOutputFeatureName | ||
sparkStage.set(inParam, inputCol.name).set(outParam, outputCol.name) | ||
val pfaJson = SparkSupport.toPFA(sparkStage, pretty = true) | ||
val pfaEngine = PFAEngine.fromJson(pfaJson).head | ||
PFAModel(inputs, (output, outputCol.name), pfaEngine) | ||
} | ||
|
||
/** | ||
* Convert row of Spark values into a json convertible Map | ||
* See [[FeatureTypeSparkConverter.toSpark]] for all possible values - we invert them here | ||
*/ | ||
private def rowToJson(row: mutable.Map[String, Any], inputs: Map[String, String]): String = { | ||
val in = inputs.map { case (k, v) => (v, row.get(k)) }.mapValues { | ||
case Some(v: Vector) => v.toArray | ||
case Some(v: mutable.WrappedArray[_]) => v.toArray(v.elemTag) | ||
case Some(v: Map[_, _]) => v.mapValues { | ||
case v: mutable.WrappedArray[_] => v.toArray(v.elemTag) | ||
case x => x | ||
} | ||
case None | Some(null) => null | ||
case Some(v) => v | ||
} | ||
Serialization.write(in) | ||
} | ||
} | ||
|
||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
/* | ||
* Copyright (c) 2017, Salesforce.com, Inc. | ||
* All rights reserved. | ||
* | ||
* Redistribution and use in source and binary forms, with or without | ||
* modification, are permitted provided that the following conditions are met: | ||
* | ||
* * Redistributions of source code must retain the above copyright notice, this | ||
* list of conditions and the following disclaimer. | ||
* | ||
* * Redistributions in binary form must reproduce the above copyright notice, | ||
* this list of conditions and the following disclaimer in the documentation | ||
* and/or other materials provided with the distribution. | ||
* | ||
* * Neither the name of the copyright holder nor the names of its | ||
* contributors may be used to endorse or promote products derived from | ||
* this software without specific prior written permission. | ||
* | ||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | ||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | ||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | ||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE | ||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | ||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | ||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | ||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | ||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | ||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | ||
*/ | ||
|
||
package com.salesforce.op.local | ||
|
||
import com.salesforce.op.{OpParams, OpWorkflow} | ||
|
||
|
||
/** | ||
* A class for running TransmogrifAI Workflow without Spark. | ||
* | ||
* @param workflow the workflow that you want to run (Note: the workflow should have the resultFeatures set) | ||
*/ | ||
class OpWorkflowRunnerLocal(val workflow: OpWorkflow) { | ||
|
||
/** | ||
* Load the model & prepare a score function for local scoring | ||
* | ||
* @param params params to use during scoring | ||
* @return score function for local scoring | ||
*/ | ||
def score(params: OpParams): ScoreFunction = { | ||
require(params.modelLocation.isDefined, "Model location must be set in params") | ||
val model = workflow.loadModel(params.modelLocation.get) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. will the standard load method work on spark models that use parquet storage without a spark context? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. None of the spark ml readers require the context explicitly, but I will need to verify, cause they might get/create spark context inside. Do you have a model in mind that I can check against? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. maybe try PCA There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. oh snap, they simply create a spark context internally when loading models 🤦♂️ https://github.com/apache/spark/blob/5264164a67df498b73facae207eda12ee133be7d/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala#L212 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. well, we also use spark context when reading the model & stages - https://github.com/salesforce/TransmogrifAI/blob/master/core/src/main/scala/com/salesforce/op/OpWorkflowModelReader.scala#L61 |
||
model.scoreFunction | ||
} | ||
|
||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
/* | ||
* Copyright (c) 2017, Salesforce.com, Inc. | ||
* All rights reserved. | ||
* | ||
* Redistribution and use in source and binary forms, with or without | ||
* modification, are permitted provided that the following conditions are met: | ||
* | ||
* * Redistributions of source code must retain the above copyright notice, this | ||
* list of conditions and the following disclaimer. | ||
* | ||
* * Redistributions in binary form must reproduce the above copyright notice, | ||
* this list of conditions and the following disclaimer in the documentation | ||
* and/or other materials provided with the distribution. | ||
* | ||
* * Neither the name of the copyright holder nor the names of its | ||
* contributors may be used to endorse or promote products derived from | ||
* this software without specific prior written permission. | ||
* | ||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | ||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | ||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | ||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE | ||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | ||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | ||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | ||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | ||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | ||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | ||
*/ | ||
|
||
package com.salesforce.op | ||
|
||
|
||
package object local extends OpWorkflowModelLocal { | ||
|
||
/** | ||
* Score function for local scoring: raw record => transformed record | ||
*/ | ||
type ScoreFunction = Map[String, Any] => Map[String, Any] | ||
|
||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why are we pulling in a shapshot?